
pytorch
文章平均质量分 53
pytorch
Shashank497
我要找到你,不管南北东西
展开
-
pytorch中nn.parameter和require_grad=True的区别
nn.parameter()首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。torch.tensor([1,2,3],requires_grad=True)这个只是将参数变成可训练的,并没有绑定在m原创 2022-02-10 13:40:37 · 2691 阅读 · 0 评论 -
tensor.narrow()函数
input.narrow(dimension, start, length) → Tensor # 表示取变量input在第dimension维上,从索引start到start+length范围(不包括start+length)的值。example:In [2]: x = torch.Tensor([[1,2,3], [4,5,6], [7,8,9]])In [4]: x.narrow(0,0,3)Out[4]: tensor([[1., 2., 3.], [4., 5., 6原创 2022-01-19 22:01:52 · 1929 阅读 · 0 评论 -
Pytorch view和permute的区别
a = torch.tensor([[[1,2,3],[4,5,6]]])b = a.view(3,2)c = a.permute(0,2,1)print(a.size(),a)print(b.size(),b)print(c.size(),c)结果:a: torch.Size([1, 2, 3]) tensor([[[1, 2, 3], [4, 5, 6]]])b: torch.Size([3, 2]) tensor([[1, 2], [3, 4],原创 2022-01-19 19:12:43 · 969 阅读 · 0 评论 -
pytorch的dataset用法详解
torch.utils.data 里面的dataset使用方法当我们继承了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引from torch.utils.data import Dataset, DataLoaderimport torchclass MyDataset(Dataset): """ 下载数据、初始化数据,都可以在这里完成 """原创 2022-01-15 15:51:21 · 8746 阅读 · 0 评论 -
pytorch中DataLoader详解
import torchimport torch.utils.data as Dataif __name__ == '__main__': torch.manual_seed(1) # reproducible BATCH_SIZE = 5 # 批训练的数据个数 x = torch.linspace(11, 20, 10) # x data: tensor([11., 12., 13., 14., 15., 16., 17., 18., 19., 20.]) .原创 2022-01-15 13:40:17 · 5868 阅读 · 0 评论 -
torch 的dataloader 的核心函数
官方解释:Dataloader 组合了 dataset & sampler,提供在数据上的 iterable主要参数:1、dataset:这个dataset一定要是torch.utils.data.Dataset本身或继承自它的类里面最主要的方法是 getitem(self, index) 用于根据index索引来取数据的2、batch_size:每个batch批次要返回几条数据3、shuffle:是否打乱数据,默认False4、sampler:sample strategy,数据选取策原创 2022-01-14 19:46:24 · 2284 阅读 · 0 评论