Transformer架构:输入部分代码实现(基于PyTorch)

部署运行你感兴趣的模型镜像

在Transformer架构中,输入部分是连接原始文本与模型的关键桥梁,其核心任务是将离散的文本序列转换为包含语义和位置信息的连续向量。本文将结合代码实现,详细解析Transformer输入部分的两大核心组件——词嵌入(Embedding)位置编码(Positional Encoding) 的原理与实现细节。


相关文章

Transformer架构:结构介绍网页链接
Transformer架构:核心模块代码实现(基于PyTorch)网页链接
Transformer架构:编码器部分代码实现(基于PyTorch)网页链接
Transformer架构:解码器部分代码实现(基于PyTorch)网页链接
Transformer架构:输出部分代码实现(基于PyTorch)网页链接
Transformer架构:整体实现代码(基于PyTorch)网页链接


一、输入部分的核心目标

输入部分流程图
在这里插入图片描述

Transformer作为无递归结构的模型,无法像RNN那样天然捕捉序列的时序信息。因此,输入部分需要完成两个关键任务:

  1. 语义映射:将离散的词索引(如“我”的索引为2)转换为连续的低维向量,捕捉词的语义特征;
  2. 位置注入:手动加入位置信息,区分“我爱NLP”与“NLP爱我”等同义词序不同的句子。

下面结合代码,逐一解析这两个组件的实现。


二、词嵌入(Embedding):从离散到连续的语义映射

2.1 原理回顾

词嵌入的本质是通过一个可学习的嵌入矩阵,将词表中的每个词(用索引表示)映射为固定维度的向量。例如,词表大小为1024,嵌入维度为512时,嵌入矩阵的形状为[1024, 512],每个行向量即为对应词的语义表示。

为了让词嵌入的量级与后续位置编码匹配(避免位置信息被“淹没”),通常会将嵌入结果乘以√嵌入维度(如√512≈22.6)。

2.2 代码实现

import torch
import torch.nn as nn
import math
from Transformer_config import device, vocabulary_size, embedding_dim  # 导入超参数

class Embedding(nn.Module):
    def __init__(self, vocabulary_size, embedding_dim):
        super().__init__()
        self.vocabulary_size = vocabulary_size  # 词表大小(如1024)
        self.embedding_dim = embedding_dim      # 嵌入维度(如512)
        
        # 定义词嵌入层:输入为词索引,输出为对应嵌入向量
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocabulary_size,  # 词表大小
            embedding_dim=embedding_dim,    # 嵌入维度
            device=device                   # 运行设备(CPU/GPU)
        )

    def forward(self, x):
        # x形状:[batch_size, seq_len](如[2, 4],2个样本,每个样本4个词索引)
        # 1. 通过嵌入层将词索引转为向量
        # 2. 乘以√embedding_dim,平衡与位置编码的量级
        x = self.embedding_layer(x) * math.sqrt(self.embedding_dim)
        return x  # 输出形状:[batch_size, seq_len, embedding_dim](如[2,4,512])

2.3 关键细节解析

  • nn.Embedding:PyTorch内置的嵌入层,会自动初始化一个随机的嵌入矩阵,并在训练中学习优化,使语义相似的词(如“爱”和“喜欢”)的向量距离更近。
  • 缩放操作* math.sqrt(embedding_dim)是为了避免嵌入向量的量级过小(初始随机值通常在[-1,1]),导致后续叠加位置编码时位置信息占比过高。

三、位置编码(Positional Encoding):注入时序信息

3.1 原理回顾

Transformer没有循环或卷积结构,无法天然捕捉词序,因此需要手动加入位置编码。位置编码通过正弦余弦函数生成,公式如下:

PE ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d model ) PE ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) \text{PE}_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\ \text{PE}_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

其中:

  • pos为词在序列中的位置(从0开始);
  • i为向量维度索引(0 ≤ i < d_model/2);
  • 正弦函数用于偶数维度,余弦函数用于奇数维度。

