pytorch中一些常用的模块中的函数(持续更新)

博客围绕PyTorch展开,介绍了常用函数,如拼接张量;还提及快速定义模型的方法,可作为搭建网络模块。此外,涵盖了模型的保存与读取操作,以及创建可训练数据集和常用优化器等内容。
部署运行你感兴趣的模型镜像
import torch # 神经网络用
import torch.nn.functional as F # 激活函数使用
import torch.utils.data as Data # 批处理使用

常用函数

torch.nn.Linear() # 线性层
torch.nn.MSELoss() # 均方差损失函数(回归用)
torch.nn.CrossEntropyLoss() # # 交叉熵损失函数(分类用)
torch.optim.SGD(net.parameters(), lr=0.2) # SGD优化器
torch.nn.functional.relu # 这是一个函数,所以在functional中而不是作为torch.nn.的一个层

拼接张量(例子)

x0 = torch.normal(2*n_data, 1)
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data, 1)
y1 = torch.ones(100)
x = torch.cat((x0,x1), 0).type(torch.FloatTensor) # 按行拼接两个张量
y = torch.cat((y0,y1), 0).type(torch.LongTensor)

快速定义模型的方法(例子),也可以看成是在搭建网络的一个模块,之后在搭建网络时使用。

net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10,2))

模型的保存

torch.save(net1, 'net.pkl') # 保存模型与模型参数全部信息
    torch.save(net1.state_dict(),'net_params.pkl') # 只保存参数

模型的读取

net2 = torch.load('net.pkl')# 读net1的全部信息给net2
							# 这样相当于net2是net的完全复制,从模型结构到训练的参数信息
net3.load_state_dict(torch.load('net_params.pkl')) # 把net1的结构参数信息给net3
												   # net3要和net的结构完全相同,然后把net使用的优化器,损失函数、
												   # 以及参数信息给net3

创建torch可训练的数据集


x = torch.linspace(1, 10, 10) # 生成1到10 按顺序十个点
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader( # 设置训练数据的各种属性
    dataset = torch_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True, # 是否打乱
    num_workers = 2, # 用几个进程提取数据
)

常用优化器

opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值