跳转至

生成式对抗网络 GAN 详解


一句话说明

GAN 让两个神经网络相互博弈:生成器造假数据骗过判别器,判别器不断学习识破假货,最终生成器能造出以假乱真的数据——在生信中用于合成细胞图像、扩充稀缺训练数据等。


核心知识点

GAN 基本原理

对抗博弈: - 生成器 G:输入随机噪声 z,输出假数据 G(z) - 判别器 D:输入数据 x,输出真假概率 D(x) - 目标:G 骗过 D,D 识破 G

损失函数(极小极大博弈):

min_G max_D V(G,D) = E[log D(x)] + E[log(1 - D(G(z)))]
- D 最大化:真实数据高分,假数据低分 - G 最小化:让假数据被判为真

主流 GAN 变体

变体创新点生信应用
DCGAN卷积 GAN,生成图像细胞图像合成
WGANWasserstein 距离,训练稳定分子生成
CGAN条件 GAN,可控生成特定细胞类型生成
CycleGAN无配对图像转换组织染色转换
StyleGAN风格控制显微镜图像增强
scGAN单细胞专用单细胞数据增强

生信应用场景

  1. 单细胞数据增强:稀缺细胞类型数据扩充(scGAN)
  2. 病理图像合成:训练数据不足时合成 H&E 图像
  3. 分子生成:生成具有特定属性的分子结构
  4. 域适应:不同测序平台之间的批次效应去除
  5. 基因表达矩阵补全:填补 dropout 缺失值

实战代码

import torch
import torch.nn as nn
import numpy as np

# ===== 1. 定义生成器 =====
class Generator(nn.Module):
    """生成器:把随机噪声转化为模拟基因表达数据"""
    def __init__(self, noise_dim=100, output_dim=2000):
        super().__init__()
        # 全连接网络,逐层扩张到输出维度
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),    # 输入噪声 -> 256 维
            nn.BatchNorm1d(256),           # 批归一化,稳定训练
            nn.ReLU(),                     # ReLU 激活
            nn.Linear(256, 512),           # 256 -> 512 维
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, output_dim),    # 512 -> 基因数维度
            nn.ReLU()                      # 基因表达值非负
        )

    def forward(self, z):
        return self.net(z)  # 返回模拟的基因表达谱

# ===== 2. 定义判别器 =====
class Discriminator(nn.Module):
    """判别器:区分真实基因表达数据和生成数据"""
    def __init__(self, input_dim=2000):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),     # 输入基因维度
            nn.LeakyReLU(0.2),             # LeakyReLU(判别器常用,避免梯度消失)
            nn.Dropout(0.3),               # Dropout 防过拟合
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),             # 输出一个标量
            nn.Sigmoid()                   # 转化为 [0,1] 概率
        )

    def forward(self, x):
        return self.net(x)  # 输出数据是真实的概率

# ===== 3. 训练 GAN =====
noise_dim = 100         # 噪声维度(潜空间大小)
gene_dim = 2000         # 基因数量
batch_size = 64
lr = 0.0002

G = Generator(noise_dim, gene_dim)       # 实例化生成器
D = Discriminator(gene_dim)              # 实例化判别器

# 两个优化器分别更新 G 和 D
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

loss_fn = nn.BCELoss()  # 二元交叉熵损失

def train_step(real_data):
    """单步训练:先训练判别器,再训练生成器"""
    batch_size = real_data.size(0)

    # -- 训练判别器 D --
    optimizer_D.zero_grad()

    # 真实数据标签 = 1
    real_labels = torch.ones(batch_size, 1)
    real_loss = loss_fn(D(real_data), real_labels)  # D 对真实数据的损失

    # 生成假数据,标签 = 0
    z = torch.randn(batch_size, noise_dim)    # 随机噪声
    fake_data = G(z).detach()                 # detach: 不更新 G 的梯度
    fake_labels = torch.zeros(batch_size, 1)
    fake_loss = loss_fn(D(fake_data), fake_labels)  # D 对假数据的损失

    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()

    # -- 训练生成器 G --
    optimizer_G.zero_grad()

    z = torch.randn(batch_size, noise_dim)
    fake_data = G(z)
    # G 的目标:让 D 把假数据判为真(标签 = 1)
    g_loss = loss_fn(D(fake_data), torch.ones(batch_size, 1))
    g_loss.backward()
    optimizer_G.step()

    return d_loss.item(), g_loss.item()

# 模拟训练(实际替换为真实的 scRNA-seq 数据)
fake_real_data = torch.randn(1000, gene_dim).abs()  # 模拟基因表达(非负)
for epoch in range(50):
    idx = torch.randint(0, 1000, (batch_size,))
    d_loss, g_loss = train_step(fake_real_data[idx])
    if epoch % 10 == 0:
        print(f'Epoch {epoch}: D_loss={d_loss:.4f}, G_loss={g_loss:.4f}')

面试常问点

Q: GAN 训练不稳定的原因? A: 模式崩溃(mode collapse)——生成器只生成有限种类的样本;梯度消失——判别器太强时 G 的梯度接近 0。WGAN 用 Wasserstein 距离解决了这两个问题。

Q: 模式崩溃是什么?怎么解决? A: G 只会生成一种/几种固定样本。解决:WGAN-GP(梯度惩罚)、MiniBatch Discrimination(让 D 看一批数据的多样性)。

Q: GAN 和 VAE 的区别? A: VAE 学习数据分布(显式),生成连续平滑;GAN 隐式建模,生成质量更高但训练不稳定。

Q: 生信中 GAN 的主要挑战? A: 基因表达数据高维稀疏(维度诅咒)、稀缺标注、评估困难(无图像感知质量指标)。


速查表

术语含义
Generator G从噪声生成假数据
Discriminator D区分真假数据
Mode Collapse生成多样性崩溃
WGAN用 Wasserstein 距离替换 JS 散度
Gradient PenaltyWGAN-GP,梯度限制技术
Conditional GAN给定条件(如细胞类型)有条件生成
scGAN单细胞数据专用 GAN
Latent Space噪声 z 所在的潜在空间