[Datawhale AI 春训营] 分子结构生成赛道程序思路解析

135 阅读12分钟

[Datawhale AI 春训营] 分子结构生成赛道程序思路解析

1. 赛题背景与意义

分子的三维(3D)结构是理解和预测其物理、化学性质以及功能表现的基础。 无论是药物的设计、高效催化剂的开发,还是新型能源材料的探索,都离不开对分子结构的深入研究。然而,在许多前沿领域,获取高质量、高分辨率的分子结构数据往往面临诸多挑战:

  • 实验方法局限性: 传统的实验手段(如X射线晶体学、核磁共振谱等)成本高昂,操作复杂,且难以应用于所有分子体系,尤其对于不稳定或难以结晶的分子。
  • 理论计算资源需求: 基于第一性原理的理论计算(如量子力学计算)虽然能提供高精度的结构信息,但需要巨大的计算资源和时间投入,难以实现大规模、快速的数据生成。

面对分子结构数据匮乏的难题,如何高效、智能地生成具有潜在应用价值的分子结构,成为了材料科学、化学和生物学等领域亟待解决的关键问题。

近年来,人工智能(AI)技术的飞速发展,特别是生成式AI模型(如生成对抗网络GAN、扩散模型Diffusion Models、变分自编码器VAE等),为分子结构的生成带来了新的希望。 这些模型能够从已有的分子数据中学习隐藏的分布规律,进而在硅片上“创造”出全新的、合理的分子结构。 通过与强化学习或条件生成等技术结合,AI甚至可以根据特定的目标性质(如能量、活性等)指导分子的设计过程,极大地加速新材料的发现与优化周期。

2. 比赛任务解析

本次比赛紧密围绕“分子设计”这一核心主题,聚焦生成式AI在3D分子结构生成中的应用。赛事的任务是要求参赛者利用提供的分子数据集,训练生成式AI模型,完成指定目标的分子结构生成。

初赛任务

目标: 生成合理且新颖的3D分子结构。 要求: 基于训练集训练AI生成模型,生成1万个分子的3D结构(包含原子元素种类和三维坐标)。 评测: 后台将评估生成分子的合理性新颖性

比赛原则注意事项

为了确保比赛的公平性并鼓励选手探索纯粹的AI生成方法,比赛设置了严格的原则,核心在于必须通过训练生成式AI模型来生成结果,严禁以下行为:

  • 直接从外部数据库收集或使用现有工具(如SMILES/2D转3D工具)生成结果。
  • 直接提交或简单修改训练集中的分子。
  • 使用任何外部数据或预训练模型。
  • 在训练中引入训练集中未出现的分子,数据增强需保证构象属于同一分子。
  • 对生成结果进行后处理、筛选等。

这些原则确保了比赛的焦点集中在AI生成模型的效能上。

3. 数据集介绍

比赛提供了小分子结构数据集,为.pkl格式。每个样本是一个字典,包含:

  • natoms: 原子数量 (int)
  • elements: 原子元素种类列表 (List[str])
  • coordinates: 原子三维坐标列表 (List[List[float]])
  • properties: 化学性质 (List[float]) - 仅复赛提供,表示绝对能量。

初赛数据特点:

  • 样本数量:4-5万
  • 元素种类:C, H, O, N, F, P, S, Cl, Br
  • 单分子最大构象数量:1
  • 最大分子原子数:60

复赛数据特点:

  • 样本数量:4-5万
  • 元素种类:C, H, O, N, F, P, S, Cl, Br
  • 单分子最大构象数量:1-4 (复赛数据可能包含同一分子的不同构象)

可以根据需要自行划分训练集和验证集。

4. 评价指标

后台首先会将生成的3D分子结构通过价键理论映射为2D图结构,然后进行评估:

  1. 有效性 (Validity): 基于价键理论和化学规则(结合RDKit等库)判断生成分子的合理性和稳定性。计算合理分子占总生成分子数的比例。
  2. 唯一性 (Uniqueness): 在有效的分子中,去除重复分子后计算不重复分子占总生成分子数的比例。
  3. 创新性 (Novelty): 在去重后的分子中,统计训练集中未出现的分子占总生成分子数的比例。
  4. 数据集分布相似性: 评估生成结果与训练集分布的差异,作为反作弊手段。

最终初赛分数是有效性、唯一性、创新性分数的加权组合。复赛则额外评估生成分子实际能量与目标能量的差距。

5. 核心程序思路(基于官方Baseline)

基本思路是基于赛事方提供的等变扩散模型(Equivariant Diffusion Model, EDM)Baseline进行训练和优化。

