pytorch 入门指南
1. pytorch 概述
pytorch是facebook 开发的torch(Lua语言)的python版本,于2017年引爆学术界
官方宣传pytorch侧重两类用户:numpy的gpu版、深度学习研究平台
pytorch使用动态图机制,相比于tensorflow最开始的静态图,更为灵活
当前pytorch支持的系统包括:win,linux,macos
2. pytorch基本库
常用的pytorch基本库主要包括:
- torch: 内含一些常用方法,与numpy比较像
- torch.Tensor:内含一些操作tensor的方法,可通过tensor.xx()进行调用
- torch.nn:内含一些常用模型,如rnn,cnn等
- torch.nn.functional:内含一些常用方法,如sigmoid,softmax等
- torch.optim:内含一些优化算法,如sgd,adam等
- torch.utils.data:内含一些数据迭代方法
3. 基本操作
a. tensor操作
# 初始化空向量
torch.empty(3,4)
# 随机初始化数组
torch.rand(4,3)
# 初始化零向量
torch.zeros(4,3, dtype=torch.int)
# 从数据构建数组
x = torch.tensor([3,4],dtype=torch.float)
x = torch.IntTensor([3,4])
# 获取tensor的尺寸,元组
x.shape
x.size()
# _在方法中的意义:表示对自身的改变
x = torch.ones(3,4)
# 以下三个式子 含义相同
x = x + x
x = torch.add(x, x)
x.add_(x)
# 索引,像操作numpy一样