2025-06-08 15:38:56 +08:00
|
|
|
|
# match.py
|
|
|
|
|
|
|
2025-09-25 20:20:24 +08:00
|
|
|
|
import argparse
|
|
|
|
|
|
import os
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
import numpy as np
|
2025-06-07 23:45:32 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from PIL import Image
|
2025-09-25 21:24:41 +08:00
|
|
|
|
try:
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
except ImportError: # pragma: no cover - fallback for environments without torch tensorboard
|
|
|
|
|
|
from tensorboardX import SummaryWriter # type: ignore
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-08 15:38:56 +08:00
|
|
|
|
from models.rord import RoRD
|
2025-09-25 20:20:24 +08:00
|
|
|
|
from utils.config_loader import load_config, to_absolute_path
|
2025-06-08 15:38:56 +08:00
|
|
|
|
from utils.data_utils import get_transform
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-09 01:49:13 +08:00
|
|
|
|
# --- Single-image keypoint and descriptor extraction ---
|
|
|
|
|
|
def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
    """Detect keypoints in one image and sample an L2-normalized descriptor for each.

    Args:
        model: Callable invoked as ``model(image_tensor)`` that returns a
            ``(detection_map, descriptor_map)`` pair. The indexing below treats
            the detection map as ``(1, 1, h, w)`` and the descriptor map as
            ``(1, C, h', w')``; keypoint cells are assumed to lie on the same
            grid as the descriptor map (TODO confirm for RoRD heads).
        image_tensor: Model input; a batch size of 1 is assumed because the
            batch axis is squeezed away.
        kp_thresh: Score threshold; cells with a score strictly greater than
            this become keypoints.

    Returns:
        ``(keypoints, descriptors)`` where ``keypoints`` is an ``(N, 2)`` float
        tensor of ``(x, y)`` positions (feature-map coordinates times 8) and
        ``descriptors`` is an ``(N, C)`` tensor with unit-L2-norm rows. When no
        cell passes the threshold both tensors are empty, with shapes
        ``(0, 2)`` and ``(0, C)``.
    """
    with torch.no_grad():
        detection_map, desc = model(image_tensor)
        device = detection_map.device

        binary_map = (detection_map > kp_thresh).squeeze(0).squeeze(0)
        coords = torch.nonzero(binary_map).float()  # (N, 2) rows of (y, x)
        if coords.shape[0] == 0:
            return coords.new_zeros((0, 2)), desc.new_zeros((0, desc.shape[1]))

        # grid_sample wants (x, y) normalized to [-1, 1]; with align_corners=True
        # the extremes land on the centres of the first/last cells, so divide by
        # (size - 1) / 2. max(..., 1) avoids a division by zero for 1-cell maps.
        desc_h, desc_w = desc.shape[2], desc.shape[3]
        half_extent = torch.tensor(
            [max(desc_w - 1, 1) / 2.0, max(desc_h - 1, 1) / 2.0], device=device
        )
        grid = (coords.flip(1) / half_extent - 1).view(1, -1, 1, 2).to(desc.dtype)

        # grid_sample returns (1, C, N, 1). Squeeze only the batch and trailing
        # axes: a bare .squeeze() would also drop N when there is exactly one
        # keypoint and break the (N, C) layout that F.normalize(dim=1) needs.
        sampled = F.grid_sample(desc, grid, align_corners=True)
        descriptors = F.normalize(sampled.squeeze(0).squeeze(-1).T, p=2, dim=1)

        # Map feature-map cells back to input pixels. The factor 8 is the VGG
        # downsampling up to relu4_3; this assumes the detection map has that stride.
        keypoints = coords.flip(1) * 8.0  # (x, y)

    return keypoints, descriptors
