
pytorch
znsoft
人工智能博士/教授级高级工程师/博士研究生导师
展开
-
Pytorch模型转换为onnx或ncnn的方法兼谈pytorch模型编写规范
使用torch.jit.trace_module( model, {"entrypoint",[parameters]}) 来记录模型,entrypoint在大多数模型中是 forward, 如果有定制,可以修改。影响成功的主要问题是:在模型的forward函数中不能使用条件语句,否则 会影响trace。此时的办法就是去掉里面的条件语句。pytorch-> torchscript->pnnx, 中间有副产品onnx模型。使用torch.jit.save来保存模型。原创 2023-01-30 20:05:10 · 645 阅读 · 0 评论 -
Windows下用amd显卡训练 : Pytorch-directml 重大升级,改为pytorch插件形式,兼容更好
新的pytorch-directml 不再是独立的pytorch 移植,变成了一个设备插件,更好用,兼容性更好。原来的版本无法跑transformers, 新版变成一个独立的计算设备 dml, 兼容性更好。原创 2022-12-22 07:49:15 · 7197 阅读 · 0 评论 -
pytorch - directml 中查看设备支持情况
print(torch.dml.device_name(0)) # 显示第0个DML设备名称。print(torch.dml.default_device()) #显示缺省DML设备id。print(torch.dml.is_available()) #显示是否有dml设备。原创 2022-11-13 17:10:50 · 2748 阅读 · 0 评论 -
AMD 显卡编译 pytorch 指南 ROCM + pytorch
ROCM + pytorch 快速安装方法需要在干净机器上安装原始参考资料 https://github.com/aieater/rocm_pytorch_informations ,有修改在ubuntu 18.04 及ubuntu 20.04 测试通过以下为安装pytorch 1.6 + rocm 3.5.1 (需要版本匹配)1. 更新系统,安装必要的库sudo apt update sudo apt -y dist-upgrade ...原创 2020-10-18 12:07:34 · 11670 阅读 · 5 评论 -
torch.eq的广播机制兼谈快速生成对角掩码
其实就是一个对角线为true的矩阵,怎么实现的?x和y的维度都不相同,进行广播机制,生成两个 6*6矩阵,这样torch.eq(x,y)后只剩 对角线上是true,其它位置是false了。torch.eq用于判断 两个矩阵是不是逐元素相等,或者和第二个值 相等。输出 x=[0,1,2,3,4,5]原创 2022-08-24 19:12:49 · 512 阅读 · 0 评论 -
解决 NCCL WARN Cuda failure ‘invalid device function‘ , unhandled cuda error, NCCL version 2.4.8
注意最后一行: enqueue.cc:197 NCCL WARN Cuda failure 'invalid device function'运行nvidia-smi 后得到的版本要和pytorch安装 时的版本一样,我的是: CUDA Version: 11.7。原创 2022-08-15 20:33:27 · 3009 阅读 · 0 评论 -
pytorch中的矩阵切片操作完全讲解
我们经常需要从2维或三维tensor中进行切片操作,比如从mask模型中取出mask所在位置的值。原创 2022-08-12 10:53:17 · 2446 阅读 · 0 评论 -
torch.where的新用法(很老但是大家忽略的用法)
(condition) → tuple of LongTensor is identical to .通常 ,我们都 会这样使用torch.where函数:torch.where(condition, x, y)但是实际上torch.where还有如题头如示的用法,返回tuple.tuple中是condition中符合条件的值的index.比如a=[1,2,3,4,5]torch.where(a>3)返回( 3,4)原创 2022-05-31 22:56:36 · 941 阅读 · 1 评论 -
快速获取 pytorch中符合条件的tensor元素个数
import torchx=torch.eye(3) # 生成一个测试用的tensor ,单位阵y=torch.nonzero(x>0) # 找出值大于0的索引位置print(y.shape[0]) #索引行数即是个数原创 2022-04-25 10:01:09 · 5743 阅读 · 0 评论 -
pytorch 交叉熵函数CrossEntropyLoss 使用详解
import torch import torch.nn as NNcriterion = NN.CrossEntropyLoss()X=torch.randn([2,150])Y=torch.randint(0,150,(2,))print(X.shape)print(Y.shape)loss=criterion(X,Y)loss0=criterion(X[0],Y[0])loss1=criterion(X[1],Y[1])loss_all=(loss0+loss1)/.原创 2022-04-19 11:49:48 · 2273 阅读 · 0 评论 -
pytorch中的 target label问题: multi-target not supported
我们在处理多分类问题时,对数据的标签也就是y值 ,通常是按one-hot编码方式处理,这个时候在计算loss 函数时就会出现以下提示:multi-target not supported查阅资料发现,在pytorch的loss计算中,不能使用onehot编码,而是需要使用类别编码,在实际计算时,pytorch会自动转换为one-hot编码,否则 会提示以下错误:multi-target not supported...原创 2021-08-18 07:12:47 · 1215 阅读 · 0 评论