项目链接:
code:https://github.com/cvg/lightglue
Paper: https://arxiv.org/pdf/2306.13643.pdf
本文模型流程:
- 准备测试数据,至少需要两张,一张作为模板图像(提前用labelimg工具标注好ROI的BBOX),一张是待匹配图像。
- 利用LightGLUE进行关键点匹配
- 利用findHomography查找两个平面之间的透视变换矩阵(也称为单应性矩阵)【所以只支持简单的平面匹配场景】
- 将模板图的BBOX复用在待匹配图像中
- 绘制匹配线,匹配点,前后剪影图
单应性矩阵:
单应性变换是一种广义的平面变换,它包含以下几种基本的二维变换:
- Translation (平移变换)
平移变换用于在平面上平移一个物体,即沿着x轴和y轴移动。该变换保持物体的形状、大小和方向不变。
其中(tx,ty)是平移量 - Euclidean (欧几里得变换)
欧几里得变换包括旋转和平移。它保持物体的形状和大小不变,仅仅改变物体的位置和方向。
- Similarity (相似变换)
相似变换包括缩放、旋转和平移。它保持物体的形状不变,但会改变物体的大小和方向
其中 s 是缩放因子 - Affine (仿射变换)
仿射变换是相对更广义的一种线性变换,可以包括平移、旋转、缩放和剪切。这种变换保持层次间的平行性质不变,即平行线依然是平行的。
- Projective (投影变换)
投影变换或透视变换是最广义的线性平面变换,不仅包括仿射变换,还允许发生透视效果(如透视缩放、透视收缩等),对于平面投影中的线性变换是最完整的描述。
准备工作:
代码下载
git clone https://github.com/cvg/LightGlue.git && cd LightGlue
环境配置
conda create -n lightglue python=3.11
python -m pip install --upgrade pip
python -m pip install -e .
# 下面的根据自己需要下载,有就不用下了~
pip3 install opencv-python
pip3 install kornia
#根据自己的cuda下载,torch下载网址:https://pytorch.org/get-started/locally/
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip3 install matplotlib
pip3 install kornia_moons
测试数据准备:
#数据集目录分布情况
--dataset_name
----A场景图片
---------1.jpg
---------1.xlm #通过labelimg工具对1.jpg中想要匹配的目标标注bbox
---------2.jpg
---------3.jpg
---------等等
----B场景图片
---------1.jpg
---------1.xlm #通过labelimg工具对1.jpg中想要匹配的目标标注bbox
---------2.jpg
---------3.jpg
---------等等
----等等场景图片
代码准备:
新建test.py,复制以下代码:
from pathlib import Path
from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch
import os
import random
import time
import cv2
import numpy as np
import xml.etree.ElementTree as ET
colors=[[0,0,255],[0,255,0],[255,0,0],[0,255,255],[255,0,255],[255,255,0],
[31, 119, 180],[255, 126, 14],[44, 160, 44],[214, 39, 40],
[147, 102, 188],[140, 86, 75],[225, 119, 194],[126, 126, 126],
[188, 188, 34],[23, 188, 206]]
def transform_point(point, mat):
point = np.array([point[0], point[1], 1]).reshape(3, 1) # 3x1
point = np.dot(mat, point) # 3x3 * 3x1
point = point / point[2]
return point[0], point[1]
def draw_bboxes_on_image(image, bbox,id):
"""在图像上绘制多个边界框,每个边界框对应一个掩码。
参数:
image: 输入图像。
bbox: 框。
"""
if bbox is not None:
x_min, y_min, x_max, y_max = bbox
while id>=len(colors):
id=id-len(colors)
# id=id-len(colors) if id>=len(colors) else id
cv2.rectangle(image, (x_min, y_min), (x_max, y_max),colors[id] , 2)
return image
def main(image_root,save_root):
# 关键点提取模型,max_num_keypoints可以设置1024或2048
extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
# extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
# 关键点匹配模型,features参数根据关键点提取模型设置
matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
# 加载测试数据
os.makedirs(save_root,exist_ok=True)
image_files=os.listdir(image_root)
image_files = [
p for p in os.listdir(image_root)
if os.path.splitext(p)[-1] in [".jpg",'.JPG','.png']
]
image_files.sort()
if len(image_files)<=1:
print("只有一张,数据不够")
return
# 第一张图作为模板图,模板图必须有含有bbox的labelimg
fix_name=image_files[0]
# 除第一张图外,其他图像匹配图
for moving_name in image_files[1:]:
start_time=time.time()
# image0源点 moving,image1目标点 fix
image0,image0_cv = load_image(os.path.join(image_root,moving_name))
image1,image1_cv = load_image(os.path.join(image_root,fix_name))
image0=image0.cuda()
image1=image1.cuda()
if image0.size()!=image1.size():
print("图像对尺寸不一致")
# 获取目标点 label
image1_ann_path=os.path.join(image_root,fix_name.split('.')[0]+".xml")
tree = ET.parse(image1_ann_path)
root = tree.getroot()
# 获取尺寸
size = root.find('size')
h,w=int(size.find('height').text),int(size.find('width').text)
# extract local features
feats0 = extractor.extract(image0) # auto-resize the image, disable with resize=None
feats1 = extractor.extract(image1)
# match the features
matches01 = matcher({'image0': feats0, 'image1': feats1})
feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
# 显示匹配线
axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, save_root,'plot_matches_'+moving_name,color="lime", lw=0.2)
viz2d.add_text(0, f'Stop after {matches01["stop"]} layers', fs=20)
# 显示匹配点
kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1] ,save_root,'plot_keypoints_'+moving_name,colors=[kpc0, kpc1], ps=10)
# findHomography查找两个平面之间的透视变换矩阵(也称为单应性矩阵)。这个矩阵是一个 3x3 的矩阵,可以将一个平面上的点映射到另一个平面上的对应点。
# H, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC, 5.0)
# 输出的变换矩阵 H 是将 src_points 映射到 dst_points 的变换矩阵
trans_mat = cv2.findHomography(m_kpts0.cpu().numpy(), m_kpts1.cpu().numpy(), cv2.RANSAC, 5.0)[0] # fix,moving
for id,obj in enumerate(root.findall('object')):
# 形变关键点,将bbox中(x,y)映射到新图中
bndbox = obj.find('bndbox')
xmin = int(bndbox.find('xmin').text)
ymin = int(bndbox.find('ymin').text)
center_min=(xmin,ymin)
new_xmin,new_ymin = transform_point(center_min,trans_mat)
new_xmin,new_ymin=int(new_xmin),int(new_ymin)
xmax = int(bndbox.find('xmax').text)
ymax = int(bndbox.find('ymax').text)
center_max=(xmax,ymax)
new_xmax,new_ymax =transform_point(center_max,trans_mat)
new_xmax,new_ymax=int(new_xmax),int(new_ymax)
if new_xmin<0 or new_ymin<0 or new_xmax>=w or new_ymax>=h:
print(moving_name,' index out !')
continue
image1_cv=draw_bboxes_on_image(image1_cv, (xmin,ymin,xmax,ymax),id=id) #目标源 fix image
image0_cv=draw_bboxes_on_image(image0_cv, (new_xmin,new_ymin,new_xmax,new_ymax),id=id)
cv2.imwrite(os.path.join(save_root,moving_name),image0_cv)
if not os.path.exists(os.path.join(save_root,fix_name)):
cv2.imwrite(os.path.join(save_root,"fix_"+fix_name),image1_cv)
# 对整张匹配图像进行形变
warp_img = cv2.warpPerspective(image0_cv, trans_mat, (w,h))
# 保存前后形变的剪影图,左边为匹配前的剪影图,后边为匹配后剪影图
before_diff=image0_cv-image1_cv
after_diff=warp_img-image1_cv
viz2d.plot_images([before_diff, after_diff],save_root,'diff_'+moving_name)
end_time=time.time()
all_time = end_time-start_time
# print('lightglue 运行时间:',all_time)
if __name__ == "__main__":
save_root='/path/result/' # 结果保存的路径
image_root = '/dataset_name/' # 上一节中测试数据路径
for dir in os.listdir(image_root):
main(os.path.join(image_root,dir),os.path.join(save_root,dir))