# Count how many training samples belong to each class.
# args.num_classes is the number of classes in the dataset; dset_train.labels
# is presumably an iterable of integer class indices (TODO confirm against the
# dataset class). The counts are computed and printed once.
cls_num_list = [0] * args.num_classes
for label in dset_train.labels:
    # Validate explicitly: a negative label would otherwise silently wrap
    # around via Python's negative list indexing and be counted toward the
    # last class, and a too-large label would raise an uninformative
    # IndexError.
    if not 0 <= label < args.num_classes:
        raise ValueError(
            f"Label {label} is outside the valid range [0, {args.num_classes})"
        )
    cls_num_list[label] += 1
print("Class Counts:", cls_num_list)