ProstT5:蛋白质序列与结构的双语语言模型¶
分类:Linux工程化
概述¶
ProstT5 (Protein structure-sequence T5) 是一种能够实现蛋白质序列(Amino Acid, AA)与三维结构表示之间相互翻译的双语蛋白质语言模型(pLM)。该模型基于 ProtT5-XL-U50 进行微调,后者是一个在数十亿蛋白质序列上通过 span corruption 预训练而成的 T5 架构编码器。通过引入来自 AlphaFold 数据库的高质量三维结构预测数据(约 1700 万个蛋白质),ProstT5 学会了将氨基酸序列转换为结构表示,也可反向操作。
蛋白质三维结构在输入模型前由 Foldseek 的 3Di tokens 技术压缩为一维字符串,使得结构信息能够像序列一样被 Transformer 处理。这种独特的“双语”能力使 ProstT5 既能生成富含结构知识的嵌入表示,又能直接进行序列⇌结构的翻译,为蛋白质工程、功能预测、结构预测等下游任务提供强大支撑。
模型与权重已托管在 Hugging Face 平台:Rostlab/ProstT5。
核心知识点¶
1. 模型基础:从 ProtT5-XL-U50 到 ProstT5¶
ProstT5 并非从头训练,而是基于 ProtT5-XL-U50 参数进行微调。ProtT5-XL-U50 是一个标准的 T5 编码器-解码器模型,专门针对蛋白质序列进行了大规模预训练。预训练任务采用 span corruption(遮蔽片段并预测),使模型掌握了丰富的氨基酸上下文语义。
微调阶段,模型接收两个方向的翻译任务: - AA→3Di:氨基酸序列 → 三维结构表示(类比“折叠”) - 3Di→AA:三维结构表示 → 氨基酸序列(类比“反向设计”)
这种双向微调让 ProstT5 成为“双语”模型,既能理解序列语言,也能理解结构语言,并且在潜在空间中将两者对齐,从而产生融合了结构信息的嵌入向量。
2. 3Di 令牌:三维结构的“一维化”¶
传统蛋白质结构通常用 PDB 坐标或接触图表示,不利于直接送入基于序列的 Transformer。Foldseek 提出的 3Di tokens 将每个残基的三维几何构象映射为 20 个离散字母(类似氨基酸的字母表),从而将结构压缩为一串字符。这种表示能够保留局部结构相似性,并且让序列与结构的互译成为可能。
3Di 字符串中的字母均使用小写,以与标准氨基酸的大写表示区分开。例如,一个蛋白质的氨基酸序列为 PRTEINO,其对应的 3Di 结构字符串可能是 strct(示意)。在实际模型中,3Di 字符串会被 token 化为独立词汇,与氨基酸 token 共享相同的序列建模框架。
3. 特殊前缀标记控制翻译方向¶
ProstT5 使用两个特殊标记指示任务方向: - <AA2fold>:表示从氨基酸序列翻译到 3Di 结构(AA → 3Di),或者仅用于提取氨基酸序列的嵌入。 - <fold2AA>:表示从 3Di 结构翻译到氨基酸序列(3Di → AA),或者提取 3Di 结构的嵌入。
在推理时,只需将对应前缀加入输入序列前,模型便会自动执行相应任务。这种方式使得同一个模型权重能够完成序列嵌入、结构嵌入以及双向翻译,无需加载不同的检查点。
4. 嵌入生成与翻译机制¶
ProstT5 可用于两种主要场景: - 生成残基级别嵌入:通过 T5EncoderModel 提取每个 token 的隐藏状态,获得融合结构信息的序列表示,可用于功能位点预测、序列聚类等任务。 - 序列与结构的双向翻译:使用 AutoModelForSeq2SeqLM 并调用 model.generate(),生成目标域的 token 序列。用户可调整束搜索、温度、top-p 等解码参数以控制生成质量。
5. 预处理注意事项¶
- 所有氨基酸序列必须大写,3Di 序列必须小写。
- 稀有或模糊氨基酸(U、Z、O、B)一律替换为
X,以简化词表并避免未知 token 错误。 - 序列中的每个字符后须插入空格(即 tokenizer 的空白分隔),以保证按字符级 tokenize。
- 生成后需要将空格移除,恢复连续字符串。
代码实操¶
环境配置¶
# 安装依赖
pip install torch
pip install transformers
pip install sentencepiece
# 针对最新 transformers 版本可能出现的 protobuf 错误
pip install protobuf
如果你使用的 Transformers 版本已经包含了 PR#24565 的修改,则可能会显示 Legacy 警告。只需在 tokenizer 中显式设置 legacy=True 即可避免,或直接忽略警告(因为默认行为即为 legacy模式)。
场景一:提取蛋白质嵌入¶
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
# 选择设备
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 加载分词器与模型(编码器模式)
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(device)
# GPU 使用半精度加速,CPU 则保持全精度
if device.type == 'cpu':
model.float()
else:
model.half()
# 准备序列:氨基酸大写,3Di 小写
sequence_examples = ["PRTEINO", "strct"]
# 替换稀有氨基酸为 X,并在字符间插入空格
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequence_examples]
# 添加方向前缀:大写序列用 <AA2fold>,小写结构用 <fold2AA>
prefixed_sequences = []
for seq in sequence_examples:
if seq.isupper():
prefixed_sequences.append("<AA2fold> " + seq)
else:
prefixed_sequences.append("<fold2AA> " + seq)
# 批编码并补齐
encodings = tokenizer.batch_encode_plus(
prefixed_sequences,
add_special_tokens=True,
padding="longest",
return_tensors='pt'
).to(device)
# 前向传播提取嵌入
with torch.no_grad():
embedding_result = model(
input_ids=encodings.input_ids,
attention_mask=encodings.attention_mask
)
# 获取残基级别嵌入(跳过特殊 token 和前缀 token)
# 第一条序列:"PRTEINO" 长度 7,前缀 "<AA2fold>" 占用 1 个 token,加上 [CLS] 等特殊 token(默认无 CLS,需确认)
# 实际 token 序列:<AA2fold> + 空格分割后的字符,共 8 个 token,索引 0 为 <AA2fold>
emb_0 = embedding_result.last_hidden_state[0, 1:8] # 形状 (7, 1024)
emb_1 = embedding_result.last_hidden_state[1, 1:6] # 形状 (5, 1024)
# 计算蛋白质级别的全局嵌入(对残基取平均)
protein_emb_0 = emb_0.mean(dim=0) # 形状 (1024,)
注释:实际 token 数量取决于前缀和特殊 token 的设置。batch_encode_plus 默认不添加开头/结尾 token,因此只保留前缀 token 和序列 token。需根据具体输出调整切片。
场景二:序列 ⇌ 结构翻译¶
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 加载分词器与模型(Seq2Seq 模式)
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False)
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
model = model.float() if device.type == 'cpu' else model.half()
# 输入氨基酸序列(大写)
sequence_examples = ["PRTEINO", "SEQWENCE"]
min_len = min(len(s) for s in sequence_examples)
max_len = max(len(s) for s in sequence_examples)
# 预处理:替换稀有氨基酸、插入空格
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequence_examples]
# 添加 AA -> 3Di 前缀
prefixed_sequences = ["<AA2fold> " + s for s in sequence_examples]
# 编码
encodings = tokenizer.batch_encode_plus(
prefixed_sequences,
add_special_tokens=True,
padding="longest",
return_tensors='pt'
).to(device)
# 生成参数(AA -> 3Di)
gen_kwargs_aa2fold = {
"do_sample": True,
"num_beams": 3,
"top_p": 0.95,
"temperature": 1.2,
"top_k": 6,
"repetition_penalty": 1.2,
}
# 生成翻译结果
with torch.no_grad():
translations = model.generate(
input_ids=encodings.input_ids,
attention_mask=encodings.attention_mask,
max_length=max_len,
min_length=min_len,
early_stopping=True,
num_return_sequences=1,
**gen_kwargs_aa2fold
)
# 解码并去除空格,恢复小写 3Di 字符串
decoded_texts = tokenizer.batch_decode(translations, skip_special_tokens=True)
structure_sequences = ["".join(ts.split()) for ts in decoded_texts]
# ---------- 反向翻译:3Di -> AA ----------
# 准备反向输入(加上 <fold2AA> 前缀)
back_prefixed = ["<fold2AA> " + ts for ts in decoded_texts] # ts 依然保留空格分隔
back_encodings = tokenizer.batch_encode_plus(
back_prefixed,
add_special_tokens=True,
padding="longest",
return_tensors='pt'
).to(device)
# 生成氨基酸序列(也可使用不同的生成参数)
with torch.no_grad():
back_translations = model.generate(
input_ids=back_encodings.input_ids,
attention_mask=back_encodings.attention_mask,
max_length=max_len,
min_length=min_len,
early_stopping=True,
num_return_sequences=1,
do_sample=True,
top_p=0.95,
temperature=1.0,
repetition_penalty=1.2,
)
back_texts = tokenizer.batch_decode(back_translations, skip_special_tokens=True)
aa_sequences = ["".join(ts.split()) for ts in back_texts] # 最终氨基酸序列(大写)
这段代码展示了从氨基酸到结构的“折叠”以及从结构回到氨基酸的“反向设计”两个方向。实际应用中可仅保留一个方向,并调整生成参数以控制序列多样性。
常见问题¶
1. 运行时报错 UnboundLocalError: cannot access local variable 'sentencepiece_model_pb2'¶
原因:较新的 transformers 版本修改了 T5 tokenizer 的导入逻辑,导致 sentencepiece_model_pb2 未正常引入。
解决方案: - 方案一:手动安装 protobuf:pip install protobuf - 方案二:降级 transformers 至更改前的版本,或安装 PR#25684 修复分支。 - 方案三:忽略该错误但功能可能异常,不推荐。
2. 出现 Legacy 警告¶
在 T5 tokenizer 中可能会看到类似警告:“You are using the legacy behaviour of the T5 tokenizer …”。这是因为 Transformers 更新了 tokenizer 行为,但默认仍保持 legacy 模式以兼容旧模型。
处理方式: - 显式设置 tokenizer = T5Tokenizer.from_pretrained(..., legacy=True) 消除警告。 - 直接忽略警告,不影响模型推理结果。
3. 输入序列应使用大写还是小写?¶
- 氨基酸序列:必须全大写,如
PRTEINO。 - 3Di 结构序列:必须全小写,如
strct。 模型通过大小写判断输入类型,并自动将稀有氨基酸替换为X。如果不按要求处理,会导致任务方向判断错误或 unknown token 过多。
4. 为什么要在字符之间添加空格?¶
ProstT5 使用的 tokenizer 是按字符级别进行分词,但内部期望每个 token 由空格分隔。如果不添加空格,tokenizer 可能会将整个字符串视为单个 token,导致无法正确建模。预处理步骤中 " ".join(list(...)) 即完成此操作。解码后需用 "".join(ts.split()) 去掉空格恢复原始字符串。
5. 生成翻译结果长度如何控制?¶
通过 model.generate() 中的 min_length 和 max_length 参数限制输出长度。通常设置为输入序列的长度范围,也可根据具体需求放宽。如果生成序列过短,可调低 repetition_penalty 或开启 early_stopping=False 强制生成长序列。但需注意模型可能在达到 max_length 时截断。
6. 能在 CPU 上运行吗?¶
可以,但速度极慢,不推荐。模型半精度 (model.half()) 仅支持 GPU,CPU 上需使用全精度 (model.float())。即使是嵌入提取,单条蛋白质也会消耗较多内存和计算资源。建议至少配备 8GB 显存的 GPU 进行批量处理。
7. 生成的 3Di 序列是否可以直接用于 Foldseek 搜索?¶
是的。生成的 3Di 字符串就是 Foldseek 能够直接使用的 3Di 描述符,可以代替真实结构进行快速结构比对。由于 3Di 是由 ProstT5 从氨基酸序列推断而来,其准确性可能低于真实实验结构,但对于没有结构信息的蛋白质,这提供了一种高效的近似结构搜索手段。
速查表¶
模型与资源¶
| 项目 | 地址 |
|---|---|
| 模型主页 | https://huggingface.co/Rostlab/ProstT5 |
| 基础模型 ProtT5-XL-U50 | https://huggingface.co/Rostlab/prot_t5_xl_uniref50 |
| 3Di 转换工具 Foldseek | https://github.com/steineggerlab/foldseek |
| 论文预印本 | https://www.biorxiv.org/content/10.1101/2023.07.23.550085v1 |
特殊前缀令牌¶
| 前缀 | 功能 | 嵌入模式 |
|---|---|---|
<AA2fold> | 氨基酸序列 → 3Di 结构 (翻译) 或提取 AA 嵌入 | 氨基酸嵌入 |
<fold2AA> | 3Di 结构 → 氨基酸序列 (翻译) 或提取 3Di 嵌入 | 结构嵌入 |
关键代码片段¶
嵌入提取(编码器模式)
model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
# 输入: "<AA2fold> P R T E I N O" 或 "<fold2AA> s t r c t"
embeddings = model(**encoded_inputs).last_hidden_state
序列翻译(Seq2Seq 模式)
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5")
# 前向翻译
input_ids = tokenizer("<AA2fold> P R T E I N O", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=length)
structure = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# 反向翻译
input_ids = tokenizer("<fold2AA> s t r c t", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=length)
sequence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
数据预处理管道¶
- 替换模糊氨基酸:
re.sub(r"[UZOB]", "X", sequence) - 字符间插入空格:
" ".join(list(sequence)) - 添加方向前缀:
"<AA2fold> " + spaced_seq(大写) /"<fold2AA> " + spaced_seq(小写) - 编码与补全:
tokenizer.batch_encode_plus(..., padding="longest") - 解码后去除空格:
"".join(decoded.split())
常用生成参数(参考值)¶
| 参数 | AA→3Di 推荐值 | 说明 |
|---|---|---|
do_sample | True | 开启采样 |
num_beams | 3 | 束搜索宽度 |
top_p | 0.95 | 核采样概率 |
temperature | 1.2 | 生成多样性控制 |
top_k | 6 | 限制候选 token 数 |
repetition_penalty | 1.2 | 抑制重复 |
可根据具体应用调整上述参数。反向翻译时温度可适当降低(如 1.0)以获得更确定的氨基酸序列。
通过以上指南,开发者可以快速上手 ProstT5,实现蛋白质序列与结构之间的高效翻译及嵌入提取,为蛋白质工程、计算生物学等领域提供强大的预训练特征支撑。