TinyTransformer 模型的网络结构图

64 阅读3分钟

用文字 + ASCII 图 + 模块化描述,完整绘制出这个 TinyTransformer 模型的网络结构图!

你可以根据下面的描述,轻松用纸笔、PPT、或绘图工具(如 draw.io)画出来。


🎯 模型总览:Encoder-Decoder Transformer

        输入序列 (src)          目标序列 (tgt)
             │                       │
             ▼                       ▼
      [Embedding + PosEnc]    [Embedding + PosEnc]
             │                       │
             ▼                       │
      ┌─────────────┐                │
      │ EncoderLayer│×2             │
      └─────────────┘                │
             │                       │
             ├───────────────────────┤
             │                       ▼
             │             ┌──────────────────┐
             │             │ DecoderLayer     │×2
             │             │ 1. Self-Attention│
             │             │ 2. Cross-Attention (←来自Encoder)
             │             │ 3. FeedForward   │
             │             └──────────────────┘
             │                       │
             │                       ▼
             │                [Linear (vocab_size)]
             │                       │
             │                       ▼
             └──────────────────▶ 输出 logits

🧱 一、输入处理(共享 Embedding + Positional Encoding)

输入:src (shape: [B, 3]) 和 tgt (shape: [B, 7])

self.embedding = nn.Embedding(vocab_size, d_model)  # 29 → 64
self.pos_encoding = PositionalEncoding(d_model, max_len)

流程:

src ──→ Embedding ──→ [B, 3, 64] ──→ + PosEnc ──→ [B, 3, 64]
tgt ──→ Embedding ──→ [B, 7, 64] ──→ + PosEnc ──→ [B, 7, 64]

✅ 位置编码 pe[1, max_len, 64],通过广播加到输入上


🔄 二、Encoder 部分(2层)

每层 EncoderLayer 包含:

  1. Multi-Head Self-Attention(4头,每头16维)
  2. Add + LayerNorm
  3. FeedForward(64 → 128 → 64)
  4. Add + LayerNorm

结构图(单层):

      x ───────────────┐
      │                │
      ▼                │
[MultiHeadAttention]   │
      │                │
      ▼                │
   Dropout             │
      │                │
      + ←──────────────┘  (Residual)
      │
   LayerNorm
      │
      ▼
[FeedForward]
      │
   Dropout
      │
      + ←──────────────┐  (Residual)
      │                │
   LayerNorm           │
      │                │
      └────────────────┘

Encoder 整体:

输入: [B, 3, 64]
  │
  ▼
EncoderLayer 1 → [B, 3, 64]
  │
  ▼
EncoderLayer 2 → [B, 3, 64]
  │
  ▼
输出: enc_out [B, 3, 64] + src_mask [B, 1, 1, 3]

🔄 三、Decoder 部分(2层)

每层 DecoderLayer 包含:

  1. Masked Multi-Head Self-Attention(看自己,屏蔽未来)
  2. Add + LayerNorm
  3. Multi-Head Cross-Attention(Q=自己,K,V=Encoder输出)
  4. Add + LayerNorm
  5. FeedForward
  6. Add + LayerNorm

结构图(单层):

      x ───────────────────┐
      │                    │
      ▼                    │
[Masked Self-Attention]    │
      │                    │
      ▼                    │
   Dropout                 │
      │                    │
      + ←──────────────────┘  (Residual)
      │
   LayerNorm
      │
      ▼
[Cross-Attention] ←────── enc_out
      │
   Dropout
      │
      + ←──────────────────┐  (Residual)
      │                    │
   LayerNorm               │
      │                    │
      ▼                    │
[FeedForward]              │
      │                    │
   Dropout                 │
      │                    │
      + ←──────────────────┘  (Residual)
      │
   LayerNorm
      │
      ▼

Decoder 整体:

输入: tgt [B, 7, 64] + enc_out [B, 3, 64] + masks
  │
  ▼
DecoderLayer 1 → [B, 7, 64]
  │
  ▼
DecoderLayer 2 → [B, 7, 64]
  │
  ▼
Linear(64, 29) → [B, 7, 29] (logits)

🧩 四、关键子模块详解

1. MultiHeadAttention(4头)

输入: q, k, v [B, L, 64]
  │
  ▼
Linear[B, L, 64]
  │
  ▼
view[B, L, 4, 16]transpose[B, 4, L, 16]
  │
  ▼
Q @ K^T[B, 4, Lq, Lk] → /√16 → + masksoftmaxdropout
  │
  ▼
@ V[B, 4, Lq, 16]
  │
  ▼
transpose[B, Lq, 4, 16]contiguousview[B, Lq, 64]
  │
  ▼
Linear[B, Lq, 64]

2. FeedForward

输入: [B, L, 64]
  │
  ▼
Linear(64, 128) → ReLU → [B, L, 128]
  │
  ▼
Dropout
  │
  ▼
Linear(128, 64) → [B, L, 64]

3. Masks

  • src_mask: [B, 1, 1, L_src] → 屏蔽 <pad>
  • tgt_mask: [B, 1, L_tgt, L_tgt] → 屏蔽 <pad> + 下三角(防偷看未来)

🎯 五、前向传播完整流程图

