
API_Net
「已注销」
这个作者很懒,什么都没留下…
展开
-
API_Net代码的train.txt与val.txt
import os if __name__=="__main__": path = "/home/yjys/datasets/CUB_200_2011" train_test_file = "train_test_split.txt" image_class_label = "image_class_labels.txt" image = "images.txt" image_path = "/home/yjys/datasets/CUB_200_2011/imag原创 2021-02-03 10:03:12 · 351 阅读 · 1 评论 -
API_Net官方代码之训练网络
导入包: import argparse import os import time import torch import torch.nn as nn import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data import torchvision.transforms as transforms import numpy as np from models import API_Net from dat原创 2021-02-03 10:01:23 · 469 阅读 · 0 评论 -
API_Net官方代码之utils工具
导入包 import torch import shutil 二、模块 1)保存模型参数,保存模型状态,状态中可以有模型参数,优化器参数,epoch等。如果是在验证集上表现比之前好,那么就是is_best=True,使用shutil.copyfile(src, des)将src文件直接拷贝到des,如果已经存在,就直接覆盖掉。 def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):#state是一个字典,包含优化器、网络等参数原创 2021-02-02 10:17:29 · 248 阅读 · 2 评论 -
API_Net官方代码之创建模型
导入包: import torch from torch import nn from torchvision import models import numpy as np from skimage import io 设置device: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 1)计算批量样本通过resnet101后输出的批量向量的距离: 计算方法是两个向量之间对应元素相减,再平方,再求和,使用完全原创 2021-02-01 22:50:54 · 319 阅读 · 2 评论 -
API_Net官方代码之数据处理
一、数据准备 总结: RandomDataset :用于验证 (val) BatchDataset:用于训练 (train) BalancedBatchSampler:决定如何采样样本,不是简单的在Dataloader中设置一个batch_size了 1)导入的包类: import torch from PIL import Image import numpy as np from torchvision import transforms from torch.utils.data import .原创 2021-02-01 21:36:02 · 518 阅读 · 2 评论