你还没受够python原生字典的写法吗?用ml_collections试试吧

20 阅读6分钟

一、ml_collections是什么

ml_collections 是一个python库,初衷设计是专为机器学习配置。而我在实际使用过程中,把它当做来一个原生字典的替代项,感觉很好用。本文主要介绍基础的使用和锁机制。

地址:pypi.org/project/ml-…

核心特性:

  • 可以让字典使用.的方式来访问和修改,而原生方式是[""]=xx;
  • 提供了锁机制,当写完字典后,在一定程度上防止了误修改;
  • 自动带了数据类型检查;

二、基础用法

下文代码所用的版本:1.1.0

1. 创建方式1:ConfigDict()

from ml_collections import ConfigDict

# ========== 方式1:创建空的ConfigDict,最基础
cfg = ConfigDict()

# ========== 方式2:初始化时传入原生字典,一键创建
cfg = ConfigDict({
    "lr": 0.001,
    "batch_size": 32,
    "epochs": 100
})

# ========== 方式3:空配置逐步赋值创建
cfg = ConfigDict()
# 基础单层配置赋值
cfg.lr = 0.001
cfg.batch_size = 32
cfg.epochs = 200
# 配置支持所有Python基础数据类型
cfg.device = "cuda"
cfg.use_amp = True
cfg.weight_decay = 1e-5

2. 创建方式2:create(),可以一次性把配置都写完

from ml_collections import config_dict

# 用 create 一次性写完所有嵌套配置
cfg = config_dict.create(
    # 数据集配置
    data=config_dict.create(
        dataset="cifar10",
        data_dir="./data",
        batch_size=32,
        num_workers=4
    ),
    # 模型配置
    model=config_dict.create(
        backbone="resnet18",
        num_classes=10,
        dropout=0.2,
        pretrained=True
    ),
    # 训练配置 + 嵌套的优化器子配置
    train=config_dict.create(
        epochs=200,
        lr=0.001,
        optimizer=config_dict.create(
            name="adam",
            weight_decay=1e-5
        )
    )
)


# 打印配置(可选)
print("实验配置:", cfg)

3. 嵌套

字典的值也可以是ConfigDict

from ml_collections import ConfigDict

# 创建根配置
cfg = ConfigDict()

# --------------------------
# 第一层:数据集相关配置(嵌套1层)
# --------------------------
cfg.data = ConfigDict()  # 先给data赋值为一个空的ConfigDict
cfg.data.dataset_name = "cifar10"  # 数据集名称
cfg.data.data_path = "./dataset/cifar10"  # 数据存放路径
cfg.data.batch_size = 32  # 批次大小
cfg.data.num_workers = 4  # 加载数据的线程数

# --------------------------
# 第一层:模型相关配置(嵌套1层)
# --------------------------
cfg.model = ConfigDict()
cfg.model.backbone = "resnet18"  # 模型骨干网络
cfg.model.num_classes = 10  # 分类任务的类别数
cfg.model.dropout_rate = 0.2  # dropout正则化概率
cfg.model.use_pretrain = True  # 是否使用预训练权重

# --------------------------
# 第一层:训练相关配置(嵌套2层)
# --------------------------
cfg.train = ConfigDict()
cfg.train.epochs = 200  # 训练总轮数
cfg.train.base_lr = 0.001  # 基础学习率
# 训练配置里再嵌套:优化器的子配置(嵌套第二层)
cfg.train.optimizer = ConfigDict()
cfg.train.optimizer.name = "adam"  # 优化器名称
cfg.train.optimizer.weight_decay = 0.00005  # 权重衰减

# --------------------------
# 取值:链式点语法,一层一层访问
# --------------------------
print("数据集名称:", cfg.data.dataset_name)
print("模型骨干网络:", cfg.model.backbone)
print("优化器名称:", cfg.train.optimizer.name)
print("权重衰减系数:", cfg.train.optimizer.weight_decay)

4. 读取方式

同时支持.式和[]中括号方式

from ml_collections import ConfigDict

cfg = ConfigDict({
    "lr":0.001,
    "model": ConfigDict({"backbone":"resnet18"})
})

# ========== 方式1:点语法 读取
print(cfg.lr)               # 输出:0.001
print(cfg.model.backbone)   # 输出:resnet18

# ========== 方式2:字典中括号语法 读取
print(cfg["lr"])            # 输出:0.001
print(cfg["model"]["backbone"]) # 输出:resnet18

5. 删除

from ml_collections import ConfigDict

cfg = ConfigDict({"lr":0.001, "batch_size":32, "epochs":200})

del cfg.epochs # 删除基础字段
del cfg["batch_size"] # 用字典语法删除,等价
print("删除后配置:", cfg) # 输出:lr: 0.001

6. 自动类型校验

int可以赋值给int,int可以赋值给float,int不能赋值给string,string也不能赋值给int

from ml_collections import config_dict

cfg = config_dict.ConfigDict()
cfg.float_field = 12.6 # float类型
cfg.integer_field = 123 # int类型
cfg.another_integer_field = 234 # int类型
cfg.nested = config_dict.ConfigDict()
cfg.nested.string_field = 'tom' # str类型

print(cfg.integer_field)  # Prints 123.
print(cfg['integer_field'])  # Prints 123 as well.

try:
  cfg.integer_field = 'tom'  # 将str赋值给int类型,报错
except TypeError as e:
  print(e)

cfg.float_field = 12  # int可以给float赋值,自动类型转换
cfg.nested.string_field = u'bob'  # str赋值

print(cfg)

三、锁机制

1. 方式1: lock()

  • lock后,不可以新增字段,不可以删除字段
  • lock后,可以修改已有的字段,包括嵌套里的字段