训练时:
src [B,3] ──→ Embed+Pos ──→ Encoder (×2) ──→ enc_out [B,3,64]
                                             │
tgt [B,7] ──→ Embed+Pos ──→ Decoder (×2) ←──┘
                             │
                          Linear(64,29)
                             │
                          logits [B,7,29]

生成时(自回归):
src ──→ Encoder ──→ enc_out
                   │
<sos> ──→ Decoder ──→ pred1
                   │
<sos>+pred1 ──→ Decoder ──→ pred2
                   │
... 直到 <eos> 或 max_len

📊 六、参数规模估算

模块参数量计算
Embedding29 × 64 = 1,856vocab_size × d_model
Encoder Layer ×22 × (4×64² + 64×128×2 + ...) ≈ 2 × 20K = 40K见下文
Decoder Layer ×22 × (类似Encoder + 额外Cross-Attn) ≈ 2 × 25K = 50K
Output Linear64 × 29 = 1,856d_model × vocab_size
总计≈ 93K

✅ 非常轻量,适合教学!


🖼️ ASCII 完整网络图

                              TRAINING TIME
                              ┌───────────────────────────────────────────────────────────────┐
                              │                                                               │
┌─────────┐    ┌───────────┐  │  ┌─────────┐    ┌───────────┐                                │
│  src    │    │ Embedding │  │  │  tgt    │    │ Embedding │                                │
│ [B,3]   │ ─▶│  + PosEnc │  │  │ [B,7]   │ ─▶│  + PosEnc │                                │
└─────────┘    └───────────┘  │  └─────────┘    └───────────┘                                │
               │ [B,3,64]     │                │ [B,7,64]                                   │
               ▼              │                ▼                                            │
       ┌─────────────────┐    │        ┌─────────────────┐                                   │
       │ Encoder Layer 1 │    │        │ Decoder Layer 1 │                                   │
       │  - Self-Attn    │    │        │  - Masked Self-Attn                                │
       │  - FFN          │    │        │  - Cross-Attn (← enc_out)                          │
       └─────────────────┘    │        │  - FFN          │                                   │
               │ [B,3,64]     │        └─────────────────┘                                   │
               ▼              │                │ [B,7,64]                                   │
       ┌─────────────────┐    │                ▼                                            │
       │ Encoder Layer 2 │    │        ┌─────────────────┐                                   │
       │  - Self-Attn    │    │        │ Decoder Layer 2 │                                   │
       │  - FFN          │    │        │  - Masked Self-Attn                                │
       └─────────────────┘    │        │  - Cross-Attn (← enc_out)                          │
               │ [B,3,64] ────┼──────▶│  - FFN          │                                   │
               ▼              │        └─────────────────┘                                   │
           enc_out            │                │ [B,7,64]                                   │
                              │                ▼                                            │
                              │        ┌─────────────────┐                                   │
                              │        │  Linear(64,29)  │                                   │
                              │        └─────────────────┘                                   │
                              │                │ [B,7,29]                                   │
                              │                ▼                                            │
                              │             logits                                          │
                              └───────────────────────────────────────────────────────────────┘

                              GENERATION TIME (自回归)
                              ┌───────────────────────────────────────────────────────────────┐
                              │                                                               │
                              │  ┌─────────┐    ┌───────────┐                                │
                              │  │  src    │    │ Embedding │                                │
                              │  │         │ ─▶│  + PosEnc │                                │
                              │  └─────────┘    └───────────┘                                │
                              │                │ [B,3,64]                                   │
                              │                ▼                                            │
                              │        ┌─────────────────┐                                   │
                              │        │ Encoder Layer 1 │                                   │
                              │        └─────────────────┘                                   │
                              │                │                                            │
                              │                ▼                                            │
                              │        ┌─────────────────┐                                   │
                              │        │ Encoder Layer 2 │                                   │
                              │        └─────────────────┘                                   │
                              │                │                                            │
                              │                ▼                                            │
                              │            enc_out [B,3,64]                                 │
                              │                │                                            │
                              │  ┌───────────────────────────────────────────────────────┐  │
                              │  │                                                       │  │
                              │  │  ┌──────┐                                             │  │
                              │  │  │<sos> │ ──→ Embed+Pos ──→ Decoder ─→ Linear ─→ pred1 │  │
                              │  │  └──────┘         ▲          ▲           ▲             │  │
                              │  │                   │          │           │             │  │
                              │  │  ┌─────────────┐  │          │           │             │  │
                              │  │  │ <sos>+pred1 │ ─┴──────────┴───────────┘             │  │
                              │  │  └─────────────┘                                       │  │
                              │  │        ▲                                               │  │
                              │  │        │ (循环直到<eos>或max_len)                      │  │
                              │  └───────────────────────────────────────────────────────┘  │
                              └───────────────────────────────────────────────────────────────┘

✅ 总结:

你拥有的是一个 完整、标准、可训练、可生成 的微型 Transformer,结构如下:

  • 输入/输出:字符级序列
  • Embedding:29词表 → 64维
  • 位置编码:sin/cos 固定编码
  • Encoder:2层,每层 = Self-Attention + FFN
  • Decoder:2层,每层 = Masked Self-Attn + Cross-Attn + FFN
  • 输出层:64 → 29 的线性层
  • 总参数:约 93K