| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  | # match.py | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  | import argparse | 
					
						
							|  |  |  |  | import os | 
					
						
							|  |  |  |  | from pathlib import Path | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import cv2 | 
					
						
							|  |  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | import torch | 
					
						
							|  |  |  |  | import torch.nn.functional as F | 
					
						
							|  |  |  |  | from PIL import Image | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  | try: | 
					
						
							|  |  |  |  |     from torch.utils.tensorboard import SummaryWriter | 
					
						
							|  |  |  |  | except ImportError:  # pragma: no cover - fallback for environments without torch tensorboard | 
					
						
							|  |  |  |  |     from tensorboardX import SummaryWriter  # type: ignore | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  | from models.rord import RoRD | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  | from utils.config_loader import load_config, to_absolute_path | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  | from utils.data_utils import get_transform | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | # --- 特征提取函数 (基本无变动) --- | 
					
						
							|  |  |  |  | def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh): | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |     with torch.no_grad(): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         detection_map, desc = model(image_tensor) | 
					
						
							|  |  |  |  |      | 
					
						
							|  |  |  |  |     device = detection_map.device | 
					
						
							|  |  |  |  |     binary_map = (detection_map > kp_thresh).squeeze(0).squeeze(0) | 
					
						
							|  |  |  |  |     coords = torch.nonzero(binary_map).float() # y, x | 
					
						
							|  |  |  |  |      | 
					
						
							|  |  |  |  |     if len(coords) == 0: | 
					
						
							|  |  |  |  |         return torch.tensor([], device=device), torch.tensor([], device=device) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 描述子采样 | 
					
						
							|  |  |  |  |     coords_for_grid = coords.flip(1).view(1, -1, 1, 2) # N, 2 -> 1, N, 1, 2 (x,y) | 
					
						
							|  |  |  |  |     # 归一化到 [-1, 1] | 
					
						
							|  |  |  |  |     coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=device) - 1 | 
					
						
							|  |  |  |  |      | 
					
						
							|  |  |  |  |     descriptors = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T | 
					
						
							|  |  |  |  |     descriptors = F.normalize(descriptors, p=2, dim=1) | 
					
						
							|  |  |  |  |      | 
					
						
							|  |  |  |  |     # 将关键点坐标从特征图尺度转换回图像尺度 | 
					
						
							|  |  |  |  |     # VGG到relu4_3的下采样率为8 | 
					
						
							|  |  |  |  |     keypoints = coords.flip(1) * 8.0 # x, y | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     return keypoints, descriptors | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | # --- (新增) 简单半径 NMS 去重 --- | 
					
						
							|  |  |  |  | def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float) -> torch.Tensor: | 
					
						
							|  |  |  |  |     if kps.numel() == 0: | 
					
						
							|  |  |  |  |         return torch.empty((0,), dtype=torch.long, device=kps.device) | 
					
						
							|  |  |  |  |     idx = torch.argsort(scores, descending=True) | 
					
						
							|  |  |  |  |     keep = [] | 
					
						
							|  |  |  |  |     taken = torch.zeros(len(kps), dtype=torch.bool, device=kps.device) | 
					
						
							|  |  |  |  |     for i in idx: | 
					
						
							|  |  |  |  |         if taken[i]: | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         keep.append(i.item()) | 
					
						
							|  |  |  |  |         di = kps - kps[i] | 
					
						
							|  |  |  |  |         dist2 = (di[:, 0]**2 + di[:, 1]**2) | 
					
						
							|  |  |  |  |         taken |= dist2 <= (radius * radius) | 
					
						
							|  |  |  |  |         taken[i] = True | 
					
						
							|  |  |  |  |     return torch.tensor(keep, dtype=torch.long, device=kps.device) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | # --- (新增) 滑动窗口特征提取函数 --- | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  | def extract_features_sliding_window(model, large_image, transform, matching_cfg): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     使用滑动窗口从大图上提取所有关键点和描述子 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     print("使用滑动窗口提取大版图特征...") | 
					
						
							|  |  |  |  |     device = next(model.parameters()).device | 
					
						
							|  |  |  |  |     W, H = large_image.size | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     window_size = int(matching_cfg.inference_window_size) | 
					
						
							|  |  |  |  |     stride = int(matching_cfg.inference_stride) | 
					
						
							|  |  |  |  |     keypoint_threshold = float(matching_cfg.keypoint_threshold) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     all_kps = [] | 
					
						
							|  |  |  |  |     all_descs = [] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     for y in range(0, H, stride): | 
					
						
							|  |  |  |  |         for x in range(0, W, stride): | 
					
						
							|  |  |  |  |             # 确保窗口不越界 | 
					
						
							|  |  |  |  |             x_end = min(x + window_size, W) | 
					
						
							|  |  |  |  |             y_end = min(y + window_size, H) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             # 裁剪窗口 | 
					
						
							|  |  |  |  |             patch = large_image.crop((x, y, x_end, y_end)) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             # 预处理 | 
					
						
							|  |  |  |  |             patch_tensor = transform(patch).unsqueeze(0).to(device) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             # 提取特征 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |             kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |              | 
					
						
							|  |  |  |  |             if len(kps) > 0: | 
					
						
							|  |  |  |  |                 # 将局部坐标转换为全局坐标 | 
					
						
							|  |  |  |  |                 kps[:, 0] += x | 
					
						
							|  |  |  |  |                 kps[:, 1] += y | 
					
						
							|  |  |  |  |                 all_kps.append(kps) | 
					
						
							|  |  |  |  |                 all_descs.append(descs) | 
					
						
							|  |  |  |  |      | 
					
						
							|  |  |  |  |     if not all_kps: | 
					
						
							|  |  |  |  |         return torch.tensor([], device=device), torch.tensor([], device=device) | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。") | 
					
						
							|  |  |  |  |     return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0) | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  | # --- (新增) FPN 路径的关键点与描述子抽取 --- | 
					
						
							|  |  |  |  | def extract_from_pyramid(model, image_tensor, kp_thresh, nms_cfg): | 
					
						
							|  |  |  |  |     with torch.no_grad(): | 
					
						
							|  |  |  |  |         pyramid = model(image_tensor, return_pyramid=True) | 
					
						
							|  |  |  |  |     all_kps = [] | 
					
						
							|  |  |  |  |     all_desc = [] | 
					
						
							|  |  |  |  |     for level_name, (det, desc, stride) in pyramid.items(): | 
					
						
							|  |  |  |  |         binary = (det > kp_thresh).squeeze(0).squeeze(0) | 
					
						
							|  |  |  |  |         coords = torch.nonzero(binary).float()  # y,x | 
					
						
							|  |  |  |  |         if len(coords) == 0: | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         scores = det.squeeze()[binary] | 
					
						
							|  |  |  |  |         # 采样描述子 | 
					
						
							|  |  |  |  |         coords_for_grid = coords.flip(1).view(1, -1, 1, 2) | 
					
						
							|  |  |  |  |         coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1 | 
					
						
							|  |  |  |  |         descs = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T | 
					
						
							|  |  |  |  |         descs = F.normalize(descs, p=2, dim=1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 映射回原图坐标 | 
					
						
							|  |  |  |  |         kps = coords.flip(1) * float(stride) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # NMS | 
					
						
							|  |  |  |  |         if nms_cfg and nms_cfg.get('enabled', False): | 
					
						
							|  |  |  |  |             keep = radius_nms(kps, scores, float(nms_cfg.get('radius', 4))) | 
					
						
							|  |  |  |  |             if len(keep) > 0: | 
					
						
							|  |  |  |  |                 kps = kps[keep] | 
					
						
							|  |  |  |  |                 descs = descs[keep] | 
					
						
							|  |  |  |  |         all_kps.append(kps) | 
					
						
							|  |  |  |  |         all_desc.append(descs) | 
					
						
							|  |  |  |  |     if not all_kps: | 
					
						
							|  |  |  |  |         return torch.tensor([], device=image_tensor.device), torch.tensor([], device=image_tensor.device) | 
					
						
							|  |  |  |  |     return torch.cat(all_kps, dim=0), torch.cat(all_desc, dim=0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | # --- 互近邻匹配 (无变动) --- | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  | def mutual_nearest_neighbor(descs1, descs2): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     if len(descs1) == 0 or len(descs2) == 0: | 
					
						
							|  |  |  |  |         return torch.empty((0, 2), dtype=torch.int64) | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     sim = descs1 @ descs2.T | 
					
						
							|  |  |  |  |     nn12 = torch.max(sim, dim=1) | 
					
						
							|  |  |  |  |     nn21 = torch.max(sim, dim=0) | 
					
						
							|  |  |  |  |     ids1 = torch.arange(0, sim.shape[0], device=sim.device) | 
					
						
							|  |  |  |  |     mask = (ids1 == nn21.indices[nn12.indices]) | 
					
						
							|  |  |  |  |     matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     return matches | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # --- (已修改) 多尺度、多实例匹配主函数 --- | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  | def match_template_multiscale( | 
					
						
							|  |  |  |  |     model, | 
					
						
							|  |  |  |  |     layout_image, | 
					
						
							|  |  |  |  |     template_image, | 
					
						
							|  |  |  |  |     transform, | 
					
						
							|  |  |  |  |     matching_cfg, | 
					
						
							|  |  |  |  |     log_writer: SummaryWriter | None = None, | 
					
						
							|  |  |  |  |     log_step: int = 0, | 
					
						
							|  |  |  |  | ): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     在不同尺度下搜索模板,并检测多个实例 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  |     # 1. 版图特征提取:根据配置选择 FPN 或滑窗 | 
					
						
							|  |  |  |  |     device = next(model.parameters()).device | 
					
						
							|  |  |  |  |     if getattr(matching_cfg, 'use_fpn', False): | 
					
						
							|  |  |  |  |         layout_tensor = transform(layout_image).unsqueeze(0).to(device) | 
					
						
							|  |  |  |  |         layout_kps, layout_descs = extract_from_pyramid(model, layout_tensor, float(matching_cfg.keypoint_threshold), getattr(matching_cfg, 'nms', {})) | 
					
						
							|  |  |  |  |     else: | 
					
						
							|  |  |  |  |         layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg) | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     if log_writer: | 
					
						
							|  |  |  |  |         log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |      | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     min_inliers = int(matching_cfg.min_inliers) | 
					
						
							|  |  |  |  |     if len(layout_kps) < min_inliers: | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         print("从大版图中提取的关键点过少,无法进行匹配。") | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |         if log_writer: | 
					
						
							|  |  |  |  |             log_writer.add_scalar("match/instances_found", 0, log_step) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         return [] | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     found_instances = [] | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device) | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales] | 
					
						
							|  |  |  |  |     keypoint_threshold = float(matching_cfg.keypoint_threshold) | 
					
						
							|  |  |  |  |     ransac_threshold = float(matching_cfg.ransac_reproj_threshold) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |      | 
					
						
							|  |  |  |  |     # 2. 多实例迭代检测 | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |     while True: | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         current_active_indices = torch.nonzero(active_layout_mask).squeeze(1) | 
					
						
							|  |  |  |  |          | 
					
						
							|  |  |  |  |         # 如果剩余活动关键点过少,则停止 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |         if len(current_active_indices) < min_inliers: | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |             break | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         current_layout_kps = layout_kps[current_active_indices] | 
					
						
							|  |  |  |  |         current_layout_descs = layout_descs[current_active_indices] | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |          | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         best_match_info = {'inliers': 0, 'H': None, 'src_pts': None, 'dst_pts': None, 'mask': None} | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |         # 3. 图像金字塔:遍历模板的每个尺度 | 
					
						
							|  |  |  |  |         print("在新尺度下搜索模板...") | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |         for scale in pyramid_scales: | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |             W, H = template_image.size | 
					
						
							|  |  |  |  |             new_W, new_H = int(W * scale), int(H * scale) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             # 缩放模板 | 
					
						
							|  |  |  |  |             scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS) | 
					
						
							|  |  |  |  |             template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device) | 
					
						
							|  |  |  |  |              | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  |             # 提取缩放后模板的特征:FPN 或单尺度 | 
					
						
							|  |  |  |  |             if getattr(matching_cfg, 'use_fpn', False): | 
					
						
							|  |  |  |  |                 template_kps, template_descs = extract_from_pyramid(model, template_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {})) | 
					
						
							|  |  |  |  |             else: | 
					
						
							|  |  |  |  |                 template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |              | 
					
						
							|  |  |  |  |             if len(template_kps) < 4: continue | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |             # 匹配当前尺度的模板和活动状态的版图特征 | 
					
						
							|  |  |  |  |             matches = mutual_nearest_neighbor(template_descs, current_layout_descs) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             if len(matches) < 4: continue | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |             # RANSAC | 
					
						
							|  |  |  |  |             # 注意:模板关键点坐标需要还原到原始尺寸,才能计算正确的H | 
					
						
							|  |  |  |  |             src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale | 
					
						
							|  |  |  |  |             dst_pts_indices = current_active_indices[matches[:, 1]] | 
					
						
							|  |  |  |  |             dst_pts = layout_kps[dst_pts_indices].cpu().numpy() | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |             H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |             if H is not None and mask.sum() > best_match_info['inliers']: | 
					
						
							|  |  |  |  |                 best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |         if best_match_info['inliers'] > min_inliers: | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |             print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x") | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |             if log_writer: | 
					
						
							|  |  |  |  |                 instance_index = len(found_instances) | 
					
						
							|  |  |  |  |                 log_writer.add_scalar("match/instance_inliers", int(best_match_info['inliers']), log_step + instance_index) | 
					
						
							|  |  |  |  |                 log_writer.add_scalar("match/instance_scale", float(best_match_info['scale']), log_step + instance_index) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |              | 
					
						
							|  |  |  |  |             inlier_mask = best_match_info['mask'].ravel().astype(bool) | 
					
						
							|  |  |  |  |             inlier_layout_kps = best_match_info['dst_pts'][inlier_mask] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             x_min, y_min = inlier_layout_kps.min(axis=0) | 
					
						
							|  |  |  |  |             x_max, y_max = inlier_layout_kps.max(axis=0) | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': best_match_info['H']} | 
					
						
							|  |  |  |  |             found_instances.append(instance) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             # 屏蔽已匹配区域的关键点,以便检测下一个实例 | 
					
						
							|  |  |  |  |             kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1] | 
					
						
							|  |  |  |  |             region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max) | 
					
						
							|  |  |  |  |             active_layout_mask[region_mask] = False | 
					
						
							|  |  |  |  |              | 
					
						
							|  |  |  |  |             print(f"剩余活动关键点: {active_layout_mask.sum()}") | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             # 如果在所有尺度下都找不到好的匹配,则结束搜索 | 
					
						
							|  |  |  |  |             print("在所有尺度下均未找到新的匹配实例,搜索结束。") | 
					
						
							|  |  |  |  |             break | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |              | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     if log_writer: | 
					
						
							|  |  |  |  |         log_writer.add_scalar("match/instances_found", len(found_instances), log_step) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     return found_instances | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | def visualize_matches(layout_path, bboxes, output_path): | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     layout_img = cv2.imread(layout_path) | 
					
						
							|  |  |  |  |     for i, bbox in enumerate(bboxes): | 
					
						
							|  |  |  |  |         x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height'] | 
					
						
							|  |  |  |  |         cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2) | 
					
						
							|  |  |  |  |         cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) | 
					
						
							|  |  |  |  |     cv2.imwrite(output_path, layout_img) | 
					
						
							|  |  |  |  |     print(f"可视化结果已保存至: {output_path}") | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配") | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径") | 
					
						
							|  |  |  |  |     parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径") | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置") | 
					
						
							|  |  |  |  |     parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置") | 
					
						
							|  |  |  |  |     parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录") | 
					
						
							|  |  |  |  |     parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录") | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  |     parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false)") | 
					
						
							|  |  |  |  |     parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重(NMS)") | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     parser.add_argument('--layout', type=str, required=True) | 
					
						
							|  |  |  |  |     parser.add_argument('--template', type=str, required=True) | 
					
						
							|  |  |  |  |     parser.add_argument('--output', type=str) | 
					
						
							|  |  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     cfg = load_config(args.config) | 
					
						
							|  |  |  |  |     config_dir = Path(args.config).resolve().parent | 
					
						
							|  |  |  |  |     matching_cfg = cfg.matching | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     logging_cfg = cfg.get("logging", None) | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     use_tensorboard = False | 
					
						
							|  |  |  |  |     log_dir = None | 
					
						
							|  |  |  |  |     experiment_name = None | 
					
						
							|  |  |  |  |     if logging_cfg is not None: | 
					
						
							|  |  |  |  |         use_tensorboard = bool(logging_cfg.get("use_tensorboard", False)) | 
					
						
							|  |  |  |  |         log_dir = logging_cfg.get("log_dir", "runs") | 
					
						
							|  |  |  |  |         experiment_name = logging_cfg.get("experiment_name", "default") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if args.disable_tensorboard: | 
					
						
							|  |  |  |  |         use_tensorboard = False | 
					
						
							|  |  |  |  |     if args.log_dir is not None: | 
					
						
							|  |  |  |  |         log_dir = args.log_dir | 
					
						
							|  |  |  |  |     if args.experiment_name is not None: | 
					
						
							|  |  |  |  |         experiment_name = args.experiment_name | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     should_log_matches = args.tb_log_matches and use_tensorboard and log_dir is not None | 
					
						
							|  |  |  |  |     writer = None | 
					
						
							|  |  |  |  |     if should_log_matches: | 
					
						
							|  |  |  |  |         log_root = Path(log_dir).expanduser() | 
					
						
							|  |  |  |  |         exp_folder = experiment_name or "default" | 
					
						
							|  |  |  |  |         tb_path = log_root / "match" / exp_folder | 
					
						
							|  |  |  |  |         tb_path.parent.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  |  |         writer = SummaryWriter(tb_path.as_posix()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 22:05:39 +08:00
										 |  |  |  |     # CLI shortcut flags override the YAML configuration | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         if args.fpn_off: | 
					
						
							|  |  |  |  |             matching_cfg.use_fpn = False | 
					
						
							|  |  |  |  |         if args.no_nms and hasattr(matching_cfg, 'nms'): | 
					
						
							|  |  |  |  |             matching_cfg.nms.enabled = False | 
					
						
							|  |  |  |  |     except Exception: | 
					
						
							|  |  |  |  |         # If the OmegaConf structure is not writable, ignore the error; later logic reads these values via getattr instead | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     transform = get_transform() | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |     model = RoRD().cuda() | 
					
						
							| 
									
										
										
										
											2025-09-25 20:20:24 +08:00
										 |  |  |  |     model.load_state_dict(torch.load(model_path)) | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |     model.eval() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |     layout_image = Image.open(args.layout).convert('L') | 
					
						
							|  |  |  |  |     template_image = Image.open(args.template).convert('L') | 
					
						
							|  |  |  |  |      | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |     detected_bboxes = match_template_multiscale( | 
					
						
							|  |  |  |  |         model, | 
					
						
							|  |  |  |  |         layout_image, | 
					
						
							|  |  |  |  |         template_image, | 
					
						
							|  |  |  |  |         transform, | 
					
						
							|  |  |  |  |         matching_cfg, | 
					
						
							|  |  |  |  |         log_writer=writer, | 
					
						
							|  |  |  |  |         log_step=0, | 
					
						
							|  |  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:49:13 +08:00
										 |  |  |  |      | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |     print("\n检测到的边界框:") | 
					
						
							| 
									
										
										
										
											2025-06-07 23:45:32 +08:00
										 |  |  |  |     for bbox in detected_bboxes: | 
					
						
							| 
									
										
										
										
											2025-06-08 15:38:56 +08:00
										 |  |  |  |         print(bbox) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if args.output: | 
					
						
							| 
									
										
										
										
											2025-09-25 21:24:41 +08:00
										 |  |  |  |         visualize_matches(args.layout, detected_bboxes, args.output) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if writer: | 
					
						
							|  |  |  |  |         writer.add_scalar("match/output_instances", len(detected_bboxes), 0) | 
					
						
							|  |  |  |  |         writer.add_text("match/layout_path", args.layout, 0) | 
					
						
							|  |  |  |  |         writer.close() |