跳转至

156_微生物组机器学习pipeline

一句话概述

微生物组机器学习pipeline整合特征选择(Boruta/mRMR)、分类建模(Random Forest/XGBoost/SVM)、模型解释(SHAP)和严格验证(嵌套交叉验证/外部队列),构建基于菌群特征的疾病诊断或预后分类模型。


核心知识点总览

知识点说明
数据预处理过滤、归一化、CLR转换处理组成性数据
特征选择Boruta/mRMR/LASSO筛选信息性菌种
Random Forest集成学习,适合高维小样本,天然支持特征重要性
XGBoost梯度提升,通常精度更高但需调参
SVM支持向量机,适合线性可分或核变换场景
SHAPSHapley Additive exPlanations模型解释
嵌套交叉验证外层评估泛化性能,内层调参
过拟合防控小样本微生物组数据最大挑战

各步骤详解

第一步:微生物组数据预处理

白话解释: 微生物组数据是"组成性数据"——各菌种的相对丰度加起来是100%。这意味着一个菌种的增加必然导致其他菌种"看起来"减少,即使它们实际没变。做机器学习前必须用CLR转换等方法处理这个问题。

代码示例:

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from skbio.stats.composition import clr, multiplicative_replacement

# 1. 加载OTU/ASV表 (样本×特征)
otu_table = pd.read_csv("otu_table.csv", index_col=0)
metadata = pd.read_csv("metadata.csv", index_col=0)

# 2. 基本过滤
# 去除极低丰度特征 (在<10%样本中出现的)
prevalence = (otu_table > 0).sum(axis=0) / len(otu_table)
otu_filtered = otu_table.loc[:, prevalence >= 0.1]
print(f"过滤后特征数: {otu_filtered.shape[1]} (原始: {otu_table.shape[1]})")

# 去除测序深度极低的样本
depth = otu_filtered.sum(axis=1)
otu_filtered = otu_filtered[depth >= 1000]

# 3. 组成性数据转换 (CLR)
# 先处理零值 (multiplicative replacement)
otu_replaced = multiplicative_replacement(otu_filtered.values)
otu_clr = pd.DataFrame(
    clr(otu_replaced),
    index=otu_filtered.index,
    columns=otu_filtered.columns
)

# 4. 或使用相对丰度 + log转换 (替代方案)
otu_relative = otu_filtered.div(otu_filtered.sum(axis=1), axis=0)
otu_log = np.log10(otu_relative + 1e-6)  # 加伪计数避免log(0)

# 5. 合并metadata
labels = metadata.loc[otu_clr.index, "Disease_status"]  # 0/1分类标签
print(f"类别分布: \n{labels.value_counts()}")


第二步:特征选择

白话解释: 微生物组通常有几百到几千个特征(物种/OTU),但样本量可能只有几十到几百。如果不做特征选择,模型很容易"记住"训练数据中的噪声(过拟合)。需要挑出真正有生物学意义的特征。

代码示例:

from boruta import BorutaPy
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
import numpy as np

X = otu_clr.values
y = labels.values

# === 方法1: Boruta特征选择 ===
rf = RandomForestClassifier(n_estimators=200, n_jobs=-1, random_state=42)
boruta = BorutaPy(
    estimator=rf,
    n_estimators='auto',
    max_iter=100,
    random_state=42,
    verbose=2
)
boruta.fit(X, y)

# Boruta选择的特征
selected_boruta = otu_clr.columns[boruta.support_].tolist()
tentative_boruta = otu_clr.columns[boruta.support_weak_].tolist()
print(f"Boruta确认特征: {len(selected_boruta)}")
print(f"Boruta待定特征: {len(tentative_boruta)}")

# === 方法2: mRMR (最小冗余最大相关) ===
# pip install mrmr_selection
from mrmr import mrmr_classif

selected_mrmr = mrmr_classif(
    X=pd.DataFrame(X, columns=otu_clr.columns),
    y=pd.Series(y),
    K=20  # 选择top 20特征
)
print(f"mRMR选择的特征: {selected_mrmr}")

# === 方法3: LASSO正则化 ===
from sklearn.linear_model import LassoCV

lasso = LassoCV(cv=5, random_state=42)
lasso.fit(X, y)
lasso_features = otu_clr.columns[lasso.coef_ != 0].tolist()
print(f"LASSO选择特征数: {len(lasso_features)}")

# === 综合: 取多种方法的交集或并集 ===
# 保守策略: 至少2种方法选中的
from collections import Counter
all_selected = selected_boruta + selected_mrmr + lasso_features
feature_counts = Counter(all_selected)
consensus_features = [f for f, c in feature_counts.items() if c >= 2]
print(f"共识特征数: {len(consensus_features)}")


