481_图像分类经典模型¶
一句话说明¶
图像分类是让模型判断一张图片属于哪个类别,从LeNet到ResNet再到ViT,经历了CNN主导到Transformer崛起的演变。
核心知识点¶
- 卷积神经网络(CNN):局部感受野+权重共享,专为图像设计
- 残差连接(ResNet):跳跃连接解决深层网络梯度消失,让网络可以深达1000层
- 迁移学习:ImageNet预训练模型 + 下游任务微调,极大降低数据需求
- Vision Transformer(ViT):把图像切成patch,当作序列用Transformer处理
- 数据增强:随机裁剪、翻转、颜色抖动,防止过拟合的必备手段
经典模型演进¶
| 模型 | 年份 | Top-1精度(ImageNet) | 参数量 | 特点 |
|---|---|---|---|---|
| AlexNet | 2012 | 57.1% | 60M | 深度学习图像分类开山之作 |
| VGG-16 | 2014 | 71.5% | 138M | 简洁均匀3×3卷积堆叠 |
| InceptionV3 | 2015 | 78.0% | 24M | Inception多尺度并行分支 |
| ResNet-50 | 2015 | 76.1% | 25M | 残差连接,训练极深网络 |
| EfficientNet-B7 | 2019 | 84.4% | 66M | 复合缩放(宽度/深度/分辨率) |
| ViT-L/16 | 2021 | 87.8% | 307M | 纯Transformer处理图像 |
| ConvNeXt-L | 2022 | 86.6% | 198M | 用Transformer思路改进CNN |
代码示例¶
# ---- 1. 使用预训练 ResNet-50 做迁移学习 ----
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# ---- 加载预训练模型 ----
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# 冻结所有层,只训练最后的分类头
for param in model.parameters():
param.requires_grad = False # 冻结参数,不更新梯度
# 替换最后一层全连接层(适配自己的分类任务)
num_classes = 10 # 自定义类别数
in_features = model.fc.in_features # 2048
model.fc = nn.Sequential(
nn.Dropout(0.5), # Dropout防过拟合
nn.Linear(in_features, num_classes) # 新的分类头
)
# ---- 2. 数据增强与加载 ----
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并resize到224
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(brightness=0.2, # 随机颜色扰动
contrast=0.2,
saturation=0.2),
transforms.ToTensor(), # 转为Tensor
transforms.Normalize( # ImageNet均值和标准差归一化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# ---- 3. 训练循环 ----
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3) # 只优化分类头
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss() # 多分类交叉熵损失
def train_epoch(model, loader, optimizer, criterion, device):
model.train()
total_loss, correct, total = 0, 0, 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算loss
loss.backward() # 反向传播
optimizer.step() # 更新参数
total_loss += loss.item()
pred = outputs.argmax(dim=1)
correct += (pred == labels).sum().item()
total += labels.size(0)
return total_loss / len(loader), correct / total
# ---- 4. ViT 图像分类 ----
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
# 加载预训练ViT(在ImageNet-21k上预训练,ImageNet-1k微调)
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# 推理
image = Image.open("cat.jpg")
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
logits = vit_model(**inputs).logits
predicted_class_id = logits.argmax(-1).item()
predicted_label = vit_model.config.id2label[predicted_class_id]
print(f"ViT预测类别: {predicted_label}")
# ---- 5. EfficientNet(高效模型)----
# pip install timm
import timm
# timm库提供大量预训练模型
efficientnet = timm.create_model(
'efficientnet_b4',
pretrained=True,
num_classes=10 # 自定义类别数
)
print(f"EfficientNet-B4参数量: {sum(p.numel() for p in efficientnet.parameters())/1e6:.1f}M")
面试常问点¶
- ResNet为什么能训练很深的网络?
残差连接:
output = F(x) + x,梯度可以绕过中间层直接流回早层,避免梯度消失ViT和CNN的主要区别?
- CNN:局部感受野,归纳偏置强(平移不变性)
ViT:全局自注意力,无归纳偏置,需更多数据/预训练
迁移学习微调的策略?
- 冻结全部→只训练分类头(数据少)
- 解冻后几层→微调高层特征(数据中等)
全部解冻小学习率→完整微调(数据充足)
BatchNorm在深度网络中的作用?
- 归一化每层输入分布,加速训练,提高稳定性,有轻微正则化效果
速查表¶
| 需求 | 选择 |
|---|---|
| 轻量边缘部署 | MobileNetV3 / EfficientNet-B0 |
| 平衡精度效率 | ResNet-50 / EfficientNet-B4 |
| 最高精度 | ViT-L / ConvNeXt-XL |
| 快速原型 | torchvision pretrained models |
| 大量模型库 | timm(PyTorch Image Models) |