from ml_collections import ConfigDict

# 1. 创建配置并初始化【已有字段】
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
# 嵌套配置
cfg.model = ConfigDict()
cfg.model.dropout = 0.2

# 2. 核心:执行上锁操作
cfg.lock()

# 上锁后 - 修改【已有字段】:正常生效,无报错 
cfg.lr = 0.0005          # 修改根层级已有字段 ✔️
cfg.batch_size = 64      # 修改根层级已有字段 ✔️
cfg.model.dropout = 0.1  # 修改嵌套层级已有字段 ✔️
print("修改后的值:", cfg.lr, cfg.batch_size, cfg.model.dropout)

#  特性2:上锁后 - 新增【任何字段】:直接报错
try:
    cfg.epochs = 200     # 新增根层级字段 
except AttributeError as e:
    print("❌ 新增根字段报错:", e)

try:
    cfg.model.backbone = "resnet18"  # 新增嵌套层级字段 ❌
except AttributeError as e:
    print("❌ 新增嵌套字段报错:", e)

# 特性3:上锁后 - 删除【任何字段】:直接报错 
try:
    del cfg.lr  # 删除根层级已有字段 ❌
except AttributeError as e:
    print("❌ 删除根字段报错:", e)

try:
    del cfg.model.dropout  # 删除嵌套层级已有字段 ❌
except AttributeError as e:
    print("❌ 删除嵌套字段报错:", e)
  • 判断是否上锁
print(cfg.is_locked) # 上锁后返回 True,解锁后返回 False
  • 彻底解锁
cfg.unlock()
  • 临时解锁
from ml_collections import ConfigDict

# 1. 创建配置 + 赋值已有字段
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
cfg.model = ConfigDict()

# 2. 上锁(核心前提)
cfg.lock()
print(f"初始上锁状态: {cfg.is_locked}")  # True

# 临时解锁 with cfg.unlocked()
# 特性:代码块内 临时解锁,可自由 新增/修改/删除 任何字段;代码块结束 自动重新上锁,无需手动操作
with cfg.unlocked():
    # ✅ 临时修改已有字段
    cfg.lr = 0.0005
    cfg.model.dropout = 0.1
    # ✅ 临时新增字段(根+嵌套都可以)
    cfg.epochs = 200
    cfg.model.backbone = "resnet18"
    # ✅ 临时删除字段
    del cfg.batch_size

# 验证:代码块结束后自动上锁 + 操作全部生效
print(f"临时解锁后状态: {cfg.is_locked}")  # True 自动上锁
print(f"修改后: {cfg.lr}, {cfg.model.dropout}")
print(f"新增后: {cfg.epochs}, {cfg.model.backbone}")
print(f"删除后batch_size是否存在: {'batch_size' in cfg}")  # False

# 上锁状态下 依旧禁止新增/删除(安全兜底)
try:
    cfg.device = "cuda"
except Exception as e:
    print(f"\n上锁禁止新增: {e.args[0]}")

2. 方式2: FrozenConfigDict()

  • 创建,特性:禁止修改、删除、新增
# 正确导入
from ml_collections import ConfigDict
from ml_collections.config_dict import FrozenConfigDict

# ✔️ 方式1:ConfigDict 转 FrozenConfigDict
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
# 嵌套ConfigDict 正确写法
cfg.model = ConfigDict()
cfg.model.dropout = 0.2
cfg.model.backbone = "resnet18"

f_cfg = FrozenConfigDict(cfg)

# ✔️ 方式2:直接创建 FrozenConfigDict (传字典)
f_cfg2 = FrozenConfigDict({"lr":0.001, "batch_size":32})

print("="*20)
print("FrozenConfigDict只读配置:", f_cfg)
print("是否为冰封配置:", isinstance(f_cfg, FrozenConfigDict)) # True

# 所有【写操作】全部禁止
# 1. ❌ 禁止修改已有字段
try:
    f_cfg.lr = 0.0005
except AttributeError as e:
    print(f"\n❌ 修改报错: {e}")

# 2. ❌ 禁止新增字段
try:
    f_cfg.epochs = 200
except AttributeError as e:
    print(f"❌ 新增报错: {e}")

# 3. ❌ 禁止删除字段
try:
    del f_cfg.batch_size
except AttributeError as e:
    print(f"❌ 删除报错: {e}")

# 4. ❌ 无lock/unlock方法,天生只读,无需上锁
try:
    f_cfg.lock()
except AttributeError as e:
    print(f"❌ 无lock方法: {e}")

# 唯一支持的操作:【读取】 
print("\n" + "="*20)
print("✅ 读取基础字段:", f_cfg.lr)
print("✅ 读取嵌套字段:", f_cfg.model.dropout)
print("✅ 字典方式读取:", f_cfg["batch_size"])
  • FrozenConfigDict与ConfigDict转换
# 正确导入
from ml_collections import ConfigDict
from ml_collections.config_dict import FrozenConfigDict

# ✔️ 转换1: ConfigDict → FrozenConfigDict (冻结,可变→不可变)
cfg = ConfigDict({"lr":0.001, "batch_size":32})
frozen_cfg = FrozenConfigDict(cfg)

# ✔️ 转换2: FrozenConfigDict → ConfigDict (解冻,不可变→可变,恢复所有权限)
unfrozen_cfg = ConfigDict(frozen_cfg)

# 解冻后 ✅ 恢复全部操作权限:增/删/改 都可以
unfrozen_cfg.lr = 1e-4        # 修改 ✔️
unfrozen_cfg.device = "cuda"  # 新增 ✔️
del unfrozen_cfg.batch_size   # 删除 ✔️
print("✅ 解冻后ConfigDict:", unfrozen_cfg)