【AI实战】用LSTM写诗作文本生成,从原理到代码详解(附完整项目)

3 阅读6分钟

💡 本教程是《AI 入门 30 天挑战》系列的项目实战部分


🎯 项目简介

这是一个基于 LSTM 的文本生成项目,可以学习古诗词、小说等文本,然后自动生成新的内容。

你将学到:

  • ✅ LSTM 循环神经网络原理
  • ✅ 字符级语言模型
  • ✅ 文本数据预处理
  • ✅ 温度参数控制创造性
  • ✅ 完整的 NLP 项目流程

项目特点:

  • 🚀 从零实现,代码清晰
  • 📝 支持中文文本生成
  • 🎲 温度参数调节创造性
  • 📊 训练过程可视化

📂 项目结构

text-generation/
├── main.py              # 主程序入口
├── dataset.py           # 数据加载和预处理
├── model.py             # LSTM 模型定义
├── train.py             # 训练脚本
├── generate.py          # 文本生成
├── requirements.txt     # 依赖包
└── README.md            # 详细说明

🚀 快速开始

1. 安装依赖

pip install torch numpy

2. 准备数据

data/ 目录下放置你的文本文件,例如 poems.txt

床前明月光疑是地上霜
举头望明月低头思故乡
春眠不觉晓处处闻啼鸟
夜来风雨声花落知多少

3. 训练模型

cd projects/text-generation
python main.py --mode train --epochs 50

训练完成后会生成:

  • model.pth - 训练好的模型权重
  • 训练过程中会打印生成的示例文本

4. 生成文本

# 使用默认参数
python main.py --mode generate --prompt "床前"

# 自定义参数
python main.py --mode generate --prompt "春眠" --length 100 --temperature 0.8

🔍 核心代码解析

LSTM 模型架构

class LSTMGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):
        super(LSTMGenerator, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # 嵌入层:将字符索引转换为向量
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # LSTM 层:捕捉序列中的长期依赖
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        # 输出层:预测下一个字符
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        # Dropout 防止过拟合
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, hidden=None):
        # 嵌入
        embeds = self.embedding(x)
        embeds = self.dropout(embeds)
        
        # LSTM
        if hidden is None:
            lstm_out, hidden = self.lstm(embeds)
        else:
            lstm_out, hidden = self.lstm(embeds, hidden)
        
        lstm_out = self.dropout(lstm_out)
        
        # 全连接
        output = self.fc(lstm_out)
        
        return output, hidden

大白话解释:

  • Embedding 层:把每个字变成一个向量(比如"床"变成 [0.1, 0.5, -0.3...])
  • LSTM 层:记住前面出现过的字,理解上下文
  • 输出层:根据当前状态预测下一个字是什么
  • Dropout:随机忘记一些信息,防止死记硬背

📊 文本生成原理

字符级模型 vs 词级模型

字符级(本项目):

  • ✅ 词汇表小(几百个常用汉字)
  • ✅ 不会出现未登录词
  • ✅ 适合古诗等短文本
  • ❌ 需要更长的序列才能理解语义

词级:

  • ✅ 更容易理解语义
  • ❌ 词汇表大(几万个词)
  • ❌ 需要分词
  • ❌ 可能遇到未登录词

温度参数(Temperature)

# 应用温度
output = output / temperature
probs = torch.softmax(output, dim=0)

# 采样
next_char_idx = torch.multinomial(probs, 1).item()

温度的作用:

  • 低温(0.2-0.5):保守,选择概率最高的字,生成内容稳定但单调
  • 中温(0.7-1.0):平衡,有一定创造性但不离谱
  • 高温(1.2-2.0):激进,随机性强,生成内容有趣但可能不通顺

示例:

Prompt: "床前明月"

Temperature 0.5: 床前明月光疑是地上霜举头望明月低头思故乡
Temperature 1.0: 床前明月光照我床上明月光照我床上明月
Temperature 1.5: 床前明月风吹花落地水声中人静夜深时

💡 优化技巧

1. 增加训练数据

# 合并多个文本文件
texts = []
for file in ['poems.txt', 'novels.txt', 'essays.txt']:
    with open(file, 'r', encoding='utf-8') as f:
        texts.append(f.read())

full_text = '\n'.join(texts)

效果:

  • 模型学习到更多风格
  • 生成内容更丰富
  • 但可能需要更多训练时间

2. 调整序列长度

SEQ_LENGTH = 50  # 默认值

建议:

  • 古诗:30-50(一句诗的长度)
  • 现代文:100-200
  • 小说:200-500

3. 增加 LSTM 层数

NUM_LAYERS = 2  # 可以尝试 3 或 4

注意:

  • 层数越多,模型越强大
  • 但也更容易过拟合
  • 需要更多数据和训练时间

🎨 生成示例

古诗词风格

Prompt: "春眠不觉"
Generated: 春眠不觉晓处处闻啼鸟夜来风雨声花落知多少

Prompt: "白日依山"
Generated: 白日依山尽黄河入海流欲穷千里目更上一层楼

创意写作

Prompt: "在一个遥远的"
Generated: 在一个遥远的星球上有一个神秘的城堡里面住着一个善良的公主她每天都在等待着王子的到来

Prompt: "人工智能的未来"
Generated: 人工智能的未来充满了无限可能它将改变我们的生活方式工作方式甚至思维方式

🤔 常见问题

Q1: 生成的文本不通顺怎么办?

A:

  • 增加训练数据量
  • 增加训练轮数(epochs)
  • 降低温度参数
  • 使用更长的序列长度

Q2: 训练太慢怎么办?

A:

  • 使用 GPU(速度提升 10-50 倍)
  • 减少词汇表大小(只保留常用字)
  • 减少隐藏层维度

Q3: 如何生成英文文本?

A: 修改数据集加载部分,使用英文文本即可。代码不需要改动,因为模型是字符级的。


📚 相关教程

这是《AI 入门 30 天挑战》的项目实战部分,前置知识:

完整 30 天教程:


🎉 总结

通过这个实战项目,你学会了:

  1. ✅ LSTM 循环神经网络原理
  2. ✅ 字符级语言模型实现
  3. ✅ 文本数据预处理技巧
  4. ✅ 温度参数控制创造性

下一步:

  • ⭐ Star GitHub 获取完整代码
  • ➕ 关注专栏查看更多项目
  • 💬 评论区分享你生成的有趣文本

其他项目实战:


🎉 恭喜你完成今天的学习!

📚 学习路径导航

上一篇当前下一篇
项目实战 - CIFAR-10项目实战 - 文本生成项目实战 - 目标检测

🔗 资源汇总

💬 互动时间

思考题:你用这个项目生成了什么有趣的文本?在评论区分享一下吧!

欢迎在评论区分享你的想法或疑问!👇

❤️ 如果有帮助

  • 👍 点赞:让更多人看到这篇教程
  • Star GitHub:获取完整代码和项目
  • 关注专栏:不错过后续更新
  • 🔄 分享给朋友:一起学习进步

明天见!继续下一个项目实战~ 🚀


本文是《AI 入门 30 天挑战》系列的项目实战篇 完整代码已开源,欢迎 Star 支持!