跳转至

变分自编码器 VAE


一句话说明

VAE 是一种生成模型,把高维数据(如基因表达)压缩到低维"潜在空间",并学会从该空间随机采样生成新数据——比普通自编码器多了"随机性",能学到数据的概率分布。


核心知识点

白话理解

普通自编码器:压缩 → 重建(确定性,只会复制) VAE:压缩成一个概率分布(均值+方差)→ 从中采样 → 重建(有随机性,能生成新样本)

核心公式

ELBO(证据下界)= 重建误差 + KL 散度

L = E[log p(x|z)] - KL(q(z|x) || p(z))
      ↑ 重建质量         ↑ 潜空间正则化
  • 重建误差:编码后再解码,数据失真越小越好
  • KL 散度:让潜在空间接近标准正态分布(方便采样)

重参数化技巧

z = μ + ε × σ,其中 ε ~ N(0,1)
把随机性移到 ε,让梯度可以通过 μ 和 σ 反向传播。

生信应用

应用工具/文章
单细胞降维可视化scVI(Lopez et al. 2018)
批次效应校正scVI、SAUCIE
基因表达补全DCA(Deep Count Autoencoder)
药物响应预测DRAGONN
分子生成分子 VAE(Gómez-Bombarelli 2018)
蛋白质设计EVE(变异效应预测)

实战代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# ===== 单细胞基因表达 VAE(参考 scVI 架构)=====

class VAE(nn.Module):
    """
    变分自编码器用于单细胞基因表达数据
    输入:细胞 × 基因 矩阵(归一化后的基因表达值)
    潜空间:低维连续表示,可用于聚类/可视化
    """
    def __init__(self, input_dim=3000, latent_dim=10, hidden_dim=128):
        super().__init__()
        # === 编码器 ===
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),   # 基因维度 -> 隐藏层
            nn.BatchNorm1d(hidden_dim),          # 批归一化
            nn.ReLU()
        )
        # 编码为均值和对数方差(各一个线性层)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # 潜变量均值 μ
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # 对数方差 log(σ²)

        # === 解码器 ===
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),  # 潜变量 -> 隐藏层
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),   # 隐藏层 -> 基因维度
            nn.Sigmoid()                         # 输出 [0,1](归一化表达值)
        )

    def encode(self, x):
        """编码:输入基因表达,输出均值和对数方差"""
        h = self.encoder(x)
        mu = self.fc_mu(h)          # 均值向量
        logvar = self.fc_logvar(h)  # 对数方差向量
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """重参数化技巧:z = μ + ε × σ,保证梯度可传播"""
        if self.training:
            std = torch.exp(0.5 * logvar)   # σ = exp(0.5 * logvar)
            eps = torch.randn_like(std)      # ε ~ N(0,1) 标准正态随机数
            return mu + eps * std            # 采样 z
        else:
            return mu  # 推断时直接用均值(确定性)

    def decode(self, z):
        """解码:从潜变量重建基因表达"""
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)          # 编码
        z = self.reparameterize(mu, logvar)  # 采样潜变量
        x_recon = self.decode(z)             # 解码重建
        return x_recon, mu, logvar

# ===== 损失函数 =====
def vae_loss(x_recon, x, mu, logvar, beta=1.0):
    """
    VAE 损失 = 重建损失 + β × KL 散度
    β=1:标准 VAE;β>1:β-VAE,潜空间更解耦
    """
    # 重建损失(均方误差,适合连续值)
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # KL 散度:让 q(z|x) 接近 N(0,1)
    # 公式:-0.5 * sum(1 + logvar - μ² - e^logvar)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kl_loss

# ===== 训练示例 =====
input_dim = 3000   # 高变基因数量
latent_dim = 10    # 潜空间维度(降维后的表示)
batch_size = 32

model = VAE(input_dim=input_dim, latent_dim=latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 模拟单细胞数据(实际替换为 scanpy 加载的归一化矩阵)
fake_data = torch.rand(500, input_dim)  # 500 个细胞

for epoch in range(30):
    idx = torch.randint(0, 500, (batch_size,))
    x = fake_data[idx]

    x_recon, mu, logvar = model(x)               # 前向传播
    loss = vae_loss(x_recon, x, mu, logvar)       # 计算损失

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.1f}')

# ===== 获取细胞嵌入(用于聚类/UMAP)=====
model.eval()
with torch.no_grad():
    mu, _ = model.encode(fake_data)  # 取均值作为细胞的低维表示
    # mu 形状: [500, 10],可直接用于 UMAP、KMeans 等
    print(f'细胞嵌入维度: {mu.shape}')

面试常问点

Q: VAE 和 AE(普通自编码器)的区别? A: AE 潜空间是确定性点,无法采样新数据;VAE 潜空间是概率分布(均值+方差),可以从中采样生成新数据。VAE 还有 KL 约束让潜空间连续有结构。

Q: 什么是重参数化技巧?为什么需要? A: 采样操作不可导,无法直接反向传播。重参数化把随机性移到外部变量 ε,z=μ+εσ,让梯度能流过 μ 和 σ。

Q: β-VAE 有什么用? A: β>1 加强 KL 约束,让潜变量各维度更独立(解耦),每个维度代表一种可解释的变化因素。

Q: scVI 用什么分布建模单细胞数据? A: 负二项分布(Negative Binomial),因为 scRNA-seq 数据是离散计数且过度分散,比高斯分布更合适。


速查表

术语解释
潜变量 z数据的低维连续表示
μ(mu)潜变量的均值向量
σ(sigma)潜变量的标准差向量
KL 散度两个分布的"距离"
ELBO证据下界,VAE 优化目标
β-VAE增大 KL 权重,强化解耦
scVI单细胞专用 VAE 框架
重参数化让采样可导的数学技巧