TensorFlow/Keras 深度学习¶
一句话概述:TensorFlow 是 Google 开发的深度学习框架,Keras 是它的高级 API(已内置),适合快速搭建模型和生产部署,在工业界仍有广泛应用。最新版 TensorFlow 2.21(2026.03)。
核心知识点¶
| 概念 | 白话解释 |
|---|---|
| Keras | 高级 API = TensorFlow 内置的简单接口(几行代码建模型) |
| Sequential | 顺序模型 = 层按顺序堆叠(最简单的模型) |
| Functional API | 函数式 API = 灵活的建模方式(支持多输入/输出) |
| tf.data | 数据管道 = 高效的数据加载和预处理 |
| SavedModel | 模型格式 = TensorFlow 的标准保存格式 |
| TFLite/LiteRT | 移动端推理 = 在手机/嵌入式设备上运行模型 |
安装配置¶
pip install tensorflow # 安装(CPU + GPU 统一包)
python -c "import tensorflow as tf; print(tf.__version__)" # 验证(2.21.x)
基本使用¶
import tensorflow as tf # 导入 TensorFlow
from tensorflow import keras # Keras API
from tensorflow.keras import layers # 层模块
# 方式一:Sequential 顺序模型(最简单)
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(10,)), # 第1层
layers.Dropout(0.3), # Dropout
layers.Dense(32, activation='relu'), # 第2层
layers.Dense(2, activation='softmax') # 输出层(2分类)
])
# 编译模型
model.compile(
optimizer='adam', # 优化器
loss='sparse_categorical_crossentropy', # 损失函数
metrics=['accuracy'] # 评估指标
)
# 查看模型结构
model.summary() # 打印模型概要
# 训练
history = model.fit(
X_train, y_train, # 训练数据
epochs=100, # 训练轮数
batch_size=32, # 批次大小
validation_split=0.2, # 20% 做验证
callbacks=[
keras.callbacks.EarlyStopping(patience=10, # 早停
restore_best_weights=True)
]
)
# 评估
loss, accuracy = model.evaluate(X_test, y_test) # 测试集评估
print(f"测试准确率: {accuracy:.4f}")
# 预测
predictions = model.predict(X_test) # 预测概率
predicted_class = predictions.argmax(axis=1) # 预测类别
方式二:Functional API¶
# 函数式 API(更灵活)
inputs = keras.Input(shape=(10,)) # 输入层
x = layers.Dense(64, activation='relu')(inputs) # 隐藏层
x = layers.Dropout(0.3)(x) # Dropout
x = layers.Dense(32, activation='relu')(x) # 隐藏层
outputs = layers.Dense(2, activation='softmax')(x) # 输出层
model = keras.Model(inputs=inputs, outputs=outputs) # 创建模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
训练可视化¶
import matplotlib.pyplot as plt
# 绘制训练曲线
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(history.history['loss'], label='Train') # 训练损失
axes[0].plot(history.history['val_loss'], label='Val') # 验证损失
axes[0].set_title('Loss')
axes[0].legend()
axes[1].plot(history.history['accuracy'], label='Train') # 训练准确率
axes[1].plot(history.history['val_accuracy'], label='Val') # 验证准确率
axes[1].set_title('Accuracy')
axes[1].legend()
plt.savefig('training_curves.png', dpi=300)
高级用法¶
保存和加载¶
# 保存整个模型
model.save('my_model.keras') # Keras 格式(推荐)
# 加载模型
loaded_model = keras.models.load_model('my_model.keras')
# 保存为 SavedModel(部署用)
model.export('saved_model_dir') # TF Serving 格式
tf.data 数据管道¶
# 创建高效数据管道
dataset = tf.data.Dataset.from_tensor_slices((X, y)) # 创建数据集
dataset = dataset.shuffle(1000) # 打乱
dataset = dataset.batch(32) # 分批
dataset = dataset.prefetch(tf.data.AUTOTUNE) # 预读取(提速)
model.fit(dataset, epochs=10) # 用数据集训练
常见报错¶
| 报错信息 | 原因 | 解决方法 |
|---|---|---|
ResourceExhaustedError | GPU 内存不足 | 减小 batch_size |
ValueError: shapes not compatible | 输入维度错误 | 检查 input_shape |
ImportError: No module named tensorflow | 未安装 | pip install tensorflow |
速查表¶
# === 核心流程 ===
# 1. 定义模型(Sequential / Functional)
# 2. model.compile(optimizer, loss, metrics)
# 3. model.fit(X, y, epochs, validation_split)
# 4. model.evaluate(X_test, y_test)
# 5. model.predict(X_new)
# === 常用层 ===
layers.Dense(units, activation) # 全连接
layers.Conv2D(filters, kernel) # 卷积
layers.LSTM(units) # LSTM
layers.BatchNormalization() # 批归一化
layers.Dropout(rate) # Dropout
# === PyTorch vs TensorFlow 选择 ===
# PyTorch: 研究首选、动态图、调试方便、占 85% 论文
# TensorFlow: 部署成熟、生态完整、移动端(LiteRT)
# 建议: 研究用 PyTorch,生产部署考虑 TensorFlow
参考:TensorFlow | Keras | 最新版 2.21.0 (2026.03) | 更新于 2026 年