5.1 等变扩散模型 (EDM) 简介

扩散模型 (Diffusion Models) 是一类生成模型,其核心思想是通过一个前向扩散过程(逐渐向数据中添加噪声,直到数据变成完全的噪声分布), 学习一个逆向去噪过程(从噪声中逐渐恢复出原始数据)。训练好的模型可以从随机噪声开始,逐步“去噪”生成新的数据样本。

等变性 (Equivariance) 是指模型对于输入数据的某些变换(如三维空间中的旋转、平移)具有相应的变换性质。 对于分子结构数据,原子坐标是三维空间中的点,其物理化学性质与分子的绝对位置和朝向无关。因此,一个等变的模型在处理分子数据时具有天然的优势, 能够更好地学习和泛化,因为它可以识别出不同朝向的同一个分子,并生成符合物理规律的结构。 EDM结合了扩散模型的强大生成能力和等变网络的结构优势,特别适用于处理3D点云数据,如分子结构。

5.2 训练策略

基于官方提供的EDM代码,主要策略是进行大力出奇迹式的长时间、大参数训练。

  • 模型架构: 采用基于等变图神经网络 (EGNN) 的结构。
    • --nf 256: 设置EGNN每层的特征维度为256,提高模型的表达能力。
    • --n_layers 9: 使用9层EGNN,构建更深的模型以学习复杂的数据分布。
  • 扩散参数:
    • --diffusion_steps 1000: 设置扩散过程(去噪步骤)为1000步。更多的步数通常能带来更精细的生成过程和更高的生成质量,但也意味着更长的生成时间。
    • --diffusion_noise_schedule polynomial_2: 采用二次多项式的噪声调度,控制不同去噪阶段的噪声量。
    • --diffusion_noise_precision 1e-5: 控制最小噪声量。
    • --diffusion_loss_type l2: 使用L2损失函数优化去噪过程。
  • 训练参数:
    • --n_epochs 1600: 进行长达1600个周期的训练,让模型充分学习数据集的分布。
    • --batch_size 96: 每次训练使用96个样本。
    • --lr 1e-4: 初始学习率为10410^{-4}
    • --ema_decay 0.9999: 设置指数移动平均(EMA)的衰减率为0.9999,使模型参数更新更加平滑和稳定,有助于提升生成性能。
    • --normalize_factors '[1,4,10]': 对不同类型的数据(坐标、分类、整数)使用不同的归一化因子。
  • 硬件与结果: 在H20显卡上进行长时间训练。通过这种“堆资源、堆时间”的方式,官方Baseline模型能够达到0.56左右的性能表现(具体指标取决于评测)。

5.3 面临的问题与优化方向

当前的策略虽然能达到一定性能,但也存在明显的不足:

  • 生成速度慢: 设置了1000步的扩散步骤,虽然有助于提高生成质量,但导致生成10000个分子耗时很长。后续重要的优化方向是探索如何减少扩散步骤或加速采样过程,同时保持生成质量。
  • 可能的改进点(来自官方Baseline文档和经验):
    • 扩散参数调整: 尝试增加 diffusion_steps(若硬件允许且需更高质量),或尝试不同的调度策略如 "cosine"
    • 网络架构调整: 谨慎调整 n_layersnf(如在128-512之间),权衡模型容量和训练效率。文档提到短epoch下增加层数效果不佳,需结合训练时长考虑。
    • 训练策略调整: 增加 n_epochs(如100-200)以获得更好收敛(虽然我们已经设置了1600,但具体收敛情况需监控),调整 lr 或使用学习率调度策略,微调 ema_decay

6. 生成结果检查脚本

为了方便检查生成的 output.pkl 文件是否符合提交格式的基本要求,以及辅助调试,这里提供了一个Python脚本。

import argparse
import os
import pickle
import sys

# 允许的元素集合
allowed_elements = {'C', 'H', 'O', 'N', 'F', 'P', 'S', 'Cl', 'Br'}

