图神经网络在生信中的应用¶
一句话说明¶
图神经网络(GNN)把生物数据天然的图结构(蛋白质互作网络、代谢通路、分子结构)转化为可学习的表示,用于功能预测、药物发现和疾病分析。
核心知识点¶
为什么生信适合用图?¶
| 生物数据类型 | 节点 | 边 |
|---|---|---|
| 蛋白质互作网络 PPI | 蛋白质 | 互作关系 |
| 知识图谱 | 基因/疾病/药物 | 关联关系 |
| 分子结构图 | 原子 | 化学键 |
| 细胞通讯图 | 细胞类型 | 配体-受体对 |
| 基因调控网络 GRN | 转录因子/基因 | 调控关系 |
GNN 核心思想¶
消息传递(Message Passing):每个节点聚合邻居节点的特征,更新自身表示。
主流 GNN 变体¶
| 模型 | 特点 | 生信应用 |
|---|---|---|
| GCN | 图卷积,谱域 | 基因功能预测 |
| GAT | 注意力权重 | PPI 网络分析 |
| GraphSAGE | 采样邻居 | 大规模图 |
| GIN | 图同构,表达能力强 | 分子属性预测 |
| MPNN | 消息传递框架 | 分子能量预测 |
代表性生信工具¶
- GCN for PPI:STRING 数据库 + GCN 预测功能
- DeepDTA:GCN 预测药物-靶标亲和力
- AttentiveFP:图注意力预测分子属性
- GREIN:GNN 预测基因调控
实战代码¶
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv # PyTorch Geometric 图神经网络库
from torch_geometric.data import Data
# ===== 1. 构建一个简单的蛋白质互作图 =====
# 假设 5 个蛋白质节点,每个蛋白质有 16 维特征
num_nodes = 5
num_features = 16
# 节点特征矩阵 [节点数, 特征维度]
x = torch.randn(num_nodes, num_features) # 随机初始化,实际用蛋白质序列特征
# 边索引 [2, 边数],每列是一条边的 (源节点, 目标节点)
edge_index = torch.tensor([
[0, 1, 1, 2, 3], # 源节点
[1, 0, 2, 3, 4] # 目标节点
], dtype=torch.long)
# 节点标签(0=无功能,1=有某功能)
y = torch.tensor([0, 1, 1, 0, 1], dtype=torch.long)
# 打包成 PyG 的 Data 对象
graph_data = Data(x=x, edge_index=edge_index, y=y)
# ===== 2. 定义 GCN 模型 =====
class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
# 第一层图卷积:输入特征 -> 隐藏层
self.conv1 = GCNConv(in_channels, hidden_channels)
# 第二层图卷积:隐藏层 -> 输出类别数
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 第一层卷积 + ReLU 激活
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training) # Dropout 防过拟合
# 第二层卷积
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1) # 输出各类别的对数概率
# ===== 3. 训练模型 =====
model = GCN(in_channels=16, hidden_channels=32, out_channels=2) # 二分类
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in range(100):
optimizer.zero_grad() # 清空梯度
out = model(graph_data.x, graph_data.edge_index) # 前向传播
loss = F.nll_loss(out, graph_data.y) # 负对数似然损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# ===== 4. 用 GAT(图注意力网络)替换 =====
class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
super().__init__()
# 多头注意力:heads 个注意力头,自动拼接
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
# 第二层:拼接后维度 = hidden_channels * heads
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
def forward(self, x, edge_index):
x = F.elu(self.conv1(x, edge_index)) # ELU 激活(GAT 常用)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
面试常问点¶
Q: GCN 和 GAT 的区别? A: GCN 对邻居做平均聚合(权重由度决定),GAT 学习可变注意力权重,对重要邻居赋予更高权重。GAT 更灵活,GCN 更快。
Q: GNN 在蛋白质功能预测中怎么用? A: 把蛋白质互作网络(如 STRING)建成图,节点特征用序列/结构特征,训练 GNN 学习每个蛋白质节点的功能标签。
Q: 过平滑问题是什么? A: GNN 层数太多时,所有节点的表示趋于相同(信息被过度聚合),导致性能下降。一般 2-4 层足够。
Q: 生信中 GNN 面临的挑战? A: 标注数据少(弱监督)、图规模大(百万节点)、异质图(不同类型节点/边)。
速查表¶
| 概念 | 解释 |
|---|---|
| 节点特征 | 每个节点(如蛋白质)的向量表示 |
| 边特征 | 节点间关系的权重/类型 |
| 消息传递 | 邻居信息聚合到中心节点 |
| 图分类 | 整张图一个标签(如分子毒性) |
| 节点分类 | 每个节点一个标签(如基因功能) |
| 链路预测 | 预测两节点是否有边(如PPI预测) |
| PyTorch Geometric | PyG,最流行的 GNN 库 |
| DGL | 深度图学习库,另一主流选择 |