pytorch 入门教程 & 常用知识整理

Tensor

detach()

这个是针对想冻结网络model里面的某些参数不参与训练时使用的。并且它是会返回一个新的tensor。

y = x.detach()

这种情况下,y和x共同一个内存,即一个修改另一个也会跟着改变。此时x还是可以正常求导。所以很多时候用的时候是x = x.detach()。

torch.cat 与 torch.stack

torch.cat不会增加新的维度,原来几个维度,还是几个维度
torch.stack会增加一个新的维度,让n维的tensor变成n+1维

x1_torch = torch.zeros(3,1)
y1_torch = torch.ones(3,1)

xy_1 = torch.cat([x1_torch, y1_torch], dim = 1)
xy_2 = torch.stack([x1_torch, y1_torch], dim=1)

xy_1的shape=(3,2)

xy_1 = tensor(
[[0., 1.],
[0., 1.],
[0., 1.]])

xy_2的shape= (3,2,1)

xy_2 = tensor(
[ [[0.],[1.]],
[[0.],[1.]],
[[0.], [1.]] ])

矩阵运算

x = torch.Tensor([[1,2],[3,4]])
矩阵相乘 按照矩阵运算法则:

y1 = x @ x.T
y2 = x.matmul(x.T)

y1=y2=tensor(
[[ 5., 11.],
[11., 25.]])

元素之间相乘:

y3 = x * x.T
y4 = x.mul(x.T)

y3 = y4 = tensor(
[[ 1., 4.],
[ 9., 16.]])

与numpy数据交换:

Tensor --> numpy

x_numpy = x_torch.numpy()

numpy --> Tensor

x_torch = torch.from_numpy(x_numpy)

但是需要注意的是,这两个之间共享底层存储,所以更改一个会影响另一个。如果要想阻断之间的关系,可以在numpy数据后面加上 copy()。

数据处理

Datasets & DataLoaders

pytorch的数据处理中最关键的就是这两个类

  • torch.utils.data.DataLoader 主要用于训练的迭代器,使用比较简单
  •   train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    
  • torch.utils.data.Dataset 做自己数据集的关键类,最关键的是继承类中一定要重载__init____len____getitem__这三个函数。常用模板如下:
  •   class Dataset_name(Dataset):
      	def __init__(self, flag='train'):
          	assert flag in ['train', 'test', 'valid']
          	self.flag = flag
         	self.__load_data__()
    
      	def __getitem__(self, index):
          	pass
      	def __len__(self):
          	pass
    
      	def __load_data__(self, csv_paths: list):
          	pass
          	print(
              "train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
              .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))
    

`

Torchvision.transform

在训练之前,数据基本都需要进行数据增强,各种剪裁等数据变换,因此这个模块是必不可少的。

torchvision.transforms.Compose

这是一个类,主要作用是把所有的数据处理集合成一个操作

  • 输入参数:它接收的元素为Torchvision.transform里面的方法组成的list
  •   T =  transforms.Compose([transforms.CenterCrop(10), transforms.ToTensor()])
    
  • 使用方法:直接对该类的实例传入图片即可,图片的格式需要是PIL或者Tensor。如果是opencv,需要把ToTensor()方法放到第一个。它调用定义的源码为:
  •      def __call__(self, img):
          for t in self.transforms:
              img = t(img)
          return img
    
  • 该类不可用于转移动端的代码,如果转移动端,需要将其替换为如下,并且输出还必须只能是Tensor类型。
   transforms = torch.nn.Sequential(
   transforms.CenterCrop(10),
   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)
torchvision.transforms.ToTensor

因为图片的读取一般分为opencv和PIL,但都需要转为tensor,或者互转,opencv中的存储方式为(h,w,c),但是tensor中一般都存储为(c,h,w),而pil也是一种特殊的格式;但是这一切都可以用ToTensor一键搞定。

#img_cv.shape = (448, 800, 3)
#img_pil.size = (800, 448)
cv2tensor =  torchvision.transforms.ToTensor()(img_cv)
#cv2tensor.shape : torch.Size([3, 448, 800])
pil2tensor = torchvision.transforms.ToTensor()(img_pil)
#pil2tensor.shape : torch.Size([3, 448, 800])

搭建模型 build model

搭建模型包括:

  1. 构建模型,其中又包括:
    • 定义单独的网络层,即__init__函数;
    • 把它们拼在一起,确定各层的执行顺序,即forward函数。
  2. 权值初始化。

torch.nn

PyTorch 把与深度学习模型搭建相关的全部类全部在 torch.nn 这个子模块中。
根据类的功能分类,常用的有如下几个部分:

  • 模型运算层:
    • 卷积层:torch.nn.Conv2d
    • 池化层:torch.nn.MaxPool2d
    • 线性层:torch.nn.Linear
    • 等等
  • 容器类,如 torch.nn.Module
  • 工具函数类 Utilities,用的较少,日后遇见可以补充。

而在 torch.nn 下面还有一个子模块 torch.nn.functional,基本上是 torch.nn里对应类的函数,比如torch.nn.ReLU的对应函数是 torch.nn.functional.relu,这两者的功能一样,运行效率也不同,但是也有很大的区别:

  1. 调用方式不同 torch.nn.XXX是类,使用前必须先进行实例化,而torch.nn.functional是函数,使用前必须传入所有参数,如果是卷积层需要传入权重。
inputs = torch.rand(64, 3, 244, 244)
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值