联邦学习详解¶
一句话说明¶
联邦学习让多个医院/机构在不共享原始数据的前提下,共同训练一个机器学习模型——各自在本地训练,只上传模型参数更新,保护患者隐私。
核心知识点¶
为什么生信需要联邦学习?¶
- 医疗数据分散在各医院,隐私法规(HIPAA、GDPR)禁止共享
- 罕见病样本量少,单中心数据不够
- 跨机构合作训练更强的诊断模型
- 多组学数据整合(各中心保留不同数据模态)
联邦学习基本流程¶
主要类型¶
| 类型 | 特点 | 生信应用 |
|---|---|---|
| 横向联邦学习 | 各方特征相同,样本不同 | 各医院相同诊断数据 |
| 纵向联邦学习 | 各方样本相同,特征不同 | 同一批患者不同检测 |
| 联邦迁移学习 | 特征和样本都不完全重叠 | 跨病种知识迁移 |
关键算法¶
| 算法 | 特点 |
|---|---|
| FedAvg | 最简单,按数据量加权平均参数 |
| FedProx | 加近端项约束,适合非 IID 数据 |
| SCAFFOLD | 控制变分,减少 client drift |
| FedBN | 批归一化层本地保留,其他联邦 |
面临挑战¶
- 非 IID 数据:各医院患者群体不同
- 通信效率:参数量大,上传慢
- 隐私攻击:梯度逆向工程可能泄露数据
- 激励机制:怎么让医院参与
实战代码¶
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import copy
import numpy as np
# ===== 模拟联邦学习:3 家医院共同训练疾病分类模型 =====
# ===== 1. 定义简单的分类模型 =====
class DiagnosisModel(nn.Module):
"""
简单的疾病分类神经网络
输入:基因表达特征
输出:疾病类别概率
"""
def __init__(self, input_dim=50, hidden_dim=32, num_classes=2):
super().__init__()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim), # 特征提取层
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, num_classes) # 分类输出层
)
def forward(self, x):
return self.network(x)
# ===== 2. 模拟各医院的本地数据(非 IID,各医院患者群体不同)=====
def generate_hospital_data(hospital_id, n_samples=200, n_features=50):
"""
为每个医院生成模拟数据
医院 ID 影响数据分布(模拟非 IID 情况)
"""
np.random.seed(hospital_id * 42)
# 各医院有不同的基因表达背景(模拟不同患者群体)
X = np.random.randn(n_samples, n_features).astype(np.float32)
X += hospital_id * 0.3 # 各医院数据有系统性偏移(批次效应)
# 标签:后20个特征和相关
y = (X[:, 40:50].sum(axis=1) > 0).astype(np.int64)
X_tensor = torch.FloatTensor(X)
y_tensor = torch.LongTensor(y)
dataset = TensorDataset(X_tensor, y_tensor)
return DataLoader(dataset, batch_size=32, shuffle=True)
# 创建 3 个医院的数据
num_hospitals = 3
hospital_dataloaders = [generate_hospital_data(i) for i in range(num_hospitals)]
# ===== 3. FedAvg 算法实现 =====
def local_train(model, dataloader, epochs=5, lr=0.01):
"""
本地训练:各医院在本地数据上训练几轮
返回更新后的模型参数
"""
model.train()
optimizer = optim.SGD(model.parameters(), lr=lr) # 随机梯度下降
loss_fn = nn.CrossEntropyLoss() # 交叉熵损失
for epoch in range(epochs):
for X_batch, y_batch in dataloader:
optimizer.zero_grad()
output = model(X_batch) # 前向传播
loss = loss_fn(output, y_batch) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新本地参数
return model.state_dict() # 返回更新后的模型参数字典
def fedavg_aggregate(global_model, local_state_dicts, weights=None):
"""
FedAvg:按权重平均各医院的模型参数
weights: 各医院的数据量权重(None=均等)
"""
if weights is None:
weights = [1.0 / len(local_state_dicts)] * len(local_state_dicts)
global_state = global_model.state_dict()
# 对每一层参数做加权平均
for key in global_state.keys():
# 初始化为零张量
aggregated_param = torch.zeros_like(global_state[key], dtype=torch.float)
for weight, local_state in zip(weights, local_state_dicts):
# 累积加权参数(注意转换为 float 避免整型问题)
aggregated_param += weight * local_state[key].float()
global_state[key] = aggregated_param.to(global_state[key].dtype)
global_model.load_state_dict(global_state) # 更新全局模型
return global_model
# ===== 4. 联邦训练主循环 =====
global_model = DiagnosisModel(input_dim=50) # 初始化全局模型
num_rounds = 20 # 联邦学习轮数
print('=== 开始联邦学习训练 ===')
for round_num in range(num_rounds):
local_states = [] # 收集各医院的本地模型参数
for hospital_id, dataloader in enumerate(hospital_dataloaders):
# 每家医院拿到全局模型副本,在本地训练
local_model = copy.deepcopy(global_model) # 深拷贝全局模型
local_state = local_train(local_model, dataloader, epochs=3)
local_states.append(local_state) # 收集本地参数更新
# 服务器聚合(等权重 FedAvg)
global_model = fedavg_aggregate(global_model, local_states)
if round_num % 5 == 0:
# 简单评估全局模型(实际应在验证集上)
print(f'Round {round_num}: 联邦聚合完成')
print('\n联邦学习完成!全局模型已训练完毕。')
print('注意:整个过程中,各医院的原始数据从未离开本地。')
面试常问点¶
Q: 联邦学习能完全保护隐私吗? A: 不能完全保护。梯度逆向攻击(gradient inversion)可能从梯度中重建部分原始数据。更严格的隐私保护需要结合差分隐私(DP)或安全多方计算(SMC)。
Q: 非 IID 数据对联邦学习有什么影响? A: 各客户端数据分布差异大时,本地训练会导致"客户端漂移",各客户端模型向不同方向更新,聚合后性能差。FedProx、SCAFFOLD 等算法专门解决此问题。
Q: 横向联邦和纵向联邦的区别? A: 横向:同样的特征(如相同基因集),不同的样本(不同患者)。纵向:同样的样本(相同患者群),不同的特征(一家医院有基因组,另一家有影像)。
Q: 生信中联邦学习有哪些实际应用? A: ①MELLODDY 项目:多家药企联合训练分子活性预测模型;②TriNetX:多医院联合临床预测;③Owkin:癌症病理图像联邦分析。
速查表¶
| 术语 | 解释 |
|---|---|
| FedAvg | 最经典联邦算法,加权平均参数 |
| 非 IID | 各客户端数据分布不同(现实常见) |
| 客户端漂移 | 非 IID 导致本地模型偏离全局目标 |
| 差分隐私 DP | 在梯度中加噪声,提供数学隐私保证 |
| 聚合服务器 | 汇总各客户端参数,更新全局模型 |
| 通信轮次 | 全局模型和客户端交互的轮数 |
| MELLODDY | 全球首个多药企联邦学习平台 |
| 梯度逆向攻击 | 从梯度反推原始训练数据的攻击 |