第三步:模型构建——Random Forest

代码示例:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import (roc_auc_score, classification_report, 
                             roc_curve, precision_recall_curve)
import matplotlib.pyplot as plt

# 使用选择后的特征
X_selected = otu_clr[consensus_features].values
y = labels.values

# === 嵌套交叉验证(防止过拟合的金标准) ===
from sklearn.model_selection import GridSearchCV

# 外层: 评估性能
outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# 内层: 调参
inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# 参数网格
param_grid = {
    'n_estimators': [100, 300, 500],
    'max_depth': [5, 10, None],
    'min_samples_leaf': [3, 5, 10],
    'max_features': ['sqrt', 'log2']
}

# 嵌套CV
outer_scores = []
outer_predictions = []

for train_idx, test_idx in outer_cv.split(X_selected, y):
    X_train, X_test = X_selected[train_idx], X_selected[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # 内层CV调参
    grid_search = GridSearchCV(
        RandomForestClassifier(random_state=42),
        param_grid, cv=inner_cv,
        scoring='roc_auc', n_jobs=-1
    )
    grid_search.fit(X_train, y_train)

    # 外层评估
    y_pred_proba = grid_search.predict_proba(X_test)[:, 1]
    auc = roc_auc_score(y_test, y_pred_proba)
    outer_scores.append(auc)
    outer_predictions.extend(zip(test_idx, y_test, y_pred_proba))

print(f"嵌套CV AUC: {np.mean(outer_scores):.3f} ± {np.std(outer_scores):.3f}")

# === 最终模型(全数据训练) ===
final_rf = GridSearchCV(
    RandomForestClassifier(random_state=42),
    param_grid, cv=inner_cv,
    scoring='roc_auc', n_jobs=-1
)
final_rf.fit(X_selected, y)
print(f"最佳参数: {final_rf.best_params_}")


第四步:XGBoost和SVM模型

代码示例:

from xgboost import XGBClassifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# === XGBoost ===
xgb_params = {
    'n_estimators': [100, 300],
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.05, 0.1],
    'subsample': [0.7, 0.8, 0.9],
    'colsample_bytree': [0.7, 0.8],
    'reg_alpha': [0, 0.1, 1],
    'reg_lambda': [1, 5, 10]
}

xgb_model = XGBClassifier(
    random_state=42,
    use_label_encoder=False,
    eval_metric='logloss'
)

xgb_grid = GridSearchCV(xgb_model, xgb_params, cv=inner_cv,
                         scoring='roc_auc', n_jobs=-1)
xgb_grid.fit(X_selected, y)
print(f"XGBoost best AUC: {xgb_grid.best_score_:.3f}")

# === SVM (需要标准化) ===
svm_pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svm', SVC(probability=True, random_state=42))
])

svm_params = {
    'svm__C': [0.01, 0.1, 1, 10, 100],
    'svm__kernel': ['rbf', 'linear'],
    'svm__gamma': ['scale', 'auto', 0.01, 0.1]
}

svm_grid = GridSearchCV(svm_pipeline, svm_params, cv=inner_cv,
                         scoring='roc_auc', n_jobs=-1)
svm_grid.fit(X_selected, y)
print(f"SVM best AUC: {svm_grid.best_score_:.3f}")

# === 模型对比 ===
models = {
    'Random Forest': final_rf.best_score_,
    'XGBoost': xgb_grid.best_score_,
    'SVM': svm_grid.best_score_
}
for name, score in sorted(models.items(), key=lambda x: x[1], reverse=True):
    print(f"{name}: AUC = {score:.3f}")


第五步:SHAP模型解释

白话解释: SHAP告诉你每个特征对每个样本的预测贡献了多少——正贡献还是负贡献。这不仅让模型可解释,还能发现哪些微生物是疾病的"正向指标"和"负向指标"。

代码示例:

import shap

# === Random Forest SHAP ===
best_rf = final_rf.best_estimator_
explainer = shap.TreeExplainer(best_rf)
shap_values = explainer.shap_values(X_selected)

# 对于二分类, shap_values[1]是正类的SHAP值
# Summary plot (最重要的特征)
feature_names = consensus_features
shap.summary_plot(shap_values[1], X_selected, feature_names=feature_names,
                  max_display=20, show=False)
plt.tight_layout()
plt.savefig("shap_summary.png", dpi=150, bbox_inches='tight')
plt.close()

