# 请根据下面的Python代码生成架构图 (Please generate an architecture diagram from the Python code below.)
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle, Circle, Polygon, Arrow
from matplotlib.text import TextPath
from matplotlib.font_manager import FontProperties
import matplotlib.patheffects as path_effects
# Font setup so the Chinese labels render. SimHei stays first, so machines
# that have it behave as before; the remaining entries are CJK-capable fonts
# commonly found on other platforms and are tried in order when SimHei is
# not installed.
plt.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Heiti TC',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
    'Arial Unicode MS',
]
# Render minus signs with the ASCII hyphen instead of the Unicode minus.
plt.rcParams['axes.unicode_minus'] = False

# Canvas: one 16x10 inch axes with a fixed 16x10 data coordinate system, so
# every drawing function below can place artists with absolute coordinates.
fig, ax = plt.subplots(figsize=(16, 10))
ax.set_facecolor('#f0f8ff')  # light blue background
ax.set_xlim(0, 16)
ax.set_ylim(0, 10)
# NOTE(review): switching the axis off may also suppress the axes background
# patch, in which case the face colour set above is not visible - confirm.
ax.axis('off')  # hide the axis lines, ticks and labels

# Colour scheme shared by all drawing functions.
colors = {
    'input': '#4e79a7',       # dark blue
    'encoder': '#59a14f',     # green
    'decoder': '#e15759',     # red
    'attention': '#edc948',   # yellow
    'ffn': '#b07aa1',         # purple
    'output': '#ff9da7',      # pink
    'connection': '#9c755f',  # brown
    'text': '#2e2e2e',        # dark grey
}
# ====================== Input processing layer ======================
def draw_input_layer():
    """Draw the input stage: embedding box, positional-encoding box and "+" node.

    Also draws the incoming arrow with the sample input sequence and the
    sinusoidal positional-encoding formula. Everything is added to the
    module-level ``ax`` using colours from ``colors``.
    """
    box_kwargs = dict(facecolor=colors['input'], alpha=0.8)
    caption_kwargs = dict(ha='center', va='center', color='white', fontsize=10)

    # Two stacked boxes: (box bottom, caption centre y, caption).
    stacked_boxes = (
        (7.5, 7.9, '输入嵌入'),
        (6.5, 6.9, '位置编码'),
    )
    for bottom, caption_y, caption in stacked_boxes:
        ax.add_patch(Rectangle((1, bottom), 1.5, 0.8, **box_kwargs))
        plt.text(1.75, caption_y, caption, **caption_kwargs)

    # Addition node on the right edge of the two boxes.
    plus_centre = (2.5, 7.0)
    ax.add_patch(Circle(plus_centre, radius=0.15, facecolor='#76b7b2'))
    plt.text(plus_centre[0], plus_centre[1], '+',
             ha='center', va='center', color='white', fontsize=12)

    # Arrow entering the embedding box from the left, with the sample tokens.
    incoming_arrow = dict(arrowstyle='->', lw=1.5, color=colors['text'])
    ax.annotate('', xy=(1, 8.0), xytext=(0.5, 8.0), arrowprops=incoming_arrow)
    plt.text(0.3, 8.2, '输入序列\n["The", "cat", "sat"]', fontsize=9)

    # Positional-encoding formula (mathtext).
    eq_text = r'$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$'
    plt.text(3.5, 7.0, eq_text, fontsize=12, color=colors['text'])
