从‘中国-熊猫’到代码:手把手用PyTorch复现Transformer的QKV计算(附可视化)

张开发
2026/5/7 0:17:17 15 分钟阅读

分享文章

从‘中国-熊猫’到代码:手把手用PyTorch复现Transformer的QKV计算(附可视化)
从词向量到注意力机制用PyTorch实现Transformer核心计算在自然语言处理领域Transformer架构已经成为现代深度学习模型的基石。理解其核心的注意力机制特别是Q(Query)、K(Key)、V(Value)计算过程对于掌握当代NLP技术至关重要。本文将通过PyTorch代码实现带您从词向量开始逐步构建完整的注意力计算流程并辅以可视化展示让抽象的概念变得直观可操作。1. 词向量与位置编码构建语义空间词向量是自然语言处理的基础构建块它将离散的单词映射到连续的向量空间。在Transformer中每个单词首先通过嵌入层转换为固定维度的向量表示。import torch import torch.nn as nn # 定义词汇表大小和嵌入维度 vocab_size 10000 embed_dim 512 # 创建嵌入层 embedding nn.Embedding(vocab_size, embed_dim) # 示例将中国和熊猫转换为词向量 china_idx torch.tensor([100]) # 假设中国在词汇表中的索引是100 panda_idx torch.tensor([200]) # 假设熊猫在词汇表中的索引是200 china_embed embedding(china_idx) panda_embed embedding(panda_idx)Transformer还引入了位置编码为序列中的每个位置添加独特的位置信息。以下是实现正弦位置编码的代码def positional_encoding(max_len, d_model): position torch.arange(max_len).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe torch.zeros(max_len, d_model) pe[:, 0::2] torch.sin(position * div_term) pe[:, 1::2] torch.cos(position * div_term) return pe max_len 50 # 最大序列长度 pos_encoding positional_encoding(max_len, embed_dim)值得注意的是在实际应用中词向量和位置编码会相加形成最终的输入表示。这种组合方式既保留了单词的语义信息又编码了其在序列中的位置关系。2. QKV矩阵注意力机制的核心注意力机制的核心在于Query、Key和Value三个矩阵的计算。这些矩阵通过线性变换从输入向量派生而来各自承担不同的角色Query表示当前关注的问题或查询Key表示可用于回答查询的线索Value包含实际的信息内容class SelfAttention(nn.Module): def __init__(self, embed_dim): super().__init__() self.embed_dim embed_dim # 定义Q、K、V的线性变换层 self.query nn.Linear(embed_dim, embed_dim) self.key nn.Linear(embed_dim, embed_dim) self.value nn.Linear(embed_dim, embed_dim) def forward(self, x): # x的形状: (batch_size, seq_len, embed_dim) Q self.query(x) # 计算Query矩阵 K self.key(x) # 计算Key矩阵 V self.value(x) # 计算Value矩阵 return Q, K, V为了更高效地计算Transformer通常使用多头注意力机制。以下是多头注意力的实现class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.embed_dim embed_dim self.num_heads num_heads self.head_dim embed_dim // num_heads assert self.head_dim * num_heads embed_dim, embed_dim必须能被num_heads整除 # 定义Q、K、V的线性变换层 self.qkv nn.Linear(embed_dim, embed_dim * 3) self.out nn.Linear(embed_dim, embed_dim) def forward(self, x): batch_size, seq_len, _ x.shape # 通过线性变换同时计算Q、K、V qkv self.qkv(x) qkv qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) qkv qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] Q, K, V qkv[0], qkv[1], qkv[2] # 计算注意力分数 scores torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) attn_weights torch.softmax(scores, dim-1) # 应用注意力权重到Value上 output torch.matmul(attn_weights, V) output output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim) # 最后的线性变换 output self.out(output) return output提示在实际应用中通常会加入mask机制来处理变长序列并防止解码器看到未来的信息。3. 注意力分数计算与可视化理解注意力分数的计算过程对于掌握Transformer至关重要。让我们通过具体示例来演示这一过程。假设我们有以下简化的词向量# 示例词向量 (维度简化为3以便可视化) china torch.tensor([3.0, 6.0, 10.0]) panda torch.tensor([2.8, 5.9, 9.8]) australia torch.tensor([-5.0, -8.0, -12.0]) kangaroo torch.tensor([-4.9, -8.1, -12.1]) # 组合成输入矩阵 (batch_size1, seq_len4, embed_dim3) x torch.stack([china, panda, australia, kangaroo]).unsqueeze(0)计算注意力分数# 初始化注意力层 attention SelfAttention(embed_dim3) # 计算Q, K, V Q, K, V attention(x) # 计算注意力分数 scores torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(3) attn_weights torch.softmax(scores, dim-1) print(注意力权重矩阵:) print(attn_weights)使用Matplotlib可视化注意力权重import matplotlib.pyplot as plt import seaborn as sns def plot_attention(weights, tokens): plt.figure(figsize(8, 6)) sns.heatmap(weights.squeeze().detach().numpy(), xticklabelstokens, yticklabelstokens, cmapYlGnBu, annotTrue) plt.title(注意力权重可视化) plt.show() tokens [中国, 熊猫, 澳大利亚, 袋鼠] plot_attention(attn_weights, tokens)关键观察从可视化结果中我们可以看到中国和熊猫之间的注意力权重较高而中国与澳大利亚之间的权重较低这与我们的语义直觉一致。4. 完整注意力层的实现与验证现在我们将前面介绍的各个组件整合成一个完整的注意力层并在实际数据上进行验证。class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.attention MultiHeadAttention(embed_dim, num_heads) self.norm1 nn.LayerNorm(embed_dim) self.ffn nn.Sequential( nn.Linear(embed_dim, 4 * embed_dim), nn.ReLU(), nn.Linear(4 * embed_dim, embed_dim) ) self.norm2 nn.LayerNorm(embed_dim) def forward(self, x): # 自注意力部分 attn_output self.attention(x) x self.norm1(x attn_output) # 残差连接和层归一化 # 前馈网络部分 ffn_output self.ffn(x) x self.norm2(x ffn_output) # 残差连接和层归一化 return x为了验证我们的实现我们可以构建一个简单的分类任务# 示例数据集 sentences [ 中国 熊猫 可爱, 澳大利亚 袋鼠 强壮, 中国 长城 雄伟, 澳大利亚 悉尼 歌剧院 ] labels [0, 1, 0, 1] # 0表示中国相关1表示澳大利亚相关 # 构建词汇表 word_to_idx {PAD: 0, UNK: 1} for sent in sentences: for word in sent.split(): if word not in word_to_idx: word_to_idx[word] len(word_to_idx) # 将句子转换为索引序列 def encode_sentence(sent, word_to_idx): return [word_to_idx.get(word, word_to_idx[UNK]) for word in sent.split()] encoded_sentences [encode_sentence(sent, word_to_idx) for sent in sentences] # 填充序列到相同长度 max_len max(len(sent) for sent in encoded_sentences) padded_sentences [sent [word_to_idx[PAD]] * (max_len - len(sent)) for sent in encoded_sentences] # 转换为PyTorch张量 input_ids torch.tensor(padded_sentences) labels torch.tensor(labels) # 定义简单分类模型 class TransformerClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_heads): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.transformer TransformerBlock(embed_dim, num_heads) self.classifier nn.Linear(embed_dim, 2) def forward(self, x): x self.embedding(x) x self.transformer(x) # 取第一个token的输出作为分类依据 x x[:, 0, :] return self.classifier(x) # 训练模型 model TransformerClassifier(len(word_to_idx), 64, 4) criterion nn.CrossEntropyLoss() optimizer torch.optim.Adam(model.parameters(), lr0.001) for epoch in range(10): optimizer.zero_grad() outputs model(input_ids) loss criterion(outputs, labels) loss.backward() optimizer.step() print(fEpoch {epoch1}, Loss: {loss.item():.4f})注意这只是一个极简化的示例实际应用中需要考虑更复杂的预处理、更大的模型规模和更长的训练过程。5. 高级主题与优化技巧在掌握了基本的注意力机制实现后我们可以探讨一些高级主题和优化技巧5.1 注意力掩码在处理变长序列或实现解码器时注意力掩码是必不可少的。以下是两种常见的掩码类型填充掩码忽略填充token的影响前瞻掩码防止解码器看到未来信息def create_masks(src, trgNone, pad_idx0): # 源序列掩码 (填充掩码) src_mask (src ! pad_idx).unsqueeze(1).unsqueeze(2) if trg is not None: # 目标序列填充掩码 trg_pad_mask (trg ! pad_idx).unsqueeze(1).unsqueeze(2) # 目标序列前瞻掩码 trg_len trg.shape[1] trg_sub_mask torch.tril(torch.ones(trg_len, trg_len)).bool() trg_mask trg_pad_mask trg_sub_mask return src_mask, trg_mask return src_mask5.2 注意力头专业化在实践中不同的注意力头可以学习关注不同类型的模式注意力头类型关注模式典型应用场景局部注意力相邻token关系语法结构分析全局注意力长距离依赖语义关联捕捉特定位置注意力固定位置关系序列标记任务5.3 高效注意力计算对于长序列标准的注意力计算可能变得非常昂贵。以下是一些优化方法稀疏注意力只计算部分位置的注意力分数局部注意力限制每个token只能关注其邻近区域线性注意力使用核技巧近似注意力计算# 局部注意力实现示例 class LocalAttention(nn.Module): def __init__(self, embed_dim, num_heads, window_size): super().__init__() self.embed_dim embed_dim self.num_heads num_heads self.window_size window_size self.head_dim embed_dim // num_heads self.qkv nn.Linear(embed_dim, embed_dim * 3) self.out nn.Linear(embed_dim, embed_dim) def forward(self, x): batch_size, seq_len, _ x.shape # 计算Q, K, V qkv self.qkv(x) qkv qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) qkv qkv.permute(2, 0, 3, 1, 4) Q, K, V qkv[0], qkv[1], qkv[2] # 计算局部注意力 output torch.zeros_like(Q) for i in range(seq_len): start max(0, i - self.window_size // 2) end min(seq_len, i self.window_size // 2 1) # 计算当前窗口的注意力 scores torch.matmul(Q[:, :, i:i1, :], K[:, :, start:end, :].transpose(-2, -1)) scores scores / math.sqrt(self.head_dim) attn_weights torch.softmax(scores, dim-1) # 应用注意力权重 output[:, :, i:i1, :] torch.matmul(attn_weights, V[:, :, start:end, :]) # 合并多头输出 output output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim) output self.out(output) return output在实际项目中我发现合理设置注意力头的数量和维度对模型性能影响很大。通常embed_dim应该是head_dim的整数倍且head_dim不宜过小至少32或64以确保每个注意力头有足够的表达能力。

更多文章