|
|
|
|
|
|
|
|
|
|
|
|
# --- Sliding-window feature extraction for large layout images ---
|
2025-09-25 20:20:24 +08:00
|
|
|
|
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
    """Extract keypoints and descriptors from a large layout image with sliding windows.

    The image is tiled with square windows of ``matching_cfg.inference_window_size``
    pixels, placed every ``matching_cfg.inference_stride`` pixels. Each window
    is passed through :func:`extract_keypoints_and_descriptors` and its
    keypoints are shifted into global image coordinates.

    Along an axis longer than the window, every window has the full configured
    size and the last one is aligned with the image border. Clipping windows at
    the border instead would create sliver patches a few pixels wide, which a
    downsampling backbone cannot process.

    Args:
        model: The RoRD model; the device of its first parameter is used.
        large_image: PIL image (``.size`` is ``(width, height)``).
        transform: Callable turning a PIL patch into a tensor; a batch axis is
            added here with ``unsqueeze(0)``. Keypoint coordinates assume it
            does not resize the patch (TODO confirm for ``get_transform``).
        matching_cfg: Config providing ``inference_window_size``,
            ``inference_stride`` and ``keypoint_threshold``.

    Returns:
        ``(keypoints, descriptors)`` concatenated over all windows: global
        ``(x, y)`` keypoints and their descriptors. Points inside window
        overlaps may appear more than once. If no window yields a keypoint,
        two empty tensors are returned.

    Raises:
        ValueError: If the window size or stride is not positive.
    """
    print("使用滑动窗口提取大版图特征...")
    device = next(model.parameters()).device
    W, H = large_image.size
    window_size = int(matching_cfg.inference_window_size)
    stride = int(matching_cfg.inference_stride)
    keypoint_threshold = float(matching_cfg.keypoint_threshold)
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            "inference_window_size and inference_stride must be positive, "
            f"got {window_size} and {stride}"
        )

    all_kps = []
    all_descs = []
    x_starts = _window_starts(W, window_size, stride)
    for y in _window_starts(H, window_size, stride):
        y_end = min(y + window_size, H)
        for x in x_starts:
            x_end = min(x + window_size, W)
            patch = large_image.crop((x, y, x_end, y_end))
            patch_tensor = transform(patch).unsqueeze(0).to(device)

            kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)
            if len(kps) == 0:
                continue

            # Window-local (x, y) -> layout coordinates.
            kps[:, 0] += x
            kps[:, 1] += y
            all_kps.append(kps)
            all_descs.append(descs)

    if not all_kps:
        return torch.tensor([], device=device), torch.tensor([], device=device)

    print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。")
    return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)


def _window_starts(length, window, stride):
    """Return start offsets of full-size windows covering ``[0, length)``.

    Offsets advance by ``stride``. If the last regular offset does not reach
    the end, one extra window aligned to the end is appended. An axis no longer
    than ``window`` gets a single window at 0, and an empty axis gets none.
    """
    if length <= 0:
        return []
    last = length - window
    if last <= 0:
        return [0]
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-09 01:49:13 +08:00
|
|
|
|
|
|
|
|
|
|
# --- Mutual nearest-neighbour descriptor matching ---
|
2025-06-08 15:38:56 +08:00
|
|
|
|
def mutual_nearest_neighbor(descs1, descs2):
    """Match two descriptor sets by mutual nearest neighbour on dot-product similarity.

    Pair ``(i, j)`` is kept only when row ``j`` of ``descs2`` is the most
    similar row for ``descs1[i]`` and row ``i`` of ``descs1`` is the most
    similar row for ``descs2[j]``. For L2-normalized descriptors, the dot
    product equals cosine similarity.

    Args:
        descs1: ``(N1, C)`` descriptor tensor.
        descs2: ``(N2, C)`` descriptor tensor on the same device.

    Returns:
        ``(M, 2)`` int64 tensor of ``(index_in_descs1, index_in_descs2)`` pairs.
        The result lives on the inputs' device, including the ``(0, 2)`` result
        returned when either input is empty.
    """
    if len(descs1) == 0 or len(descs2) == 0:
        # Same device as the non-empty path, so callers never mix devices.
        return torch.empty((0, 2), dtype=torch.int64, device=descs1.device)

    sim = descs1 @ descs2.T
    nn12 = torch.max(sim, dim=1)  # best descs2 row for each descs1 row
    nn21 = torch.max(sim, dim=0)  # best descs1 row for each descs2 row
    ids1 = torch.arange(0, sim.shape[0], device=sim.device)
    # Mutual check: following 1 -> 2 -> 1 must return to the starting row.
    mask = ids1 == nn21.indices[nn12.indices]
    return torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
