跳转至

联邦学习详解


一句话说明

联邦学习让多个医院/机构在不共享原始数据的前提下,共同训练一个机器学习模型——各自在本地训练,只上传模型参数更新,保护患者隐私。


核心知识点

为什么生信需要联邦学习?

  • 医疗数据分散在各医院,隐私法规(HIPAA、GDPR)禁止共享
  • 罕见病样本量少,单中心数据不够
  • 跨机构合作训练更强的诊断模型
  • 多组学数据整合(各中心保留不同数据模态)

联邦学习基本流程

1. 服务器下发全局模型参数
2. 各客户端(医院)用本地数据训练
3. 各客户端上传模型梯度/参数更新
4. 服务器聚合更新(如 FedAvg 加权平均)
5. 更新全局模型,重复

主要类型

类型特点生信应用
横向联邦学习各方特征相同,样本不同各医院相同诊断数据
纵向联邦学习各方样本相同,特征不同同一批患者不同检测
联邦迁移学习特征和样本都不完全重叠跨病种知识迁移

关键算法

算法特点
FedAvg最简单,按数据量加权平均参数
FedProx加近端项约束,适合非 IID 数据
SCAFFOLD控制变分,减少 client drift
FedBN批归一化层本地保留,其他联邦

面临挑战

  • 非 IID 数据:各医院患者群体不同
  • 通信效率:参数量大,上传慢
  • 隐私攻击:梯度逆向工程可能泄露数据
  • 激励机制:怎么让医院参与

实战代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import copy
import numpy as np

# ===== 模拟联邦学习:3 家医院共同训练疾病分类模型 =====

# ===== 1. 定义简单的分类模型 =====
class DiagnosisModel(nn.Module):
    """
    简单的疾病分类神经网络
    输入:基因表达特征
    输出:疾病类别概率
    """
    def __init__(self, input_dim=50, hidden_dim=32, num_classes=2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),   # 特征提取层
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)  # 分类输出层
        )

    def forward(self, x):
        return self.network(x)

# ===== 2. 模拟各医院的本地数据(非 IID,各医院患者群体不同)=====
def generate_hospital_data(hospital_id, n_samples=200, n_features=50):
    """
    为每个医院生成模拟数据
    医院 ID 影响数据分布(模拟非 IID 情况)
    """
    np.random.seed(hospital_id * 42)

    # 各医院有不同的基因表达背景(模拟不同患者群体)
    X = np.random.randn(n_samples, n_features).astype(np.float32)
    X += hospital_id * 0.3  # 各医院数据有系统性偏移(批次效应)

    # 标签:后20个特征和相关
    y = (X[:, 40:50].sum(axis=1) > 0).astype(np.int64)

    X_tensor = torch.FloatTensor(X)
    y_tensor = torch.LongTensor(y)
    dataset = TensorDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=32, shuffle=True)

# 创建 3 个医院的数据
num_hospitals = 3
hospital_dataloaders = [generate_hospital_data(i) for i in range(num_hospitals)]

# ===== 3. FedAvg 算法实现 =====
def local_train(model, dataloader, epochs=5, lr=0.01):
    """
    本地训练:各医院在本地数据上训练几轮
    返回更新后的模型参数
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)  # 随机梯度下降
    loss_fn = nn.CrossEntropyLoss()                    # 交叉熵损失

    for epoch in range(epochs):
        for X_batch, y_batch in dataloader:
            optimizer.zero_grad()
            output = model(X_batch)            # 前向传播
            loss = loss_fn(output, y_batch)    # 计算损失
            loss.backward()                    # 反向传播
            optimizer.step()                   # 更新本地参数

    return model.state_dict()  # 返回更新后的模型参数字典

def fedavg_aggregate(global_model, local_state_dicts, weights=None):
    """
    FedAvg:按权重平均各医院的模型参数
    weights: 各医院的数据量权重(None=均等)
    """
    if weights is None:
        weights = [1.0 / len(local_state_dicts)] * len(local_state_dicts)

    global_state = global_model.state_dict()

    # 对每一层参数做加权平均
    for key in global_state.keys():
        # 初始化为零张量
        aggregated_param = torch.zeros_like(global_state[key], dtype=torch.float)
        for weight, local_state in zip(weights, local_state_dicts):
            # 累积加权参数(注意转换为 float 避免整型问题)
            aggregated_param += weight * local_state[key].float()
        global_state[key] = aggregated_param.to(global_state[key].dtype)

    global_model.load_state_dict(global_state)  # 更新全局模型
    return global_model

# ===== 4. 联邦训练主循环 =====
global_model = DiagnosisModel(input_dim=50)  # 初始化全局模型
num_rounds = 20  # 联邦学习轮数

print('=== 开始联邦学习训练 ===')
for round_num in range(num_rounds):
    local_states = []  # 收集各医院的本地模型参数

    for hospital_id, dataloader in enumerate(hospital_dataloaders):
        # 每家医院拿到全局模型副本,在本地训练
        local_model = copy.deepcopy(global_model)   # 深拷贝全局模型
        local_state = local_train(local_model, dataloader, epochs=3)
        local_states.append(local_state)             # 收集本地参数更新

    # 服务器聚合(等权重 FedAvg)
    global_model = fedavg_aggregate(global_model, local_states)

    if round_num % 5 == 0:
        # 简单评估全局模型(实际应在验证集上)
        print(f'Round {round_num}: 联邦聚合完成')

print('\n联邦学习完成!全局模型已训练完毕。')
print('注意:整个过程中,各医院的原始数据从未离开本地。')

面试常问点

Q: 联邦学习能完全保护隐私吗? A: 不能完全保护。梯度逆向攻击(gradient inversion)可能从梯度中重建部分原始数据。更严格的隐私保护需要结合差分隐私(DP)或安全多方计算(SMC)。

Q: 非 IID 数据对联邦学习有什么影响? A: 各客户端数据分布差异大时,本地训练会导致"客户端漂移",各客户端模型向不同方向更新,聚合后性能差。FedProx、SCAFFOLD 等算法专门解决此问题。

Q: 横向联邦和纵向联邦的区别? A: 横向:同样的特征(如相同基因集),不同的样本(不同患者)。纵向:同样的样本(相同患者群),不同的特征(一家医院有基因组,另一家有影像)。

Q: 生信中联邦学习有哪些实际应用? A: ①MELLODDY 项目:多家药企联合训练分子活性预测模型;②TriNetX:多医院联合临床预测;③Owkin:癌症病理图像联邦分析。


速查表

术语解释
FedAvg最经典联邦算法,加权平均参数
非 IID各客户端数据分布不同(现实常见)
客户端漂移非 IID 导致本地模型偏离全局目标
差分隐私 DP在梯度中加噪声,提供数学隐私保证
聚合服务器汇总各客户端参数,更新全局模型
通信轮次全局模型和客户端交互的轮数
MELLODDY全球首个多药企联邦学习平台
梯度逆向攻击从梯度反推原始训练数据的攻击