# Bar plot (特征重要性排序)
shap.summary_plot(shap_values[1], X_selected, feature_names=feature_names,
                  plot_type="bar", max_display=20, show=False)
plt.savefig("shap_bar.png", dpi=150, bbox_inches='tight')
plt.close()

# === XGBoost SHAP ===
best_xgb = xgb_grid.best_estimator_
explainer_xgb = shap.TreeExplainer(best_xgb)
shap_values_xgb = explainer_xgb.shap_values(X_selected)

# 单样本解释 (force plot)
shap.force_plot(
    explainer_xgb.expected_value,
    shap_values_xgb[0, :],
    X_selected[0, :],
    feature_names=feature_names,
    matplotlib=True,
    show=False
)
plt.savefig("shap_force_sample0.png", dpi=150, bbox_inches='tight')

# Dependence plot (特定特征的SHAP值vs特征值)
top_feature_idx = np.abs(shap_values_xgb).mean(axis=0).argmax()
shap.dependence_plot(top_feature_idx, shap_values_xgb, X_selected,
                     feature_names=feature_names, show=False)
plt.savefig("shap_dependence.png", dpi=150)

# === SHAP交互效应 ===
# 看哪些菌种之间有协同效应
shap_interaction = explainer_xgb.shap_interaction_values(X_selected)


第六步:模型验证与报告

代码示例:

from sklearn.metrics import (roc_curve, auc, precision_recall_curve, 
                             confusion_matrix, classification_report)
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
import numpy as np

# === 综合ROC曲线 (多折CV) ===
fig, ax = plt.subplots(figsize=(8, 6))
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)