核心优势

  • 周期性:利用三角函数的周期性(如sin(α+β) = sinα·cosβ + cosα·sinβ),使模型能捕捉相对位置(如第3个词与第5个词的距离);
  • 无长度限制:三角函数对任意长度的序列都能生成有效编码,适用于超长文本。

3.2 代码实现

class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)  # Dropout层,防止过拟合
        
        # 1. 初始化位置编码矩阵PE,形状为[max_len, embedding_dim](如[64, 512])
        pe = torch.zeros(max_len, embedding_dim, device=device)
        
        # 2. 生成位置序列:[0, 1, ..., max_len-1],形状转为[max_len, 1]
        position = torch.arange(0, max_len, device=device).unsqueeze(1)
        
        # 3. 生成衰减因子:10000^(2i/d_model)的倒数(用指数函数实现,更稳定)
        # 维度索引i的范围:[0, 2, 4, ..., embedding_dim-2](共embedding_dim/2个)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2, device=device) * 
            (-math.log(10000.0) / embedding_dim)
        )
        
        # 4. 填充PE矩阵:偶数维度用sin,奇数维度用cos
        pe[:, 0::2] = torch.sin(position * div_term)  # 0::2表示步长为2,取偶数索引
        pe[:, 1::2] = torch.cos(position * div_term)  # 1::2表示步长为2,取奇数索引
        
        # 5. 扩展维度为[1, max_len, embedding_dim],适配批量输入(batch_size维度)
        pe = pe.unsqueeze(0)
        
        # 6. 注册为缓冲区(不参与训练的固定参数,随模型保存/加载)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x形状:[batch_size, seq_len, embedding_dim](如[2,4,512])
        # 1. 叠加位置编码:仅取与输入序列长度匹配的位置编码(x.shape[1]为实际序列长度)
        x = x + self.pe[:, :x.shape[1]]  # self.pe[:, :x.shape[1]]形状:[1, seq_len, 512]
        # 2. 应用Dropout,增强泛化能力
        x = self.dropout(x)
        return x  # 输出形状不变:[batch_size, seq_len, embedding_dim]

3.3 关键细节解析

  • PE矩阵生成:通过position * div_term计算每个位置的编码值,其中div_term随维度索引增大而衰减,确保不同维度捕捉不同尺度的位置信息(低维度捕捉短距离,高维度捕捉长距离)。
  • register_buffer:将PE矩阵注册为缓冲区,使其不参与反向传播(无需训练),但会随模型一起保存,避免每次推理时重新生成。
  • Dropout层:在叠加位置编码后加入Dropout,随机丢弃部分位置信息,防止模型过度依赖特定位置的编码,增强泛化能力。

四、输入部分完整流程与示例

4.1 完整流程

输入部分的整体流程为:

  1. 词索引输入:原始输入为批量词索引(如[[1,2,3,4], [5,6,7,8]]);
  2. 词嵌入:通过Embedding类将索引转为向量并缩放;
  3. 位置编码:通过PositionalEncoding类叠加位置信息,输出最终输入向量。

4.2 代码示例与输出

def test_input():
    # 1. 定义输入:2个样本,每个样本4个词索引(形状[2,4])
    x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device=device)
    
    # 2. 初始化词嵌入和位置编码
    embedding = Embedding(vocabulary_size, embedding_dim)
    positional_encoding = PositionalEncoding(embedding_dim, max_len, dropout)
    
    # 3. 执行词嵌入
    embed_x = embedding(x)
    print(f"词嵌入后形状:{embed_x.shape}")  # 输出:torch.Size([2, 4, 512])
    
    # 4. 执行位置编码
    pe_x = positional_encoding(embed_x)
    print(f"位置编码后形状:{pe_x.shape}")  # 输出:torch.Size([2, 4, 512])

if __name__ == '__main__':
    from Transformer_config import max_len, dropout  # 导入超参数
    test_input()

五、完整代码

Transformer_config部分:

# 导入必备的工具包
import torch
import copy

