
pytorch
开心就哈哈
这个作者很懒,什么都没留下…
展开
-
如何将标签转为one-hot形式
假设存在如下标签矩阵,包含4个类,0,1,2,3。label 的维度为3行 5列。>>> label = torch.randint(1,4,(3,5))>>> labeltensor([[2, 2, 1, 3, 1], [1, 2, 1, 1, 1], [1, 3, 3, 1, 3]])>>> (1)扩展label维度,因为label含有4个类的标签,因此扩展label的维度 至 3x 5 x 4,即三个原创 2021-08-19 22:00:46 · 1246 阅读 · 0 评论 -
torch.nn.CrossEntropy() 计算方式
torch.nn.CrossEntropy() 主要由torch.nn.LogSoftmax 与torch.nn.NLLLoss 两个计算方式组成。(1)torch.nn.LogSoftmax 是由Softmax() 与对数计算得到。(2)torch.nn.NLLLoss 利用标签target,提取torch.nn.LogSoftmax输出结果中对应特征维度的数值,再取绝对值,然后再所有样本上计算均值得到torch.nn.CrossEntropy() 计算结果。假设输入数据input的维度.原创 2021-08-15 17:10:28 · 669 阅读 · 0 评论 -
多类分割 交叉熵损失函数 pytorch 实现思路
参考文献:https://discuss.pytorch.org/t/multiclass-segmentation/54065原创 2020-10-31 20:46:33 · 1202 阅读 · 0 评论 -
Pytorch 测试阶段,加载模型对输入数据进行推理,程序提示 CUDA of Memory 错误
BUG:在测试的时候,传递prev_state_spatial,出现显存溢出问题。解决方案,将模型得到的中间变量保存至cpu,在使用的时候,将该数据转移到cuda上。参考文献:【1】https://blog.youkuaiyun.com/nlite827109223/article/details/89847044...原创 2020-08-13 20:05:18 · 420 阅读 · 0 评论 -
Pytorch 加权交叉熵实现分析
此文不涉及公式。假设模型对输入图像数据(1(batch size) * 1 (channels) * 3 (height) * 3 (width))的分割输出结果为1 * 2 (只有两类,前景,背景)* 3 (height)* 3 (width),ground truth为 target ( 1 (batch size) * 3 * 3)。假设模型的分割结果为input, ground truth 为target。根据文献[1],计算交叉熵损失有两中方式,一种是用F.nll_loss();.原创 2020-06-09 21:01:10 · 8861 阅读 · 5 评论 -
output = torch.Tensor.scatter_(dim, index, src)
dim = 0, 按照数值方向操作;dim = 1, 按照水平方向操作;a = torch.rand(2, 5)print(a)b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)print(b) output = torch.Tensor.scatter_(...原创 2020-01-13 17:32:26 · 330 阅读 · 0 评论 -
安装torch_geometric 报错 ModuleNotFoundError: No module named 'torch_scatter'
环境:macOS, python=3.5.6 torch=1.3.0在pip install torch-geometric安装好torch-geometric之后,进入python编译器,运行from torch_geometric.data import data,报错:ModuleNotFoundError: No module named 'torch_scatter'...原创 2019-12-10 19:57:47 · 8911 阅读 · 2 评论 -
pytorch gather
yOut[34]: tensor([0, 2])y_hatOut[35]: tensor([[0.1000, 0.3000, 0.6000], [0.3000, 0.2000, 0.5000]])y_hat.gather(1, y.view(-1, 1))聚合方向y_hat的维度1,聚合位置:a[0][y.view(-1,1)[0]] = 0.1a[1][...原创 2019-09-18 21:06:48 · 445 阅读 · 0 评论