生成式对抗网络 GAN 详解¶
一句话说明¶
GAN 让两个神经网络相互博弈:生成器造假数据骗过判别器,判别器不断学习识破假货,最终生成器能造出以假乱真的数据——在生信中用于合成细胞图像、扩充稀缺训练数据等。
核心知识点¶
GAN 基本原理¶
对抗博弈: - 生成器 G:输入随机噪声 z,输出假数据 G(z) - 判别器 D:输入数据 x,输出真假概率 D(x) - 目标:G 骗过 D,D 识破 G
损失函数(极小极大博弈):
- D 最大化:真实数据高分,假数据低分 - G 最小化:让假数据被判为真主流 GAN 变体¶
| 变体 | 创新点 | 生信应用 |
|---|---|---|
| DCGAN | 卷积 GAN,生成图像 | 细胞图像合成 |
| WGAN | Wasserstein 距离,训练稳定 | 分子生成 |
| CGAN | 条件 GAN,可控生成 | 特定细胞类型生成 |
| CycleGAN | 无配对图像转换 | 组织染色转换 |
| StyleGAN | 风格控制 | 显微镜图像增强 |
| scGAN | 单细胞专用 | 单细胞数据增强 |
生信应用场景¶
- 单细胞数据增强:稀缺细胞类型数据扩充(scGAN)
- 病理图像合成:训练数据不足时合成 H&E 图像
- 分子生成:生成具有特定属性的分子结构
- 域适应:不同测序平台之间的批次效应去除
- 基因表达矩阵补全:填补 dropout 缺失值
实战代码¶
import torch
import torch.nn as nn
import numpy as np
# ===== 1. 定义生成器 =====
class Generator(nn.Module):
"""生成器:把随机噪声转化为模拟基因表达数据"""
def __init__(self, noise_dim=100, output_dim=2000):
super().__init__()
# 全连接网络,逐层扩张到输出维度
self.net = nn.Sequential(
nn.Linear(noise_dim, 256), # 输入噪声 -> 256 维
nn.BatchNorm1d(256), # 批归一化,稳定训练
nn.ReLU(), # ReLU 激活
nn.Linear(256, 512), # 256 -> 512 维
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, output_dim), # 512 -> 基因数维度
nn.ReLU() # 基因表达值非负
)
def forward(self, z):
return self.net(z) # 返回模拟的基因表达谱
# ===== 2. 定义判别器 =====
class Discriminator(nn.Module):
"""判别器:区分真实基因表达数据和生成数据"""
def __init__(self, input_dim=2000):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 512), # 输入基因维度
nn.LeakyReLU(0.2), # LeakyReLU(判别器常用,避免梯度消失)
nn.Dropout(0.3), # Dropout 防过拟合
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1), # 输出一个标量
nn.Sigmoid() # 转化为 [0,1] 概率
)
def forward(self, x):
return self.net(x) # 输出数据是真实的概率
# ===== 3. 训练 GAN =====
noise_dim = 100 # 噪声维度(潜空间大小)
gene_dim = 2000 # 基因数量
batch_size = 64
lr = 0.0002
G = Generator(noise_dim, gene_dim) # 实例化生成器
D = Discriminator(gene_dim) # 实例化判别器
# 两个优化器分别更新 G 和 D
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
loss_fn = nn.BCELoss() # 二元交叉熵损失
def train_step(real_data):
"""单步训练:先训练判别器,再训练生成器"""
batch_size = real_data.size(0)
# -- 训练判别器 D --
optimizer_D.zero_grad()
# 真实数据标签 = 1
real_labels = torch.ones(batch_size, 1)
real_loss = loss_fn(D(real_data), real_labels) # D 对真实数据的损失
# 生成假数据,标签 = 0
z = torch.randn(batch_size, noise_dim) # 随机噪声
fake_data = G(z).detach() # detach: 不更新 G 的梯度
fake_labels = torch.zeros(batch_size, 1)
fake_loss = loss_fn(D(fake_data), fake_labels) # D 对假数据的损失
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
# -- 训练生成器 G --
optimizer_G.zero_grad()
z = torch.randn(batch_size, noise_dim)
fake_data = G(z)
# G 的目标:让 D 把假数据判为真(标签 = 1)
g_loss = loss_fn(D(fake_data), torch.ones(batch_size, 1))
g_loss.backward()
optimizer_G.step()
return d_loss.item(), g_loss.item()
# 模拟训练(实际替换为真实的 scRNA-seq 数据)
fake_real_data = torch.randn(1000, gene_dim).abs() # 模拟基因表达(非负)
for epoch in range(50):
idx = torch.randint(0, 1000, (batch_size,))
d_loss, g_loss = train_step(fake_real_data[idx])
if epoch % 10 == 0:
print(f'Epoch {epoch}: D_loss={d_loss:.4f}, G_loss={g_loss:.4f}')
面试常问点¶
Q: GAN 训练不稳定的原因? A: 模式崩溃(mode collapse)——生成器只生成有限种类的样本;梯度消失——判别器太强时 G 的梯度接近 0。WGAN 用 Wasserstein 距离解决了这两个问题。
Q: 模式崩溃是什么?怎么解决? A: G 只会生成一种/几种固定样本。解决:WGAN-GP(梯度惩罚)、MiniBatch Discrimination(让 D 看一批数据的多样性)。
Q: GAN 和 VAE 的区别? A: VAE 学习数据分布(显式),生成连续平滑;GAN 隐式建模,生成质量更高但训练不稳定。
Q: 生信中 GAN 的主要挑战? A: 基因表达数据高维稀疏(维度诅咒)、稀缺标注、评估困难(无图像感知质量指标)。
速查表¶
| 术语 | 含义 |
|---|---|
| Generator G | 从噪声生成假数据 |
| Discriminator D | 区分真假数据 |
| Mode Collapse | 生成多样性崩溃 |
| WGAN | 用 Wasserstein 距离替换 JS 散度 |
| Gradient Penalty | WGAN-GP,梯度限制技术 |
| Conditional GAN | 给定条件(如细胞类型)有条件生成 |
| scGAN | 单细胞数据专用 GAN |
| Latent Space | 噪声 z 所在的潜在空间 |