# ====================== Encoder layers ======================
def draw_encoder_layers():
    """Draw the six-layer encoder stack plus its title and formulas.

    Layers are stacked top to bottom, one data unit apart, starting at
    y = 8.0. Each layer gets a translucent frame, four labelled sublayer
    boxes, a rotated layer label, and (from the second layer on) an arrow
    coming from the layer above.
    """
    # Sublayer boxes, in drawing order (later boxes paint over earlier ones):
    # (box x, drop below centre, width, height, colour,
    #  label x, label drop below centre, label, font size)
    sublayers = (
        (4.7, 0.15, 1.2, 0.3, colors['attention'], 5.3, 0.0, '多头自注意力', 9),
        (6.0, 0.1, 0.7, 0.2, '#76b7b2', 6.35, 0.0, 'Add & Norm', 8),
        (4.7, 0.4, 1.2, 0.3, colors['ffn'], 5.3, 0.25, '前馈网络', 9),
        (6.0, 0.35, 0.7, 0.2, '#76b7b2', 6.35, 0.25, 'Add & Norm', 8),
    )
    link_style = dict(arrowstyle='->', lw=1.0, color=colors['connection'])

    for layer_no in range(1, 7):
        centre_y = 9.0 - layer_no  # 8.0 for the first layer, 3.0 for the last

        # Translucent frame around the whole layer.
        ax.add_patch(Rectangle((4.5, centre_y - 0.4), 3.0, 0.8,
                               facecolor=colors['encoder'], alpha=0.2))

        for box_x, box_drop, width, height, face, text_x, text_drop, label, size in sublayers:
            ax.add_patch(Rectangle((box_x, centre_y - box_drop), width, height,
                                   facecolor=face, alpha=0.7))
            plt.text(text_x, centre_y - text_drop, label,
                     ha='center', va='center', fontsize=size)

        # Rotated layer number to the left of the frame.
        plt.text(4.2, centre_y, f'编码器 {layer_no}', fontsize=9, rotation=90)

        # Arrow from the bottom of the previous frame to the top of this one.
        if layer_no > 1:
            ax.annotate('', xy=(5.5, centre_y + 0.4), xytext=(5.5, centre_y + 0.6),
                        arrowprops=link_style)

    # Stack title.
    plt.text(6.0, 8.8, '编码器堆栈 (×6)', fontsize=12, fontweight='bold',
             color=colors['encoder'])

    # Scaled dot-product attention and feed-forward formulas (mathtext).
    attn_text = r'$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$'
    ffn_text = r'$\text{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2$'
    for formula_y, formula in ((7.0, attn_text), (6.2, ffn_text)):
        plt.text(8.0, formula_y, formula, fontsize=11, color=colors['text'])
# ====================== Decoder layers ======================
def draw_decoder_layers():
    """Draw the six-layer decoder stack, its title, and the output-embedding stage.

    Layers are stacked top to bottom, one data unit apart, starting at
    y = 8.0. Each layer gets a translucent frame, four labelled sublayer
    boxes, a rotated layer label, and (from the second layer on) an arrow
    coming from the layer above.

    NOTE(review): each layer shows two attention sublayers and no
    feed-forward sublayer - confirm whether that omission is intended.
    NOTE(review): the output-embedding and positional-encoding boxes
    (x 10.5-12.0, y 3.5-4.3 and 2.5-3.3) fall inside the frames of decoder
    layers 5 and 6 (y 3.6-4.4 and 2.6-3.4) - confirm the intended placement.
    """
    # Sublayer boxes, in drawing order (later boxes paint over earlier ones):
    # (box x, drop below centre, width, height, colour,
    #  label x, label drop below centre, label, font size)
    sublayers = (
        (10.7, 0.15, 1.8, 0.3, colors['attention'], 11.6, 0.0, '掩码多头注意力', 9),
        (12.6, 0.1, 0.7, 0.2, '#76b7b2', 12.95, 0.0, 'Add & Norm', 8),
        (10.7, 0.4, 1.8, 0.3, colors['attention'], 11.6, 0.25, '编码器-解码器注意力', 9),
        (12.6, 0.35, 0.7, 0.2, '#76b7b2', 12.95, 0.25, 'Add & Norm', 8),
    )
    link_style = dict(arrowstyle='->', lw=1.0, color=colors['connection'])

    for layer_no in range(1, 7):
        centre_y = 9.0 - layer_no  # 8.0 for the first layer, 3.0 for the last

        # Translucent frame around the whole layer.
        ax.add_patch(Rectangle((10.5, centre_y - 0.4), 3.0, 0.8,
                               facecolor=colors['decoder'], alpha=0.2))

        for box_x, box_drop, width, height, face, text_x, text_drop, label, size in sublayers:
            ax.add_patch(Rectangle((box_x, centre_y - box_drop), width, height,
                                   facecolor=face, alpha=0.7))
            plt.text(text_x, centre_y - text_drop, label,
                     ha='center', va='center', fontsize=size)

        # Rotated layer number to the left of the frame.
        plt.text(10.2, centre_y, f'解码器 {layer_no}', fontsize=9, rotation=90)

        # Arrow from the bottom of the previous frame to the top of this one.
        if layer_no > 1:
            ax.annotate('', xy=(11.8, centre_y + 0.4), xytext=(11.8, centre_y + 0.6),
                        arrowprops=link_style)

    # Stack title.
    plt.text(12.0, 8.8, '解码器堆栈 (×6)', fontsize=12, fontweight='bold',
             color=colors['decoder'])

    # Output embedding and positional encoding boxes:
    # (box bottom, caption centre y, caption).
    box_kwargs = dict(facecolor=colors['input'], alpha=0.8)
    caption_kwargs = dict(ha='center', va='center', color='white', fontsize=10)
    for bottom, caption_y, caption in ((3.5, 3.9, '输出嵌入'), (2.5, 2.9, '位置编码')):
        ax.add_patch(Rectangle((10.5, bottom), 1.5, 0.8, **box_kwargs))
        plt.text(11.25, caption_y, caption, **caption_kwargs)

    # Addition node on the right edge of the two boxes.
    ax.add_patch(Circle((12.0, 3.0), radius=0.15, facecolor='#76b7b2'))
    plt.text(12.0, 3.0, '+', ha='center', va='center', color='white', fontsize=12)
