initial commit

This commit is contained in:
Jiao77
2025-06-07 23:45:32 +08:00
parent fc7147eaea
commit 94ba17771e
10 changed files with 1427 additions and 0 deletions

199
match.py Normal file
View File

@@ -0,0 +1,199 @@
import torch
import torch.nn.functional as F
from models.rord import RoRD
from torchvision import transforms
from utils.transforms import SobelTransform
import numpy as np
import cv2
from PIL import Image
def extract_keypoints_and_descriptors(model, image):
"""
从 RoRD 模型中提取关键点和描述子。
参数:
model (RoRD): RoRD 模型。
image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。
返回:
tuple: (keypoints_input, descriptors)
- keypoints_input: [N, 2] float tensor关键点在输入图像中的坐标。
- descriptors: [N, 128] float tensorL2 归一化的描述子。
"""
with torch.no_grad():
detection_map, _, desc_rord = model(image)
desc = desc_rord # 使用 RoRD 描述子头
# 从检测图中提取关键点
thresh = 0.5
binary_map = (detection_map > thresh).float()
coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d)
keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标stride=16
# 从描述子图中提取描述子
# detection_map 的形状为 [1, 1, H/16, W/16]desc 的形状为 [1, 128, H/8, W/8]
# 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2)
keypoints_desc = (coords * 2).long() # [N, 2],整数坐标
H_desc, W_desc = desc.shape[2], desc.shape[3]
mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc)
keypoints_desc = keypoints_desc[mask]
keypoints_input = keypoints_input[mask]
# 提取描述子
descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128]
# L2 归一化描述子
descriptors = F.normalize(descriptors, p=2, dim=1)
return keypoints_input, descriptors
def mutual_nearest_neighbor(template_descs, layout_descs):
"""
使用互最近邻MNN找到模板和版图之间的匹配。
参数:
template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。
layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。
返回:
list: [(i_template, i_layout)],互最近邻匹配对的列表。
"""
M, N = template_descs.size(0), layout_descs.size(0)
if M == 0 or N == 0:
return []
similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵
# 找到每个模板描述子的最近邻
nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M]
# 找到每个版图描述子的最近邻
nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N]
# 找到互最近邻
mutual_matches = []
for i in range(M):
j = nn_template_to_layout[i]
if nn_layout_to_template[j] == i:
mutual_matches.append((i.item(), j.item()))
return mutual_matches
def ransac_filter(matches, template_kps, layout_kps):
"""
使用 RANSAC 对匹配进行几何验证,并返回内点。
参数:
matches (list): [(i_template, i_layout)],匹配对列表。
template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。
layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。
返回:
tuple: (inlier_matches, num_inliers)
- inlier_matches: [(i_template, i_layout)],内点匹配对。
- num_inliers: int内点数量。
"""
src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches])
dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches])
if len(src_pts) < 4:
return [], 0
try:
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0)
if H is None:
return [], 0
inliers = mask.ravel() > 0
num_inliers = np.sum(inliers)
inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]]
return inlier_matches, num_inliers
except cv2.error:
return [], 0
def match_template_to_layout(model, layout_image, template_image):
"""
使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。
参数:
model (RoRD): RoRD 模型。
layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。
template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。
返回:
list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。
"""
# 提取版图和模板的关键点和描述子
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image)
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image)
# 初始化活动版图关键点掩码
active_layout = torch.ones(len(layout_kps), dtype=bool)
bboxes = []
while True:
# 获取当前活动的版图关键点和描述子
current_layout_kps = layout_kps[active_layout]
current_layout_descs = layout_descs[active_layout]
if len(current_layout_descs) == 0:
break
# MNN 匹配
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
if len(matches) == 0:
break
# 将当前版图索引映射回原始版图索引
active_indices = torch.nonzero(active_layout).squeeze(1)
matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches]
# RANSAC 过滤
inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps)
if num_inliers > 10: # 设置内点阈值
# 获取内点在版图中的关键点
inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches]
inlier_layout_kps = np.array(inlier_layout_kps)
# 计算边框
x_min = int(inlier_layout_kps[:, 0].min())
y_min = int(inlier_layout_kps[:, 1].min())
x_max = int(inlier_layout_kps[:, 0].max())
y_max = int(inlier_layout_kps[:, 1].max())
bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min})
# 屏蔽内点
for _, j in inlier_matches:
active_layout[j] = False
else:
break
return bboxes
if __name__ == "__main__":
# 设置变换
transform = transforms.Compose([
SobelTransform(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# 加载模型
model = RoRD().cuda()
model.load_state_dict(torch.load('path/to/weights.pth'))
model.eval()
# 加载版图和模板图像
layout_image = Image.open('path/to/layout.png').convert('L')
layout_tensor = transform(layout_image).unsqueeze(0).cuda()
template_image = Image.open('path/to/template.png').convert('L')
template_tensor = transform(template_image).unsqueeze(0).cuda()
# 执行匹配
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
# 打印检测到的边框
print("检测到的边框:")
for bbox in detected_bboxes:
print(bbox)