
pytorch
小鱼儿小于儿
开心做自己
展开
-
torch小知识点
知识点import torchx = torch.Tensor([1, 2, 3, 4]) # torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称。print('-' * 50)print(x) # tensor([1., 2., 3., 4.])print(x.size()) # torch.Size([4])print(x.dim()) # 1print(x.numpy()) # [1. 2. 3. 4.]print('-' *原创 2021-03-10 15:01:03 · 284 阅读 · 0 评论 -
小程序
小程序import torchimport matplotlib.pyplot as pltx_data = torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0]])y_data = torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0]])class MODEL(torch.nn.Module): def __init__(self,n_input,n_output): super(MODEL, se原创 2021-03-09 15:11:05 · 156 阅读 · 0 评论 -
pytorch笔记
pytorch#mnist手写识别 CNN训练使用的数据集是MNIST识别手写文字0-9,文字标签的编码方式为one-hot编码。 导入库 (os,torch,nn,Data) -设置epoch、批大小、学习率等使用torchvision.datasets下载数据集并制作批训练分发器,train_data有train_data和train_labels两个子属性,以前2000条数据作为测试集加速测试过程 - 定义CNN类(两层卷积层再全连接到10个节点表示数字,每层卷积后用ReLU激原创 2021-03-09 12:01:22 · 488 阅读 · 0 评论 -
pytorch笔记
pytorchimport torchimport torch.nn.functional as Fimport matplotlib.pyplot as plt# torch.manual_seed(1) # reproduciblex = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)y = x.pow(2) + 0.2*torch.rand(x.size()原创 2021-03-09 11:59:40 · 125 阅读 · 0 评论 -
pytorch RNN
pytorch RNNimport numpy as npimport torchimport torch.nn as nnfrom torch.autograd import Variableimport torchvisionimport matplotlib.pyplot as plt# Hyper ParametersEPOCH = 1BATCH_SIZE = 64TIME_STEP = 28 # rnn time step / image heightIN原创 2021-03-01 18:04:57 · 165 阅读 · 0 评论 -
pytorch
pytorchclass torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, t原创 2021-03-01 15:13:49 · 109 阅读 · 0 评论