一、ml_collections是什么
ml_collections 是一个python库,初衷设计是专为机器学习配置。而我在实际使用过程中,把它当做来一个原生字典的替代项,感觉很好用。本文主要介绍基础的使用和锁机制。
核心特性:
- 可以让字典使用.的方式来访问和修改,而原生方式是[""]=xx;
- 提供了锁机制,当写完字典后,在一定程度上防止了误修改;
- 自动带了数据类型检查;
二、基础用法
下文代码所用的版本:1.1.0
1. 创建方式1:ConfigDict()
from ml_collections import ConfigDict
# ========== 方式1:创建空的ConfigDict,最基础
cfg = ConfigDict()
# ========== 方式2:初始化时传入原生字典,一键创建
cfg = ConfigDict({
"lr": 0.001,
"batch_size": 32,
"epochs": 100
})
# ========== 方式3:空配置逐步赋值创建
cfg = ConfigDict()
# 基础单层配置赋值
cfg.lr = 0.001
cfg.batch_size = 32
cfg.epochs = 200
# 配置支持所有Python基础数据类型
cfg.device = "cuda"
cfg.use_amp = True
cfg.weight_decay = 1e-5
2. 创建方式2:create(),可以一次性把配置都写完
from ml_collections import config_dict
# 用 create 一次性写完所有嵌套配置
cfg = config_dict.create(
# 数据集配置
data=config_dict.create(
dataset="cifar10",
data_dir="./data",
batch_size=32,
num_workers=4
),
# 模型配置
model=config_dict.create(
backbone="resnet18",
num_classes=10,
dropout=0.2,
pretrained=True
),
# 训练配置 + 嵌套的优化器子配置
train=config_dict.create(
epochs=200,
lr=0.001,
optimizer=config_dict.create(
name="adam",
weight_decay=1e-5
)
)
)
# 打印配置(可选)
print("实验配置:", cfg)
3. 嵌套
字典的值也可以是ConfigDict
from ml_collections import ConfigDict
# 创建根配置
cfg = ConfigDict()
# --------------------------
# 第一层:数据集相关配置(嵌套1层)
# --------------------------
cfg.data = ConfigDict() # 先给data赋值为一个空的ConfigDict
cfg.data.dataset_name = "cifar10" # 数据集名称
cfg.data.data_path = "./dataset/cifar10" # 数据存放路径
cfg.data.batch_size = 32 # 批次大小
cfg.data.num_workers = 4 # 加载数据的线程数
# --------------------------
# 第一层:模型相关配置(嵌套1层)
# --------------------------
cfg.model = ConfigDict()
cfg.model.backbone = "resnet18" # 模型骨干网络
cfg.model.num_classes = 10 # 分类任务的类别数
cfg.model.dropout_rate = 0.2 # dropout正则化概率
cfg.model.use_pretrain = True # 是否使用预训练权重
# --------------------------
# 第一层:训练相关配置(嵌套2层)
# --------------------------
cfg.train = ConfigDict()
cfg.train.epochs = 200 # 训练总轮数
cfg.train.base_lr = 0.001 # 基础学习率
# 训练配置里再嵌套:优化器的子配置(嵌套第二层)
cfg.train.optimizer = ConfigDict()
cfg.train.optimizer.name = "adam" # 优化器名称
cfg.train.optimizer.weight_decay = 0.00005 # 权重衰减
# --------------------------
# 取值:链式点语法,一层一层访问
# --------------------------
print("数据集名称:", cfg.data.dataset_name)
print("模型骨干网络:", cfg.model.backbone)
print("优化器名称:", cfg.train.optimizer.name)
print("权重衰减系数:", cfg.train.optimizer.weight_decay)
4. 读取方式
同时支持.式和[]中括号方式
from ml_collections import ConfigDict
cfg = ConfigDict({
"lr":0.001,
"model": ConfigDict({"backbone":"resnet18"})
})
# ========== 方式1:点语法 读取
print(cfg.lr) # 输出:0.001
print(cfg.model.backbone) # 输出:resnet18
# ========== 方式2:字典中括号语法 读取
print(cfg["lr"]) # 输出:0.001
print(cfg["model"]["backbone"]) # 输出:resnet18
5. 删除
from ml_collections import ConfigDict
cfg = ConfigDict({"lr":0.001, "batch_size":32, "epochs":200})
del cfg.epochs # 删除基础字段
del cfg["batch_size"] # 用字典语法删除,等价
print("删除后配置:", cfg) # 输出:lr: 0.001
6. 自动类型校验
int可以赋值给int,int可以赋值给float,int不能赋值给string,string也不能赋值给int
from ml_collections import config_dict
cfg = config_dict.ConfigDict()
cfg.float_field = 12.6 # float类型
cfg.integer_field = 123 # int类型
cfg.another_integer_field = 234 # int类型
cfg.nested = config_dict.ConfigDict()
cfg.nested.string_field = 'tom' # str类型
print(cfg.integer_field) # Prints 123.
print(cfg['integer_field']) # Prints 123 as well.
try:
cfg.integer_field = 'tom' # 将str赋值给int类型,报错
except TypeError as e:
print(e)
cfg.float_field = 12 # int可以给float赋值,自动类型转换
cfg.nested.string_field = u'bob' # str赋值
print(cfg)
三、锁机制
1. 方式1: lock()
- lock后,不可以新增字段,不可以删除字段
- lock后,可以修改已有的字段,包括嵌套里的字段
from ml_collections import ConfigDict
# 1. 创建配置并初始化【已有字段】
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
# 嵌套配置
cfg.model = ConfigDict()
cfg.model.dropout = 0.2
# 2. 核心:执行上锁操作
cfg.lock()
# 上锁后 - 修改【已有字段】:正常生效,无报错
cfg.lr = 0.0005 # 修改根层级已有字段 ✔️
cfg.batch_size = 64 # 修改根层级已有字段 ✔️
cfg.model.dropout = 0.1 # 修改嵌套层级已有字段 ✔️
print("修改后的值:", cfg.lr, cfg.batch_size, cfg.model.dropout)
# 特性2:上锁后 - 新增【任何字段】:直接报错
try:
cfg.epochs = 200 # 新增根层级字段
except AttributeError as e:
print("❌ 新增根字段报错:", e)
try:
cfg.model.backbone = "resnet18" # 新增嵌套层级字段 ❌
except AttributeError as e:
print("❌ 新增嵌套字段报错:", e)
# 特性3:上锁后 - 删除【任何字段】:直接报错
try:
del cfg.lr # 删除根层级已有字段 ❌
except AttributeError as e:
print("❌ 删除根字段报错:", e)
try:
del cfg.model.dropout # 删除嵌套层级已有字段 ❌
except AttributeError as e:
print("❌ 删除嵌套字段报错:", e)
- 判断是否上锁
print(cfg.is_locked) # 上锁后返回 True,解锁后返回 False
- 彻底解锁
cfg.unlock()
- 临时解锁
from ml_collections import ConfigDict
# 1. 创建配置 + 赋值已有字段
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
cfg.model = ConfigDict()
# 2. 上锁(核心前提)
cfg.lock()
print(f"初始上锁状态: {cfg.is_locked}") # True
# 临时解锁 with cfg.unlocked()
# 特性:代码块内 临时解锁,可自由 新增/修改/删除 任何字段;代码块结束 自动重新上锁,无需手动操作
with cfg.unlocked():
# ✅ 临时修改已有字段
cfg.lr = 0.0005
cfg.model.dropout = 0.1
# ✅ 临时新增字段(根+嵌套都可以)
cfg.epochs = 200
cfg.model.backbone = "resnet18"
# ✅ 临时删除字段
del cfg.batch_size
# 验证:代码块结束后自动上锁 + 操作全部生效
print(f"临时解锁后状态: {cfg.is_locked}") # True 自动上锁
print(f"修改后: {cfg.lr}, {cfg.model.dropout}")
print(f"新增后: {cfg.epochs}, {cfg.model.backbone}")
print(f"删除后batch_size是否存在: {'batch_size' in cfg}") # False
# 上锁状态下 依旧禁止新增/删除(安全兜底)
try:
cfg.device = "cuda"
except Exception as e:
print(f"\n上锁禁止新增: {e.args[0]}")
2. 方式2: FrozenConfigDict()
- 创建,特性:禁止修改、删除、新增
# 正确导入
from ml_collections import ConfigDict
from ml_collections.config_dict import FrozenConfigDict
# ✔️ 方式1:ConfigDict 转 FrozenConfigDict
cfg = ConfigDict()
cfg.lr = 0.001
cfg.batch_size = 32
# 嵌套ConfigDict 正确写法
cfg.model = ConfigDict()
cfg.model.dropout = 0.2
cfg.model.backbone = "resnet18"
f_cfg = FrozenConfigDict(cfg)
# ✔️ 方式2:直接创建 FrozenConfigDict (传字典)
f_cfg2 = FrozenConfigDict({"lr":0.001, "batch_size":32})
print("="*20)
print("FrozenConfigDict只读配置:", f_cfg)
print("是否为冰封配置:", isinstance(f_cfg, FrozenConfigDict)) # True
# 所有【写操作】全部禁止
# 1. ❌ 禁止修改已有字段
try:
f_cfg.lr = 0.0005
except AttributeError as e:
print(f"\n❌ 修改报错: {e}")
# 2. ❌ 禁止新增字段
try:
f_cfg.epochs = 200
except AttributeError as e:
print(f"❌ 新增报错: {e}")
# 3. ❌ 禁止删除字段
try:
del f_cfg.batch_size
except AttributeError as e:
print(f"❌ 删除报错: {e}")
# 4. ❌ 无lock/unlock方法,天生只读,无需上锁
try:
f_cfg.lock()
except AttributeError as e:
print(f"❌ 无lock方法: {e}")
# 唯一支持的操作:【读取】
print("\n" + "="*20)
print("✅ 读取基础字段:", f_cfg.lr)
print("✅ 读取嵌套字段:", f_cfg.model.dropout)
print("✅ 字典方式读取:", f_cfg["batch_size"])
- FrozenConfigDict与ConfigDict转换
# 正确导入
from ml_collections import ConfigDict
from ml_collections.config_dict import FrozenConfigDict
# ✔️ 转换1: ConfigDict → FrozenConfigDict (冻结,可变→不可变)
cfg = ConfigDict({"lr":0.001, "batch_size":32})
frozen_cfg = FrozenConfigDict(cfg)
# ✔️ 转换2: FrozenConfigDict → ConfigDict (解冻,不可变→可变,恢复所有权限)
unfrozen_cfg = ConfigDict(frozen_cfg)
# 解冻后 ✅ 恢复全部操作权限:增/删/改 都可以
unfrozen_cfg.lr = 1e-4 # 修改 ✔️
unfrozen_cfg.device = "cuda" # 新增 ✔️
del unfrozen_cfg.batch_size # 删除 ✔️
print("✅ 解冻后ConfigDict:", unfrozen_cfg)