# 预定义的网络层torch.nn, 工具开发者已经帮助我们开发好的一些常用层,
import torch.nn as nn
import torch.nn.functional as F

# 数学计算工具包
import math

# torch中变量封装函数Variable.
from torch.autograd import Variable

# 模型超参数及配置项
# device: 指定模型运行设备(CPU/GPU)
# vocabulary_size: 词表总大小
# embedding_dim: 词向量嵌入维度
# d_ff: 前馈神经网络中间层维度
# dropout: dropout层丢弃概率
# max_len: 序列最大处理长度
# heads: 多头注意力机制的头数
# batch_size: 训练批次大小
device = 'cpu'
vocabulary_size = 1024
embedding_dim = 512
d_ff = 1024
dropout = 0.1
max_len = 64
heads = 8
batch_size = 2

Transformer_input部分:

from Transformer_config import *  # 导入配置文件中的模型参数和全局变量

# Embeddings类 实现思路分析
# 1 init函数 (self, d_model, vocab)
    # 设置类属性 定义词嵌入层 self.lut层
# 2 forward(x)函数
    # self.lut(x) * math.sqrt(self.d_model)
class Embedding(nn.Module):
    """
    词嵌入层,将离散的词汇索引转换为连续的向量表示
    
    参数:
        vocabulary_size (int): 词汇表大小,表示可处理的不同词汇数量
        embedding_dim (int): 词嵌入维度,每个词汇对应的特征向量长度
    """
    def __init__(self,vocabulary_size,embedding_dim):
        # 参数vocab   词汇表大小
        # 参数d_model 每个词汇的特征尺寸 词嵌入维度
        super().__init__()  # 调用nn.Module基类的初始化方法
        self.vocabulary_size = vocabulary_size  # 存储词汇表大小
        self.embedding_dim = embedding_dim  # 存储词嵌入维度

        # 定义吃嵌入层
        # 创建PyTorch嵌入层:vocabulary_size表示词汇表大小,embedding_dim表示向量维度
        self.embedding_layer = nn.Embedding(vocabulary_size,embedding_dim,device=device)

    def forward(self,x):
        """
        前向传播过程:将输入索引转换为嵌入向量并缩放
        
        参数:
            x (Tensor): 输入词汇索引张量,形状为[batch_size, seq_len]
            
        返回:
            Tensor: 缩放后的词嵌入向量,形状为[batch_size, seq_len, embedding_dim]
        """
        # 将x传给self.lut并与根号下self.d_model相乘作为结果返回
        # x经过词嵌入后 增大x的值, 词嵌入后的embedding_vector+位置编码信息,值量纲差差不多
        # 通过嵌入层将输入索引转换为稠密向量
        x = self.embedding_layer(x) 
        # 缩放嵌入向量使其与位置编码的量级匹配
        x = x * math.sqrt(self.embedding_dim)
        return x  # 返回缩放后的词嵌入结果

