from . import BaseActor
from lib.utils.misc import NestedTensor
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
import torch
import math
import numpy as np
import numpy
import cv2
import torch.nn.functional as F
import torchvision.transforms.functional as tvisf
import lib.train.data.bounding_box_utils as bbutils
from lib.utils.merge import merge_template_search
from torch.distributions.categorical import Categorical
from ...utils.heapmap_utils import generate_heatmap
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
def IoU(rect1, rect2):
""" caculate interection over union
Args:
rect1: (x1, y1, x2, y2)
rect2: (x1, y1, x2, y2)
Returns:
iou
"""
# overlap
x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
ww = np.maximum(0, xx2 - xx1)
hh = np.maximum(0, yy2 - yy1)
area = (x2 - x1) * (y2 - y1)
target_a = (tx2 - tx1) * (ty2 - ty1)
inter = ww * hh
iou = inter / (area + target_a - inter)
return iou
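# Illustrative sanity check for IoU (added for clarity, not part of the original file):
# two unit squares offset by half a side overlap in a 0.5 x 0.5 region, so
# IoU = 0.25 / (1 + 1 - 0.25) = 1/7.
def _demo_iou():
    a = np.array([0.0, 0.0, 1.0, 1.0])  # (x1, y1, x2, y2)
    b = np.array([0.5, 0.5, 1.5, 1.5])
    print(IoU(a, b))  # ~0.142857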
def fp16_clamp(x, min=None, max=None):
if not x.is_cuda and x.dtype == torch.float16:
        # clamp for CPU float16: fp16 tensors have no clamp implementation on CPU
return x.float().clamp(min, max).half()
return x.clamp(min, max)
def generate_sa_simdr(joints):
    '''
    Generate SimDR-style 1-D Gaussian targets for box coordinates.
    :param joints: [num_joints, 4] array of (x1, y1, x2, y2) values (num_joints is hard-coded to 48 below)
    :return: target_x1, target_y1, target_x2, target_y2 - per-coordinate 1-D Gaussian targets
    '''
num_joints = 48
image_size = [256, 256]
simdr_split_ratio = 1.5625
sigma = 6
target_x1 = np.zeros((num_joints,
int(image_size[0] * simdr_split_ratio)),
dtype=np.float32)
target_y1 = np.zeros((num_joints,
int(image_size[1] * simdr_split_ratio)),
dtype=np.float32)
target_x2 = np.zeros((num_joints,
int(image_size[0] * simdr_split_ratio)),
dtype=np.float32)
target_y2 = np.zeros((num_joints,
int(image_size[1] * simdr_split_ratio)),
dtype=np.float32)
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
tmp_size = sigma * 3
for joint_id in range(num_joints):
mu_x1 = joints[joint_id][0]
mu_y1 = joints[joint_id][1]
mu_x2 = joints[joint_id][2]
mu_y2 = joints[joint_id][3]
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
sigma * np.sqrt(np.pi * 2))
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
sigma * np.sqrt(np.pi * 2))
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
sigma * np.sqrt(np.pi * 2))
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
sigma * np.sqrt(np.pi * 2))
return target_x1, target_y1, target_x2, target_y2
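# Illustrative usage sketch for generate_sa_simdr (added for clarity, not part of the
# original file): the function hard-codes 48 joints and a 256x256 image with a 1.5625
# split ratio, so it returns four (48, 400) arrays of 1-D Gaussian targets, one per
# box coordinate.
def _demo_generate_sa_simdr():
    joints = np.tile(np.array([32.0, 48.0, 96.0, 128.0]), (48, 1))  # hypothetical (x1, y1, x2, y2) rows
    tx1, ty1, tx2, ty2 = generate_sa_simdr(joints)
    print(tx1.shape, ty2.shape)  # (48, 400) (48, 400)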
def SIoU_loss(test1, test2, theta=4):
    eps = 1e-7
    # angle cost
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
dist = ((cx_pred - cx_gt) ** 2 + (cy_pred - cy_gt) ** 2) ** 0.5
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
x = ch / (dist + eps)
angle = 1 - 2 * torch.sin(torch.arcsin(x) - torch.pi / 4) ** 2
# distance cost
xmin = torch.min(test1[:, 0], test2[:, 0])
xmax = torch.max(test1[:, 2], test2[:, 2])
ymin = torch.min(test1[:, 1], test2[:, 1])
ymax = torch.max(test1[:, 3], test2[:, 3])
cw = xmax - xmin
ch = ymax - ymin
px = ((cx_gt - cx_pred) / (cw + eps)) ** 2
py = ((cy_gt - cy_pred) / (ch + eps)) ** 2
gama = 2 - angle
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
# shape cost
w_pred = test1[:, 2] - test1[:, 0]
h_pred = test1[:, 3] - test1[:, 1]
w_gt = test2[:, 2] - test2[:, 0]
h_gt = test2[:, 3] - test2[:, 1]
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
# IoU loss
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
wh = fp16_clamp(rb - lt, min=0)
overlap = wh[..., 0] * wh[..., 1]
area1 = (test1[..., 2] - test1[..., 0]) * (
test1[..., 3] - test1[..., 1])
area2 = (test2[..., 2] - test2[..., 0]) * (
test2[..., 3] - test2[..., 1])
iou = overlap / (area1 + area2 - overlap)
SIoU = 1 - iou + (omega + dis) / 2
return SIoU, iou
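# Illustrative sketch for SIoU_loss (added for clarity, not part of the original file):
# inputs are (N, 4) tensors of (x1, y1, x2, y2) boxes; for an exact match the IoU term
# is 1 and the angle/distance/shape penalties vanish, so the per-box loss is ~0.
def _demo_siou_loss():
    pred = torch.tensor([[10.0, 10.0, 50.0, 50.0], [0.0, 0.0, 20.0, 40.0]])
    gt = torch.tensor([[12.0, 8.0, 48.0, 52.0], [0.0, 0.0, 20.0, 40.0]])
    loss, iou = SIoU_loss(pred, gt, theta=4)
    print(loss, iou)  # second entry: loss ~ 0, iou = 1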
def ciou(pred, target, eps=1e-7):
# overlap
lt = torch.max(pred[:, :2], target[:, :2])
rb = torch.min(pred[:, 2:], target[:, 2:])
wh = (rb - lt).clamp(min=0)
overlap = wh[:, 0] * wh[:, 1]
# union
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = ap + ag - overlap + eps
# IoU
ious = overlap / union
# enclose area
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
cw = enclose_wh[:, 0]
ch = enclose_wh[:, 1]
c2 = cw ** 2 + ch ** 2 + eps
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
b2_x1, b2_y1 = target[:, 0], target[:, 1]
b2_x2, b2_y2 = target[:, 2], target[:, 3]
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
rho2 = left + right
factor = 4 / math.pi ** 2
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
# CIoU
cious = ious - (rho2 / c2 + v ** 2 / (1 - ious + v))
return cious, ious
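# Illustrative sketch for ciou (added for clarity, not part of the original file): for
# disjoint boxes the overlap is 0, so plain IoU is 0 while CIoU goes negative because
# of the enclosing-box center-distance penalty.
def _demo_ciou():
    pred = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    gt = torch.tensor([[20.0, 0.0, 30.0, 10.0]])
    cious, ious = ciou(pred, gt)
    print(cious, ious)  # cious ~ -0.4, ious = 0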
class ARTrackV2SeqActor(BaseActor):
""" Actor for training OSTrack models """
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
super().__init__(net, objective)
self.loss_weight = loss_weight
self.settings = settings
self.bs = self.settings.batchsize # batch size
self.cfg = cfg
self.bins = bins
self.search_size = search_size
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
self.focal = None
self.range = cfg.MODEL.RANGE
self.loss_weight['KL'] = 0
self.loss_weight['focal'] = 0
self.pre_num = cfg.MODEL.PRENUM
self.pre_bbox = None
self.x_feat_rem = None
def __call__(self, data):
"""
args:
            data - The input data, should contain the fields 'template_images', 'search_images' and 'search_anno'.
template_images: (N_t, batch, 3, H, W)
search_images: (N_s, batch, 3, H, W)
returns:
loss - the training loss
status - dict containing detailed losses
"""
# forward pass
out_dict = self.forward_pass(data)
# compute losses
loss, status = self.compute_losses(out_dict, data)
return loss, status
def _bbox_clip(self, cx, cy, width, height, boundary):
cx = max(0, min(cx, boundary[1]))
cy = max(0, min(cy, boundary[0]))
width = max(10, min(width, boundary[1]))
height = max(10, min(height, boundary[0]))
return cx, cy, width, height
def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
"""
args:
im: bgr based image
pos: center position
model_sz: exemplar size
            original_sz: original (context) size
avg_chans: channel average
"""
if isinstance(pos, float):
pos = [pos, pos]
sz = original_sz
im_sz = im.shape
c = (original_sz + 1) / 2
context_xmin = np.floor(pos[0] - c + 0.5)
context_xmax = context_xmin + sz - 1
context_ymin = np.floor(pos[1] - c + 0.5)
context_ymax = context_ymin + sz - 1
left_pad = int(max(0., -context_xmin))
top_pad = int(max(0., -context_ymin))
right_pad = int(max(0., context_xmax - im_sz[1] + 1))
bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
context_xmin = context_xmin + left_pad
context_xmax = context_xmax + left_pad
context_ymin = context_ymin + top_pad
context_ymax = context_ymax + top_pad
r, c, k = im.shape
if any([top_pad, bottom_pad, left_pad, right_pad]):
size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k)
te_im = np.zeros(size, np.uint8)
te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
if top_pad:
te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
if bottom_pad:
te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
if left_pad:
te_im[:, 0:left_pad, :] = avg_chans
if right_pad:
te_im[:, c + left_pad:, :] = avg_chans
im_patch = te_im[int(context_ymin):int(context_ymax + 1),
int(context_xmin):int(context_xmax + 1), :]
else:
im_patch = im[int(context_ymin):int(context_ymax + 1),
int(context_xmin):int(context_xmax + 1), :]
if not np.array_equal(model_sz, original_sz):
try:
im_patch = cv2.resize(im_patch, (model_sz, model_sz))
            except Exception:
                return None
im_patch = im_patch.transpose(2, 0, 1)
im_patch = im_patch[np.newaxis, :, :, :]
im_patch = im_patch.astype(np.float32)
im_patch = torch.from_numpy(im_patch)
im_patch = im_patch.cuda()
return im_patch
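    # Illustrative usage sketch (added for clarity, not part of the original file):
    # get_subwindow crops an original_sz x original_sz window centered at `pos`, pads
    # regions outside the image with avg_chans, resizes to model_sz and returns a
    # (1, 3, model_sz, model_sz) tensor on the GPU (a CUDA device is assumed because
    # the method calls .cuda()).
    def _demo_get_subwindow(self):
        im = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # hypothetical BGR frame
        avg_chans = np.mean(im, axis=(0, 1))
        patch = self.get_subwindow(im, pos=[320.0, 240.0], model_sz=128,
                                   original_sz=200, avg_chans=avg_chans)
        print(patch.shape)  # torch.Size([1, 3, 128, 128])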
def batch_init(self, images, template_bbox, initial_bbox) -> dict:
self.frame_num = 1
self.device = 'cuda'
# Convert bbox (x1, y1, w, h) -> (cx, cy, w, h)
template_bbox_1 = template_bbox[:, 0]
        # If only a single template frame is given, duplicate it as the second frame
if template_bbox.shape[1] == 1:
template_bbox_2 = template_bbox_1.copy()
else:
template_bbox_2 = template_bbox[:, 1]
        # Make sure the boxes are 2-D
if template_bbox_1.ndim == 1:
template_bbox_1 = np.expand_dims(template_bbox_1, axis=0)
if template_bbox_2.ndim == 1:
template_bbox_2 = np.expand_dims(template_bbox_2, axis=0)
template_bbox_1 = bbutils.batch_xywh2center2(template_bbox_1)
template_bbox_2 = bbutils.batch_xywh2center2(template_bbox_2)
initial_bbox = bbutils.batch_xywh2center2(initial_bbox)
self.center_pos = initial_bbox[:, :2]
self.size = initial_bbox[:, 2:]
self.pre_bbox = initial_bbox
for i in range(self.pre_num - 1):
self.pre_bbox = numpy.concatenate((self.pre_bbox, initial_bbox), axis=1)
template_factor = self.cfg.DATA.TEMPLATE.FACTOR
w_z_1 = template_bbox_1[:, 2] * template_factor
h_z_1 = template_bbox_1[:, 3] * template_factor
s_z_1 = np.ceil(np.sqrt(w_z_1 * h_z_1))
w_z_2 = template_bbox_2[:, 2] * template_factor
h_z_2 = template_bbox_2[:, 3] * template_factor
s_z_2 = np.ceil(np.sqrt(w_z_2 * h_z_2))
self.channel_average = []
for img_pair in images:
            # Channel average of the first template frame
            avg_1 = np.mean(img_pair[0], axis=(0, 1))
            # Use the second frame if available; otherwise reuse the first
            if len(img_pair) > 1:
                avg_2 = np.mean(img_pair[1], axis=(0, 1))
            else:
                avg_2 = avg_1  # single frame -> duplicate
self.channel_average.append(avg_1)
self.channel_average.append(avg_2)
self.channel_average = np.array(self.channel_average)
z_crop_list = []
z_1_list = []
z_2_list = []
for i in range(len(images)):
            # The first frame must exist
img1 = images[i][0]
pos1 = template_bbox_1[i, :2]
            # Safely fetch the second template image
if len(images[i]) > 1:
img2 = images[i][1]
else:
                img2 = img1  # single frame -> reuse the first
pos2 = template_bbox_2[i, :2]
            # Make sure the images are 3-channel (RGB)
if img1.ndim == 2:
                img1 = np.stack([img1, img1, img1], axis=-1)  # grayscale -> RGB
if img2.ndim == 2:
img2 = np.stack([img2, img2, img2], axis=-1)
here_crop_1 = self.get_subwindow(img1, pos1, self.cfg.DATA.TEMPLATE.SIZE, s_z_1[i],
self.channel_average[2 * i])
here_crop_2 = self.get_subwindow(img2, pos2, self.cfg.DATA.TEMPLATE.SIZE, s_z_2[i],
self.channel_average[2 * i + 1])
if here_crop_1 is None or here_crop_2 is None:
return None
z_crop_1 = here_crop_1.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
z_crop_2 = here_crop_2.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.inplace = False
z_crop_1[0] = tvisf.normalize(z_crop_1[0], self.mean, self.std, self.inplace)
z_crop_2[0] = tvisf.normalize(z_crop_2[0], self.mean, self.std, self.inplace)
z_1_list.append(z_crop_1.unsqueeze(1).clone())
z_2_list.append(z_crop_2.unsqueeze(1).clone())
z_crop = torch.cat([z_crop_1.unsqueeze(1), z_crop_2.unsqueeze(1)], dim=1)
z_crop_list.append(z_crop.clone())
z_crop = torch.cat(z_crop_list, dim=0)
z_1_crop = torch.cat(z_1_list, dim=0)
z_2_crop = torch.cat(z_2_list, dim=0)
z_2_crop = z_2_crop.squeeze(1).to(self.net.module.backbone.word_embeddings.weight)
z_2_feat = self.net.module.backbone.patch_embed(z_2_crop)
out = {'template_images': z_crop, "z_1": z_1_crop, "z_2": z_2_crop, "z_2_feat": z_2_feat}
return out
# def batch_init(self, images, template_bbox, initial_bbox) -> dict:
# self.frame_num = 1
# self.device = 'cuda'
# # Convert bbox (x1, y1, w, h) -> (cx, cy, w, h)
# template_bbox_1 = template_bbox[:, 0]
# template_bbox_2 = template_bbox[:, 1]
    # # Added later: make sure the boxes are 2-D
# if template_bbox_1.ndim == 1:
# template_bbox_1 = np.expand_dims(template_bbox_1, axis=0) # (4,) -> (1, 4)
# template_bbox_1 = bbutils.batch_xywh2center2(template_bbox_1)
# #template_bbox_1 = bbutils.batch_xywh2center2(template_bbox_1) # ndarray:(2*num_seq,4)
# template_bbox_2 = bbutils.batch_xywh2center2(template_bbox_2) # ndarray:(2*num_seq,4)
# initial_bbox = bbutils.batch_xywh2center2(initial_bbox) # ndarray:(2*num_seq,4)
# self.center_pos = initial_bbox[:, :2] # ndarray:(2*num_seq,2)
# self.size = initial_bbox[:, 2:] # ndarray:(2*num_seq,2)
# self.pre_bbox = initial_bbox
# for i in range(self.pre_num - 1):
# self.pre_bbox = numpy.concatenate((self.pre_bbox, initial_bbox), axis=1)
# template_factor = self.cfg.DATA.TEMPLATE.FACTOR
# w_z_1 = template_bbox_1[:, 2] * template_factor # ndarray:(2*num_seq)
# h_z_1 = template_bbox_1[:, 3] * template_factor # ndarray:(2*num_seq)
# s_z_1 = np.ceil(np.sqrt(w_z_1 * h_z_1)) # ndarray:(2*num_seq)
# w_z_2 = template_bbox_2[:, 2] * template_factor # ndarray:(2*num_seq)
# h_z_2 = template_bbox_2[:, 3] * template_factor # ndarray:(2*num_seq)
# s_z_2 = np.ceil(np.sqrt(w_z_2 * h_z_2)) # ndarray:(2*num_seq)
# self.channel_average = []
# for img in images:
# self.channel_average.append(np.mean(img[0], axis=(0, 1)))
# self.channel_average.append(np.mean(img[1], axis=(0, 1)))
# self.channel_average = np.array(self.channel_average) # ndarray:(2*num_seq,3)
# # get crop
# z_crop_list = []
# z_1_list = []
# z_2_list = []
# for i in range(len(images)):
# here_crop_1 = self.get_subwindow(images[i][0], template_bbox_1[i, :2],
# self.cfg.DATA.TEMPLATE.SIZE, s_z_1[i], self.channel_average[2 * i])
# here_crop_2 = self.get_subwindow(images[i][1], template_bbox_2[i, :2],
# self.cfg.DATA.TEMPLATE.SIZE, s_z_2[i], self.channel_average[2 * i + 1])
# z_crop_1 = here_crop_1.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
# z_crop_2 = here_crop_2.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
# self.mean = [0.485, 0.456, 0.406]
# self.std = [0.229, 0.224, 0.225]
# self.inplace = False
# z_crop_1[0] = tvisf.normalize(z_crop_1[0], self.mean, self.std, self.inplace)
# z_crop_2[0] = tvisf.normalize(z_crop_2[0], self.mean, self.std, self.inplace)
# z_1_list.append(z_crop_1.unsqueeze(1).clone())
# z_2_list.append(z_crop_2.unsqueeze(1).clone())
# z_crop = torch.concat([z_crop_1.unsqueeze(1), z_crop_2.unsqueeze(1)], dim=1)
# z_crop_list.append(z_crop.clone())
# z_crop = torch.cat(z_crop_list, dim=0) # Tensor(2*num_seq,3,128,128)
# z_1_crop = torch.cat(z_1_list, dim=0)
# z_2_crop = torch.cat(z_2_list, dim=0)
# z_2_crop = z_2_crop.squeeze(1).to(self.net.module.backbone.word_embeddings.weight)
# z_2_feat = self.net.module.backbone.patch_embed(z_2_crop)
# out = {'template_images': z_crop, "z_1": z_1_crop, "z_2": z_2_crop, "z_2_feat": z_2_feat}
# return out
def batch_track(self, img, gt_boxes, template, dz_feat, action_mode='max') -> dict:
search_factor = self.cfg.DATA.SEARCH.FACTOR
w_x = self.size[:, 0] * search_factor
h_x = self.size[:, 1] * search_factor
s_x = np.ceil(np.sqrt(w_x * h_x))
gt_boxes_corner = bbutils.batch_xywh2corner(gt_boxes) # ndarray:(2*num_seq,4)
initial_bbox = bbutils.batch_xywh2center2(gt_boxes)
x_crop_list = []
gt_in_crop_list = []
pre_seq_list = []
pre_seq_in_list = []
x_feat_list = []
target_in_search_list = []
update_feat_list = []
for i in range(len(img)):
template_factor = self.cfg.DATA.TEMPLATE.FACTOR
w_z_1 = initial_bbox[:, 2] * template_factor # ndarray:(2*num_seq)
h_z_1 = initial_bbox[:, 3] * template_factor # ndarray:(2*num_seq)
s_z_1 = np.ceil(np.sqrt(w_z_1 * h_z_1)) # ndarray:(2*num_seq)
channel_avg = np.mean(img[i], axis=(0, 1))
target_in_search = self.get_subwindow(img[i], initial_bbox[i, :2], self.cfg.DATA.TEMPLATE.SIZE,
round(s_z_1[i]), channel_avg)
x_crop = self.get_subwindow(img[i], self.center_pos[i], self.cfg.DATA.SEARCH.SIZE,
round(s_x[i]), channel_avg)
            if x_crop is None:
return None
            if target_in_search is None:
return None
for q in range(self.pre_num):
pre_seq_temp = bbutils.batch_center2corner(self.pre_bbox[:, 0 + 4 * q:4 + 4 * q])
if q == 0:
pre_seq = pre_seq_temp
else:
pre_seq = numpy.concatenate((pre_seq, pre_seq_temp), axis=1)
if gt_boxes_corner is not None and np.sum(np.abs(gt_boxes_corner[i] - np.zeros(4))) > 10:
pre_in = np.zeros(4 * self.pre_num)
for w in range(self.pre_num):
pre_in[0 + w * 4:2 + w * 4] = pre_seq[i, 0 + w * 4:2 + w * 4] - self.center_pos[i]
pre_in[2 + w * 4:4 + w * 4] = pre_seq[i, 2 + w * 4:4 + w * 4] - self.center_pos[i]
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] * (
self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] / self.cfg.DATA.SEARCH.SIZE
pre_seq_list.append(pre_in)
gt_in_crop = np.zeros(4)
gt_in_crop[:2] = gt_boxes_corner[i, :2] - self.center_pos[i]
gt_in_crop[2:] = gt_boxes_corner[i, 2:] - self.center_pos[i]
gt_in_crop = gt_in_crop * (self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
gt_in_crop[2:] = gt_in_crop[2:] - gt_in_crop[:2] # (x1,y1,x2,y2) to (x1,y1,w,h)
gt_in_crop_list.append(gt_in_crop)
else:
pre_in = np.zeros(4 * self.pre_num)
pre_seq_list.append(pre_in)
gt_in_crop_list.append(np.zeros(4))
pre_seq_input = torch.from_numpy(pre_in).clamp(-0.5 * self.range + 0.5, 0.5 + self.range * 0.5)
pre_seq_input = (pre_seq_input + (0.5 * self.range - 0.5)) * (self.bins - 1)
pre_seq_in_list.append(pre_seq_input.clone())
x_crop = x_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
target_in_search = target_in_search.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
rem_x = x_crop
x_crop[0] = tvisf.normalize(x_crop[0], self.mean, self.std, self.inplace)
target_in_search[0] = tvisf.normalize(target_in_search[0], self.mean, self.std, self.inplace)
x_crop_list.append(x_crop.clone())
target_in_search_list.append(target_in_search.clone())
x_crop = torch.cat(x_crop_list, dim=0)
target_in_search = torch.cat(target_in_search_list, dim=0)
pre_seq_output = torch.cat(pre_seq_in_list, dim=0).reshape(-1, 4 * self.pre_num)
pre = torch.zeros_like(pre_seq_output)
outputs = self.net(template, dz_feat.cuda(), x_crop, seq_input=pre_seq_output, head_type=None,
stage="batch_track",
search_feature=self.x_feat_rem, target_in_search_img=target_in_search,
gt_bboxes=gt_boxes[-1])
selected_indices = outputs['seqs'].detach()
x_feat = outputs['x_feat'].detach().cpu()
self.x_feat_rem = x_feat.clone()
x_feat_list.append(x_feat.clone())
update_feat = outputs['dz_feat'].detach().cpu()
update_feat_list.append(update_feat.clone())
pred_bbox = selected_indices[:, 0:4].data.cpu().numpy()
bbox = (pred_bbox / (self.bins - 1) - (self.range * 0.5 - 0.5)) * s_x.reshape(-1, 1)
cx = bbox[:, 0] + self.center_pos[:, 0] - s_x / 2
cy = bbox[:, 1] + self.center_pos[:, 1] - s_x / 2
width = bbox[:, 2] - bbox[:, 0]
height = bbox[:, 3] - bbox[:, 1]
cx = cx + width / 2
cy = cy + height / 2
for i in range(len(img)):
cx[i], cy[i], width[i], height[i] = self._bbox_clip(cx[i], cy[i], width[i],
height[i], img[i].shape[:2])
self.center_pos = np.stack([cx, cy], 1)
self.size = np.stack([width, height], 1)
for e in range(self.pre_num):
if e != self.pre_num - 1:
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = self.pre_bbox[:, 4 + e * 4:8 + e * 4]
else:
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = numpy.stack([cx, cy, width, height], 1)
bbox = np.stack([cx - width / 2, cy - height / 2, width, height], 1)
out = {
'dz_feat': update_feat,
'search_images': x_crop,
'target_in_search': target_in_search,
'pred_bboxes': bbox,
'selected_indices': selected_indices.cpu(),
'gt_in_crop': torch.tensor(np.stack(gt_in_crop_list, axis=0), dtype=torch.float),
'pre_seq': torch.tensor(np.stack(pre_seq_list, axis=0), dtype=torch.float),
'x_feat': torch.tensor([item.cpu().detach().numpy() for item in x_feat_list], dtype=torch.float),
}
return out
def explore(self, data):
print("\n🔍 [DEBUG] Data keys:", data.keys())
for k, v in data.items():
if hasattr(v, 'shape'):
print(f" {k}: shape{list(v.shape)} dtype={v.dtype}")
elif isinstance(v, list):
shapes = [getattr(x, 'shape', type(x)) for x in v]
print(f" {k}: list[{len(v)}] -> {shapes}")
else:
print(f" {k}: {type(v)}")
results = {}
num_frames = data['num_frames']
images = data['search_images']
gt_bbox = data['search_annos']
template_img_list = data['template_images']
        template_anno_list = data['template_annos']  # <- this is the key input
num_seq = len(num_frames)
        # ✅ STEP 1: safely normalize template_anno_list so that each entry becomes (2, 4)
        # (done here, inside explore())
processed_template_bboxes = []
for i, anno in enumerate(template_anno_list):
if isinstance(anno, np.ndarray):
if anno.ndim == 1 and anno.size == 4:
anno = anno.reshape(1, 4) # (4,) -> (1,4)
elif anno.ndim == 2:
                    pass  # already 2-D, nothing to do
else:
raise ValueError(f"Unexpected bbox shape: {anno.shape}")
else:
raise TypeError(f"Expected np.ndarray, got {type(anno)}")
            # If only one frame is present, duplicate it
if anno.shape[0] == 1:
anno = np.tile(anno, (2, 1)) # (1,4) -> (2,4)
elif anno.shape[0] >= 2:
                anno = anno[:2]  # keep the first two frames
else:
raise ValueError(f"Bbox has zero frames: {i}")
processed_template_bboxes.append(anno)
template_bbox = np.stack(processed_template_bboxes) # (B, 2, 4)
search_images_list = []
search_anno_list = []
pre_seq_list = []
x_feat_list = []
target_in_search_list = []
template_all_list = []
        dz_feat_update_list = []
        iou_list = []
for idx in range(np.max(num_frames)):
here_images = [img[idx] for img in images]
here_gt_bbox = np.stack([gt[idx] for gt in gt_bbox]) # (B, 4)
if idx == 0:
                # ✅ Extract the bboxes of the two template frames
template_bbox_1 = template_bbox[:, 0] # (B, 4)
template_bbox_2 = template_bbox[:, 1] # (B, 4)
                # Build the (B, 2, 4) input format expected by batch_init
combined_template_bbox = np.stack([template_bbox_1, template_bbox_2], axis=1)
outputs_template = self.batch_init(
template_img_list,
combined_template_bbox,
here_gt_bbox
)
if outputs_template is None:
return None
results['template_images'] = outputs_template['z_1']
self.template_temp = outputs_template['z_1'].clone()
z_all = [outputs_template['z_1'], outputs_template['z_2']]
results['z_all'] = z_all
self.dz_feat_update = outputs_template['z_2_feat']
else:
outputs = self.batch_track(here_images, here_gt_bbox, self.template_temp, self.dz_feat_update,
action_mode='half')
                if outputs is None:
return None
template_all_list.append(self.template_temp.clone())
                dz_feat_update_list.append(self.dz_feat_update.clone().to(outputs['dz_feat']))
x_feat = outputs['x_feat']
self.dz_feat_update = outputs['dz_feat']
pred_bbox = outputs['pred_bboxes']
search_images_list.append(outputs['search_images'])
target_in_search_list.append(outputs['target_in_search'])
search_anno_list.append(outputs['gt_in_crop'])
if len(outputs['pre_seq']) != 8:
print(outputs['pre_seq'])
print(len(outputs['pre_seq']))
print(idx)
print(data['num_frames'])
print(data['search_annos'])
return None
pre_seq_list.append(outputs['pre_seq'])
pred_bbox_corner = bbutils.batch_xywh2corner(pred_bbox)
gt_bbox_corner = bbutils.batch_xywh2corner(here_gt_bbox)
here_iou = []
for i in range(num_seq):
bbox_iou = IoU(pred_bbox_corner[i], gt_bbox_corner[i])
here_iou.append(bbox_iou)
iou_list.append(here_iou)
x_feat_list.append(x_feat.clone())
search_images_reverse_list = []
search_anno_reverse_list = []
action_tensor_reverse_list = []
iou_reverse_list = []
pre_seq_reverse_list = []
x_feat_reverse_list = []
target_in_search_reverse_list = []
dz_feat_update_reverse_list = []
template_all_reverse_list = []
for idx in range(np.max(num_frames)):
real_idx = np.max(num_frames) - 1 - idx
here_images = [img[real_idx] for img in images] # S, N
here_gt_bbox = np.array([gt[real_idx] for gt in gt_bbox])
here_gt_bbox = np.concatenate([here_gt_bbox], 0)
if idx == 0:
                outputs_template = self.batch_init(template_img_list, template_bbox, here_gt_bbox)
                if outputs_template is None:
                    return None
results['template_images'] = outputs_template['z_1']
self.template_temp = outputs_template['z_1'].clone()
z_all = [outputs_template['z_1'], outputs_template['z_2']]
results['z_all'] = z_all
self.dz_feat_update = outputs_template['z_2_feat'].clone()
else:
outputs = self.batch_track(here_images, here_gt_bbox, self.template_temp, self.dz_feat_update,
action_mode='half')
                if outputs is None:
return None
template_all_reverse_list.append(self.template_temp.clone())
dz_feat_update_reverse_list.append(self.dz_feat_update.clone().to(outputs['dz_feat']))
x_feat = outputs['x_feat']
self.dz_feat_update = outputs['dz_feat']
pred_bbox = outputs['pred_bboxes']
search_images_reverse_list.append(outputs['search_images'])
target_in_search_reverse_list.append(outputs['target_in_search'])
search_anno_reverse_list.append(outputs['gt_in_crop'])
if len(outputs['pre_seq']) != 8:
print(outputs['pre_seq'])
print(len(outputs['pre_seq']))
print(idx)
print(data['num_frames'])
print(data['search_annos'])
return None
pre_seq_reverse_list.append(outputs['pre_seq'])
pred_bbox_corner = bbutils.batch_xywh2corner(pred_bbox)
gt_bbox_corner = bbutils.batch_xywh2corner(here_gt_bbox)
here_iou = []
for i in range(num_seq):
bbox_iou = IoU(pred_bbox_corner[i], gt_bbox_corner[i])
here_iou.append(bbox_iou)
iou_reverse_list.append(here_iou)
x_feat_reverse_list.append(x_feat.clone())
results['x_feat'] = torch.cat([torch.stack(x_feat_list), torch.stack(x_feat_reverse_list)], dim=2)
results['search_images'] = torch.cat([torch.stack(search_images_list), torch.stack(search_images_reverse_list)],
dim=1)
results['template_images_z0'] = torch.cat(
[torch.stack(template_all_list), torch.stack(template_all_reverse_list)], dim=1)
results['dz_feat_update'] = torch.cat(
            [torch.stack(dz_feat_update_list), torch.stack(dz_feat_update_reverse_list)], dim=1)
results['search_anno'] = torch.cat([torch.stack(search_anno_list), torch.stack(search_anno_reverse_list)],
dim=1)
results['pre_seq'] = torch.cat([torch.stack(pre_seq_list), torch.stack(pre_seq_reverse_list)], dim=1)
results['target_in_search'] = torch.cat(
[torch.stack(target_in_search_list), torch.stack(target_in_search_reverse_list)], dim=1)
iou_tensor = torch.tensor(iou_list, dtype=torch.float)
iou_tensor_reverse = torch.tensor(iou_reverse_list, dtype=torch.float)
results['baseline_iou'] = torch.cat([iou_tensor[:, :num_seq], iou_tensor_reverse[:, :num_seq]], dim=1)
# results['explore_iou'] = iou_tensor[:, num_seq:]
# results['action_tensor'] = torch.stack(action_tensor_list)
return results
def forward_pass(self, data):
# currently only support 1 template and 1 search region
assert len(data['template_images']) == 1
assert len(data['search_images']) == 1
template_list = []
for i in range(self.settings.num_template):
template_img_i = data['template_images'][i].view(-1,
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
template_list.append(template_img_i)
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
box_mask_z = None
ce_keep_rate = None
if self.cfg.MODEL.BACKBONE.CE_LOC:
box_mask_z = generate_mask_cond(self.cfg, template_list[0].shape[0], template_list[0].device,
data['template_anno'][0])
ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH
ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH
ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch,
total_epochs=ce_start_epoch + ce_warm_epoch,
ITERS_PER_EPOCH=1,
base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0])
if len(template_list) == 1:
template_list = template_list[0]
gt_bbox = data['search_anno'][-1]
begin = self.bins
end = self.bins + 1
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
gt_bbox = gt_bbox.clamp(min=0.0, max=1.0)
data['real_bbox'] = gt_bbox
seq_ori = gt_bbox * (self.bins - 1)
seq_ori = seq_ori.int().to(search_img)
B = seq_ori.shape[0]
seq_ori_4_4 = seq_ori[:, 0:3]
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
data['seq_input'] = seq_input
data['seq_output'] = seq_output
out_dict = self.net(template=template_list,
search=search_img,
ce_template_mask=box_mask_z,
ce_keep_rate=ce_keep_rate,
return_last_attn=False,
seq_input=seq_input)
return out_dict
def compute_sequence_losses(self, data):
num_frames = data['search_images'].shape[0]
template_images_for = data['template_images_z0'].reshape(-1, *data['template_images_z0'].size()[2:])
dz_feat = data['dz_feat_update'].reshape(-1, *data['dz_feat_update'].size()[2:])
target_in_search = data['target_in_search'].reshape(-1, *data['target_in_search'].size()[2:])
search_images = data['search_images'].reshape(-1, *data['search_images'].size()[2:])
search_anno = data['search_anno'].reshape(-1, *data['search_anno'].size()[2:])
pre_seq = data['pre_seq'].reshape(-1, 4 * self.pre_num)
x_feat = data['x_feat'].reshape(-1, *data['x_feat'].size()[2:])
epoch = data['epoch']
if epoch < 11:
self.loss_weight['focal'] = 2
self.loss_weight['score_update'] = 1
elif epoch < 31:
self.loss_weight['focal'] = 0
self.loss_weight['score_update'] = 0.1
else:
self.loss_weight['focal'] = 0
self.loss_weight['score_update'] = 0.0
pre_seq = pre_seq.clamp(-0.5 * self.range + 0.5, 0.5 + self.range * 0.5)
pre_seq = (pre_seq + (self.range * 0.5 - 0.5)) * (self.bins - 1)
outputs = self.net(template_images_for, dz_feat, search_images, seq_input=pre_seq, stage="forward_pass",
search_feature=x_feat, target_in_search_img=target_in_search)
score = outputs['score']
renew_loss = outputs['renew_loss']
pred_feat = outputs["feat"]
        if self.focal is None:
weight = torch.ones(self.bins * self.range + 6) * 1
weight[self.bins * self.range + 4] = 0.1
weight[self.bins * self.range + 3] = 0.1
weight[self.bins * self.range + 2] = 0.1
weight[self.bins * self.range + 1] = 0.1
weight[self.bins * self.range] = 0.1
            weight = weight.to(pred_feat)
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
search_anno[:, 2] = search_anno[:, 2] + search_anno[:, 0]
search_anno[:, 3] = search_anno[:, 3] + search_anno[:, 1]
target = (search_anno / self.cfg.DATA.SEARCH.SIZE + (self.range * 0.5 - 0.5)) * (self.bins - 1)
target = target.clamp(min=0.0, max=(self.bins * self.range - 0.0001))
target_iou = target
end_flag = torch.ones((target.shape[0], 1)) * (self.bins * self.range + 1)
end_flag = end_flag.to(target)
target = torch.cat([target], dim=1)
target = target.reshape(-1).to(torch.int64)
pred = pred_feat.permute(1, 0, 2).reshape(-1, self.bins * self.range + 6)
varifocal_loss = self.focal(pred, target)
pred = pred_feat[0:4, :, 0:self.bins * self.range]
target = target_iou[:, 0:4].to(pred_feat) / (self.bins - 1) - (self.range * 0.5 - 0.5)
out = pred.softmax(-1).to(pred)
        mul = torch.range((-1 * self.range * 0.5 + 0.5) + 1 / (self.bins * self.range),
                          (self.range * 0.5 + 0.5) - 1 / (self.bins * self.range),
                          2 / (self.bins * self.range)).to(pred)
ans = out * mul
ans = ans.sum(dim=-1)
ans = ans.permute(1, 0).to(pred)
extra_seq = ans
extra_seq = extra_seq.to(pred)
cious, iou = SIoU_loss(extra_seq, target, 4)
cious = cious.mean()
score_real = score
score_loss = self.objective['l1'](score_real, iou)
giou_loss = cious
l1_loss = self.objective['l1'](extra_seq, target)
loss_bb = (self.loss_weight['giou'] * giou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight[
'focal'] * varifocal_loss)
total_losses = loss_bb + renew_loss * self.loss_weight['score_update'] + score_loss * self.loss_weight['score_update']
mean_iou = iou.detach().mean()
status = {"Loss/total": total_losses.item() / 2,
"Loss/score": score_loss.item() / 2,
"Loss/giou": giou_loss.item() / 2,
"Loss/l1": l1_loss.item() / 2,
"Loss/location": varifocal_loss.item() / 2,
"Loss/renew": renew_loss.item() / 2,
"IoU": mean_iou.item() / 2}
return total_losses, status
The above is lib/train/actors/artrackv2_seq.py.
import torch
import numpy as np
def batch_center2corner(boxes):
xmin = boxes[:, 0] - boxes[:, 2] * 0.5
ymin = boxes[:, 1] - boxes[:, 3] * 0.5
xmax = boxes[:, 0] + boxes[:, 2] * 0.5
ymax = boxes[:, 1] + boxes[:, 3] * 0.5
if isinstance(boxes, np.ndarray):
return np.stack([xmin, ymin, xmax, ymax], 1)
else:
return torch.stack([xmin, ymin, xmax, ymax], 1)
def batch_corner2center(boxes):
cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
w = (boxes[:, 2] - boxes[:, 0])
h = (boxes[:, 3] - boxes[:, 1])
if isinstance(boxes, np.ndarray):
return np.stack([cx, cy, w, h], 1)
else:
return torch.stack([cx, cy, w, h], 1)
def batch_xywh2center(boxes):
cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2
cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2
w = boxes[:, 2]
h = boxes[:, 3]
if isinstance(boxes, np.ndarray):
return np.stack([cx, cy, w, h], 1)
else:
return torch.stack([cx, cy, w, h], 1)
# def batch_xywh2center2(boxes):
# cx = boxes[:, 0] + boxes[:, 2] / 2
# cy = boxes[:, 1] + boxes[:, 3] / 2
# w = boxes[:, 2]
# h = boxes[:, 3]
# if isinstance(boxes, np.ndarray):
# return np.stack([cx, cy, w, h], 1)
# else:
# return torch.stack([cx, cy, w, h], 1)
def batch_xywh2center2(boxes):
if not isinstance(boxes, np.ndarray):
boxes = np.array(boxes)
original_shape = boxes.shape
if boxes.ndim == 1:
if boxes.size == 4:
boxes = boxes.reshape(1, -1)
else:
raise ValueError(f"Invalid 1D box size: {boxes.size}, expected 4")
cx = boxes[:, 0] + boxes[:, 2] / 2
cy = boxes[:, 1] + boxes[:, 3] / 2
w = boxes[:, 2]
h = boxes[:, 3]
output = np.stack([cx, cy, w, h], axis=1)
if len(original_shape) == 1:
output = output.squeeze(0)
return output
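# Illustrative sketch for batch_xywh2center2 (added for clarity, not part of the
# original file): both a single (4,) box and a batched (N, 4) array are accepted,
# and the output keeps the input's dimensionality.
def _demo_batch_xywh2center2():
    single = np.array([10.0, 20.0, 30.0, 40.0])  # (x, y, w, h)
    batch = np.array([[10.0, 20.0, 30.0, 40.0],
                      [0.0, 0.0, 8.0, 6.0]])
    print(batch_xywh2center2(single))  # [25. 40. 30. 40.]
    print(batch_xywh2center2(batch))   # (2, 4) array of (cx, cy, w, h) rows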
def batch_xywh2corner(boxes):
xmin = boxes[:, 0]
ymin = boxes[:, 1]
xmax = boxes[:, 0] + boxes[:, 2]
ymax = boxes[:, 1] + boxes[:, 3]
if isinstance(boxes, np.ndarray):
return np.stack([xmin, ymin, xmax, ymax], 1)
else:
return torch.stack([xmin, ymin, xmax, ymax], 1)
def rect_to_rel(bb, sz_norm=None):
"""Convert standard rectangular parametrization of the bounding box [x, y, w, h]
to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
args:
bb - N x 4 tensor of boxes.
sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
"""
c = bb[...,:2] + 0.5 * bb[...,2:]
if sz_norm is None:
c_rel = c / bb[...,2:]
else:
c_rel = c / sz_norm
sz_rel = torch.log(bb[...,2:])
return torch.cat((c_rel, sz_rel), dim=-1)
def rel_to_rect(bb, sz_norm=None):
"""Inverts the effect of rect_to_rel. See above."""
sz = torch.exp(bb[...,2:])
if sz_norm is None:
c = bb[...,:2] * sz
else:
c = bb[...,:2] * sz_norm
tl = c - 0.5 * sz
return torch.cat((tl, sz), dim=-1)
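# Illustrative round-trip sketch for rect_to_rel / rel_to_rect (added for clarity,
# not part of the original file): converting to the relative parametrization and
# back recovers the original (x, y, w, h) boxes.
def _demo_rect_rel_roundtrip():
    bb = torch.tensor([[10.0, 20.0, 30.0, 40.0],
                       [5.0, 5.0, 8.0, 16.0]])
    back = rel_to_rect(rect_to_rel(bb))
    print(torch.allclose(back, bb))  # True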
def masks_to_bboxes(mask, fmt='c'):
""" Convert a mask tensor to one or more bounding boxes.
Note: This function is a bit new, make sure it does what it says. /Andreas
:param mask: Tensor of masks, shape = (..., H, W)
:param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
't' => "top left + size" or (x_left, y_top, width, height)
'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
:return: tensor containing a batch of bounding boxes, shape = (..., 4)
"""
batch_shape = mask.shape[:-2]
mask = mask.reshape((-1, *mask.shape[-2:]))
bboxes = []
for m in mask:
mx = m.sum(dim=-2).nonzero()
my = m.sum(dim=-1).nonzero()
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
bboxes.append(bb)
bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
bboxes = bboxes.reshape(batch_shape + (4,))
if fmt == 'v':
return bboxes
x1 = bboxes[..., :2]
s = bboxes[..., 2:] - x1 + 1
if fmt == 'c':
return torch.cat((x1 + 0.5 * s, s), dim=-1)
elif fmt == 't':
return torch.cat((x1, s), dim=-1)
raise ValueError("Undefined bounding box layout '%s'" % fmt)
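# Illustrative sketch for masks_to_bboxes (added for clarity, not part of the original
# file): a 6x8 mask with a filled block over rows 2-3 and columns 3-5 gives the
# vertices (3, 2, 5, 3) in 'v' format and (cx, cy, w, h) = (4.5, 3.0, 3.0, 2.0) in the
# default 'c' format.
def _demo_masks_to_bboxes():
    mask = torch.zeros(6, 8)
    mask[2:4, 3:6] = 1
    print(masks_to_bboxes(mask, fmt='v'))  # tensor([3., 2., 5., 3.])
    print(masks_to_bboxes(mask, fmt='c'))  # tensor([4.5000, 3.0000, 3.0000, 2.0000])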
def masks_to_bboxes_multi(mask, ids, fmt='c'):
assert mask.dim() == 2
bboxes = []
for id in ids:
mx = (mask == id).sum(dim=-2).nonzero()
my = (mask == id).float().sum(dim=-1).nonzero()
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
x1 = bb[:2]
s = bb[2:] - x1 + 1
if fmt == 'v':
pass
elif fmt == 'c':
bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
elif fmt == 't':
bb = torch.cat((x1, s), dim=-1)
else:
raise ValueError("Undefined bounding box layout '%s'" % fmt)
bboxes.append(bb)
return bboxes
The above is lib/train/data/bounding_box_utils.py.
root@train-test08-0:/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main# cd /data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main ; /usr/bin/env /bin/python /root/.vscode-server/extensions/ms-python.python-2022.16.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 37657 -- /data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/tracking/train.py --script artrackv2_seq --config artrackv2_seq_256_full --save_dir ./output --mode single --use_wandb 0
script_name: artrackv2_seq.py config_name: artrackv2_seq_256_full.yaml
New configuration is shown below.
MODEL configuration: {'PRETRAIN_FILE': 'mae_pretrain_vit_base.pth', 'PRETRAIN_PTH': '', 'EXTRA_MERGER': False, 'RETURN_INTER': False, 'RETURN_STAGES': [2, 5, 8, 11], 'DECODER': {'TYPE': 'mask', 'MASK_RATIO': 0.75, 'EMBEDDIM': 512, 'DEPTH': 8, 'NUMHEADS': 16, 'MLPRATIO': 4}, 'BACKBONE': {'TYPE': 'vit_base_patch16_224', 'STRIDE': 16, 'PATCHSIZE': 16, 'MID_PE': False, 'SEP_SEG': False, 'CAT_MODE': 'direct', 'MERGE_LAYER': 0, 'ADD_CLS_TOKEN': False, 'CLS_TOKEN_USE_MODE': 'ignore', 'CE_LOC': [], 'CE_KEEP_RATIO': [], 'CE_TEMPLATE_RANGE': 'ALL'}, 'BINS': 400, 'RANGE': 2, 'EXTENSION': 3, 'PRENUM': 7, 'ENCODER_LAYER': 3, 'NUM_HEADS': 12, 'MLP_RATIO': 4, 'QKV_BIAS': True, 'DROP_RATE': 0.1, 'ATTN_DROP': 0.0, 'DROP_PATH': 0.0, 'DECODER_LAYER': 6, 'HEAD': {'TYPE': 'PIX', 'NUM_CHANNELS': 768}}
TRAIN configuration: {'LR': 8e-05, 'WEIGHT_DECAY': 0.05, 'EPOCH': 40, 'LR_DROP_EPOCH': 999, 'BATCH_SIZE': 8, 'NUM_WORKER': 0, 'OPTIMIZER': 'ADAMW', 'BACKBONE_MULTIPLIER': 0.1, 'GIOU_WEIGHT': 2.0, 'L1_WEIGHT': 0.0, 'SCORE_WEIGHT': 1.0, 'FREEZE_LAYERS': [0], 'PRINT_INTERVAL': 1, 'VAL_EPOCH_INTERVAL': 10, 'GRAD_CLIP_NORM': 0.1, 'AMP': False, 'CE_START_EPOCH': 20, 'CE_WARM_EPOCH': 80, 'DROP_PATH_RATE': 0.1, 'SCHEDULER': {'TYPE': 'step', 'DECAY_RATE': 0.05}}
DATA configuration: {'MAX_GAP': 300, 'SAMPLER_MODE': 'causal', 'MEAN': [0.485, 0.456, 0.406], 'STD': [0.229, 0.224, 0.225], 'MAX_SAMPLE_INTERVAL': 200, 'MAX_INTERVAL': 5, 'INTERVAL_PROB': 0.0, 'TEMP': 2, 'TRAIN': {'DATASETS_NAME': ['LASOT'], 'DATASETS_RATIO': [1], 'SAMPLE_PER_EPOCH': 1000}, 'VAL': {'DATASETS_NAME': ['LASOT'], 'DATASETS_RATIO': [1], 'SAMPLE_PER_EPOCH': 1000}, 'SEARCH': {'SIZE': 256, 'FACTOR': 4.0, 'CENTER_JITTER': 3, 'SCALE_JITTER': 0.25, 'NUMBER': 24}, 'TEMPLATE': {'NUMBER': 2, 'SIZE': 128, 'FACTOR': 2.0, 'CENTER_JITTER': 0, 'SCALE_JITTER': 0}}
TEST configuration: {'TEMPLATE_FACTOR': 2.0, 'TEMPLATE_SIZE': 128, 'SEARCH_FACTOR': 4.0, 'SEARCH_SIZE': 256, 'EPOCH': 40}
Load pretrained model from: /data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/models/artrackv2_seq/../../../pretrained_models/mae_pretrain_vit_base.pth
0.75
[✅] Using global build_dataloaders function
🔥 build_dataloaders called with config:
Train datasets: ['LASOT']
Val datasets: ['LASOT']
Learnable parameters are shown below.
identity
backbone.output_bias
backbone.cls_token
backbone.pos_embed
backbone.pos_embed_z
backbone.pos_embed_z0
backbone.pos_embed_z1
backbone.pos_embed_x
backbone.word_embeddings.weight
backbone.position_embeddings.weight
backbone.prev_position_embeddings.weight
backbone.patch_embed.proj.weight
backbone.patch_embed.proj.bias
backbone.blocks.0.norm1.weight
backbone.blocks.0.norm1.bias
backbone.blocks.0.attn.qkv.weight
backbone.blocks.0.attn.qkv.bias
backbone.blocks.0.attn.proj.weight
backbone.blocks.0.attn.proj.bias
backbone.blocks.0.norm2.weight
backbone.blocks.0.norm2.bias
backbone.blocks.0.mlp.fc1.weight
backbone.blocks.0.mlp.fc1.bias
backbone.blocks.0.mlp.fc2.weight
backbone.blocks.0.mlp.fc2.bias
backbone.blocks.1.norm1.weight
backbone.blocks.1.norm1.bias
backbone.blocks.1.attn.qkv.weight
backbone.blocks.1.attn.qkv.bias
backbone.blocks.1.attn.proj.weight
backbone.blocks.1.attn.proj.bias
backbone.blocks.1.norm2.weight
backbone.blocks.1.norm2.bias
backbone.blocks.1.mlp.fc1.weight
backbone.blocks.1.mlp.fc1.bias
backbone.blocks.1.mlp.fc2.weight
backbone.blocks.1.mlp.fc2.bias
backbone.blocks.2.norm1.weight
backbone.blocks.2.norm1.bias
backbone.blocks.2.attn.qkv.weight
backbone.blocks.2.attn.qkv.bias
backbone.blocks.2.attn.proj.weight
backbone.blocks.2.attn.proj.bias
backbone.blocks.2.norm2.weight
backbone.blocks.2.norm2.bias
backbone.blocks.2.mlp.fc1.weight
backbone.blocks.2.mlp.fc1.bias
backbone.blocks.2.mlp.fc2.weight
backbone.blocks.2.mlp.fc2.bias
backbone.blocks.3.norm1.weight
backbone.blocks.3.norm1.bias
backbone.blocks.3.attn.qkv.weight
backbone.blocks.3.attn.qkv.bias
backbone.blocks.3.attn.proj.weight
backbone.blocks.3.attn.proj.bias
backbone.blocks.3.norm2.weight
backbone.blocks.3.norm2.bias
backbone.blocks.3.mlp.fc1.weight
backbone.blocks.3.mlp.fc1.bias
backbone.blocks.3.mlp.fc2.weight
backbone.blocks.3.mlp.fc2.bias
backbone.blocks.4.norm1.weight
backbone.blocks.4.norm1.bias
backbone.blocks.4.attn.qkv.weight
backbone.blocks.4.attn.qkv.bias
backbone.blocks.4.attn.proj.weight
backbone.blocks.4.attn.proj.bias
backbone.blocks.4.norm2.weight
backbone.blocks.4.norm2.bias
backbone.blocks.4.mlp.fc1.weight
backbone.blocks.4.mlp.fc1.bias
backbone.blocks.4.mlp.fc2.weight
backbone.blocks.4.mlp.fc2.bias
backbone.blocks.5.norm1.weight
backbone.blocks.5.norm1.bias
backbone.blocks.5.attn.qkv.weight
backbone.blocks.5.attn.qkv.bias
backbone.blocks.5.attn.proj.weight
backbone.blocks.5.attn.proj.bias
backbone.blocks.5.norm2.weight
backbone.blocks.5.norm2.bias
backbone.blocks.5.mlp.fc1.weight
backbone.blocks.5.mlp.fc1.bias
backbone.blocks.5.mlp.fc2.weight
backbone.blocks.5.mlp.fc2.bias
backbone.blocks.6.norm1.weight
backbone.blocks.6.norm1.bias
backbone.blocks.6.attn.qkv.weight
backbone.blocks.6.attn.qkv.bias
backbone.blocks.6.attn.proj.weight
backbone.blocks.6.attn.proj.bias
backbone.blocks.6.norm2.weight
backbone.blocks.6.norm2.bias
backbone.blocks.6.mlp.fc1.weight
backbone.blocks.6.mlp.fc1.bias
backbone.blocks.6.mlp.fc2.weight
backbone.blocks.6.mlp.fc2.bias
backbone.blocks.7.norm1.weight
backbone.blocks.7.norm1.bias
backbone.blocks.7.attn.qkv.weight
backbone.blocks.7.attn.qkv.bias
backbone.blocks.7.attn.proj.weight
backbone.blocks.7.attn.proj.bias
backbone.blocks.7.norm2.weight
backbone.blocks.7.norm2.bias
backbone.blocks.7.mlp.fc1.weight
backbone.blocks.7.mlp.fc1.bias
backbone.blocks.7.mlp.fc2.weight
backbone.blocks.7.mlp.fc2.bias
backbone.blocks.8.norm1.weight
backbone.blocks.8.norm1.bias
backbone.blocks.8.attn.qkv.weight
backbone.blocks.8.attn.qkv.bias
backbone.blocks.8.attn.proj.weight
backbone.blocks.8.attn.proj.bias
backbone.blocks.8.norm2.weight
backbone.blocks.8.norm2.bias
backbone.blocks.8.mlp.fc1.weight
backbone.blocks.8.mlp.fc1.bias
backbone.blocks.8.mlp.fc2.weight
backbone.blocks.8.mlp.fc2.bias
backbone.blocks.9.norm1.weight
backbone.blocks.9.norm1.bias
backbone.blocks.9.attn.qkv.weight
backbone.blocks.9.attn.qkv.bias
backbone.blocks.9.attn.proj.weight
backbone.blocks.9.attn.proj.bias
backbone.blocks.9.norm2.weight
backbone.blocks.9.norm2.bias
backbone.blocks.9.mlp.fc1.weight
backbone.blocks.9.mlp.fc1.bias
backbone.blocks.9.mlp.fc2.weight
backbone.blocks.9.mlp.fc2.bias
backbone.blocks.10.norm1.weight
backbone.blocks.10.norm1.bias
backbone.blocks.10.attn.qkv.weight
backbone.blocks.10.attn.qkv.bias
backbone.blocks.10.attn.proj.weight
backbone.blocks.10.attn.proj.bias
backbone.blocks.10.norm2.weight
backbone.blocks.10.norm2.bias
backbone.blocks.10.mlp.fc1.weight
backbone.blocks.10.mlp.fc1.bias
backbone.blocks.10.mlp.fc2.weight
backbone.blocks.10.mlp.fc2.bias
backbone.blocks.11.norm1.weight
backbone.blocks.11.norm1.bias
backbone.blocks.11.attn.qkv.weight
backbone.blocks.11.attn.qkv.bias
backbone.blocks.11.attn.proj.weight
backbone.blocks.11.attn.proj.bias
backbone.blocks.11.norm2.weight
backbone.blocks.11.norm2.bias
backbone.blocks.11.mlp.fc1.weight
backbone.blocks.11.mlp.fc1.bias
backbone.blocks.11.mlp.fc2.weight
backbone.blocks.11.mlp.fc2.bias
backbone.extension.0.norm1.weight
backbone.extension.0.norm1.bias
backbone.extension.0.attn.qkv.weight
backbone.extension.0.attn.qkv.bias
backbone.extension.0.attn.proj.weight
backbone.extension.0.attn.proj.bias
backbone.extension.0.norm2.weight
backbone.extension.0.norm2.bias
backbone.extension.0.mlp.fc1.weight
backbone.extension.0.mlp.fc1.bias
backbone.extension.0.mlp.fc2.weight
backbone.extension.0.mlp.fc2.bias
backbone.extension.1.norm1.weight
backbone.extension.1.norm1.bias
backbone.extension.1.attn.qkv.weight
backbone.extension.1.attn.qkv.bias
backbone.extension.1.attn.proj.weight
backbone.extension.1.attn.proj.bias
backbone.extension.1.norm2.weight
backbone.extension.1.norm2.bias
backbone.extension.1.mlp.fc1.weight
backbone.extension.1.mlp.fc1.bias
backbone.extension.1.mlp.fc2.weight
backbone.extension.1.mlp.fc2.bias
backbone.extension.2.norm1.weight
backbone.extension.2.norm1.bias
backbone.extension.2.attn.qkv.weight
backbone.extension.2.attn.qkv.bias
backbone.extension.2.attn.proj.weight
backbone.extension.2.attn.proj.bias
backbone.extension.2.norm2.weight
backbone.extension.2.norm2.bias
backbone.extension.2.mlp.fc1.weight
backbone.extension.2.mlp.fc1.bias
backbone.extension.2.mlp.fc2.weight
backbone.extension.2.mlp.fc2.bias
backbone.norm.weight
backbone.norm.bias
score_mlp.layers.0.0.weight
score_mlp.layers.0.0.bias
score_mlp.layers.1.weight
score_mlp.layers.1.bias
cross_2_decoder.mask_token
cross_2_decoder.decoder_embed.weight
cross_2_decoder.decoder_embed.bias
cross_2_decoder.decoder_blocks.0.norm1.weight
cross_2_decoder.decoder_blocks.0.norm1.bias
cross_2_decoder.decoder_blocks.0.attn.qkv.weight
cross_2_decoder.decoder_blocks.0.attn.qkv.bias
cross_2_decoder.decoder_blocks.0.attn.proj.weight
cross_2_decoder.decoder_blocks.0.attn.proj.bias
cross_2_decoder.decoder_blocks.0.norm2.weight
cross_2_decoder.decoder_blocks.0.norm2.bias
cross_2_decoder.decoder_blocks.0.mlp.fc1.weight
cross_2_decoder.decoder_blocks.0.mlp.fc1.bias
cross_2_decoder.decoder_blocks.0.mlp.fc2.weight
cross_2_decoder.decoder_blocks.0.mlp.fc2.bias
cross_2_decoder.decoder_blocks.1.norm1.weight
cross_2_decoder.decoder_blocks.1.norm1.bias
cross_2_decoder.decoder_blocks.1.attn.qkv.weight
cross_2_decoder.decoder_blocks.1.attn.qkv.bias
cross_2_decoder.decoder_blocks.1.attn.proj.weight
cross_2_decoder.decoder_blocks.1.attn.proj.bias
cross_2_decoder.decoder_blocks.1.norm2.weight
cross_2_decoder.decoder_blocks.1.norm2.bias
cross_2_decoder.decoder_blocks.1.mlp.fc1.weight
cross_2_decoder.decoder_blocks.1.mlp.fc1.bias
cross_2_decoder.decoder_blocks.1.mlp.fc2.weight
cross_2_decoder.decoder_blocks.1.mlp.fc2.bias
cross_2_decoder.decoder_blocks.2.norm1.weight
cross_2_decoder.decoder_blocks.2.norm1.bias
cross_2_decoder.decoder_blocks.2.attn.qkv.weight
cross_2_decoder.decoder_blocks.2.attn.qkv.bias
cross_2_decoder.decoder_blocks.2.attn.proj.weight
cross_2_decoder.decoder_blocks.2.attn.proj.bias
cross_2_decoder.decoder_blocks.2.norm2.weight
cross_2_decoder.decoder_blocks.2.norm2.bias
cross_2_decoder.decoder_blocks.2.mlp.fc1.weight
cross_2_decoder.decoder_blocks.2.mlp.fc1.bias
cross_2_decoder.decoder_blocks.2.mlp.fc2.weight
cross_2_decoder.decoder_blocks.2.mlp.fc2.bias
cross_2_decoder.decoder_blocks.3.norm1.weight
cross_2_decoder.decoder_blocks.3.norm1.bias
cross_2_decoder.decoder_blocks.3.attn.qkv.weight
cross_2_decoder.decoder_blocks.3.attn.qkv.bias
cross_2_decoder.decoder_blocks.3.attn.proj.weight
cross_2_decoder.decoder_blocks.3.attn.proj.bias
cross_2_decoder.decoder_blocks.3.norm2.weight
cross_2_decoder.decoder_blocks.3.norm2.bias
cross_2_decoder.decoder_blocks.3.mlp.fc1.weight
cross_2_decoder.decoder_blocks.3.mlp.fc1.bias
cross_2_decoder.decoder_blocks.3.mlp.fc2.weight
cross_2_decoder.decoder_blocks.3.mlp.fc2.bias
cross_2_decoder.decoder_blocks.4.norm1.weight
cross_2_decoder.decoder_blocks.4.norm1.bias
cross_2_decoder.decoder_blocks.4.attn.qkv.weight
cross_2_decoder.decoder_blocks.4.attn.qkv.bias
cross_2_decoder.decoder_blocks.4.attn.proj.weight
cross_2_decoder.decoder_blocks.4.attn.proj.bias
cross_2_decoder.decoder_blocks.4.norm2.weight
cross_2_decoder.decoder_blocks.4.norm2.bias
cross_2_decoder.decoder_blocks.4.mlp.fc1.weight
cross_2_decoder.decoder_blocks.4.mlp.fc1.bias
cross_2_decoder.decoder_blocks.4.mlp.fc2.weight
cross_2_decoder.decoder_blocks.4.mlp.fc2.bias
cross_2_decoder.decoder_blocks.5.norm1.weight
cross_2_decoder.decoder_blocks.5.norm1.bias
cross_2_decoder.decoder_blocks.5.attn.qkv.weight
cross_2_decoder.decoder_blocks.5.attn.qkv.bias
cross_2_decoder.decoder_blocks.5.attn.proj.weight
cross_2_decoder.decoder_blocks.5.attn.proj.bias
cross_2_decoder.decoder_blocks.5.norm2.weight
cross_2_decoder.decoder_blocks.5.norm2.bias
cross_2_decoder.decoder_blocks.5.mlp.fc1.weight
cross_2_decoder.decoder_blocks.5.mlp.fc1.bias
cross_2_decoder.decoder_blocks.5.mlp.fc2.weight
cross_2_decoder.decoder_blocks.5.mlp.fc2.bias
cross_2_decoder.decoder_blocks.6.norm1.weight
cross_2_decoder.decoder_blocks.6.norm1.bias
cross_2_decoder.decoder_blocks.6.attn.qkv.weight
cross_2_decoder.decoder_blocks.6.attn.qkv.bias
cross_2_decoder.decoder_blocks.6.attn.proj.weight
cross_2_decoder.decoder_blocks.6.attn.proj.bias
cross_2_decoder.decoder_blocks.6.norm2.weight
cross_2_decoder.decoder_blocks.6.norm2.bias
cross_2_decoder.decoder_blocks.6.mlp.fc1.weight
cross_2_decoder.decoder_blocks.6.mlp.fc1.bias
cross_2_decoder.decoder_blocks.6.mlp.fc2.weight
cross_2_decoder.decoder_blocks.6.mlp.fc2.bias
cross_2_decoder.decoder_blocks.7.norm1.weight
cross_2_decoder.decoder_blocks.7.norm1.bias
cross_2_decoder.decoder_blocks.7.attn.qkv.weight
cross_2_decoder.decoder_blocks.7.attn.qkv.bias
cross_2_decoder.decoder_blocks.7.attn.proj.weight
cross_2_decoder.decoder_blocks.7.attn.proj.bias
cross_2_decoder.decoder_blocks.7.norm2.weight
cross_2_decoder.decoder_blocks.7.norm2.bias
cross_2_decoder.decoder_blocks.7.mlp.fc1.weight
cross_2_decoder.decoder_blocks.7.mlp.fc1.bias
cross_2_decoder.decoder_blocks.7.mlp.fc2.weight
cross_2_decoder.decoder_blocks.7.mlp.fc2.bias
cross_2_decoder.decoder_norm.weight
cross_2_decoder.decoder_norm.bias
cross_2_decoder.decoder_pred.weight
cross_2_decoder.decoder_pred.bias
checkpoints will be saved to /data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/output/checkpoints
move_data True
No matching checkpoint file found
🔍 [DEBUG] Data keys: dict_keys(['template_images', 'template_annos', 'search_images', 'search_annos', 'seq_id', 'dataset', 'search_class', 'num_frames'])
template_images: list[8] -> [(360, 480, 3), (720, 1280, 3), (720, 1280, 3), (720, 1280, 3), (240, 320, 3), (240, 320, 3), (360, 480, 3), (720, 1280, 3)]
template_annos: list[8] -> [(4,), (4,), (4,), (4,), (4,), (4,), (4,), (4,)]
search_images: list[8] -> [(24, 360, 480, 3), (24, 720, 1280, 3), (24, 720, 1280, 3), (24, 720, 1280, 3), (24, 240, 320, 3), (24, 240, 320, 3), (24, 360, 480, 3), (24, 720, 1280, 3)]
search_annos: list[8] -> [(24, 4), (24, 4), (24, 4), (24, 4), (24, 4), (24, 4), (24, 4), (24, 4)]
seq_id: list[8] -> [<class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>]
dataset: list[8] -> [<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]
search_class: list[8] -> [<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]
num_frames: list[8] -> [<class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>]
Training crashed at epoch 1
Traceback for the error!
Traceback (most recent call last):
File "/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/train/../../lib/train/trainers/base_trainer.py", line 85, in train
self.train_epoch()
File "/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/train/../../lib/train/trainers/ltr_seq_trainer_v2.py", line 175, in train_epoch
self.cycle_dataset(loader)
File "/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/train/../../lib/train/trainers/ltr_seq_trainer_v2.py", line 83, in cycle_dataset
explore_result = self.actor.explore(data)
File "/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/train/../../lib/train/actors/artrackv2_seq.py", line 665, in explore
outputs_template = self.batch_init(
File "/data1/MyCode_wzt/ARTrackV2-main/ARTrackV2-main/lib/train/../../lib/train/actors/artrackv2_seq.py", line 402, in batch_init
z_2_crop = z_2_crop.squeeze(1).to(self.net.module.backbone.word_embeddings.weight)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 947, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ARTrackV2Seq' object has no attribute 'module'
Restarting training from last epoch ...
Finished training!
The above is the error output.
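The traceback points at self.net.module.backbone.word_embeddings.weight inside batch_init: with --mode single, self.net is the bare ARTrackV2Seq module rather than a DataParallel/DistributedDataParallel wrapper, so it has no .module attribute. A minimal defensive sketch (an added suggestion, not code from the repository) unwraps the wrapper only when it exists:

def unwrap(net):
    # DataParallel / DistributedDataParallel expose the real model as net.module;
    # a plain nn.Module (single-GPU run) is already the model itself.
    return net.module if hasattr(net, "module") else net

# hypothetical usage inside batch_init / batch_track, replacing the self.net.module accesses:
backbone = unwrap(self.net).backbone
z_2_crop = z_2_crop.squeeze(1).to(backbone.word_embeddings.weight)
z_2_feat = backbone.patch_embed(z_2_crop)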