initial commit
This commit is contained in:
199
match.py
Normal file
199
match.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from models.rord import RoRD
|
||||
from torchvision import transforms
|
||||
from utils.transforms import SobelTransform
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
def extract_keypoints_and_descriptors(model, image):
|
||||
"""
|
||||
从 RoRD 模型中提取关键点和描述子。
|
||||
|
||||
参数:
|
||||
model (RoRD): RoRD 模型。
|
||||
image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。
|
||||
|
||||
返回:
|
||||
tuple: (keypoints_input, descriptors)
|
||||
- keypoints_input: [N, 2] float tensor,关键点在输入图像中的坐标。
|
||||
- descriptors: [N, 128] float tensor,L2 归一化的描述子。
|
||||
"""
|
||||
with torch.no_grad():
|
||||
detection_map, _, desc_rord = model(image)
|
||||
desc = desc_rord # 使用 RoRD 描述子头
|
||||
|
||||
# 从检测图中提取关键点
|
||||
thresh = 0.5
|
||||
binary_map = (detection_map > thresh).float()
|
||||
coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d)
|
||||
keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标(stride=16)
|
||||
|
||||
# 从描述子图中提取描述子
|
||||
# detection_map 的形状为 [1, 1, H/16, W/16],desc 的形状为 [1, 128, H/8, W/8]
|
||||
# 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2)
|
||||
keypoints_desc = (coords * 2).long() # [N, 2],整数坐标
|
||||
H_desc, W_desc = desc.shape[2], desc.shape[3]
|
||||
mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc)
|
||||
keypoints_desc = keypoints_desc[mask]
|
||||
keypoints_input = keypoints_input[mask]
|
||||
|
||||
# 提取描述子
|
||||
descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128]
|
||||
|
||||
# L2 归一化描述子
|
||||
descriptors = F.normalize(descriptors, p=2, dim=1)
|
||||
|
||||
return keypoints_input, descriptors
|
||||
|
||||
def mutual_nearest_neighbor(template_descs, layout_descs):
|
||||
"""
|
||||
使用互最近邻(MNN)找到模板和版图之间的匹配。
|
||||
|
||||
参数:
|
||||
template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。
|
||||
layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。
|
||||
|
||||
返回:
|
||||
list: [(i_template, i_layout)],互最近邻匹配对的列表。
|
||||
"""
|
||||
M, N = template_descs.size(0), layout_descs.size(0)
|
||||
if M == 0 or N == 0:
|
||||
return []
|
||||
similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵
|
||||
|
||||
# 找到每个模板描述子的最近邻
|
||||
nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M]
|
||||
|
||||
# 找到每个版图描述子的最近邻
|
||||
nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N]
|
||||
|
||||
# 找到互最近邻
|
||||
mutual_matches = []
|
||||
for i in range(M):
|
||||
j = nn_template_to_layout[i]
|
||||
if nn_layout_to_template[j] == i:
|
||||
mutual_matches.append((i.item(), j.item()))
|
||||
|
||||
return mutual_matches
|
||||
|
||||
def ransac_filter(matches, template_kps, layout_kps):
|
||||
"""
|
||||
使用 RANSAC 对匹配进行几何验证,并返回内点。
|
||||
|
||||
参数:
|
||||
matches (list): [(i_template, i_layout)],匹配对列表。
|
||||
template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。
|
||||
layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。
|
||||
|
||||
返回:
|
||||
tuple: (inlier_matches, num_inliers)
|
||||
- inlier_matches: [(i_template, i_layout)],内点匹配对。
|
||||
- num_inliers: int,内点数量。
|
||||
"""
|
||||
src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches])
|
||||
dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches])
|
||||
|
||||
if len(src_pts) < 4:
|
||||
return [], 0
|
||||
|
||||
try:
|
||||
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0)
|
||||
if H is None:
|
||||
return [], 0
|
||||
inliers = mask.ravel() > 0
|
||||
num_inliers = np.sum(inliers)
|
||||
inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]]
|
||||
return inlier_matches, num_inliers
|
||||
except cv2.error:
|
||||
return [], 0
|
||||
|
||||
def match_template_to_layout(model, layout_image, template_image):
|
||||
"""
|
||||
使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。
|
||||
|
||||
参数:
|
||||
model (RoRD): RoRD 模型。
|
||||
layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。
|
||||
template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。
|
||||
|
||||
返回:
|
||||
list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。
|
||||
"""
|
||||
# 提取版图和模板的关键点和描述子
|
||||
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image)
|
||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image)
|
||||
|
||||
# 初始化活动版图关键点掩码
|
||||
active_layout = torch.ones(len(layout_kps), dtype=bool)
|
||||
|
||||
bboxes = []
|
||||
while True:
|
||||
# 获取当前活动的版图关键点和描述子
|
||||
current_layout_kps = layout_kps[active_layout]
|
||||
current_layout_descs = layout_descs[active_layout]
|
||||
|
||||
if len(current_layout_descs) == 0:
|
||||
break
|
||||
|
||||
# MNN 匹配
|
||||
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
||||
|
||||
if len(matches) == 0:
|
||||
break
|
||||
|
||||
# 将当前版图索引映射回原始版图索引
|
||||
active_indices = torch.nonzero(active_layout).squeeze(1)
|
||||
matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches]
|
||||
|
||||
# RANSAC 过滤
|
||||
inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps)
|
||||
|
||||
if num_inliers > 10: # 设置内点阈值
|
||||
# 获取内点在版图中的关键点
|
||||
inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches]
|
||||
inlier_layout_kps = np.array(inlier_layout_kps)
|
||||
|
||||
# 计算边框
|
||||
x_min = int(inlier_layout_kps[:, 0].min())
|
||||
y_min = int(inlier_layout_kps[:, 1].min())
|
||||
x_max = int(inlier_layout_kps[:, 0].max())
|
||||
y_max = int(inlier_layout_kps[:, 1].max())
|
||||
bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min})
|
||||
|
||||
# 屏蔽内点
|
||||
for _, j in inlier_matches:
|
||||
active_layout[j] = False
|
||||
else:
|
||||
break
|
||||
|
||||
return bboxes
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置变换
|
||||
transform = transforms.Compose([
|
||||
SobelTransform(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5], std=[0.5])
|
||||
])
|
||||
|
||||
# 加载模型
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load('path/to/weights.pth'))
|
||||
model.eval()
|
||||
|
||||
# 加载版图和模板图像
|
||||
layout_image = Image.open('path/to/layout.png').convert('L')
|
||||
layout_tensor = transform(layout_image).unsqueeze(0).cuda()
|
||||
|
||||
template_image = Image.open('path/to/template.png').convert('L')
|
||||
template_tensor = transform(template_image).unsqueeze(0).cuda()
|
||||
|
||||
# 执行匹配
|
||||
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
||||
|
||||
# 打印检测到的边框
|
||||
print("检测到的边框:")
|
||||
for bbox in detected_bboxes:
|
||||
print(bbox)
|
||||
Reference in New Issue
Block a user