# 位置编码器类PositionalEncoding 实现思路分析
# 1 init函数  (self, d_model, dropout, max_len=5000)
#   super()函数 定义层self.dropout
#   定义位置编码矩阵pe  定义位置列-矩阵position 定义变化矩阵div_term
#   套公式div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0)/d_model))
#   位置列-矩阵 * 变化矩阵 阿达码积my_matmulres
#   给pe矩阵偶数列奇数列赋值 pe[:, 0::2] pe[:, 1::2]
#   pe矩阵注册到模型缓冲区 pe.unsqueeze(0)三维 self.register_buffer('pe', pe)
# 2 forward(self, x) 返回self.dropout(x)
#   给x数据添加位置特征信息 x = x + Variable( self.pe[:,:x.size()[1]], requires_grad=False)
class PositionalEncoding(nn.Module):
    """
    位置编码层,为序列添加位置信息以保留顺序关系
    
    参数:
        embedding_dim (int): 词嵌入维度,需与词嵌入层输出维度一致
        max_len (int): 支持的最大序列长度
        dropout (float): Dropout层的丢弃概率
    """
    def __init__(self,embedding_dim,max_len,dropout):
        # 参数d_model 词嵌入维度 eg: 512个特征
        # 参数max_len 单词token个数 eg: 64个单词
        super().__init__()  # 调用nn.Module基类的初始化方法

        # 定义dropout层
        # 创建Dropout层,p=dropout指定丢弃概率
        self.dropout = nn.Dropout(p=dropout)

        # 思路:位置编码矩阵 + 特征矩阵 相当于给特征增加了位置信息
        # 定义位置编码矩阵PE eg pe[60, 512], 位置编码矩阵和特征矩阵形状是一样的
        # 初始化位置编码矩阵:max_len×embedding_dim的全零张量
        pe = torch.zeros(max_len,embedding_dim,device=device)

        # 定义位置列-矩阵position  数据形状[max_len,1] eg: [0,1,2,3,4...64]^T
        # 创建位置索引序列并增加一个维度变为列向量
        position = torch.arange(0,max_len).unsqueeze(1)

        # 定义变化矩阵div_term [1,256]
        # torch.arange(start=1, end=512, 2)结果并不包含end。在start和end之间做一个等差数组 [0, 2, 4, 6 ... 510]
        # 计算频率调节因子:用于生成不同频率的正弦/余弦波
        div_term = torch.exp(torch.arange(0,embedding_dim,2) * -(math.log(10000.0)/embedding_dim))

        # 位置列-矩阵 @ 变化矩阵 做矩阵运算 [64*1]@ [1*256] ==> 64 *256
        # 矩阵相乘也就是行列对应位置相乘再相加,其含义,给每一个列属性(列特征)增加位置编码信息
        # 计算位置和频率的乘积矩阵
        my_matmulres = position * div_term

        # 给位置编码矩阵奇数列,赋值sin曲线特征
        # 对偶数索引位置应用正弦函数
        pe[:,0::2] = torch.sin(my_matmulres)
        # 给位置编码矩阵偶数列,赋值cos曲线特征
        # 对奇数索引位置应用余弦函数
        pe[:,1::2] = torch.cos(my_matmulres)

        # 形状变化 [60,512]-->[1,64,512]
        # 增加批次维度:从[seq_len, d_model]变为[1, seq_len, d_model]
        pe = pe.unsqueeze(0)

        # 把pe位置编码矩阵 注册成模型的持久缓冲区buffer; 模型保存再加载时,可以根模型参数一样,一同被加载
        # 什么是buffer: 对模型效果有帮助的,但是却不是模型结构中超参数或者参数,不参与模型训练
        # 将位置编码注册为模型缓冲区(不参与训练但会保存)
        self.register_buffer('pe',pe)

    def forward(self,x):
        """
        前向传播过程:为输入序列添加位置编码信息
        
        参数:
            x (Tensor): 输入张量,形状为[batch_size, seq_len, embedding_dim]
            
        返回:
            Tensor: 添加位置编码后的张量,形状不变
        """
        # x--》来自于embedding之后的结果--》[batch_size, seq_len, embed_dim]-->[2, 4, 512]
        # 将x和位置编码的信息进行融合
        # 添加位置编码:截取与输入序列长度匹配的部分
        x = x + self.pe[:,:x.shape[1]]
        # 应用Dropout防止过拟合
        x = self.dropout(x)
        return x  # 返回添加位置编码后的结果


def untest_input(x):
    """
    测试函数:验证词嵌入和位置编码模块的功能
    
    参数:
        x (Tensor): 输入词汇索引张量,形状为[batch_size, seq_len]
        
    返回:
        Tensor: 经过词嵌入和位置编码处理后的张量
    """
    # 创建词嵌入层实例
    my_embedding = Embedding(vocabulary_size,embedding_dim)

    # 应用词嵌入层处理输入
    embed_x = my_embedding(x)

    # 创建位置编码层实例
    my_positiona = PositionalEncoding(embedding_dim,max_len,dropout)
    # 应用位置编码层处理嵌入结果
    pe_result = my_positiona(embed_x)

    # 打印原始输入信息
    print(f'原始输入为:{x}')
    # 打印词嵌入处理结果
    print(f'{"*"*30}词嵌入{"*"*30}')
    print(f'词嵌入后的形状:{embed_x.shape}')
    # 打印位置编码处理结果
    print(f'{"*"*30}位置编码{"*"*30}')
    print(f'位置编码后的形状:{pe_result.shape}')

    # 返回处理后的结果
    return pe_result