# --- Multi-scale, multi-instance template matching ---
|
2025-09-25 21:24:41 +08:00
|
|
|
|
def match_template_multiscale(
    model,
    layout_image,
    template_image,
    transform,
    matching_cfg,
    log_writer: SummaryWriter | None = None,
    log_step: int = 0,
):
    """Detect every instance of ``template_image`` in ``layout_image`` over several template scales.

    Procedure:
      1. Extract layout features once with :func:`extract_features_sliding_window`.
      2. Extract template features once per scale in ``matching_cfg.pyramid_scales``.
      3. Repeat until no instance is found:
         - Match every scale's template descriptors against the still-active
           layout descriptors and fit a RANSAC homography per scale.
         - Keep the scale with the most inliers and accept it if it has at
           least ``matching_cfg.min_inliers`` inliers.
         - Deactivate the layout keypoints inside the accepted instance's box,
           so the next round can find another instance.

    Args:
        model: The RoRD model.
        layout_image: PIL image to search in.
        template_image: PIL image to search for.
        transform: Callable turning a PIL image into a model input tensor.
        matching_cfg: Config providing ``min_inliers``, ``pyramid_scales``,
            ``keypoint_threshold`` and ``ransac_reproj_threshold`` (plus the
            sliding-window keys used during layout extraction).
        log_writer: Optional TensorBoard writer. When given, the method logs
            the layout keypoint count, each instance's inliers and scale, and
            the total instance count.
        log_step: Base step for logged scalars. Per-instance scalars use
            ``log_step + instance_index``.

    Returns:
        A list of dicts, one per instance, with these keys:
          - ``x``, ``y``, ``width``, ``height``: ints giving the axis-aligned
            box around the instance's inlier layout keypoints.
          - ``homography``: the matrix from ``cv2.findHomography``, mapping
            original-size template pixels to layout pixels.
    """
    # 1. Layout features, extracted once.
    layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
    if log_writer:
        log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)

    min_inliers = int(matching_cfg.min_inliers)
    if len(layout_kps) < min_inliers:
        print("从大版图中提取的关键点过少,无法进行匹配。")
        if log_writer:
            log_writer.add_scalar("match/instances_found", 0, log_step)
        return []

    pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales]
    keypoint_threshold = float(matching_cfg.keypoint_threshold)
    ransac_threshold = float(matching_cfg.ransac_reproj_threshold)

    # 2. Template features per scale. They do not depend on which layout
    # keypoints are still active, so compute them once, not once per instance.
    template_pyramid = _extract_template_pyramid(
        model, template_image, transform, pyramid_scales, keypoint_threshold, layout_kps.device
    )

    found_instances = []
    active_layout_mask = torch.ones(len(layout_kps), dtype=torch.bool, device=layout_kps.device)

    # 3. Iterative multi-instance detection.
    while True:
        active_indices = torch.nonzero(active_layout_mask).squeeze(1)
        # Too few active keypoints left to support another instance.
        if len(active_indices) < min_inliers:
            break
        active_descs = layout_descs[active_indices]

        print("在新尺度下搜索模板...")
        best = _best_homography_over_scales(
            template_pyramid, active_descs, active_indices, layout_kps, ransac_threshold
        )

        # "min_inliers" is a lower bound, so exactly min_inliers inliers qualifies
        # (consistent with the active-keypoint check above).
        if best is None or best['inliers'] < min_inliers:
            print("在所有尺度下均未找到新的匹配实例,搜索结束。")
            break

        print(f"找到一个匹配实例!内点数: {best['inliers']}, 使用的模板尺度: {best['scale']:.2f}x")
        if log_writer:
            instance_index = len(found_instances)
            log_writer.add_scalar("match/instance_inliers", int(best['inliers']), log_step + instance_index)
            log_writer.add_scalar("match/instance_scale", float(best['scale']), log_step + instance_index)

        inlier_mask = best['mask'].ravel().astype(bool)
        inlier_layout_kps = best['dst_pts'][inlier_mask]
        x_min, y_min = (float(v) for v in inlier_layout_kps.min(axis=0))
        x_max, y_max = (float(v) for v in inlier_layout_kps.max(axis=0))

        # NOTE(review): the box spans only the inlier layout keypoints, not the
        # template outline projected through the homography, so it may be
        # tighter than the true instance extent.
        found_instances.append({
            'x': int(x_min),
            'y': int(y_min),
            'width': int(x_max - x_min),
            'height': int(y_max - y_min),
            'homography': best['H'],
        })

        # Deactivate every layout keypoint inside the box. The inliers lie inside
        # it, so each accepted instance removes at least one point and the loop
        # terminates.
        kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
        region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
        active_layout_mask[region_mask] = False
        print(f"剩余活动关键点: {int(active_layout_mask.sum())}")

    if log_writer:
        log_writer.add_scalar("match/instances_found", len(found_instances), log_step)

    return found_instances


def _extract_template_pyramid(model, template_image, transform, scales, keypoint_threshold, device):
    """Extract template keypoints/descriptors once for each pyramid scale.

    Returns a list of ``(scale, keypoints, descriptors)`` tuples, with keypoints
    in the resized template's pixels. A scale is skipped when it would resize
    the template to an empty image, or when it yields fewer than 4 keypoints
    (the minimum for a homography).
    """
    base_w, base_h = template_image.size
    pyramid = []
    for scale in scales:
        new_w, new_h = int(base_w * scale), int(base_h * scale)
        if new_w < 1 or new_h < 1:
            continue  # PIL cannot resize to a zero-sized image
        scaled_template = template_image.resize((new_w, new_h), Image.LANCZOS)
        template_tensor = transform(scaled_template).unsqueeze(0).to(device)
        kps, descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
        if len(kps) < 4:
            continue
        pyramid.append((scale, kps, descs))
    return pyramid


