图像匹配的superpoint算法训练
官方没有开源的训练代码,从github上找了野生的基于pytorch实现训练。
1.检测器magicpoint以及人工数据的生成
2.用magicpoint给自己的数据集打上伪标签
我没有训练magicpoint了,直接用已经训练好的预训练权重,在这里:
2.1 main
自己编写一个文件然后运行,内容如下下面
这是官方的说明:
其中
任务:export_detector_homoAdapt
配置文件: “configs/magicpoint_coco_export.yaml”
输出保存伪标签的文件夹:“magicpoint_synth_homoAdapt_uavtianda0002_train”
2.2 接下来我们看setting.py文件
DATA_PATH = 'zz_datasets' # path for datasets 也就是输入图片存放的位置
EXPER_PATH = 'zz_logs' # path for saving checkpoints 也就是输出图片存放的位置,伪标签也在。
SYN_TMPDIR = './zz_temp_datasets/' # path for dumping synthetic data 没啥用
DEBUG = False # true: will make synthetic data only uses draw_checkboard and ignore other classes,
# DEBUG = False
前面提到的"magicpoint_synth_homoAdapt_uavtianda0002_train"也就是在’zz_logs’ 文件夹下,这个你可以自己命名
2.3然后看yaml配置文件
data:
dataset: 'Coco' # 'coco' 'hpatches'
export_folder: 'val' # train, val
preprocessing:
resize: [240, 320]
# resize: [480, 640]
gaussian_label:
enable: false # false
sigma: 1.
augmentation:
photometric:
enable: false
homography_adaptation:
enable: true
num: 100 # 100
aggregation: 'sum'
filter_counts: 0
homographies:
params:
translation: true
rotation: true
scaling: true
perspective: true
scaling_amplitude: 0.2
perspective_amplitude_x: 0.2
perspective_amplitude_y: 0.2
allow_artifacts: true
patch_ratio: 0.85
training:
workers_test: 2
model:
# name: 'SuperPointNet' # 'SuperPointNet_gauss2'
name: 'SuperPointNet_gauss2' # 'SuperPointNet_gauss2'
params: {
}
batch_size: 1
eval_batch_size: 1
detection_threshold: 0.015 # 0.015
nms: 4
top_k: 600
subpixel:
enable: true
# pretrained: 'logs/magicpoint_synth20/checkpoints/superPointNet_200000_checkpoint.pth.tar' # 'SuperPointNet'
pretrained: 'logs/magicpoint_synth_t2/checkpoints/superPointNet_100000_checkpoint.pth.tar'
data:
dataset: 'Coco' # 'coco' 'hpatches'
export_folder: 'val' # train, val
这两行的意思就是,datasets文件夹里的Coco.py文件。
第三行为任务,可以是train也可以是val。生成训练集和验证集在不同文件夹
Coco.py的第57行定义了输入图片的位置
def __init__(self, export=False, transform=None, task='train', **config):
# Update config
self.config = self.default_config
self.config = dict_update(self.config, config)
# dict_update(self.config, config)
self.transforms = transform
self.action = 'train' if task == 'train' else 'val'
# get files
base_path = Path(DATA_PATH, 'uav_tianda/' + task + '2014/')
下面是伪标签的保存地址,yaml文件中的export_folder: 设置为val和train各跑一次代码,得到两个文件夹里的伪标签。
3.使用带为伪标签的数据进行训练得到superpoint
3.1 main函数
定义这样一个函数,然后运行
这是官网的介绍:
superpoint_uavtianda_visdrone是训练好的权重的内容。
3.2看看yaml配置文件
data:
# name: 'coco'
dataset: 'Coco' # 'coco'
# labels: datasets/magicpoint_synth20_homoAdapt100_coco_f1/predictions
labels: zz_logs/magicpoint_synth_homoAdapt_uavtianda_train/predictions
# root: # datasets/COCO
root: # zz_datasets/uav_tianda
root_split_txt: # /datasets/COCO
注意labels一定要到predictions这一层,才能找到标签文件。
用于训练的图片在这: