第二次大修

This commit is contained in:
Jiao77
2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

239
match.py
View File

@@ -1,199 +1,108 @@
# match.py
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

import config
from models.rord import RoRD
from utils.data_utils import get_transform
from utils.transforms import SobelTransform
def extract_keypoints_and_descriptors(model, image, kp_thresh):
    """Extract keypoints and L2-normalized descriptors with a RoRD model.

    Args:
        model: RoRD network; called as ``model(image)`` and expected to
            return ``(detection_map, descriptor_map)``.
        image (torch.Tensor): input tensor of shape [1, 1, H, W].
        kp_thresh (float): detection-score threshold for keeping a cell
            of the detection map as a keypoint.

    Returns:
        tuple: ``(keypoints_input, descriptors)`` where
            - ``keypoints_input`` is an [N, 2] float tensor of (x, y)
              coordinates in input-image space;
            - ``descriptors`` is an [N, D] L2-normalized float tensor.
    """
    with torch.no_grad():
        detection_map, desc = model(image)

    # Threshold the detection heat-map and collect the surviving cells.
    binary_map = (detection_map > kp_thresh).float()
    coords = torch.nonzero(binary_map[0, 0]).float()  # [N, 2] as (row, col)

    # Map feature-map cells to input pixels as (x, y); descriptor stride is 8.
    keypoints_input = coords[:, [1, 0]] * 8.0

    # Guard: grid_sample on zero points is ill-defined — return empty tensors.
    if coords.shape[0] == 0:
        return keypoints_input, torch.empty(0, desc.shape[1], device=desc.device)

    # Bilinearly sample descriptors at the keypoint locations.
    # grid_sample wants (x, y) normalized to [-1, 1] (align_corners=True).
    scale = torch.tensor(
        [(desc.shape[3] - 1) / 2, (desc.shape[2] - 1) / 2], device=desc.device
    )
    grid = coords.flip(1).view(1, -1, 1, 2) / scale - 1
    sampled = F.grid_sample(desc, grid, align_corners=True)  # [1, D, N, 1]
    # Index instead of squeeze(): squeeze() would also drop the N axis
    # when there is exactly one keypoint, yielding a 1-D tensor.
    descriptors = sampled[0, :, :, 0].T  # [N, D]
    descriptors = F.normalize(descriptors, p=2, dim=1)
    return keypoints_input, descriptors
def mutual_nearest_neighbor(descs1, descs2):
    """Mutual-nearest-neighbor (MNN) matching between two descriptor sets.

    Args:
        descs1 (torch.Tensor): [M, D] L2-normalized descriptors.
        descs2 (torch.Tensor): [N, D] L2-normalized descriptors.

    Returns:
        np.ndarray: [K, 2] integer array of (index_into_descs1,
        index_into_descs2) pairs that are each other's nearest neighbor
        under the dot-product (cosine) similarity.
    """
    # Guard: torch.max over a zero-size dimension raises; no descriptors
    # on either side simply means no matches.
    if descs1.shape[0] == 0 or descs2.shape[0] == 0:
        return np.empty((0, 2), dtype=np.int64)

    sim = descs1 @ descs2.T  # [M, N] similarity matrix
    nn12 = torch.max(sim, dim=1)  # best descs2 index for each descs1 row
    nn21 = torch.max(sim, dim=0)  # best descs1 index for each descs2 row
    ids1 = torch.arange(0, sim.shape[0], device=sim.device)
    # Keep i only if descs2's best match for (descs1's best match of i) is i.
    mask = (ids1 == nn21.indices[nn12.indices])
    matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
    return matches.cpu().numpy()
def match_template_to_layout(model, layout_image, template_image):
    """Iteratively find every instance of the template in the layout.

    Each round: MNN-match template descriptors against the still-active
    layout keypoints, fit a RANSAC homography, record the inlier bounding
    box, then deactivate all layout keypoints inside that box so the next
    round can discover the next instance. Stops when too few keypoints,
    matches, or inliers remain.

    Args:
        model: RoRD network (see ``extract_keypoints_and_descriptors``).
        layout_image (torch.Tensor): [1, 1, H, W] layout tensor.
        template_image (torch.Tensor): [1, 1, h, w] template tensor.

    Returns:
        list[dict]: one dict per detected instance with integer keys
        'x', 'y', 'width', 'height' and 'homography' (3x3 np.ndarray
        mapping template points to layout points).
    """
    layout_kps, layout_descs = extract_keypoints_and_descriptors(
        model, layout_image, config.KEYPOINT_THRESHOLD)
    template_kps, template_descs = extract_keypoints_and_descriptors(
        model, template_image, config.KEYPOINT_THRESHOLD)

    # True = keypoint still available for matching.
    active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
    found_instances = []

    while True:
        current_indices = torch.nonzero(active_layout_mask).squeeze(1)
        if len(current_indices) < config.MIN_INLIERS:
            break

        current_layout_kps = layout_kps[current_indices]
        current_layout_descs = layout_descs[current_indices]
        matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
        # cv2.findHomography needs at least 4 correspondences.
        if len(matches) < 4:
            break

        src_pts = template_kps[matches[:, 0]].cpu().numpy()
        dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy()
        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,
                                     config.RANSAC_REPROJ_THRESHOLD)
        if H is None or mask.sum() < config.MIN_INLIERS:
            break

        # Bounding box of the inlier keypoints in layout coordinates.
        inlier_mask = mask.ravel().astype(bool)
        inlier_layout_kps = dst_pts[inlier_mask]
        x_min, y_min = inlier_layout_kps.min(axis=0)
        x_max, y_max = inlier_layout_kps.max(axis=0)
        found_instances.append({
            'x': int(x_min), 'y': int(y_min),
            'width': int(x_max - x_min), 'height': int(y_max - y_min),
            'homography': H,
        })

        # Deactivate every layout keypoint inside the detected box so the
        # next iteration only searches the remaining area.
        kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
        region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
        active_layout_mask[region_mask] = False
        print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}")

    return found_instances
def visualize_matches(layout_path, template_path, bboxes, output_path):
    """Draw each detected bounding box (with a label) on the layout image
    read from *layout_path* and write the annotated image to *output_path*."""
    canvas = cv2.imread(layout_path)
    for idx, box in enumerate(bboxes):
        top_left = (box['x'], box['y'])
        bottom_right = (box['x'] + box['width'], box['y'] + box['height'])
        cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 2)
        label_anchor = (box['x'], box['y'] - 10)
        cv2.putText(canvas, f"Match {idx+1}", label_anchor,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.imwrite(output_path, canvas)
    print(f"可视化结果已保存至: {output_path}")
if __name__ == "__main__":
    # Command-line interface: run RoRD template matching on one layout image.
    parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配")
    parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
    parser.add_argument('--layout', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    parser.add_argument('--output', type=str)
    args = parser.parse_args()

    # Model and preprocessing pipeline.
    transform = get_transform()
    model = RoRD().cuda()
    model.load_state_dict(torch.load(args.model_path))
    model.eval()

    # Both images are loaded as single-channel ('L') tensors on the GPU.
    layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda()
    template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda()

    detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)

    print("\n检测到的边界框:")
    for bbox in detected_bboxes:
        print(bbox)

    # Optional visualization of the detections.
    if args.output:
        visualize_matches(args.layout, args.template, detected_bboxes, args.output)