for train_idx, test_idx in outer_cv.split(X_selected, y):
    X_train, X_test = X_selected[train_idx], X_selected[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    model = RandomForestClassifier(**final_rf.best_params_, random_state=42)
    model.fit(X_train, y_train)
    y_proba = model.predict_proba(X_test)[:, 1]

    fpr, tpr, _ = roc_curve(y_test, y_proba)
    roc_auc = auc(fpr, tpr)
    aucs.append(roc_auc)

    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    ax.plot(fpr, tpr, alpha=0.3)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = np.mean(aucs)
std_auc = np.std(aucs)

ax.plot(mean_fpr, mean_tpr, 'b-', 
        label=f'Mean ROC (AUC = {mean_auc:.3f} ± {std_auc:.3f})')
ax.fill_between(mean_fpr, 
                np.mean(tprs, axis=0) - np.std(tprs, axis=0),
                np.mean(tprs, axis=0) + np.std(tprs, axis=0),
                alpha=0.2)
ax.plot([0, 1], [0, 1], 'k--')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Nested Cross-Validation ROC')
ax.legend()
plt.savefig("roc_curve.png", dpi=150)

# === 学习曲线 (检查过拟合) ===
from sklearn.model_selection import learning_curve

train_sizes, train_scores, test_scores = learning_curve(
    final_rf.best_estimator_, X_selected, y,
    cv=5, scoring='roc_auc',
    train_sizes=np.linspace(0.2, 1.0, 10),
    n_jobs=-1
)

plt.figure(figsize=(8, 5))
plt.plot(train_sizes, train_scores.mean(axis=1), label='Training')
plt.plot(train_sizes, test_scores.mean(axis=1), label='Validation')
plt.fill_between(train_sizes, 
                 train_scores.mean(axis=1) - train_scores.std(axis=1),
                 train_scores.mean(axis=1) + train_scores.std(axis=1), alpha=0.1)
plt.xlabel('Training Set Size')
plt.ylabel('AUC')
plt.title('Learning Curve')
plt.legend()
plt.savefig("learning_curve.png", dpi=150)


实战命令

# === 环境安装 ===
pip install scikit-learn xgboost shap boruta mrmr_selection \
  scikit-bio pandas numpy matplotlib seaborn

# === 完整pipeline脚本运行 ===
python microbiome_ml_pipeline.py \
  --otu_table otu_table.csv \
  --metadata metadata.csv \
  --target Disease \
  --output results/ \
  --n_features 20 \
  --outer_cv 5 \
  --inner_cv 3

# === R替代方案 (mikropml) ===
Rscript -e '
install.packages("mikropml")
library(mikropml)
results <- run_ml(otu_data, "rf", outcome_colname="Disease",
                  kfold=5, cv_times=100, seed=42)
'

面试常问点

Q1:微生物组数据做ML为什么需要CLR转换?

A: 微生物组相对丰度是组成性数据(compositional data)——各成分之和受约束为1。这导致:(1) 假相关性:一个物种增加会导致其他物种比例"被动"降低;(2) 子组成不一致性:从全部物种中取子集后结论可能改变。CLR(中心对数比)转换通过对每个物种取log后减去所有物种log值的几何均值,将数据从"单纯形"空间映射到欧氏空间,消除组成性约束。

Q2:为什么微生物组ML特别容易过拟合?

A: 三个主要原因:(1) 高维小样本(p>>n):通常几百种微生物但只有几十到几百个样本,随机模式很容易被"记住";(2) 特征相关性高:共生/竞争的微生物之间有强相关,增加了有效参数空间;(3) 批次效应:不同测序批次的技术差异可能与生物学变量混淆。必须使用嵌套交叉验证和外部验证来诚实评估性能。

Q3:嵌套交叉验证和普通CV有什么区别?为什么要用嵌套CV?

A: 普通CV在同一数据上同时做了调参和评估——使用CV选出最佳参数后报告的CV分数是有偏的(乐观的)。嵌套CV用外层CV评估泛化性能,内层CV调参——外层测试集从未参与任何决策(包括参数选择)。对于小样本微生物组数据,这种区别可以导致报告AUC差异5-15%。论文中报告的性能必须来自外层CV。

Q4:SHAP解释在微生物组研究中的意义是什么?

A: SHAP提供了从"黑盒分类器"到"生物学发现"的桥梁:(1) 识别biomarker:SHAP值最大的微生物是最佳诊断标志物候选;(2) 方向性:SHAP值的正负指示该微生物丰度增加是促进还是抑制疾病预测;(3) 非线性效应:dependence plot可以揭示某些微生物只在特定丰度阈值以上才有影响;(4) 交互效应:SHAP interaction values可以发现协同作用的菌群组合。

Q5:如何处理类别不平衡问题?

A: 微生物组研究中常见case:control = 1:3甚至更极端。策略包括:(1) 采样方法:SMOTE过采样(慎用,组成性数据的SMOTE可能不合理)、随机欠采样、加权采样;(2) 算法级:设置class_weight='balanced'(RF/SVM)或scale_pos_weight(XGBoost);(3) 评价指标:不用accuracy,用AUC-ROC、AUC-PR、F1。推荐优先使用class_weight参数和AUC-PR指标。


易错点

1. 数据泄露(Data Leakage)

错误: 在全数据上做特征选择后再做CV。 正确做法: 特征选择必须在CV的训练折内进行。否则测试折的信息已通过特征选择"泄露"到模型中,导致性能高估。嵌套CV中特征选择应放在内层。

2. 不处理组成性数据就建模

错误: 直接用相对丰度值做RF/XGBoost。 正确做法: 至少做CLR转换。虽然RF对单调变换不敏感,但组成性约束导致的假相关仍会影响特征选择和解释。

3. 报告训练集性能

错误: 报告训练AUC=0.99作为模型性能。 正确做法: 必须报告holdout或嵌套CV的性能。训练AUC接近1而测试AUC显著低是过拟合的标志。

4. 忽略批次效应

错误: 训练集和测试集来自同一批次,外部验证集来自不同批次时性能骤降。 正确做法: (1) CV中按批次分层(同一批次的样本要在同一折中);(2) 使用ComBat等方法校正批次效应(但要在CV训练折内做);(3) 多中心数据做LOCO(leave-one-center-out)验证。

5. 特征过多不做选择

错误: 500个OTU全部输入模型,样本只有80个。 正确做法: 经验法则:特征数不超过样本数的1/5-1/10。80个样本建议最终模型使用8-15个特征。先用filter方法(方差过滤、prevalence)大幅缩减,再用wrapper方法精选。


补充知识

微生物组ML专用工具

  • mikropml(R): 微生物组ML一站式R包
  • SIAMCAT(R): Bioconductor包,专为微生物组分类设计
  • q2-sample-classifier: QIIME2插件,集成多种ML方法
  • MetAML: 微生物组ML benchmark框架

发表要求

论文中报告ML结果需要包含: 1. 完整的方法描述(预处理、特征选择、模型、验证) 2. 嵌套CV或独立验证集的性能(非训练集) 3. 参数选择的方式和搜索空间 4. 学习曲线(证明样本量是否足够) 5. 可解释性分析(SHAP/特征重要性) 6. 代码和数据可获取性

前沿方向

  • 深度学习: 图神经网络利用菌种间生态关系
  • 迁移学习: 从大队列(如AGP)预训练,fine-tune到特定疾病
  • 多组学整合: 菌群+代谢组+宿主转录组联合建模
  • 因果推断: 从关联性到因果性(MR + 干预实验)