
PyTorch
ZZ58
这个作者很懒,什么都没留下…
展开
-
Pytorch实现手写数字识别
''' 用PyTorch完成手写数字识别 1.准备数据 2.构建模型 3.模型的训练 4.模型的保存 5.模型的评估 ''' import torch import os import numpy as np from torch import nn from torch import optim import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision.datasets import原创 2021-05-09 14:41:39 · 249 阅读 · 0 评论 -
PyTorch实现线性回归
import torch import torch.nn as nn from torch.optim import SGD import matplotlib.pyplot as plt #0.准备数据 x = torch.rand([500,1]) y_true = 3*x + 0.8 #1.定义模型 class MyLinear(nn.Module): def __init__(self): super(MyLinear, self).__init__() se原创 2021-05-09 10:53:47 · 121 阅读 · 0 评论 -
Pytorch--常用API
1.nn.Module a.__init__, b.farward nn.Module定义了__call__方法,及Lr的实例,能够直接被传入参数调用,实际上用的是forward方法并传入参数 from torch import nn class Lr(nn.Module): def __init__(self): super(Lr,self).__init_() self.linear = nn.Linear(1,1) def forward(self,x): out = se原创 2021-05-07 14:37:31 · 304 阅读 · 0 评论 -
安装PyTorch以及GPU上运行
Pytorch:torch,torchvision,cudatoolkit 对应的版本 https://pytorch.org/get-started/previous-versions/ 下载torch以及torchvision,toolkit网址 http://download.pytorch.org/whl/torch_stable.html 下载下来的whl文件 在这个目录下 pip install xxxxxx.whl 在GPU上执行程序 a.自定义的参数需要转化为cuda支持的tensor b.原创 2021-05-07 15:15:11 · 100 阅读 · 0 评论