# ====================== Output layer ======================
def draw_output_layer():
    """Draw the output stage: linear box, softmax box, output arrow and a
    small horizontal bar chart of example output probabilities.

    Everything is added to the module-level ``ax`` using colours from
    ``colors``.
    """
    box_kwargs = dict(facecolor=colors['output'], alpha=0.8)

    # Linear layer
    ax.add_patch(Rectangle((14.0, 3.0), 1.0, 0.6, **box_kwargs))
    plt.text(14.5, 3.3, '线性层', ha='center', va='center', fontsize=10)

    # Softmax
    ax.add_patch(Rectangle((14.0, 2.2), 1.0, 0.6, **box_kwargs))
    plt.text(14.5, 2.5, 'Softmax', ha='center', va='center', fontsize=10)

    # Output arrow. annotate() draws from `xytext` to `xy`, so the tail sits
    # on the right edge of the Softmax box and the head points outwards at
    # the output-sequence label.
    ax.annotate('', xy=(15.5, 2.5), xytext=(15.0, 2.5),
                arrowprops=dict(arrowstyle='->', lw=1.5, color=colors['text']))
    plt.text(15.7, 2.7, '输出序列\n["Le", "chat", "s\'assit"]', fontsize=9)

    # Example output probability distribution, bars scaled by 4 data units.
    words = ['Le', 'chat', "s'assit", 'est', 'assis', '...']
    probs = [0.45, 0.25, 0.15, 0.07, 0.05, 0.03]
    # Rows run from y = 1.5 down to y = 0.25 so that every 0.2-high bar stays
    # above the lower y-limit of 0 and below the chart title at y = 1.8;
    # bars below the limit would be clipped by the axes.
    top_row_y = 1.5
    row_step = 0.25
    for row, (word, prob) in enumerate(zip(words, probs)):
        y_pos = top_row_y - row * row_step
        bar_length = prob * 4
        plt.barh(y_pos, bar_length, height=0.2, color=colors['output'], alpha=0.6)
        plt.text(bar_length + 0.1, y_pos, f'{word}: {prob:.2f}', va='center', fontsize=9)
    plt.text(0.5, 1.8, '输出概率分布', fontsize=10, fontweight='bold')