def validate_output(file_path):
    """
    验证生成的pkl文件是否符合提交格式要求。
    """
    all_elements_in_file = set()  # 用于收集文件中实际出现的元素

    if not os.path.exists(file_path):
        return False, f"错误:文件 '{file_path}' 不存在。"

    try:
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
    except Exception as e:
        return False, f"错误:加载文件失败 - {str(e)}"

    # 检查数据是否为包含10000个分子的列表
    if not isinstance(data, list) or len(data) != 10000:
        return False, f"错误:数据必须是包含 exactly 10000 个分子的列表,当前数量为 {len(data)}"

    for idx, mol in enumerate(data):
        # 检查每个分子是否为字典且包含所有必要键
        if not isinstance(mol, dict):
            return False, f"错误:第{idx}个分子不是字典类型"
        required_keys = {'natoms', 'elements', 'coordinates'}
        if set(mol.keys()) != required_keys:
            return False, f"错误:第{idx}个分子缺少必要键,应包含 {required_keys},实际包含 {set(mol.keys())}"

        # 检查 'natoms' 是否为整数且符合范围
        natoms = mol.get('natoms')
        if not isinstance(natoms, int) or natoms < 1 or natoms > 60:
            return False, f"错误:第{idx}个分子的原子数 natoms ({natoms}) 不合法(需为整数且在1-60之间)"

        # 检查 'elements' 是否为字符串列表且长度等于 natoms
        elements = mol.get('elements')
        if (not isinstance(elements, list)
            or not all(isinstance(e, str) for e in elements)
            or len(elements) != natoms):
            return False, f"错误:第{idx}个分子的元素列表 elements 不合法(需为字符串列表,长度应为 {natoms},实际长度 {len(elements)})"

        # 检查元素是否在允许的集合中,并收集文件中出现的元素
        invalid_elements = [e for e in elements if e not in allowed_elements]
        if invalid_elements:
            return False, f"错误:第{idx}个分子包含无效元素:{invalid_elements}"
        all_elements_in_file.update(elements) # 收集出现的元素

        # 检查 'coordinates' 是否为3D坐标列表且长度等于 natoms
        coords = mol.get('coordinates')
        if (not isinstance(coords, list)
            or len(coords) != natoms):
            return False, f"错误:第{idx}个分子的坐标列表 coordinates 长度不合法(需为列表,长度应为 {natoms},实际长度 {len(coords)})"
        for coord_idx, coord in enumerate(coords):
            if (not isinstance(coord, list)
                or len(coord) != 3
                or not all(isinstance(c, (int, float)) for c in coord)):
                return False, f"错误:第{idx}个分子的第{coord_idx}个坐标格式不合法(需为包含3个数字的列表)"

    # 检查是否所有允许的元素都在生成文件中至少出现过一次
    missing_allowed_elements = allowed_elements - all_elements_in_file
    if missing_allowed_elements:
         print(f"警告:生成结果中未包含以下允许的元素:{', '.join(sorted(missing_allowed_elements))}. 这可能影响分布相似性评估。", file=sys.stderr)

    return True, "验证通过!所有分子符合基本的提交格式要求。"


def main():
    parser = argparse.ArgumentParser(description='Validate generated molecular structures PKL file.')
    parser.add_argument('file_path', type=str, help='Path to the output PKL file (e.g., outputs/.../eval/data.pkl)')

    args = parser.parse_args()

    result, message = validate_output(args.file_path)
    print(message)

if __name__ == "__main__":
    main()

脚本功能说明:

  • 该脚本定义了一个 validate_output 函数,用于检查指定的 .pkl 文件是否满足初赛提交的基本要求:
    • 文件是否存在且可加载。
    • 数据结构是否为包含10000个字典的列表。
    • 每个字典是否包含 natoms, elements, coordinates 三个键。
    • natoms 是否为整数且在1到60之间。
    • elements 是否为字符串列表,长度等于 natoms,且所有元素都在允许的集合内。
    • coordinates 是否为列表,长度等于 natoms,且每个元素是包含3个数字的列表。
    • (额外检查)提醒如果文件中未能生成所有允许的元素。
  • main 函数解析命令行参数,接收文件路径并调用 validate_output 进行检查,打印结果信息。
  • 使用方法:在命令行中运行 python your_script_name.py /path/to/your/output.pkl

重要提示: 这个脚本仅检查文件格式和基本结构约束(如原子数量、元素类型是否合法、坐标格式等)。 它不能替代赛事后台对分子合理性、唯一性和创新性的评估。 一个通过此脚本检查的文件,仍可能因化学价不合理、结构不稳定等原因在后台评估中得分很低。

7. 总结与展望

当前的程序思路是利用官方提供的等变扩散模型Baseline,并通过长时间、大参数的训练来充分挖掘模型的生成能力。 尽管这种策略在一定程度上是“暴力”的,并在生成速度上存在瓶颈,但它提供了一个扎实的基础。

未来的工作将围绕如何提高生成效率(优化扩散过程)、探索不同的模型架构或训练策略(如条件生成以更好地控制分子性质)、以及如何更有效地利用复赛提供的能量数据来指导生成过程展开。