def _best_homography_over_scales(template_pyramid, active_descs, active_indices, layout_kps, ransac_threshold):
    """Return the RANSAC homography with the most inliers over all template scales.

    Each scale's template descriptors are matched to the active layout
    descriptors by mutual nearest neighbour. The returned dict has the keys
    ``inliers``, ``H``, ``mask``, ``scale`` and ``dst_pts``. Returns ``None``
    when no scale produces a homography with at least one inlier. Ties keep the
    earlier scale.
    """
    best = None
    for scale, template_kps, template_descs in template_pyramid:
        matches = mutual_nearest_neighbor(template_descs, active_descs)
        if len(matches) < 4:
            continue

        # Undo the template resize so the homography maps original-size
        # template pixels to layout pixels.
        src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale
        dst_pts = layout_kps[active_indices[matches[:, 1]]].cpu().numpy()

        homography, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)
        if homography is None:
            continue
        inliers = int(mask.sum())
        if inliers > (best['inliers'] if best is not None else 0):
            best = {'inliers': inliers, 'H': homography, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
    return best
|
|
|
|
|
|
|
2025-06-09 01:49:13 +08:00
|
|
|
|
|
|
|
|
|
|
def visualize_matches(layout_path, bboxes, output_path):
    """Draw detected bounding boxes on the layout image and save the result.

    Args:
        layout_path: Path of the layout image to annotate (read in colour).
        bboxes: Iterable of dicts with integer ``x``, ``y``, ``width`` and
            ``height`` keys, as returned by ``match_template_multiscale``.
        output_path: Destination image path; OpenCV picks the format from the
            extension.

    Raises:
        FileNotFoundError: If ``layout_path`` cannot be read as an image.
    """
    layout_img = cv2.imread(layout_path)
    if layout_img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"无法读取版图图像: {layout_path}")

    for index, bbox in enumerate(bboxes, start=1):
        x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
        cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Keep the label's baseline inside the image when a box touches the top edge.
        label_y = max(y - 10, 20)
        cv2.putText(layout_img, f"Match {index}", (x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # cv2.imwrite returns False (e.g. missing output directory) instead of raising.
    if not cv2.imwrite(output_path, layout_img):
        print(f"可视化结果保存失败: {output_path}")
        return
    print(f"可视化结果已保存至: {output_path}")
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-09 01:49:13 +08:00
|
|
|
|
|
2025-06-07 23:45:32 +08:00
|
|
|
|
def main():
    """Command-line entry point: match a template against a layout image with RoRD.

    Reads the YAML config. The values of ``--model_path``, ``--log_dir`` and
    ``--experiment_name`` override the config. The script then loads the
    model, runs :func:`match_template_multiscale`, and prints the detected
    boxes. When ``--output`` is given, it also writes a visualization. When
    ``--tb_log_matches`` is set and TensorBoard is enabled, matching statistics
    go to ``<log_dir>/match/<experiment_name>``.
    """
    parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
    parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
    parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置")
    parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
    parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
    parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
    parser.add_argument('--layout', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    parser.add_argument('--output', type=str)
    args = parser.parse_args()

    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent
    matching_cfg = cfg.matching
    logging_cfg = cfg.get("logging", None)
    model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))

    # TensorBoard settings: config values first, then command-line overrides.
    use_tensorboard = False
    log_dir = None
    experiment_name = None
    if logging_cfg is not None:
        use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
        log_dir = logging_cfg.get("log_dir", "runs")
        experiment_name = logging_cfg.get("experiment_name", "default")

    if args.disable_tensorboard:
        use_tensorboard = False
    if args.log_dir is not None:
        log_dir = args.log_dir
    if args.experiment_name is not None:
        experiment_name = args.experiment_name

    should_log_matches = args.tb_log_matches and use_tensorboard and log_dir is not None
    writer = None
    if should_log_matches:
        tb_path = Path(log_dir).expanduser() / "match" / (experiment_name or "default")
        tb_path.parent.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(tb_path.as_posix())

    try:
        transform = get_transform()
        # Fall back to CPU so the script also runs on hosts without CUDA;
        # map_location lets a CUDA-saved checkpoint load there too.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RoRD().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        layout_image = Image.open(args.layout).convert('L')
        template_image = Image.open(args.template).convert('L')

        detected_bboxes = match_template_multiscale(
            model,
            layout_image,
            template_image,
            transform,
            matching_cfg,
            log_writer=writer,
            log_step=0,
        )

        print("\n检测到的边界框:")
        for bbox in detected_bboxes:
            print(bbox)

        if args.output:
            visualize_matches(args.layout, detected_bboxes, args.output)

        if writer is not None:
            writer.add_scalar("match/output_instances", len(detected_bboxes), 0)
            writer.add_text("match/layout_path", args.layout, 0)
    finally:
        # Flush and close the event file even if matching fails midway.
        if writer is not None:
            writer.close()


if __name__ == "__main__":
    main()
|