Solving OOM During Model Weight Conversion

Converting the full set of Deepseek2-236B weights on a server with 2 TB of RAM fails with an out-of-memory error. Below are some practical changes that bring the conversion back within that budget.
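Before changing any code, it helps to see where host memory actually peaks. Below is a minimal sketch of logging resident memory around the expensive steps, assuming psutil is installed; the log_rss helper and its call sites are illustrative and not part of loader_hf.py:

import os
import psutil

def log_rss(tag):
    # Resident set size of the current process, in GiB.
    rss_gib = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3
    print(f"[mem] {tag}: {rss_gib:.1f} GiB", flush=True)

# Hypothetical call sites: after loading the HF weights and after each layer is sent.
log_rss("after get_modules_from_pretrained")
log_rss("after sending transformer layer 0")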

Taking hf2mg (HuggingFace-to-Megatron) conversion as the example, here are the modifications to loader_hf.py:

def _load_checkpoint(model_provider, queue, args):
    # Llama-2 requires HF transformers >=4.31.0.
    verify_transformers_version()

    # Search in directory above this.
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir,
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    # Load the HuggingFace model wrapper.
    model_hf = get_huggingface_model(args)
    # Fetch the HuggingFace model's configuration.
    args_hf = model_hf.get_args()
    args_hf.moe_grouped_gemm = args.moe_grouped_gemm
    args_hf.spec = args.spec

    # Build the Megatron model wrapper.
    model_mg = get_megatron_model(model_provider, args_cmd=args)
    # Initialize the Megatron arguments.
    model_mg.initialize_megatron_args(args_hf, queue)

    # Configure the parallel strategy.
    model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
    model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
    model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
    model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)

    # Get first pipe stage.
    model_mg.set_tensor_model_parallel_rank(0)
    model_mg.set_pipeline_model_parallel_rank(0)

    margs = model_mg.get_args()
    md = build_metadata(args, margs)
    queue.put(md)
    # Populate the parameters: model_hf loads them from the pretrained HF
    # checkpoint, while model_mg is built with randomly initialized weights.
    model_hf.get_modules_from_pretrained()
    model_mg.get_modules_from_config()

    model_mg.update_module(model_hf)
    ###############################
    # update_module copies model_hf's parameters into model_mg. Once it
    # returns, model_hf has served its purpose, so we delete the variable
    # and release that memory.
    ###############################
    del model_hf
    import gc
    gc.collect()

    # Next, stream out each part of the weights.
    def queue_put(name, msg):
        logger.info(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings.
    message = get_message_preprocess(model_mg, md)
    queue_put("embeddings", message)

    for layer_idx in range(margs.num_layers):
        # Grab all parallel tensors for this layer.
        message = {}
        # message = get_message_layer_norm(message, model_mg, layer_idx, md, args)
        # message = get_message_layer_attn(message, model_mg, layer_idx, md, args)
        # message = get_message_layer_mlp(message, model_mg, layer_idx, md)
        # to_detach(message)
        # queue_put(f"transformer layer {layer_idx}", message)
        ###############################
        # Once layer layer_idx of model_mg has been processed, its parameters
        # are never used again, so that memory can also be released. We
        # therefore always index layer 0: process layer 0, send it, then
        # delete it, so only one transformer layer is resident at a time.
        ###############################
        message = get_message_layer_norm(message, model_mg, 0, md, args)
        message = get_message_layer_attn(message, model_mg, 0, md, args)
        # layer_idx is still passed here because the function uses it to
        # decide whether this is an MoE layer.
        message = get_message_layer_mlp(message, model_mg, layer_idx, md)
        to_detach(message)
        queue_put(f"transformer layer {layer_idx}", message)
        # Drop the layer we just sent and reclaim its memory
        # (gc was already imported above).
        del model_mg.module[0][0][0].decoder.layers[0]
        gc.collect()

    # Send final norm from tp_rank 0.
    message = get_message_postprocess(model_mg, md)
    queue_put("final norm", message)

    # Send the output layer.
    message = get_message_output_layer(model_mg, md)
    if message is not None:
        queue_put("output layer", message)

    queue.put("done")
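The loop above works because decoder.layers is an nn.ModuleList and each layer is consumed exactly once: always read slot 0, send it, then delete slot 0, so at most one transformer layer is resident at a time. Here is a self-contained sketch of that pattern with toy layers (the Linear modules and the message dict are illustrative, not the converter's real objects):

import gc
import torch

# Stand-in for decoder.layers: each entry is needed exactly once.
layers = torch.nn.ModuleList(torch.nn.Linear(1024, 1024) for _ in range(8))

for _ in range(len(layers)):
    layer = layers[0]                                   # always read slot 0
    message = {k: v.detach() for k, v in layer.state_dict().items()}
    # ... hand `message` off here (queue.put in the real converter) ...
    del message, layer                                  # drop our references,
    del layers[0]                                       # then drop the slot itself,
    gc.collect()                                        # so only one layer stays resident

get_message_layer_mlp needs a matching change so that it, too, always reads from layer 0: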

def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
    margs = model.get_args()
    first_k_dense_replace = model.get_first_k_dense_replace()
    moe_layer_freq = model.get_moe_layer_freq()
    shared_expert_gate = getattr(margs, 'shared_expert_gate', None)

    if layer_idx >= first_k_dense_replace and layer_idx % moe_layer_freq == 0:
        ###############################
        # Keep this consistent with the loop above: the weights always live
        # in layer 0, so read from layer 0 regardless of layer_idx.
        ###############################
        layer_idx = 0
        message["mlp_moe"] = {}
        mlp_router_weight = model.get_layers_mlp_router_weight(layer_idx=layer_idx)
        message["mlp_moe"]["mlp router weight"] = mlp_router_weight
        if shared_expert_gate:
            shared_expert_gate = model.get_layers_mlp_shared_expert_gate_weight(layer_idx=layer_idx)
            message["mlp_moe"]["mlp shared_expert_gate weight"] = shared_expert_gate
        if getattr(margs, "n_shared_experts", None) is not None:
            fc1_weight = model.get_layers_mlp_shared_experts_linear_fc1_weight(layer_idx=layer_idx)
            fc2_weight = model.get_layers_mlp_shared_experts_linear_fc2_weight(layer_idx=layer_idx)
            message["mlp_moe"]["mlp shared experts linear fc1 weight"] = fc1_weight
            message["mlp_moe"]["mlp shared experts linear fc2 weight"] = fc2_weight
        if margs.moe_grouped_gemm:
            weight1 = model.get_layers_mlp_experts_weight1_module(layer_idx=layer_idx)
            weight2 = model.get_layers_mlp_experts_weight2_module(layer_idx=layer_idx)
            message["mlp_moe"]["mlp experts weight1 module"] = weight1
            message["mlp_moe"]["mlp experts weight2 module"] = weight2
        else:
            for expert_idx in range(margs.num_experts):
                kwargs = {'expert_idx': expert_idx}
                expert = _get_message_layer_mlp({}, model, layer_idx, md=md, tp_size=tp_size, is_moe_mlp=True, **kwargs)
                message["mlp_moe"][f"expert {expert_idx}"] = expert
        return message
    else:
        ###############################
        # Keep this consistent with the loop above: always read from layer 0.
        ###############################
        layer_idx = 0
        return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)


With these changes, the peak memory of the deepseek2-236B weight conversion drops to under 2 TB.
The same idea can be applied in saver.py and loader_mg.py to reduce their memory usage as well; a sketch of the reusable step is shown below.
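A rough sketch of that release-as-you-go step, written as a reusable helper; the decoder.layers path in the comment mirrors loader_hf.py and may differ in the real saver.py / loader_mg.py:

import gc
from torch import nn

def pop_consumed_layer(layers: nn.ModuleList) -> None:
    # Drop slot 0 once its weights have been sent or written out,
    # so at most one transformer layer stays resident in this list.
    del layers[0]
    gc.collect()

# Hypothetical call site inside the per-layer loop, after the layer is handled:
# pop_consumed_layer(model_mg.module[0][0][0].decoder.layers)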