跳转至

图神经网络在生信中的应用


一句话说明

图神经网络(GNN)把生物数据天然的图结构(蛋白质互作网络、代谢通路、分子结构)转化为可学习的表示,用于功能预测、药物发现和疾病分析。


核心知识点

为什么生信适合用图?

生物数据类型节点
蛋白质互作网络 PPI蛋白质互作关系
知识图谱基因/疾病/药物关联关系
分子结构图原子化学键
细胞通讯图细胞类型配体-受体对
基因调控网络 GRN转录因子/基因调控关系

GNN 核心思想

消息传递(Message Passing):每个节点聚合邻居节点的特征,更新自身表示。

h_v^(l+1) = UPDATE( h_v^(l), AGGREGATE({h_u^(l) : u ∈ N(v)}) )

主流 GNN 变体

模型特点生信应用
GCN图卷积,谱域基因功能预测
GAT注意力权重PPI 网络分析
GraphSAGE采样邻居大规模图
GIN图同构,表达能力强分子属性预测
MPNN消息传递框架分子能量预测

代表性生信工具

  • GCN for PPI:STRING 数据库 + GCN 预测功能
  • DeepDTA:GCN 预测药物-靶标亲和力
  • AttentiveFP:图注意力预测分子属性
  • GREIN:GNN 预测基因调控

实战代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv  # PyTorch Geometric 图神经网络库
from torch_geometric.data import Data

# ===== 1. 构建一个简单的蛋白质互作图 =====
# 假设 5 个蛋白质节点,每个蛋白质有 16 维特征
num_nodes = 5
num_features = 16

# 节点特征矩阵 [节点数, 特征维度]
x = torch.randn(num_nodes, num_features)  # 随机初始化,实际用蛋白质序列特征

# 边索引 [2, 边数],每列是一条边的 (源节点, 目标节点)
edge_index = torch.tensor([
    [0, 1, 1, 2, 3],  # 源节点
    [1, 0, 2, 3, 4]   # 目标节点
], dtype=torch.long)

# 节点标签(0=无功能,1=有某功能)
y = torch.tensor([0, 1, 1, 0, 1], dtype=torch.long)

# 打包成 PyG 的 Data 对象
graph_data = Data(x=x, edge_index=edge_index, y=y)

# ===== 2. 定义 GCN 模型 =====
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # 第一层图卷积:输入特征 -> 隐藏层
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # 第二层图卷积:隐藏层 -> 输出类别数
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # 第一层卷积 + ReLU 激活
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)  # Dropout 防过拟合
        # 第二层卷积
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 输出各类别的对数概率

# ===== 3. 训练模型 =====
model = GCN(in_channels=16, hidden_channels=32, out_channels=2)  # 二分类
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):
    optimizer.zero_grad()  # 清空梯度
    out = model(graph_data.x, graph_data.edge_index)  # 前向传播
    loss = F.nll_loss(out, graph_data.y)  # 负对数似然损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

# ===== 4. 用 GAT(图注意力网络)替换 =====
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        # 多头注意力:heads 个注意力头,自动拼接
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # 第二层:拼接后维度 = hidden_channels * heads
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))  # ELU 激活(GAT 常用)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

面试常问点

Q: GCN 和 GAT 的区别? A: GCN 对邻居做平均聚合(权重由度决定),GAT 学习可变注意力权重,对重要邻居赋予更高权重。GAT 更灵活,GCN 更快。

Q: GNN 在蛋白质功能预测中怎么用? A: 把蛋白质互作网络(如 STRING)建成图,节点特征用序列/结构特征,训练 GNN 学习每个蛋白质节点的功能标签。

Q: 过平滑问题是什么? A: GNN 层数太多时,所有节点的表示趋于相同(信息被过度聚合),导致性能下降。一般 2-4 层足够。

Q: 生信中 GNN 面临的挑战? A: 标注数据少(弱监督)、图规模大(百万节点)、异质图(不同类型节点/边)。


速查表

概念解释
节点特征每个节点(如蛋白质)的向量表示
边特征节点间关系的权重/类型
消息传递邻居信息聚合到中心节点
图分类整张图一个标签(如分子毒性)
节点分类每个节点一个标签(如基因功能)
链路预测预测两节点是否有边(如PPI预测)
PyTorch GeometricPyG,最流行的 GNN 库
DGL深度图学习库,另一主流选择