跳转至

481_图像分类经典模型


一句话说明

图像分类是让模型判断一张图片属于哪个类别,从LeNet到ResNet再到ViT,经历了CNN主导到Transformer崛起的演变。


核心知识点

  • 卷积神经网络(CNN):局部感受野+权重共享,专为图像设计
  • 残差连接(ResNet):跳跃连接解决深层网络梯度消失,让网络可以深达1000层
  • 迁移学习:ImageNet预训练模型 + 下游任务微调,极大降低数据需求
  • Vision Transformer(ViT):把图像切成patch,当作序列用Transformer处理
  • 数据增强:随机裁剪、翻转、颜色抖动,防止过拟合的必备手段

经典模型演进

模型年份Top-1精度(ImageNet)参数量特点
AlexNet201257.1%60M深度学习图像分类开山之作
VGG-16201471.5%138M简洁均匀3×3卷积堆叠
InceptionV3201578.0%24MInception多尺度并行分支
ResNet-50201576.1%25M残差连接,训练极深网络
EfficientNet-B7201984.4%66M复合缩放(宽度/深度/分辨率)
ViT-L/16202187.8%307M纯Transformer处理图像
ConvNeXt-L202286.6%198M用Transformer思路改进CNN

代码示例

# ---- 1. 使用预训练 ResNet-50 做迁移学习 ----
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# ---- 加载预训练模型 ----
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 冻结所有层,只训练最后的分类头
for param in model.parameters():
    param.requires_grad = False  # 冻结参数,不更新梯度

# 替换最后一层全连接层(适配自己的分类任务)
num_classes = 10  # 自定义类别数
in_features = model.fc.in_features  # 2048
model.fc = nn.Sequential(
    nn.Dropout(0.5),                        # Dropout防过拟合
    nn.Linear(in_features, num_classes)     # 新的分类头
)

# ---- 2. 数据增强与加载 ----
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),      # 随机裁剪并resize到224
    transforms.RandomHorizontalFlip(),       # 随机水平翻转
    transforms.ColorJitter(brightness=0.2,  # 随机颜色扰动
                           contrast=0.2,
                           saturation=0.2),
    transforms.ToTensor(),                   # 转为Tensor
    transforms.Normalize(                    # ImageNet均值和标准差归一化
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ---- 3. 训练循环 ----
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)  # 只优化分类头
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()  # 多分类交叉熵损失

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)            # 前向传播
        loss = criterion(outputs, labels)  # 计算loss
        loss.backward()                    # 反向传播
        optimizer.step()                   # 更新参数
        total_loss += loss.item()
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(loader), correct / total

# ---- 4. ViT 图像分类 ----
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image

# 加载预训练ViT(在ImageNet-21k上预训练,ImageNet-1k微调)
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# 推理
image = Image.open("cat.jpg")
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = vit_model(**inputs).logits
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = vit_model.config.id2label[predicted_class_id]
    print(f"ViT预测类别: {predicted_label}")

# ---- 5. EfficientNet(高效模型)----
# pip install timm
import timm

# timm库提供大量预训练模型
efficientnet = timm.create_model(
    'efficientnet_b4',
    pretrained=True,
    num_classes=10   # 自定义类别数
)
print(f"EfficientNet-B4参数量: {sum(p.numel() for p in efficientnet.parameters())/1e6:.1f}M")

面试常问点

  1. ResNet为什么能训练很深的网络?
  2. 残差连接:output = F(x) + x,梯度可以绕过中间层直接流回早层,避免梯度消失

  3. ViT和CNN的主要区别?

  4. CNN:局部感受野,归纳偏置强(平移不变性)
  5. ViT:全局自注意力,无归纳偏置,需更多数据/预训练

  6. 迁移学习微调的策略?

  7. 冻结全部→只训练分类头(数据少)
  8. 解冻后几层→微调高层特征(数据中等)
  9. 全部解冻小学习率→完整微调(数据充足)

  10. BatchNorm在深度网络中的作用?

  11. 归一化每层输入分布,加速训练,提高稳定性,有轻微正则化效果

速查表

需求选择
轻量边缘部署MobileNetV3 / EfficientNet-B0
平衡精度效率ResNet-50 / EfficientNet-B4
最高精度ViT-L / ConvNeXt-XL
快速原型torchvision pretrained models
大量模型库timm(PyTorch Image Models)