大模型训练原理详解:从预训练到推理

0 阅读3分钟

📌 本文首发于掘金,作者:AI探索者 🔗 注明出处

背景

本文是魔塔社区"七天入门大模型"课程的学习笔记,详细讲解大模型的训练原理、推理过程,以及PyTorch代码实战。


训练 vs 推理

概念定义公式
推理使用训练好的模型,根据输入得到输出y = ax + b(已知a,b,求y)
训练使用数据,反推参数a,b的值已知x,y,求a,b

三种数据类型

数据类型例子训练阶段
预训练数据"床前明月光"预训练
微调数据"番茄炒蛋怎么做?答:..."微调
对齐数据"如何杀人?答:通过法律途径"RLHF

训练流程

预训练数据 → 预训练 → Base模型
     ↓
微调数据 → 微调
     ↓
对齐数据 → 对齐 → Chat模型

模型推理过程

输入文本 → 分词 → Embedding → Transformer → 概率计算 → 采样 → 输出

自回归生成

text = "今天天气"
while not ended:
    next_token = model.predict(text)
    text += next_token
    if next_token == "EOS":
        break

PyTorch 入门

Tensor 基本操作

import torch

a = torch.tensor([1.], requires_grad=True)
b = torch.tensor([2.], requires_grad=True)

c = a * b
c.backward()

print(a.grad, b.grad)  # tensor([2.]) tensor([1.])

神经网络

import torch
from torch.nn import Linear

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(4, 4)

model = MyModule()
print(model.state_dict())

常见坑:设备问题

错误

RuntimeError: Expected all tensors to be on the same device

解决

# 正确做法
model = model.to('cuda')
data = torch.tensor([1,2,3]).to('cuda')  # 必须接收返回值!
output = model(data)

完整训练代码

import os
import random
import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F


# 1. 设置随机种子(保证结果可复现)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# 2. 确定CUDA随机性(防止训练结果不稳定)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# 3. 定义模型
class MyModule(torch.nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.linear = torch.nn.Linear(16, n_classes)
        self.relu = torch.nn.ReLU()


    def forward(self, tensor, label=None):
        output = {'logits': self.relu(self.linear(tensor))}
        if label is not None:
            loss_fct = CrossEntropyLoss()
            output['loss'] = loss_fct(output['logits'], label)
        return output


# 4. 定义数据集
class MyDataset(Dataset):
    def __len__(self):
        return 5

    def __getitem__(self, index):
        return {'tensor': torch.rand(16), 'label': torch.tensor(1)}


# 5. 初始化
model = MyModule()
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=4)
optimizer = AdamW(model.parameters(), lr=5e-4)
lr_scheduler = StepLR(optimizer, 2)


# 辅助函数:打印模型预测结果
def print_predictions(model, dataset, title="模型预测结果"):
    print(f"\n{'='*50}")
    print(f"{title}")
    print(f"{'='*50}")
    model.eval()
    with torch.no_grad():
        for i in range(len(dataset)):
            sample = dataset[i]
            tensor = sample['tensor'].unsqueeze(0)
            label = sample['label'].item()

            output = model(tensor)
            logits = output['logits'][0]
            probs = F.softmax(logits, dim=0)
            pred = torch.argmax(logits).item()

            print(f"样本 {i}: 真实标签={label}, 预测标签={pred}, 置信度=[{probs[0]:.4f}, {probs[1]:.4f}]")
    model.train()


# 6. 训练前:输出模型预测结果
print_predictions(model, dataset, "【训练前】模型预测结果")


# 7. 训练循环
print(f"\n{'='*50}")
print("开始训练...")
print(f"{'='*50}")

for epoch in range(3):
    epoch_loss = 0
    batch_count = 0
    for batch in dataloader:
        # 前向传播
        output = model(**batch)

        # 反向传播
        output['loss'].backward()

        # 更新参数
        optimizer.step()

        # 清理梯度
        optimizer.zero_grad()

        # 调整学习率
        lr_scheduler.step()

        epoch_loss += output['loss'].item()
        batch_count += 1

    avg_loss = epoch_loss / batch_count
    print(f"Epoch {epoch}, 平均 Loss: {avg_loss:.4f}")


# 8. 训练后:输出模型预测结果
print_predictions(model, dataset, "【训练后】模型预测结果")


# 9. 输出模型参数变化情况
print(f"\n{'='*50}")
print("训练完成!")
print(f"{'='*50}")

Pasted image 20260227091806.png

总结

概念说明
预训练用海量文本学通用知识 → Base模型
微调用问答数据学专业知识
对齐让回答符合人类期望 → Chat模型
推理用训练好的模型输出结果

参考


如果觉得有帮助,欢迎点赞、收藏、关注!