用文字 + ASCII 图 + 模块化描述,完整绘制出这个 TinyTransformer 模型的网络结构图!
你可以根据下面的描述,轻松用纸笔、PPT、或绘图工具(如 draw.io)画出来。
🎯 模型总览:Encoder-Decoder Transformer
输入序列 (src) 目标序列 (tgt)
│ │
▼ ▼
[Embedding + PosEnc] [Embedding + PosEnc]
│ │
▼ │
┌─────────────┐ │
│ EncoderLayer│×2 │
└─────────────┘ │
│ │
├───────────────────────┤
│ ▼
│ ┌──────────────────┐
│ │ DecoderLayer │×2
│ │ 1. Self-Attention│
│ │ 2. Cross-Attention (←来自Encoder)
│ │ 3. FeedForward │
│ └──────────────────┘
│ │
│ ▼
│ [Linear (vocab_size)]
│ │
│ ▼
└──────────────────▶ 输出 logits
🧱 一、输入处理(共享 Embedding + Positional Encoding)
输入:src (shape: [B, 3]) 和 tgt (shape: [B, 7])
self.embedding = nn.Embedding(vocab_size, d_model) # 29 → 64
self.pos_encoding = PositionalEncoding(d_model, max_len)
流程:
src ──→ Embedding ──→ [B, 3, 64] ──→ + PosEnc ──→ [B, 3, 64]
tgt ──→ Embedding ──→ [B, 7, 64] ──→ + PosEnc ──→ [B, 7, 64]
✅ 位置编码
pe是[1, max_len, 64],通过广播加到输入上
🔄 二、Encoder 部分(2层)
每层 EncoderLayer 包含:
- Multi-Head Self-Attention(4头,每头16维)
- Add + LayerNorm
- FeedForward(64 → 128 → 64)
- Add + LayerNorm
结构图(单层):
x ───────────────┐
│ │
▼ │
[MultiHeadAttention] │
│ │
▼ │
Dropout │
│ │
+ ←──────────────┘ (Residual)
│
LayerNorm
│
▼
[FeedForward]
│
Dropout
│
+ ←──────────────┐ (Residual)
│ │
LayerNorm │
│ │
└────────────────┘
Encoder 整体:
输入: [B, 3, 64]
│
▼
EncoderLayer 1 → [B, 3, 64]
│
▼
EncoderLayer 2 → [B, 3, 64]
│
▼
输出: enc_out [B, 3, 64] + src_mask [B, 1, 1, 3]
🔄 三、Decoder 部分(2层)
每层 DecoderLayer 包含:
- Masked Multi-Head Self-Attention(看自己,屏蔽未来)
- Add + LayerNorm
- Multi-Head Cross-Attention(Q=自己,K,V=Encoder输出)
- Add + LayerNorm
- FeedForward
- Add + LayerNorm
结构图(单层):
x ───────────────────┐
│ │
▼ │
[Masked Self-Attention] │
│ │
▼ │
Dropout │
│ │
+ ←──────────────────┘ (Residual)
│
LayerNorm
│
▼
[Cross-Attention] ←────── enc_out
│
Dropout
│
+ ←──────────────────┐ (Residual)
│ │
LayerNorm │
│ │
▼ │
[FeedForward] │
│ │
Dropout │
│ │
+ ←──────────────────┘ (Residual)
│
LayerNorm
│
▼
Decoder 整体:
输入: tgt [B, 7, 64] + enc_out [B, 3, 64] + masks
│
▼
DecoderLayer 1 → [B, 7, 64]
│
▼
DecoderLayer 2 → [B, 7, 64]
│
▼
Linear(64, 29) → [B, 7, 29] (logits)
🧩 四、关键子模块详解
1. MultiHeadAttention(4头)
输入: q, k, v [B, L, 64]
│
▼
Linear → [B, L, 64]
│
▼
view → [B, L, 4, 16] → transpose → [B, 4, L, 16]
│
▼
Q @ K^T → [B, 4, Lq, Lk] → /√16 → + mask → softmax → dropout
│
▼
@ V → [B, 4, Lq, 16]
│
▼
transpose → [B, Lq, 4, 16] → contiguous → view → [B, Lq, 64]
│
▼
Linear → [B, Lq, 64]
2. FeedForward
输入: [B, L, 64]
│
▼
Linear(64, 128) → ReLU → [B, L, 128]
│
▼
Dropout
│
▼
Linear(128, 64) → [B, L, 64]
3. Masks
- src_mask:
[B, 1, 1, L_src]→ 屏蔽<pad> - tgt_mask:
[B, 1, L_tgt, L_tgt]→ 屏蔽<pad>+ 下三角(防偷看未来)
🎯 五、前向传播完整流程图
训练时:
src [B,3] ──→ Embed+Pos ──→ Encoder (×2) ──→ enc_out [B,3,64]
│
tgt [B,7] ──→ Embed+Pos ──→ Decoder (×2) ←──┘
│
Linear(64,29)
│
logits [B,7,29]
生成时(自回归):
src ──→ Encoder ──→ enc_out
│
<sos> ──→ Decoder ──→ pred1
│
<sos>+pred1 ──→ Decoder ──→ pred2
│
... 直到 <eos> 或 max_len
📊 六、参数规模估算
| 模块 | 参数量 | 计算 |
|---|---|---|
| Embedding | 29 × 64 = 1,856 | vocab_size × d_model |
| Encoder Layer ×2 | 2 × (4×64² + 64×128×2 + ...) ≈ 2 × 20K = 40K | 见下文 |
| Decoder Layer ×2 | 2 × (类似Encoder + 额外Cross-Attn) ≈ 2 × 25K = 50K | |
| Output Linear | 64 × 29 = 1,856 | d_model × vocab_size |
| 总计 | ≈ 93K |
✅ 非常轻量,适合教学!
🖼️ ASCII 完整网络图
TRAINING TIME
┌───────────────────────────────────────────────────────────────┐
│ │
┌─────────┐ ┌───────────┐ │ ┌─────────┐ ┌───────────┐ │
│ src │ │ Embedding │ │ │ tgt │ │ Embedding │ │
│ [B,3] │ ─▶│ + PosEnc │ │ │ [B,7] │ ─▶│ + PosEnc │ │
└─────────┘ └───────────┘ │ └─────────┘ └───────────┘ │
│ [B,3,64] │ │ [B,7,64] │
▼ │ ▼ │
┌─────────────────┐ │ ┌─────────────────┐ │
│ Encoder Layer 1 │ │ │ Decoder Layer 1 │ │
│ - Self-Attn │ │ │ - Masked Self-Attn │
│ - FFN │ │ │ - Cross-Attn (← enc_out) │
└─────────────────┘ │ │ - FFN │ │
│ [B,3,64] │ └─────────────────┘ │
▼ │ │ [B,7,64] │
┌─────────────────┐ │ ▼ │
│ Encoder Layer 2 │ │ ┌─────────────────┐ │
│ - Self-Attn │ │ │ Decoder Layer 2 │ │
│ - FFN │ │ │ - Masked Self-Attn │
└─────────────────┘ │ │ - Cross-Attn (← enc_out) │
│ [B,3,64] ────┼──────▶│ - FFN │ │
▼ │ └─────────────────┘ │
enc_out │ │ [B,7,64] │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Linear(64,29) │ │
│ └─────────────────┘ │
│ │ [B,7,29] │
│ ▼ │
│ logits │
└───────────────────────────────────────────────────────────────┘
GENERATION TIME (自回归)
┌───────────────────────────────────────────────────────────────┐
│ │
│ ┌─────────┐ ┌───────────┐ │
│ │ src │ │ Embedding │ │
│ │ │ ─▶│ + PosEnc │ │
│ └─────────┘ └───────────┘ │
│ │ [B,3,64] │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Encoder Layer 1 │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Encoder Layer 2 │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ enc_out [B,3,64] │
│ │ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ ┌──────┐ │ │
│ │ │<sos> │ ──→ Embed+Pos ──→ Decoder ─→ Linear ─→ pred1 │ │
│ │ └──────┘ ▲ ▲ ▲ │ │
│ │ │ │ │ │ │
│ │ ┌─────────────┐ │ │ │ │ │
│ │ │ <sos>+pred1 │ ─┴──────────┴───────────┘ │ │
│ │ └─────────────┘ │ │
│ │ ▲ │ │
│ │ │ (循环直到<eos>或max_len) │ │
│ └───────────────────────────────────────────────────────┘ │
└───────────────────────────────────────────────────────────────┘
✅ 总结:
你拥有的是一个 完整、标准、可训练、可生成 的微型 Transformer,结构如下:
- 输入/输出:字符级序列
- Embedding:29词表 → 64维
- 位置编码:sin/cos 固定编码
- Encoder:2层,每层 = Self-Attention + FFN
- Decoder:2层,每层 = Masked Self-Attn + Cross-Attn + FFN
- 输出层:64 → 29 的线性层
- 总参数:约 93K