if __name__ == '__main__':
    """
    主程序入口:执行测试函数
    """
    # 创建测试输入张量:2个样本,每个样本4个token
    test_input = torch.tensor([[1,2,3,4],[5,6,7,8]])
    # 执行测试函数
    untest_input(test_input)

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>我们之前讨论的是3D-DCGAN的TensorFlow实现,现在用户要求获取基于PyTorch代码实现。根据引用内容,用户提供了基于PyTorch的文本分类和Transformer实现的例子,但我们的主题是3D-DCGAN。因此,我们需要将之前的TensorFlow代码转换为PyTorch实现。 核心要点: 1. 生成器(Generator)和判别器(Discriminator)的架构需要转换为PyTorch的nn.Module形式。 2. 训练循环需要按照PyTorch的风格编写(包括梯度清零、反向传播等)。 3. 使用PyTorch的优化器和损失函数。 以下是基于PyTorch的3D-DCGAN实现: 步骤1:导入必要的库 步骤2:定义生成器和判别器 步骤3:定义训练过程 注意:我们依然使用64x64x64的体素网格,生成器的输入是潜在向量z,输出是64x64x64x1的体素;判别器输入64x64x64x1的体素,输出一个标量(真/假)。 代码实现: </think>以下是根据您需求提供的基于PyTorch的3D-DCGAN代码实现方案: ### 一、核心代码实现PyTorch版) ```python import torch import torch.nn as nn # 生成器架构 class Generator(nn.Module): def __init__(self, latent_dim): super(Generator, self).__init__() self.main = nn.Sequential( # 输入: 潜在空间向量 (z) nn.Linear(latent_dim, 4*4*4*256), nn.BatchNorm1d(4*4*4*256), nn.LeakyReLU(0.2), # 重构为3D张量 nn.Unflatten(1, (256, 4, 4, 4)), # 3D转置卷积上采样 nn.ConvTranspose3d(256, 128, kernel_size=5, stride=2, padding=2, bias=False), nn.BatchNorm3d(128), nn.LeakyReLU(0.2), # 输出: (128, 8, 8, 8) nn.ConvTranspose3d(128, 64, kernel_size=5, stride=2, padding=2, bias=False), nn.BatchNorm3d(64), nn.LeakyReLU(0.2), # 输出: (64, 16, 16, 16) nn.ConvTranspose3d(64, 32, kernel_size=5, stride=2, padding=2, bias=False), nn.BatchNorm3d(32), nn.LeakyReLU(0.2), # 输出: (32, 32, 32, 32) # 最终输出层 nn.ConvTranspose3d(32, 1, kernel_size=5, stride=2, padding=2, bias=False), nn.Tanh() # 输出: (1, 64, 64, 64) ) def forward(self, input): return self.main(input) # 判别器架构(带自注意力机制) class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.conv_layers = nn.Sequential( # 输入: 1x64x64x64 nn.Conv3d(1, 32, kernel_size=5, stride=2, padding=2), nn.LeakyReLU(0.2), nn.Dropout3d(0.3), # 输出: 32x32x32x32 nn.Conv3d(32, 64, kernel_size=5, stride=2, padding=2), nn.LeakyReLU(0.2), nn.Dropout3d(0.3), # 输出: 64x16x16x16 nn.Conv3d(64, 128, kernel_size=5, stride=2, padding=2), nn.LeakyReLU(0.2), nn.Dropout3d(0.3) # 输出: 128x8x8x8 ) # 自注意力层(参考引用[2]的实现) self.attention = SelfAttention(128) self.final_layers = nn.Sequential( nn.Conv3d(128, 256, kernel_size=5, stride=2, padding=2), nn.LeakyReLU(0.2), nn.Dropout3d(0.3), # 输出: 256x4x4x4 nn.Flatten(), nn.Linear(256*4*4*4, 1), nn.Sigmoid() ) def forward(self, input): features = self.conv_layers(input) attended = self.attention(features) return self.final_layers(attended) # 自注意力模块(基于引用[2]实现) class SelfAttention(nn.Module): def __init__(self, in_channels): super(SelfAttention, self).__init__() self.query = nn.Conv3d(in_channels, in_channels//8, kernel_size=1) self.key = nn.Conv3d(in_channels, in_channels//8, kernel_size=1) self.value = nn.Conv3d(in_channels, in_channels, kernel_size=1) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): batch_size, C, D, H, W = x.size() query = self.query(x).view(batch_size, -1, D*H*W).permute(0, 2, 1) key = self.key(x).view(batch_size, -1, D*H*W) energy = torch.bmm(query, key) attention = F.softmax(energy, dim=-1) value = self.value(x).view(batch_size, -1, D*H*W) out = torch.bmm(value, attention.permute(0, 2, 1)) out = out.view(batch_size, C, D, H, W) return self.gamma * out + x ``` ### 二、训练配置(PyTorch版) ```python # 初始化模型 latent_dim = 100 generator = Generator(latent_dim) discriminator = Discriminator() # 损失函数与优化器 criterion = nn.BCELoss() g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 训练循环 def train(real_volumes): batch_size = real_volumes.size(0) # 训练判别器 d_optimizer.zero_grad() # 真实样本 real_labels = torch.ones(batch_size, 1) real_output = discriminator(real_volumes) d_loss_real = criterion(real_output, real_labels) # 生成样本 noise = torch.randn(batch_size, latent_dim) fake_volumes = generator(noise) fake_labels = torch.zeros(batch_size, 1) fake_output = discriminator(fake_volumes.detach()) d_loss_fake = criterion(fake_output, fake_labels) d_loss = d_loss_real + d_loss_fake d_loss.backward() d_optimizer.step() # 训练生成器 g_optimizer.zero_grad() fake_output = discriminator(fake_volumes) g_loss = criterion(fake_output, real_labels) # 欺骗判别器 g_loss.backward() g_optimizer.step() return d_loss.item(), g_loss.item() ``` ### 三、关键改进点 1. **自注意力机制** 在判别器中引入自注意力模块[^2],增强模型对长距离空间依赖关系的捕捉能力: $$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$ 其中$Q$, $K$, $V$分别表示查询、键和值矩阵。 2. **梯度惩罚优化** ```python # Wasserstein GAN引入梯度惩罚 alpha = torch.rand(batch_size, 1, 1, 1, 1) interpolates = (alpha * real_volumes + (1 - alpha) * fake_volumes).requires_grad_(True) d_interpolates = discriminator(interpolates) gradients = torch.autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(d_interpolates), create_graph=True, retain_graph=True)[0] # 梯度惩罚项 gradient_penalty = ((gradients.norm(2, dim=(1,2,3,4)) - 1)**2).mean() * 10 d_loss += gradient_penalty ``` 3. **Transformer编码器增强**(参考引用[3]) ```python class TransformerEncoder(nn.Module): def __init__(self, embed_dim): super().__init__() self.encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=8, dim_feedforward=2048) self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6) def forward(self, x): # 将3D体素展开为序列 x = x.flatten(2).permute(2, 0, 1) return self.transformer_encoder(x) ``` ### 四、性能优化技巧 1. **混合精度训练** ```python from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() with autocast(): fake_volumes = generator(noise) # ... 后续计算 ... scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() ``` 2. **分布式训练** ```python # 初始化分布式训练 torch.distributed.init_process_group(backend='nccl') generator = nn.parallel.DistributedDataParallel(generator) discriminator = nn.parallel.DistributedDataParallel(discriminator) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值