Pytorch读取存储

部署运行你感兴趣的模型镜像

Pytorch读写Tensor

摘自:动手学深度学习PYTORCH版(DEMO)
链接:https://link.zhihu.com/?target=https%3A//github.com/OUCMachineLearning/OUCML/blob/master/BOOK/Dive-into-DL-PyTorch.pdf在这里插入图片描述

import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')

然后将数据从存储的文件读回内存。

x2 = torch.load('x.pt')
x2

输出:

tensor([1., 1., 1.])

还可以存储一个Tensor列表并读回内存

y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

输出:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典

torch.save({
'x': x, 
'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

输出:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

Pytorch读写模型

state_dict

在这里插入图片描述

class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.hidden = nn.Linear(3, 2)
		self.act = nn.ReLU()
		self.output = nn.Linear(2, 1)
		
	def forward(self, x):
		a = self.act(self.hidden(x))
		return self.output(a)
net = MLP()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],
[ 0.2030, -0.2073, -0.0104]])),
('hidden.bias', tensor([-0.3117, -0.4232])),
('output.weight', tensor([[-0.4556, 0.4084]])),
('output.bias', tensor([-0.3573]))])

在这里插入图片描述

optimizer = torch.optim.SGD(net.parameters(), lr=0.001,
momentum=0.9)
optimizer.state_dict()

输出:

{'param_groups': [{'dampening': 0,
'lr': 0.001,
'momentum': 0.9,
'nesterov': False,
'params': [4736167728, 4736166648, 4736167368, 4736165352],
'weight_decay': 0}],
'state': {}}
保存和加载模型

在这里插入图片描述
保存

torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

实验:

X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

输出:

tensor([[1],
[1]], dtype=torch.uint8)

在这里插入图片描述
官方文档https://pytorch.org/tutorials/beginner/saving_loading_models.html

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

PyTorch中,直接读取CSV文件并不内置像Pandas那样的功能,因为PyTorch主要用于处理张量数据,而非大数据集的批处理读取。不过,你可以结合Pandas来完成CSV数据的预处理,然后再转换成PyTorch所需的格式。以下是步骤: 1. 首先安装pandas库,如果你还没有安装可以使用以下命令: ``` !pip install pandas ``` 2. 使用Pandas库的`read_csv`函数读取CSV文件: ```python import pandas as pd data = pd.read_csv('your_file.csv') # 替换 'your_file.csv' 为你需要读取的CSV文件路径 ``` 这会返回一个DataFrame对象,DataFrame是一种二维表格型的数据结构,非常适合存储表格数据。 3. 对数据进行预处理,例如选择某些列、处理缺失值、转换数据类型等,这取决于你的任务需求。 4. 将DataFrame转换为PyTorch能处理的张量形式。如果你的数据已经是数值型并且可以直接用作模型输入,你可以这样做: ```python import torch tensor_data = torch.tensor(data.values) ``` 如果有类别特征需要编码,可以先转为one-hot编码或者使用LabelEncoder。 5. 最后,根据需要将张量划分为训练集、验证集和测试集: ```python train_data, val_data, test_data = train_test_split(tensor_data, labels, test_size=0.2, random_state=42) ``` 其中,`labels`是你想要作为目标变量的列名。记住,PyTorch的张量通常是四维的(batch_size, channels, height, width),如果你的数据不需要这样的形状,可能还需要进一步处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值