466_文本分类任务实战¶
一句话说明¶
文本分类是把一段文字自动归入预定类别的任务,是NLP最基础、应用最广的任务之一(垃圾邮件过滤、情感判断、新闻分类等)。
核心知识点¶
- 任务定义:给定文本 x,预测标签 y ∈ {C1, C2, ..., Cn}
- 单标签 vs 多标签:一条文本只属于一类 vs 可属于多类
- 数据不平衡:实际业务中各类样本数差异极大,需特殊处理
- 特征表示:从 BoW → TF-IDF → Word2Vec → BERT Embedding 演进
- 主流范式:预训练模型(BERT/RoBERTa)+ 下游微调
经典模型/方法¶
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| TF-IDF + 逻辑回归 | 快、可解释 | 忽略语义 | 小数据、基线 |
| TextCNN | 捕捉局部特征 | 缺长程依赖 | 短文本 |
| LSTM/BiLSTM | 序列建模 | 训练慢、梯度消失 | 中等长度文本 |
| BERT 微调 | 强大语义表示 | 资源消耗大 | 高精度需求 |
| DistilBERT | BERT压缩版,快3倍 | 精度略低 | 生产推理 |
代码示例¶
# 使用 HuggingFace Transformers 做文本分类微调
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import torch
# ---- 1. 自定义数据集 ----
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_len=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
# 分词并转换为模型输入格式
enc = self.tokenizer(
self.texts[idx],
max_length=self.max_len,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': enc['input_ids'].squeeze(0), # token id序列
'attention_mask': enc['attention_mask'].squeeze(0), # 注意力掩码
'label': torch.tensor(self.labels[idx], dtype=torch.long)
}
# ---- 2. 加载预训练BERT ----
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# num_labels 设置为分类类别数
model = BertForSequenceClassification.from_pretrained(
'bert-base-chinese', num_labels=2
)
# ---- 3. 简单训练循环 ----
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5) # BERT标准学习率
# 模拟数据
texts = ["今天天气真好", "这个产品太差了"]
labels = [1, 0] # 1=正面, 0=负面
dataset = TextDataset(texts, labels, tokenizer)
loader = DataLoader(dataset, batch_size=2)
model.train()
for batch in loader:
outputs = model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['label'] # 传入labels自动计算交叉熵损失
)
loss = outputs.loss # 模型内部已计算loss
loss.backward() # 反向传播
optimizer.step() # 更新参数
optimizer.zero_grad() # 清零梯度
print(f"Loss: {loss.item():.4f}")
# ---- 4. 推理预测 ----
model.eval()
with torch.no_grad():
enc = tokenizer("这个电影很精彩", return_tensors='pt',
max_length=128, truncation=True, padding='max_length')
logits = model(**enc).logits # 原始输出分数
pred = logits.argmax(dim=-1).item() # 取最大值对应类别
print(f"预测类别: {pred}")
面试常问点¶
- BERT 为什么适合文本分类?
双向Transformer编码器,[CLS] token汇聚全句语义,接全连接层分类
类别不平衡怎么处理?
- 重采样(过采样少数类/欠采样多数类)
- 损失函数加权(class_weight)
Focal Loss(降低易分样本权重)
TextCNN的核心思想?
多种尺寸卷积核(2,3,4)提取不同粒度n-gram特征,max-pooling后拼接
微调BERT时超参怎么设?
- 学习率:2e-5 ~ 5e-5;batch_size:16或32;epoch:3~5
速查表¶
| 场景 | 推荐方案 |
|---|---|
| 快速基线 | TF-IDF + sklearn |
| 中文分类 | bert-base-chinese 微调 |
| 生产部署 | DistilBERT / ONNX导出 |
| 数据极少 | Few-shot + GPT prompt |
| 多标签 | sigmoid输出 + BCELoss |