This is the code in main.py:

from datetime import datetime
from functools import partial
from PIL import Image
import cv2
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
# Data augmentation (core augmentation section)
# Set up the arguments
parser = argparse.ArgumentParser(description='Train MoCo on CIFAR-10')
parser.add_argument('-a', '--arch', default='resnet18')
# lr: 0.06 for batch 512 (or 0.03 for batch 256)
parser.add_argument('--lr', '--learning-rate', default=0.06, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=300, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')
parser.add_argument('--batch-size', default=64, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
# moco specific configs:
parser.add_argument('--moco-dim', default=128, type=int, help='feature dimension')
parser.add_argument('--moco-k', default=4096, type=int, help='queue size; number of negative keys')
parser.add_argument('--moco-m', default=0.99, type=float, help='moco momentum of updating key encoder')
parser.add_argument('--moco-t', default=0.1, type=float, help='softmax temperature')
parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')
# knn monitor
parser.add_argument('--knn-k', default=20, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; may differ from moco-t')
# utils
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='', type=str, metavar='PATH', help='path to cache (default: none)')
'''
args = parser.parse_args() # running in command line
'''
args = parser.parse_args('') # running in ipynb
# set command line arguments here when running in ipynb
args.epochs = 300  # modified here
args.cos = True
args.schedule = [] # cos in use
args.symmetric = False
if args.results_dir == '':
args.results_dir = "E:\\contrast\\yolov8\\MoCo\\run\\cache-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
moco_args = args
class CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        # two standard augmented views of the original image
        im_1 = self.transform(img)
        im_2 = self.transform(img)
        # two extra views from the degradation-and-augmentation pipeline
        degraded_results = image_degradation_and_augmentation(img)
        im_3 = self.transform(Image.fromarray(degraded_results['augmented_images'][0]))  # first degraded-augmented view
        im_4 = self.transform(Image.fromarray(degraded_results['cutmix_image']))
        return im_1, im_2, im_3, im_4  # original views + degradation views
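# Illustrative sketch (assumed, not in the original script): one batch from this
# dataset yields four tensors; shapes assume the 32x32 train_transform below.
# pair_loader = DataLoader(CIFAR10Pair(root=..., train=True, transform=train_transform), batch_size=4)
# im_1, im_2, im_3, im_4 = next(iter(pair_loader))
# print(im_1.shape)  # torch.Size([4, 3, 32, 32]); im_2, im_3, im_4 match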
# Data loader definition; the original two-view CIFAR10Pair is kept below for reference
# class CIFAR10Pair(CIFAR10):
# """CIFAR10 Dataset.
# """
# def __getitem__(self, index):
# img = self.data[index]
# img = Image.fromarray(img)
# if self.transform is not None:
# im_1 = self.transform(img)
# im_2 = self.transform(img)
# return im_1, im_2
import random
def apply_interpolation_degradation(img, method):
    """
    Apply interpolation degradation.
    Args:
        img: input image (numpy array)
        method: interpolation method ('nearest', 'bilinear', 'bicubic')
    Returns:
        the degraded image
    """
    # image size
    h, w = img.shape[:2]
    # degrade by downsampling then upsampling with the chosen interpolation
    if method == 'nearest':
        downsampled = cv2.resize(img, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST)
        degraded = cv2.resize(downsampled, (w, h), interpolation=cv2.INTER_NEAREST)
    elif method == 'bilinear':
        downsampled = cv2.resize(img, (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
        degraded = cv2.resize(downsampled, (w, h), interpolation=cv2.INTER_LINEAR)
    elif method == 'bicubic':
        downsampled = cv2.resize(img, (w // 2, h // 2), interpolation=cv2.INTER_CUBIC)
        degraded = cv2.resize(downsampled, (w, h), interpolation=cv2.INTER_CUBIC)
    else:
        degraded = img
    return degraded
def darken_image(img, intensity=0.3):
    """
    Darkening: lower brightness and raise contrast in dark regions.
    Args:
        img: input image (numpy array)
        intensity: darkening strength (0.1-0.9)
    Returns:
        the darkened image
    """
    # clamp the strength
    intensity = max(0.1, min(0.9, intensity))
    # convert to HSV
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
    # reduce brightness (V channel)
    hsv[:, :, 2] = hsv[:, :, 2] * intensity
    # raise contrast in dark regions via gamma correction
    gamma = 1.0 + (1.0 - intensity)  # stronger darkening -> larger gamma
    hsv[:, :, 2] = np.power(hsv[:, :, 2] / 255.0, gamma) * 255.0
    # clip to the 0-255 range
    hsv[:, :, 2] = np.clip(hsv[:, :, 2], 0, 255)
    # convert back to RGB
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
def random_affine(image):
    """
    Random affine transform (scaling and translation).
    Args:
        image: input image (numpy array)
    Returns:
        the transformed image
    """
    height, width = image.shape[:2]
    # random scale factor (0.8 to 1.2)
    scale = random.uniform(0.8, 1.2)
    # random translation (up to 10% of the image size)
    max_trans = 0.1 * min(width, height)
    tx = random.randint(-int(max_trans), int(max_trans))
    ty = random.randint(-int(max_trans), int(max_trans))
    # transformation matrix
    M = np.array([[scale, 0, tx], [0, scale, ty]], dtype=np.float32)
    # apply the affine transform
    transformed = cv2.warpAffine(image, M, (width, height))
    return transformed
def augment_hsv(image, h_gain=0.1, s_gain=0.5, v_gain=0.5):
    """
    HSV color-space augmentation.
    Args:
        image: input image (numpy array)
        h_gain, s_gain, v_gain: gain range for each channel
    Returns:
        the augmented image
    """
    # sample and clamp the gains
    h_gain = max(-0.1, min(0.1, random.uniform(-h_gain, h_gain)))
    s_gain = max(0.5, min(1.5, random.uniform(1 - s_gain, 1 + s_gain)))
    v_gain = max(0.5, min(1.5, random.uniform(1 - v_gain, 1 + v_gain)))
    # convert to HSV
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    # apply the gains
    hsv[:, :, 0] = (hsv[:, :, 0] * (1 + h_gain)) % 180
    hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s_gain, 0, 255)
    hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v_gain, 0, 255)
    # convert back to RGB
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
# def mixup(img1, img2, alpha=0.6):
#     """
#     Blend two images together.
#     Args:
#         img1, img2: input images (numpy arrays)
#         alpha: Beta-distribution parameter controlling the mixing ratio
#     Returns:
#         the blended image
#     """
#     # sample the mixing ratio
#     lam = random.betavariate(alpha, alpha)
#     # make sure the images have the same size
#     if img1.shape != img2.shape:
#         img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
#     # blend the images
#     mixed = (lam * img1.astype(np.float32) + (1 - lam) * img2.astype(np.float32)).astype(np.uint8)
#     return mixed
# def image_degradation_and_augmentation(image, dark_intensity=0.3):
#     """
#     Full image degradation and augmentation pipeline.
#     Args:
#         image: input image (PIL.Image or numpy array)
#     Returns:
#         dict with all degradation groups and the final augmented results
#     """
#     # ensure the input is a numpy array
#     if not isinstance(image, np.ndarray):
#         image = np.array(image)
#     # ensure the image is RGB
#     if len(image.shape) == 2:
#         image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
#     elif image.shape[2] == 4:
#         image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
#     # original image
#     original = image.copy()
#     # interpolation methods
#     interpolation_methods = ['nearest', 'bilinear', 'bicubic']
#     # first degradation group: the three interpolation methods
#     group1 = []
#     for method in interpolation_methods:
#         degraded = apply_interpolation_degradation(original, method)
#         group1.append(degraded)
#     # second degradation group: one extra random degradation
#     group2 = []
#     for img in group1:
#         # randomly pick a degradation method
#         method = random.choice(interpolation_methods)
#         extra_degraded = apply_interpolation_degradation(img, method)
#         group2.append(extra_degraded)
#     # all degraded images
#     all_degraded_images = [original] + group1 + group2
#     # apply darkening (before augmentation)
#     darkened_images = [darken_image(img, intensity=dark_intensity) for img in all_degraded_images]
#     # apply data augmentation
#     # 1. random affine transform
#     affine_images = [random_affine(img) for img in darkened_images]
#     # 2. HSV augmentation
#     hsv_images = [augment_hsv(img) for img in affine_images]
#     # 3. MixUp: blend two randomly chosen augmented images
#     mixed_image = mixup(
#         random.choice(hsv_images),
#         random.choice(hsv_images)
#     )
#     results = {
#         'original': original,
#         'degraded_group1': group1,       # first degradation group
#         'degraded_group2': group2,       # second degradation group
#         'augmented_images': hsv_images,  # all augmented images (original + six degraded)
#         'mixup_image': mixed_image       # MixUp blended image
#     }
#     return results
# # def add_gaussian_noise(image, mean=0, sigma=25):
# #     """Add Gaussian noise."""
# #     noise = np.random.normal(mean, sigma, image.shape)
# #     noisy = np.clip(image + noise, 0, 255).astype(np.uint8)
# #     return noisy
# # def random_cutout(image, max_holes=3, max_height=16, max_width=16):
# #     """Random CutOut augmentation."""
# #     h, w = image.shape[:2]
# #     for _ in range(random.randint(1, max_holes)):
# #         hole_h = random.randint(1, max_height)
# #         hole_w = random.randint(1, max_width)
# #         y = random.randint(0, h - hole_h)
# #         x = random.randint(0, w - hole_w)
# #         image[y:y+hole_h, x:x+hole_w] = 0
# #     return image
import pywt
def wavelet_degradation(image, level=0.5):
    """Degrade by attenuating the wavelet detail coefficients."""
    # 2-D wavelet decomposition over the spatial axes (per channel)
    coeffs = pywt.dwt2(image, 'haar', axes=(0, 1))
    cA, (cH, cV, cD) = coeffs
    # attenuate the high-frequency (detail) coefficients
    cH, cV, cD = cH * level, cV * level, cD * level
    # reconstruct, crop to the original size, and restore the uint8 range
    rec = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(0, 1))[:image.shape[0], :image.shape[1]]
    return np.clip(rec, 0, 255).astype(np.uint8)
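# Quick sanity check (illustrative only): the degraded output keeps the input's
# shape and uint8 dtype, so it can feed cv2.cvtColor in darken_image later.
# _img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
# _out = wavelet_degradation(_img, level=0.5)
# assert _out.shape == _img.shape and _out.dtype == np.uint8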
def adaptive_interpolation_degradation(image):
    """Adaptive interpolation degradation (randomly nearest-neighbor or bicubic)."""
    if random.choice([True, False]):
        method = cv2.INTER_NEAREST  # nearest-neighbor interpolation
    else:
        method = cv2.INTER_CUBIC  # bicubic interpolation
    # downscale, then upscale back
    scale_factor = random.uniform(0.3, 0.8)
    small = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=method)
    return cv2.resize(small, (image.shape[1], image.shape[0]), interpolation=method)
def bilinear_degradation(image):
    """Bilinear interpolation degradation."""
    # downscale, then upscale back
    scale_factor = random.uniform(0.3, 0.8)
    small = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
    return cv2.resize(small, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
def cutmix(img1, img2, bboxes1=None, bboxes2=None, beta=1.0):
    """
    CutMix: paste a random crop of img2 into img1.
    Args:
        img1: first input image (numpy array)
        img2: second input image (numpy array)
        bboxes1: bounding boxes of the first image (optional)
        bboxes2: bounding boxes of the second image (optional)
        beta: Beta-distribution parameter controlling the crop size
    Returns:
        the mixed image and the mixed bounding boxes (if any)
    """
    # make sure the images have the same size
    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
    h, w = img1.shape[:2]
    # sample lambda (the mixing ratio) for the crop region
    lam = np.random.beta(beta, beta)
    # crop width and height
    cut_ratio = np.sqrt(1. - lam)
    cut_w = int(w * cut_ratio)
    cut_h = int(h * cut_ratio)
    # random center of the crop region
    cx = np.random.randint(w)
    cy = np.random.randint(h)
    # crop region boundaries
    x1 = np.clip(cx - cut_w // 2, 0, w)
    y1 = np.clip(cy - cut_h // 2, 0, h)
    x2 = np.clip(cx + cut_w // 2, 0, w)
    y2 = np.clip(cy + cut_h // 2, 0, h)
    # perform the CutMix operation
    mixed_img = img1.copy()
    mixed_img[y1:y2, x1:x2] = img2[y1:y2, x1:x2]
    # recompute the actual mixing ratio from the pasted area
    lam = 1 - ((x2 - x1) * (y2 - y1) / (w * h))
    # handle bounding boxes (if any)
    mixed_bboxes = None
    if bboxes1 is not None and bboxes2 is not None:
        mixed_bboxes = []
        # boxes from the first image keep weight lam
        for bbox in bboxes1:
            mixed_bboxes.append(bbox + [lam])
        # boxes from the second image are kept only if their center lies in the crop
        for bbox in bboxes2:
            bbox_x_center = (bbox[0] + bbox[2]) / 2
            bbox_y_center = (bbox[1] + bbox[3]) / 2
            if (x1 <= bbox_x_center <= x2) and (y1 <= bbox_y_center <= y2):
                mixed_bboxes.append(bbox + [1 - lam])
    return mixed_img, mixed_bboxes
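# Illustrative usage (commented out so importing the module stays side-effect
# free): mix two random images; without boxes the second return value is None.
# _a = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
# _b = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
# _mixed, _boxes = cutmix(_a, _b, beta=1.0)  # _mixed.shape == (32, 32, 3), _boxes is None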
def image_degradation_and_augmentation(image, bboxes=None):
    """
    Full image degradation and augmentation pipeline (modified to use CutMix).
    Args:
        image: input image (PIL.Image or numpy array)
        bboxes: bounding boxes (optional)
    Returns:
        dict with all degradation groups and the final augmented results
    """
    # ensure the input is a numpy array
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    # ensure the image is RGB
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original = image.copy()
    # first degradation group: the three basic degradations
    # (the original code also appended the list to itself here, which was a bug)
    degraded_sets = []
    degraded_sets.append(wavelet_degradation(original.copy()))
    degraded_sets.append(adaptive_interpolation_degradation(original.copy()))
    degraded_sets.append(bilinear_degradation(original.copy()))
    # second degradation group: degrade each image again with a random method
    methods = [wavelet_degradation, adaptive_interpolation_degradation, bilinear_degradation]
    group2 = []
    for img in degraded_sets:
        selected_method = random.choice(methods)
        group2.append(selected_method(img))
    all_degraded_images = [original] + degraded_sets + group2
    # apply darkening (the original is already included above, so no extra copy)
    dark_degraded = [darken_image(img) for img in all_degraded_images]
    # apply data augmentation
    # 1. random affine transform
    affine_images = [random_affine(img) for img in dark_degraded]
    # 2. HSV augmentation
    hsv_images = [augment_hsv(img) for img in affine_images]
    # 3. CutMix: blend two randomly chosen augmented images
    mixed_image, mixed_bboxes = cutmix(
        random.choice(hsv_images),
        random.choice(hsv_images),
        bboxes1=bboxes,
        bboxes2=bboxes)
    results = {
        'original': original,
        'degraded': dark_degraded,
        'augmented_images': hsv_images,  # all augmented images (original + six degraded)
        'cutmix_image': mixed_image,     # CutMix blended image
        'cutmix_bboxes': mixed_bboxes    # blended bounding boxes (None if no bboxes given)
    }
    return results
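# Illustrative call (assumed shapes): CIFAR10Pair above consumes the
# 'augmented_images' and 'cutmix_image' keys of this dict.
# _res = image_degradation_and_augmentation(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
# print(len(_res['augmented_images']))  # 7: original + 3 + 3 degraded views
# print(_res['cutmix_image'].shape)     # (32, 32, 3)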
train_transform = transforms.Compose([
transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
# data preparation
train_data = CIFAR10Pair(root="E:/contrast/yolov8/MoCo/data_visdrone2019", train=True, transform=train_transform, download=False)
moco_train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
memory_data = CIFAR10(root="E:/contrast/yolov8/MoCo/data_visdrone2019", train=True, transform=test_transform, download=False)
memory_loader = DataLoader(memory_data, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_data = CIFAR10(root="E:/contrast/yolov8/MoCo/data_visdrone2019", train=False, transform=test_transform, download=False)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
# Base encoder definition
# SplitBatchNorm: simulate multi-gpu behavior of BatchNorm in one gpu by splitting alone the batch dimension
# implementation adapted from https://github.com/davidcpage/cifar10-fast/blob/master/torch_backend.py
class SplitBatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, num_splits, **kw):
super().__init__(num_features, **kw)
self.num_splits = num_splits
def forward(self, input):
N, C, H, W = input.shape
if self.training or not self.track_running_stats:
running_mean_split = self.running_mean.repeat(self.num_splits)
running_var_split = self.running_var.repeat(self.num_splits)
outcome = nn.functional.batch_norm(
input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split,
self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
True, self.momentum, self.eps).view(N, C, H, W)
self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0))
self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0))
return outcome
else:
return nn.functional.batch_norm(
input, self.running_mean, self.running_var,
self.weight, self.bias, False, self.momentum, self.eps)
class ModelBase(nn.Module):
"""
Common CIFAR ResNet recipe.
Comparing with ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, stride=1
(ii) removes pool1
"""
def __init__(self, feature_dim=128, arch=None, bn_splits=16):
super(ModelBase, self).__init__()
# use split batchnorm
norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
resnet_arch = getattr(resnet, arch)
net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)
self.net = []
for name, module in net.named_children():
if name == 'conv1':
module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
if isinstance(module, nn.MaxPool2d):
continue
if isinstance(module, nn.Linear):
self.net.append(nn.Flatten(1))
self.net.append(module)
self.net = nn.Sequential(*self.net)
def forward(self, x):
x = self.net(x)
# note: not normalized here
return x
# MoCo model definition
class ModelMoCo(nn.Module):
def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
super(ModelMoCo, self).__init__()
self.K = K
self.m = m
self.T = T
self.symmetric = symmetric
# create the encoders
self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data) # initialize
            param_k.requires_grad = False  # not updated by gradient (does not train)
# create the queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
    def _momentum_update_key_encoder(self):  # momentum update of encoder_k
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
    def _dequeue_and_enqueue(self, keys):  # dequeue and enqueue
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose
ptr = (ptr + batch_size) % self.K # move pointer
self.queue_ptr[0] = ptr
@torch.no_grad()
def _batch_shuffle_single_gpu(self, x):
"""
Batch shuffle, for making use of BatchNorm.
"""
# random shuffle index
idx_shuffle = torch.randperm(x.shape[0]).cuda()
# index for restoring
idx_unshuffle = torch.argsort(idx_shuffle)
return x[idx_shuffle], idx_unshuffle
@torch.no_grad()
def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
"""
Undo batch shuffle.
"""
return x[idx_unshuffle]
def contrastive_loss(self, im_q, im_k):
# compute query features
q = self.encoder_q(im_q) # queries: NxC
q = nn.functional.normalize(q, dim=1) # already normalized
# compute key features
with torch.no_grad(): # no gradient to keys
# shuffle for making use of BN
im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
k = self.encoder_k(im_k_) # keys: NxC
k = nn.functional.normalize(k, dim=1) # already normalized
# undo shuffle
k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
# compute logits
# Einstein sum is more intuitive
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.T
# labels: positive key indicators
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
        loss = nn.CrossEntropyLoss().cuda()(logits, labels)  # cross-entropy loss
return loss, q, k
def forward(self, im1, im2):
"""
Input:
            im1: a batch of query images
            im2: a batch of key images
Output:
loss
"""
# update the key encoder
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder()
# compute loss
        if self.symmetric:  # symmetric loss
loss_12, q1, k2 = self.contrastive_loss(im1, im2)
loss_21, q2, k1 = self.contrastive_loss(im2, im1)
loss = loss_12 + loss_21
k = torch.cat([k1, k2], dim=0)
else: # asymmetric loss
loss, q, k = self.contrastive_loss(im1, im2)
self._dequeue_and_enqueue(k)
return loss
# create model
moco_model = ModelMoCo(
dim=args.moco_dim,
K=args.moco_k,
m=args.moco_m,
T=args.moco_t,
arch=args.arch,
bn_splits=args.bn_splits,
symmetric=args.symmetric,
).cuda()
# print(moco_model.encoder_q)
moco_model_1 = ModelMoCo(
dim=args.moco_dim,
K=args.moco_k,
m=args.moco_m,
T=args.moco_t,
arch=args.arch,
bn_splits=args.bn_splits,
symmetric=args.symmetric,
).cuda()
# print(moco_model_1.encoder_q)
"""
CIFAR10 Dataset.
"""
from torch.cuda import amp
scaler = amp.GradScaler(enabled=torch.cuda.is_available())  # enabled expects a bool, not torch.version.cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# train for one epoch
# def moco_train(net, net_1, data_loader, train_optimizer, epoch, args):
# net.train()
# adjust_learning_rate(moco_optimizer, epoch, args)
# total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
# loss_add = 0.0
# for im_1, im_2 in train_bar:
# im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)
#     loss = net(im_1, im_2)  # contrastive loss on the original views (zero grad -> backward -> step)
#     # lossT = loss  # use only the original contrastive loss
# # train_optimizer.zero_grad()
# # lossT.backward()
# # train_optimizer.step()
# # loss_add += lossT.item()
# # total_num += data_loader.batch_size
# # total_loss += loss.item() * data_loader.batch_size
# # train_bar.set_description(
# # 'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(
# # epoch, args.epochs,
# # train_optimizer.param_groups[0]['lr'],
# # loss_add / total_num
# # )
# # )
#     # Fourier-transform processing pipeline
#     # im_3 = torch.rfft(im_1, 3, onesided=False, normalized=True)[:, :, :, :, 0]
#     fft_output = torch.fft.fftn(im_1, dim=(-3, -2, -1), norm="ortho")  # transform to the frequency domain
#     real_imag = torch.view_as_real(fft_output)  # split into real and imaginary parts
#     im_3 = real_imag[..., 0]  # take the real part of the spectrum as a new view
#     # this adds frequency-domain augmentation, complementary to the spatial-domain augmentations
#     # im_4 = torch.rfft(im_2, 3, onesided=False, normalized=True)[:, :, :, :, 0]
#     fft_output = torch.fft.fftn(im_2, dim=(-3, -2, -1), norm="ortho")
#     real_imag = torch.view_as_real(fft_output)
#     im_4 = real_imag[..., 0]
#     loss_1 = net_1(im_3, im_4)  # contrastive loss on frequency-domain features
#     lossT = 0.8 * loss + 0.2 * loss_1  # fused multi-modal contrastive loss
# train_optimizer.zero_grad()
# lossT.backward()
# train_optimizer.step()
# loss_add += lossT
# total_num += data_loader.batch_size
# total_loss += loss.item() * data_loader.batch_size
# # train_bar.set_description(
# # 'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, moco_optimizer.param_groups[0]['lr'],
# # loss_add / total_num))
#     return (loss_add / total_num).cpu().item()  # loss value consumed by yolov5
def moco_train(net, net_1, data_loader, train_optimizer, epoch, args):
    net.train()
    adjust_learning_rate(train_optimizer, epoch, args)
    total_loss, total_num = 0.0, 0
    train_bar = tqdm(data_loader)
    for im_1, im_2, im_3, im_4 in train_bar:  # receive the four views
        im_1, im_2 = im_1.cuda(), im_2.cuda()
        im_3, im_4 = im_3.cuda(), im_4.cuda()
        # spatial-domain contrastive loss on the original views
        loss_orig = net(im_1, im_2)
        # spatial-domain contrastive loss on the degraded-augmented views
        loss_degraded = net(im_3, im_4)
        # frequency-domain processing (on the degraded-augmented views)
        fft_3 = torch.fft.fftn(im_3, dim=(-3, -2, -1), norm="ortho")
        fft_3 = torch.view_as_real(fft_3)[..., 0]  # real part
        fft_4 = torch.fft.fftn(im_4, dim=(-3, -2, -1), norm="ortho")
        fft_4 = torch.view_as_real(fft_4)[..., 0]
        # frequency-domain contrastive loss
        loss_freq = net_1(fft_3, fft_4)
        # multi-modal loss fusion
        loss = 0.6 * loss_orig + 0.3 * loss_degraded + 0.1 * loss_freq
        # backpropagation
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()
        # track the running loss; scale by batch size so the return value is a per-sample average
        total_num += data_loader.batch_size
        total_loss += loss.item() * data_loader.batch_size
        # train_bar.set_description(f'Epoch: [{epoch}/{args.epochs}] Loss: {total_loss / total_num:.4f}')
    return total_loss / total_num
# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, args):  # learning-rate decay
"""Decay the learning rate based on schedule"""
lr = args.lr
if args.cos: # cosine lr schedule
lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
else: # stepwise lr schedule
for milestone in args.schedule:
lr *= 0.1 if epoch >= milestone else 1.
for param_group in optimizer.param_groups:
param_group['lr'] = lr
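# Worked example of the cosine schedule above: with args.lr = 0.06 and
# args.epochs = 300, lr = 0.06 * 0.5 * (1 + cos(pi * epoch / 300)), i.e.
# 0.06 at epoch 0, 0.03 at epoch 150, and approximately 0 at epoch 300.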
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, args):
net.eval()
classes = len(memory_data_loader.dataset.classes)
total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
with torch.no_grad():
# generate feature bank
for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
feature = net(data.cuda(non_blocking=True))
feature = F.normalize(feature, dim=1)
feature_bank.append(feature)
# [D, N]
feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
# [N]
feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop over the test data to predict labels by weighted kNN search
test_bar = tqdm(test_data_loader)
for data, target in test_bar:
data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
feature = net(data)
feature = F.normalize(feature, dim=1)
pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t)
total_num += data.size(0)
total_top1 += (pred_labels[:, 0] == target).float().sum().item()
test_bar.set_description(
'Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100))
return total_top1 / total_num * 100
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()
# counts for each class
one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
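# Shape walk-through for knn_predict: with B query images, an N-image memory
# bank and C classes, sim_matrix is [B, N], sim_weight/sim_indices are
# [B, knn_k], one_hot_label is [B*knn_k, C] after scatter, and pred_labels is
# [B, C] with class indices sorted by descending weighted kNN score.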
# Start training
# define the optimizer
moco_optimizer = torch.optim.SGD(moco_model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
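# The script defines moco_train/test and the optimizer but never invokes them.
# A minimal driver-loop sketch (assumed; modeled on the functions above, with a
# hypothetical checkpoint name). Note that moco_optimizer only updates
# moco_model; moco_model_1 receives gradients through loss_freq but is never
# stepped.
# if not os.path.exists(args.results_dir):
#     os.makedirs(args.results_dir)
# for epoch in range(1, args.epochs + 1):
#     train_loss = moco_train(moco_model, moco_model_1, moco_train_loader, moco_optimizer, epoch, args)
#     test_acc = test(moco_model.encoder_q, memory_loader, test_loader, epoch, args)
#     torch.save({'epoch': epoch,
#                 'state_dict': moco_model.state_dict(),
#                 'optimizer': moco_optimizer.state_dict()},
#                args.results_dir + '/model_last.pth')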
How should the problems above be fixed?