跳转至

TensorFlow/Keras 深度学习

一句话概述:TensorFlow 是 Google 开发的深度学习框架,Keras 是它的高级 API(已内置),适合快速搭建模型和生产部署,在工业界仍有广泛应用。最新版 TensorFlow 2.21(2026.03)。

核心知识点

概念白话解释
Keras高级 API = TensorFlow 内置的简单接口(几行代码建模型)
Sequential顺序模型 = 层按顺序堆叠(最简单的模型)
Functional API函数式 API = 灵活的建模方式(支持多输入/输出)
tf.data数据管道 = 高效的数据加载和预处理
SavedModel模型格式 = TensorFlow 的标准保存格式
TFLite/LiteRT移动端推理 = 在手机/嵌入式设备上运行模型

安装配置

pip install tensorflow                                # 安装(CPU + GPU 统一包)
python -c "import tensorflow as tf; print(tf.__version__)"  # 验证(2.21.x)

基本使用

import tensorflow as tf                               # 导入 TensorFlow
from tensorflow import keras                          # Keras API
from tensorflow.keras import layers                   # 层模块

# 方式一:Sequential 顺序模型(最简单)
model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(10,)),  # 第1层
    layers.Dropout(0.3),                              # Dropout
    layers.Dense(32, activation='relu'),              # 第2层
    layers.Dense(2, activation='softmax')             # 输出层(2分类)
])

# 编译模型
model.compile(
    optimizer='adam',                                  # 优化器
    loss='sparse_categorical_crossentropy',           # 损失函数
    metrics=['accuracy']                              # 评估指标
)

# 查看模型结构
model.summary()                                       # 打印模型概要

# 训练
history = model.fit(
    X_train, y_train,                                 # 训练数据
    epochs=100,                                       # 训练轮数
    batch_size=32,                                    # 批次大小
    validation_split=0.2,                             # 20% 做验证
    callbacks=[
        keras.callbacks.EarlyStopping(patience=10,    # 早停
                                      restore_best_weights=True)
    ]
)

# 评估
loss, accuracy = model.evaluate(X_test, y_test)       # 测试集评估
print(f"测试准确率: {accuracy:.4f}")

# 预测
predictions = model.predict(X_test)                   # 预测概率
predicted_class = predictions.argmax(axis=1)          # 预测类别

方式二:Functional API

# 函数式 API(更灵活)
inputs = keras.Input(shape=(10,))                     # 输入层
x = layers.Dense(64, activation='relu')(inputs)       # 隐藏层
x = layers.Dropout(0.3)(x)                           # Dropout
x = layers.Dense(32, activation='relu')(x)            # 隐藏层
outputs = layers.Dense(2, activation='softmax')(x)    # 输出层

model = keras.Model(inputs=inputs, outputs=outputs)   # 创建模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

训练可视化

import matplotlib.pyplot as plt

# 绘制训练曲线
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history.history['loss'], label='Train')       # 训练损失
axes[0].plot(history.history['val_loss'], label='Val')     # 验证损失
axes[0].set_title('Loss')
axes[0].legend()

axes[1].plot(history.history['accuracy'], label='Train')   # 训练准确率
axes[1].plot(history.history['val_accuracy'], label='Val') # 验证准确率
axes[1].set_title('Accuracy')
axes[1].legend()

plt.savefig('training_curves.png', dpi=300)

高级用法

保存和加载

# 保存整个模型
model.save('my_model.keras')                           # Keras 格式(推荐)

# 加载模型
loaded_model = keras.models.load_model('my_model.keras')

# 保存为 SavedModel(部署用)
model.export('saved_model_dir')                        # TF Serving 格式

tf.data 数据管道

# 创建高效数据管道
dataset = tf.data.Dataset.from_tensor_slices((X, y))  # 创建数据集
dataset = dataset.shuffle(1000)                        # 打乱
dataset = dataset.batch(32)                            # 分批
dataset = dataset.prefetch(tf.data.AUTOTUNE)           # 预读取(提速)

model.fit(dataset, epochs=10)                          # 用数据集训练

常见报错

报错信息原因解决方法
ResourceExhaustedErrorGPU 内存不足减小 batch_size
ValueError: shapes not compatible输入维度错误检查 input_shape
ImportError: No module named tensorflow未安装pip install tensorflow

速查表

# === 核心流程 ===
# 1. 定义模型(Sequential / Functional)
# 2. model.compile(optimizer, loss, metrics)
# 3. model.fit(X, y, epochs, validation_split)
# 4. model.evaluate(X_test, y_test)
# 5. model.predict(X_new)

# === 常用层 ===
layers.Dense(units, activation)     # 全连接
layers.Conv2D(filters, kernel)      # 卷积
layers.LSTM(units)                  # LSTM
layers.BatchNormalization()         # 批归一化
layers.Dropout(rate)                # Dropout

# === PyTorch vs TensorFlow 选择 ===
# PyTorch: 研究首选、动态图、调试方便、占 85% 论文
# TensorFlow: 部署成熟、生态完整、移动端(LiteRT)
# 建议: 研究用 PyTorch,生产部署考虑 TensorFlow

参考:TensorFlow | Keras | 最新版 2.21.0 (2026.03) | 更新于 2026 年