跳转至

466_文本分类任务实战


一句话说明

文本分类是把一段文字自动归入预定类别的任务,是NLP最基础、应用最广的任务之一(垃圾邮件过滤、情感判断、新闻分类等)。


核心知识点

  • 任务定义:给定文本 x,预测标签 y ∈ {C1, C2, ..., Cn}
  • 单标签 vs 多标签:一条文本只属于一类 vs 可属于多类
  • 数据不平衡:实际业务中各类样本数差异极大,需特殊处理
  • 特征表示:从 BoW → TF-IDF → Word2Vec → BERT Embedding 演进
  • 主流范式:预训练模型(BERT/RoBERTa)+ 下游微调

经典模型/方法

方法优点缺点适用场景
TF-IDF + 逻辑回归快、可解释忽略语义小数据、基线
TextCNN捕捉局部特征缺长程依赖短文本
LSTM/BiLSTM序列建模训练慢、梯度消失中等长度文本
BERT 微调强大语义表示资源消耗大高精度需求
DistilBERTBERT压缩版,快3倍精度略低生产推理

代码示例

# 使用 HuggingFace Transformers 做文本分类微调

from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import torch

# ---- 1. 自定义数据集 ----
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # 分词并转换为模型输入格式
        enc = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': enc['input_ids'].squeeze(0),          # token id序列
            'attention_mask': enc['attention_mask'].squeeze(0), # 注意力掩码
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# ---- 2. 加载预训练BERT ----
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# num_labels 设置为分类类别数
model = BertForSequenceClassification.from_pretrained(
    'bert-base-chinese', num_labels=2
)

# ---- 3. 简单训练循环 ----
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # BERT标准学习率

# 模拟数据
texts = ["今天天气真好", "这个产品太差了"]
labels = [1, 0]  # 1=正面, 0=负面

dataset = TextDataset(texts, labels, tokenizer)
loader = DataLoader(dataset, batch_size=2)

model.train()
for batch in loader:
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['label']  # 传入labels自动计算交叉熵损失
    )
    loss = outputs.loss  # 模型内部已计算loss
    loss.backward()      # 反向传播
    optimizer.step()     # 更新参数
    optimizer.zero_grad()  # 清零梯度
    print(f"Loss: {loss.item():.4f}")

# ---- 4. 推理预测 ----
model.eval()
with torch.no_grad():
    enc = tokenizer("这个电影很精彩", return_tensors='pt',
                    max_length=128, truncation=True, padding='max_length')
    logits = model(**enc).logits          # 原始输出分数
    pred = logits.argmax(dim=-1).item()   # 取最大值对应类别
    print(f"预测类别: {pred}")

面试常问点

  1. BERT 为什么适合文本分类?
  2. 双向Transformer编码器,[CLS] token汇聚全句语义,接全连接层分类

  3. 类别不平衡怎么处理?

  4. 重采样(过采样少数类/欠采样多数类)
  5. 损失函数加权(class_weight)
  6. Focal Loss(降低易分样本权重)

  7. TextCNN的核心思想?

  8. 多种尺寸卷积核(2,3,4)提取不同粒度n-gram特征,max-pooling后拼接

  9. 微调BERT时超参怎么设?

  10. 学习率:2e-5 ~ 5e-5;batch_size:16或32;epoch:3~5

速查表

场景推荐方案
快速基线TF-IDF + sklearn
中文分类bert-base-chinese 微调
生产部署DistilBERT / ONNX导出
数据极少Few-shot + GPT prompt
多标签sigmoid输出 + BCELoss