Could the following code be rewritten as pseudocode?

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from mat1 import MyDataset
from network import Network, calc_coeff
from ift import IFT_Module
import numpy as np
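# Note: mat1, network, and ift are assumed to be project-local modules providing
# MyDataset, Network / calc_coeff, and IFT_Module; they are not standard packages.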
class CustomModel(nn.Module):
    def __init__(self, n_cls, K, confi_threshold=0.9):
super().__init__()
self.confi = confi_threshold
self.dim = 4
self.n_cls = n_cls
self.K = K
self.source_feat_bank = nn.Parameter(torch.zeros(n_cls * K, self.dim).half(), requires_grad=False)
self.target_feat_bank = nn.Parameter(torch.zeros(n_cls * K, self.dim).half(), requires_grad=False)
#self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.2))
self.source_max_probs_list = [0.0 for i in range(self.n_cls * self.K)]
self.target_max_probs_list = [0.0 for i in range(self.n_cls * self.K)]
self.source_key_dict = {i: i for i in range(self.n_cls * self.K)}
self.target_key_dict = {i: i for i in range(self.n_cls * self.K)}
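        # Bank layout: rows [c*K, c*K + K) hold the K most confident prediction vectors
        # seen so far for class c; *_max_probs_list tracks the confidence of each stored
        # slot and *_key_dict records which class label filled it.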
def forward(self, pseudo_label, construct=True, source=True, label=None):
if construct:
            max_probs, label_p = torch.max(pseudo_label, dim=-1)
            if source:
                for i, l in enumerate(label):
                    # Only keep predictions that agree with the ground-truth label.
                    if l == label_p[i]:
                        index = l.item() * self.K
                        l_list = self.source_max_probs_list[index: index + self.K]
                        if max_probs[i] > min(l_list):
                            # Replace the least confident slot of this class.
                            min_index = l_list.index(min(l_list))
                            self.source_max_probs_list[index + min_index] = max_probs[i].item()
                            self.source_feat_bank[index + min_index] = pseudo_label[i].clone()
                            self.source_key_dict[index + min_index] = label_p[i].item()
            else:
                for i, l in enumerate(label_p):
                    index = l.item() * self.K
                    l_list = self.target_max_probs_list[index: index + self.K]
                    min_index = l_list.index(min(l_list))  # make sure min_index is assigned before the comparison
                    if max_probs[i] > l_list[min_index]:
                        self.target_max_probs_list[index + min_index] = max_probs[i].item()
                        self.target_feat_bank[index + min_index] = pseudo_label[i].clone()
                        self.target_key_dict[index + min_index] = label_p[i].item()
        # Reshape to (n_cls, K, dim) and average over dim=1: the K entries stored for each
        # class are averaged into a single vector, giving a (n_cls, dim) bank with one mean
        # feature vector per class.
        source_bank = torch.mean(self.source_feat_bank.reshape(self.n_cls, self.K, self.dim), dim=1)
        target_bank = torch.mean(self.target_feat_bank.reshape(self.n_cls, self.K, self.dim), dim=1)
return source_bank, target_bank
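
# Illustrative sketch (commented out, toy shapes): how the bank update above behaves.
# In the real pipeline pseudo_label comes from the classifier head of the network.
# model = CustomModel(n_cls=4, K=5)
# fake_probs = torch.softmax(torch.randn(8, 4), dim=-1)   # 8 hypothetical prediction vectors
# fake_labels = torch.randint(0, 4, (8,))
# s_bank, t_bank = model(fake_probs, construct=True, source=True, label=fake_labels)
# s_bank.shape  # -> torch.Size([4, 4]): one class-averaged entry per class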
def parse_batch_test(batch):
    # batch is assumed to be indexed so that batch[1] holds the input signals and
    # batch[2] the labels; whatever batch[0] contains is not used here.
    input_data = batch[1]
    labels = batch[2]
    return input_data, labels
def construct_bank(model, base_network, dset_loaders, device):
with torch.no_grad():
model.eval()
base_network.eval()
print("Constructing source feature bank...")
for batch in tqdm(dset_loaders["source"]):
            signals, labels = parse_batch_test(batch)
            signals, labels = signals.to(device).clone(), labels.to(device).clone()
            features_source, outputs_source = base_network(signals)
            # pseudo_label = torch.softmax(outputs_source, dim=-1)  # enable if probabilities rather than raw logits should fill the bank
            source_bank, _ = model(outputs_source, construct=True, source=True, label=labels)
            # Stop early once every slot of the source bank holds a confidence above 0.99.
            if min(model.source_max_probs_list) > 0.99:
                break
print("Constructing target feature bank...")
for batch in tqdm(dset_loaders["target"]):
            signals, labels = parse_batch_test(batch)
            signals, labels = signals.to(device).clone(), labels.to(device).clone()
            features_target, outputs_target = base_network(signals)
            # pseudo_label = torch.softmax(outputs_target, dim=-1)
            _, target_bank = model(outputs_target, construct=True, source=False)
            # Stop early once every slot of the target bank holds a confidence above 0.99.
            if min(model.target_max_probs_list) > 0.99:
                break
print('Feature banks are completed!')
print("Source Feature Bank:")
print(source_bank)
print("Target Feature Bank:")
print(target_bank)
return source_bank, target_bank
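
# After construct_bank has run, the class-averaged banks can also be re-read at any time,
# without further updates, via model(pseudo_label=None, construct=False); this is what
# the commented example below does.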
# if __name__ == "__main__":
# # Set the required values
# n_cls = 4  # number of classes
# K = 5  # number of features stored per class
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
# dsets = {}
# dset_loaders = {}
# dsets["source"] = MyDataset(ext='source')
# dsets["target"] = MyDataset(ext='target')
# dset_loaders["source"] = DataLoader(dsets["source"], batch_size=64, \
# shuffle=False, drop_last=False)
# dset_loaders["target"] = DataLoader(dsets["target"], batch_size=64, \
# shuffle=False, drop_last=False)
# class_num = 4
#
# base_network = Network()
# base_network = base_network.cuda()
#
# custom_model = CustomModel(n_cls=n_cls, K=K)
# custom_model.to(device)
#
#
# # Build the feature banks
# construct_bank(custom_model, base_network, dset_loaders, device)
#
# # Get source_bank and target_bank from custom_model
# source_bank, target_bank = custom_model(pseudo_label=None, construct=False)
#
# # Here you need to load or define your CLIP model
# ift_module = IFT_Module(custom_model, beta_s=1.0, beta_t=1.0)
# ift_module.to(device)
#
# for batch in dset_loaders["source"]:
#     signals, labels = parse_batch_test(batch)  # parse_batch_test returns (signals, labels)
#     signals, labels = signals.to(device), labels.to(device)
#     Ft = source_bank  # source-domain image features Ft
#     Fv, _ = base_network(signals)  # image features of the current source batch Fv (base_network returns (features, outputs))
#     logits = ift_module(Ft, Fv, source_bank, target_bank)
#     # logits is the output of IFT_Module and can be used for further processing, e.g. computing a loss
#
# # for batch in dset_loaders["target"]:
# #     signals = batch[1].to(device)  # batch[1] holds the target-domain image features
# #     Fv, _ = base_network(signals)  # image features of the current target batch
# #     Ft = target_bank
# #     logits = ift_module(Ft, Fv, source_bank, target_bank)
# print(logits)