
机器学习
Flyforever-Tang
这个作者很懒,什么都没留下…
展开
-
Pytorch:解决使用yolov5时网络不佳导致无法下载数据集的问题
Pytorch:解决使用yolov5时网络不佳导致无法下载数据集的问题原创 2022-10-04 17:43:38 · 1134 阅读 · 0 评论 -
Pytorch:修改模型的特定模块/层
def _set_module(model, submodule_key, module): tokens = submodule_key.split('.') sub_tokens = tokens[:-1] cur_mod = model for s in sub_tokens: cur_mod = getattr(cur_mod, s) setattr(cur_mod, tokens[-1], module)参数如下model:模型sub原创 2022-04-10 15:22:28 · 4687 阅读 · 0 评论 -
Pytorch:在with torch.no_grad()的代码段里临时允许计算梯度
一般情况下,训练神经网络时,pytorch的测试部分语句都会写在 with torch.no_grad() 的代码段下,以关闭tensor的自动求导、计算梯度功能,节省显存和运算时间。但是有时会希望临时允许计算梯度,比如作者是在用pytorch_grad_cam生成神经网络的可解释热图时,它需要通过计算梯度生成,否则会报“element 0 of tensors does not require grad and does not have a grad_fn”的错误。百度没有搜到怎么临时允许计算梯度,原创 2022-03-22 21:37:47 · 1742 阅读 · 0 评论 -
Pytorch:获得模型每一层的名字
以resnet18模型为例,可以打印出模型的结构:from torchvision import modelsmodel = models.resnet18()print(model)'''打印结果'''# ResNet(# (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# (bn1): BatchNorm2d(64, eps=1e-05, momentum=0原创 2022-03-13 17:00:44 · 5239 阅读 · 0 评论 -
Pytorch:ImageFolder只读取部分类别文件夹
直接继承并且重写ImageFolder类的find_classes方法即可from torchvision.datasets.folder import *from typing import *class FilterableImageFolder(ImageFolder): def __init__( self, root: str, transform: Optional[Callable] = None,原创 2022-03-12 10:20:19 · 3530 阅读 · 1 评论 -
使用ImageFolder读取数据集时忽略特定文件
如果事先知道需要忽略哪些文件,当然直接从数据集里删除就行了。但如果需要在程序运行时动态确认,或者筛选规则比较复杂,人工不好做,就需要让ImageFolder在读取时使用自定义的筛选规则。ImageFolder有一个可选参数为is_valid_file,参数类型为可调用的函数,该函数传入一个str参数,返回一个bool值。当返回值为True时保留该文件,否则忽略。例如,读取时想要忽略所有文件名带‘invalid’的文件,代码如下:class Check(object): def __init__原创 2022-03-11 19:01:45 · 1722 阅读 · 0 评论 -
Pytorch训练GoogLeNet时损失函数报错
报错信息:TypeError: cross_entropy_loss(): argument ‘input’ (position 1) must be Tensor, not GoogLeNetOutputs需要把model输出的GoogLeNetOutputs转化为适用于损失函数的logits形式output = model(x)output = output.logits原创 2021-12-29 15:20:33 · 5762 阅读 · 3 评论 -
keras读取预训练模型统一接口
import numpy as npfrom tensorflow.keras.layers import Flatten, Densefrom keras.models import Modelimport tensorflow.keras.applications as KerasModelsupported_model = np.array([ ['xception', 'Xception'], ['vgg16'原创 2021-08-25 17:00:25 · 223 阅读 · 0 评论 -
Pytorch加载本地CIFAR10数据集
在线下载经常报错,可以预先下载好数据集放到本地。下载数据集(官网页面:http://www.cs.toronto.edu/~kriz/cifar.html 下载地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz),使用IDM很快就下载完了,不要听某些人瞎忽悠去下载还要积分的资源,百度云也大可不必。把cifar-10-python.tar.gz更名为cifar-10-batches-py.tar.gz并放到本地文件夹里,这里放到./data原创 2021-08-24 17:43:46 · 4616 阅读 · 1 评论 -
pytorch判断模型是否处于训练状态
使用self.training,这样就可以让forward函数采用两种执行方式,然后就可以做一些骚操作了import torch.nn as nnclass myNet(nn.Module): def __init__(self): super(myNet, self).__init__() def forward(self): if self.training: print('training') else:原创 2021-08-12 15:58:37 · 3743 阅读 · 0 评论 -
查找numpy.ndarray中的重复元素
先转成list,然后再借助count查找(ndarray没有count方法)list= ndarray.tolist()for i in range(len(list)): if list.count(list[i]) > 1: print(list[i])原创 2021-07-27 15:51:23 · 2646 阅读 · 0 评论 -
保存Shap生成的神经网络解释图(shap.image_plot)
保存Shap生成的神经网络解释图(shap.image_plot)调用shap.image_plot后发现使用plt.savefig保存下来的图像为空白图,经过查资料发现这是因为调用plt.show()后会生成新画板。(参考链接:保存plot_如何解决plt.savefig()保存的图片为空白的问题?)找到了一篇介绍如何保存Shap图的博客(原文地址:shap解释模型特征,多张图保存的实现(要改源码)),但是里面并没有提到image_plot怎么处理。此外,前面那个链接里提到的“先在画图前调用myfig原创 2021-06-23 10:49:23 · 8217 阅读 · 5 评论 -
ImageDataGenerator读取的数据集转Numpy array
ImageDataGenerator读取的数据集转Numpy array常用的数据集类型可以分为两种:(1)一种是网上的经典数据集,比如mnist,一般会给写好的读取方法,比如mnist.load_data(),读取出来的返回值是Numpy array;(2)一种是自己本地的数据集,路径下每个文件夹代表一类图像,目录结构类似于data--type1----img1-1----img1-2--type2----img2-2----img2-2--type3----img3-2----原创 2021-06-18 17:09:01 · 891 阅读 · 1 评论