# ====================== Connections ======================
def draw_connections():
    """Draw the arrows that link the diagram's components.

    All arrows use the ``colors['connection']`` colour and are added to the
    module-level ``ax``.
    """
    def connect(start, end, **extra_props):
        """Draw one arrow whose tail is at `start` and whose head is at `end`."""
        # annotate() draws from `xytext` (tail) to `xy` (head).
        ax.annotate('', xy=end, xytext=start,
                    arrowprops=dict(arrowstyle='->', lw=1.5,
                                    color=colors['connection'], **extra_props))

    # Input layer -> encoder
    # NOTE(review): the head lands at y = 7.0, which is the height of the
    # second encoder frame rather than the first - confirm the intended target.
    connect((2.5, 7.0), (4.5, 7.0))

    # Encoder -> decoder, one horizontal link per layer row.
    # NOTE(review): links are drawn layer-to-layer; confirm this is the
    # intended depiction of how the encoder output reaches the decoder.
    for i in range(6):
        row_y = 7.6 - i * 1.0
        connect((7.5, row_y), (10.5, row_y), linestyle='-')

    # Connection inside the decoder (from its right edge back to Add & Norm)
    connect((13.5, 7.0), (12.6, 7.0))

    # Output embedding -> decoder
    connect((12.0, 3.0), (10.5, 3.0))

    # Decoder -> output layer: tail on the decoder's right edge (x = 13.5),
    # head on the output boxes' left edge (x = 14.0). The endpoints were
    # previously swapped, which made the arrow point back into the decoder.
    connect((13.5, 3.0), (14.0, 3.0))

    # Output layer -> final output: tail on the Softmax box's right edge
    # (x = 15.0), head pointing outwards. Previously swapped as well.
    connect((15.0, 2.5), (15.5, 2.5))
# ====================== Technical notes ======================
def add_technical_notes(x=0.5, top=6.0, line_step=0.26):
    """Write the block of technical notes in the left column of the figure.

    Args:
        x: Data x-coordinate of the left edge of every note line.
        top: Data y-coordinate of the first note line.
        line_step: Vertical distance in data units between consecutive lines.

    The defaults place the 15 lines between y = 6.0 and y = 2.36. That keeps
    them below the input boxes (which end at y = 6.5) and above the
    "output probability distribution" title at y = 1.8. The previous layout
    (start 4.0, step 0.4) ran down to y = -1.6, off the bottom of the 0-10
    canvas and across the probability chart.
    """
    notes = [
        "核心参数:",
        "- 嵌入维度: $d_{model}=512$",
        "- 注意力头数: $h=8$",
        "- 前馈网络维度: $d_{ff}=2048$",
        "- 编码器/解码器层数: $N=6$",
        "",
        "关键技术:",
        "• 多头注意力: 并行捕获不同语义信息",
        "• 残差连接: 缓解深层网络梯度消失",
        "• 层归一化: 稳定训练过程",
        "• 位置编码: 注入序列顺序信息",
        "",
        "中文优化:",
        "- RoPE旋转位置编码优化长文本处理",
        "- 适应中文分词特性",
    ]
    for row, note in enumerate(notes):
        plt.text(x, top - row * line_step, note, fontsize=9, color=colors['text'])
# ====================== Main title ======================
def add_title():
    """Place the main title at the top centre of the figure.

    The title is drawn twice at the same position: once in the regular text
    colour, then again in white with a path effect that strokes the glyphs
    in the text colour before rendering them normally.
    """
    title = "Transformer 架构详解 (Vaswani et al., 2017)"
    anchor_x, anchor_y = 8, 9.5
    shared_kwargs = dict(fontsize=18, fontweight='bold', ha='center')

    # Base layer in the regular text colour.
    plt.text(anchor_x, anchor_y, title, color=colors['text'], **shared_kwargs)

    # White layer on top, outlined so the title stands out.
    outlined = plt.text(anchor_x, anchor_y, title, color='white', **shared_kwargs)
    outline = path_effects.Stroke(linewidth=3, foreground=colors['text'])
    outlined.set_path_effects([outline, path_effects.Normal()])
# ====================== Render the figure ======================
# Run every drawing step in order; later steps draw on top of earlier ones
# where artists share a z-order.
for draw_step in (
    draw_input_layer,
    draw_encoder_layers,
    draw_decoder_layers,
    draw_output_layer,
    draw_connections,
    add_technical_notes,
    add_title,
):
    draw_step()

# Save a high-resolution image, then open the interactive window.
plt.tight_layout()
plt.savefig('transformer_architecture.png', dpi=300, bbox_inches='tight')
plt.show()
# 最新发布 (Latest release) - stray text, not part of the script