Files
2025-09-25 22:05:39 +08:00

365 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# match.py
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError: # pragma: no cover - fallback for environments without torch tensorboard
from tensorboardX import SummaryWriter # type: ignore
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
# --- Feature extraction (largely unchanged) ---
def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
    """Detect keypoints in one image and sample an L2-normalised descriptor for each.

    Args:
        model: Callable returning ``(detection_map, desc)`` for ``image_tensor``.
        image_tensor: Input passed to ``model`` unchanged.
        kp_thresh: Detection-map values strictly greater than this are keypoints.

    Returns:
        ``(keypoints, descriptors)``. ``keypoints`` has one ``(x, y)`` row per
        detection, in feature-map coordinates multiplied by 8; ``descriptors``
        has one unit-length row per keypoint. Both are empty tensors when
        nothing exceeds the threshold.
    """
    with torch.no_grad():
        detection_map, desc = model(image_tensor)

    device = detection_map.device
    binary_map = (detection_map > kp_thresh).squeeze(0).squeeze(0)
    coords = torch.nonzero(binary_map).float()  # rows of (y, x)

    if len(coords) == 0:
        return torch.tensor([], device=device), torch.tensor([], device=device)

    # Descriptor sampling: (N, 2) -> (1, N, 1, 2) in (x, y) order.
    grid = coords.flip(1).view(1, -1, 1, 2)
    # Normalise to [-1, 1]. A 1-pixel-wide axis would give a zero divisor
    # (0 / 0 -> NaN), so clamp it; the only valid coordinate there is 0,
    # which then maps to -1, i.e. that single pixel.
    half_extent = torch.tensor(
        [max(desc.shape[3] - 1, 1) / 2, max(desc.shape[2] - 1, 1) / 2],
        device=device,
    )
    grid = grid / half_extent - 1

    sampled = F.grid_sample(desc, grid, align_corners=True)  # (1, C, N, 1)
    # Index the singleton axes explicitly: a bare ``.squeeze()`` would also
    # drop the keypoint axis when N == 1 (or the channel axis when C == 1)
    # and make the normalisation below fail.
    descriptors = sampled[0, :, :, 0].T  # (N, C)
    descriptors = F.normalize(descriptors, p=2, dim=1)

    # Convert keypoints from feature-map scale back to image scale.
    # VGG down to relu4_3 has a downsampling factor of 8.
    keypoints = coords.flip(1) * 8.0  # (x, y)
    return keypoints, descriptors
# --- (New) Simple radius-based NMS for de-duplicating keypoints ---
def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float) -> torch.Tensor:
    """Greedy radius-based non-maximum suppression over 2-D keypoints.

    Keypoints are visited from highest to lowest score; each one that has not
    been suppressed yet is kept, and every keypoint within ``radius``
    (inclusive) of it is suppressed.

    Args:
        kps: Keypoints whose first two columns are the coordinates.
        scores: One score per keypoint.
        radius: Suppression radius, in the same units as ``kps``.

    Returns:
        A ``torch.long`` tensor of kept indices into ``kps``, ordered by
        descending score, on the same device as ``kps``.
    """
    target_device = kps.device
    if kps.numel() == 0:
        return torch.empty((0,), dtype=torch.long, device=target_device)

    radius_sq = radius * radius
    suppressed = torch.zeros(len(kps), dtype=torch.bool, device=target_device)
    survivors = []
    for candidate in torch.argsort(scores, descending=True).tolist():
        if suppressed[candidate]:
            continue
        survivors.append(candidate)
        offset = kps - kps[candidate]
        sq_dist = offset[:, 0] ** 2 + offset[:, 1] ** 2
        suppressed |= sq_dist <= radius_sq
        # Mark the winner itself explicitly so it is never revisited.
        suppressed[candidate] = True
    return torch.tensor(survivors, dtype=torch.long, device=target_device)
# --- (New) Sliding-window feature extraction ---
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
    """Extract keypoints and descriptors from a large image with a sliding window.

    The image is cropped into windows of ``matching_cfg.inference_window_size``
    pixels placed every ``matching_cfg.inference_stride`` pixels. Each crop is
    run through ``extract_keypoints_and_descriptors`` and the resulting
    keypoints are shifted by the window origin into whole-image coordinates.

    Args:
        model: Network whose parameters determine the target device.
        large_image: Image exposing ``.size`` as ``(W, H)`` and ``.crop(box)``.
        transform: Callable converting a crop to a tensor.
        matching_cfg: Config providing ``inference_window_size``,
            ``inference_stride`` and ``keypoint_threshold``.

    Returns:
        ``(keypoints, descriptors)`` concatenated over all windows, or two
        empty tensors when no window yields a keypoint.

    Raises:
        ValueError: If the window size or the stride is not positive.
    """
    print("使用滑动窗口提取大版图特征...")
    device = next(model.parameters()).device
    W, H = large_image.size
    window_size = int(matching_cfg.inference_window_size)
    stride = int(matching_cfg.inference_stride)
    keypoint_threshold = float(matching_cfg.keypoint_threshold)
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            "inference_window_size and inference_stride must be positive, "
            f"got {window_size} and {stride}"
        )

    all_kps = []
    all_descs = []
    for y in range(0, H, stride):
        # Clamp so the window never leaves the image.
        y_end = min(y + window_size, H)
        for x in range(0, W, stride):
            x_end = min(x + window_size, W)
            patch = large_image.crop((x, y, x_end, y_end))
            patch_tensor = transform(patch).unsqueeze(0).to(device)
            kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)
            if len(kps) > 0:
                # Convert window-local coordinates to whole-image coordinates.
                kps[:, 0] += x
                kps[:, 1] += y
                all_kps.append(kps)
                all_descs.append(descs)
            # Once a window touches the right edge, every further window in
            # this row is a strict sub-window of it: it would only add
            # duplicate keypoints and may be a sliver as small as 1 px.
            if x_end >= W:
                break
        # Same reasoning for rows once the bottom edge has been reached.
        if y_end >= H:
            break

    if not all_kps:
        return torch.tensor([], device=device), torch.tensor([], device=device)
    print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。")
    return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
# --- (New) Keypoint and descriptor extraction for the FPN path ---
def extract_from_pyramid(model, image_tensor, kp_thresh, nms_cfg):
    """Extract keypoints and descriptors from every level of the model's pyramid.

    Args:
        model: Callable; ``model(image_tensor, return_pyramid=True)`` must
            return a mapping whose values are ``(det, desc, stride)`` tuples.
        image_tensor: Input passed to ``model`` unchanged.
        kp_thresh: Detection values strictly greater than this are keypoints.
        nms_cfg: Optional mapping with ``enabled`` and ``radius`` (default 4)
            entries; when enabled, ``radius_nms`` is applied within each level.

    Returns:
        ``(keypoints, descriptors)`` concatenated over levels. Keypoints are
        ``(x, y)`` rows scaled by each level's stride; descriptors are
        unit-length rows. Two empty tensors are returned when no level has a
        detection.
    """
    with torch.no_grad():
        pyramid = model(image_tensor, return_pyramid=True)

    nms_enabled = bool(nms_cfg and nms_cfg.get('enabled', False))
    nms_radius = float(nms_cfg.get('radius', 4)) if nms_enabled else 0.0

    all_kps = []
    all_desc = []
    for det, desc, stride in pyramid.values():
        # Drop only the batch and channel axes. A bare ``.squeeze()`` would
        # also drop a spatial axis of size 1 and no longer match ``binary``.
        det_map = det.squeeze(0).squeeze(0)
        binary = det_map > kp_thresh
        coords = torch.nonzero(binary).float()  # rows of (y, x)
        if len(coords) == 0:
            continue
        scores = det_map[binary]

        # Sample descriptors at the keypoint locations.
        grid = coords.flip(1).view(1, -1, 1, 2)
        # Clamp the divisor so a 1-pixel-wide axis does not produce 0 / 0.
        half_extent = torch.tensor(
            [max(desc.shape[3] - 1, 1) / 2, max(desc.shape[2] - 1, 1) / 2],
            device=desc.device,
        )
        grid = grid / half_extent - 1
        sampled = F.grid_sample(desc, grid, align_corners=True)  # (1, C, N, 1)
        # Explicit indexing keeps the (N, C) layout even for a single keypoint,
        # which ``.squeeze()`` would flatten to 1-D.
        descs = F.normalize(sampled[0, :, :, 0].T, p=2, dim=1)

        # Map back to original-image coordinates.
        kps = coords.flip(1) * float(stride)

        if nms_enabled:
            keep = radius_nms(kps, scores, nms_radius)
            if len(keep) > 0:
                kps = kps[keep]
                descs = descs[keep]

        all_kps.append(kps)
        all_desc.append(descs)

    if not all_kps:
        return torch.tensor([], device=image_tensor.device), torch.tensor([], device=image_tensor.device)
    return torch.cat(all_kps, dim=0), torch.cat(all_desc, dim=0)
# --- Mutual nearest-neighbour matching (unchanged) ---
def mutual_nearest_neighbor(descs1, descs2):
    """Match two descriptor sets by mutual nearest neighbour on dot-product similarity.

    Args:
        descs1: Descriptor tensor with one row per descriptor.
        descs2: Descriptor tensor with one row per descriptor.

    Returns:
        An integer tensor of shape ``(M, 2)``; each row ``(i, j)`` means that
        ``descs2[j]`` is the best match of ``descs1[i]`` and vice versa. The
        result has shape ``(0, 2)`` when either input is empty.
    """
    if len(descs1) == 0 or len(descs2) == 0:
        # Create the empty result on the inputs' device so callers can mix it
        # with the non-empty case without a device mismatch.
        return torch.empty((0, 2), dtype=torch.int64, device=descs1.device)

    sim = descs1 @ descs2.T
    nn12 = torch.max(sim, dim=1)  # best column for every row
    nn21 = torch.max(sim, dim=0)  # best row for every column
    ids1 = torch.arange(0, sim.shape[0], device=sim.device)
    # Keep row i only if the best row of its best column is i itself.
    mask = ids1 == nn21.indices[nn12.indices]
    return torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
# --- (Modified) Main multi-scale, multi-instance matching routine ---
def match_template_multiscale(
    model,
    layout_image,
    template_image,
    transform,
    matching_cfg,
    log_writer: SummaryWriter | None = None,
    log_step: int = 0,
):
    """Search the layout for instances of the template at several template scales.

    Layout features are extracted once (FPN path or sliding window, depending
    on ``matching_cfg.use_fpn``). The search then repeats: the template
    features of every scale are matched against the still-active layout
    keypoints, a homography is fitted with RANSAC, and the scale with the most
    inliers wins. If it has more than ``matching_cfg.min_inliers`` inliers, an
    instance is recorded and all layout keypoints inside the bounding box of
    its inliers are deactivated; otherwise the search stops.

    Args:
        model: Feature network; its parameters determine the device.
        layout_image: Image to search in.
        template_image: Image to search for; must expose ``.size`` and ``.resize``.
        transform: Callable converting an image to a tensor.
        matching_cfg: Config providing ``keypoint_threshold``, ``min_inliers``,
            ``pyramid_scales``, ``ransac_reproj_threshold`` and optionally
            ``use_fpn`` / ``nms``.
        log_writer: Optional writer receiving scalar summaries.
        log_step: Base step used for the logged scalars.

    Returns:
        A list of dicts with keys ``x``, ``y``, ``width``, ``height`` (integer
        bounding box of the inlier layout keypoints) and ``homography``.
    """
    # 1. Layout feature extraction: FPN or sliding window, as configured.
    device = next(model.parameters()).device
    use_fpn = getattr(matching_cfg, 'use_fpn', False)
    keypoint_threshold = float(matching_cfg.keypoint_threshold)
    if use_fpn:
        layout_tensor = transform(layout_image).unsqueeze(0).to(device)
        layout_kps, layout_descs = extract_from_pyramid(
            model, layout_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {})
        )
    else:
        layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
    if log_writer:
        log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)

    min_inliers = int(matching_cfg.min_inliers)
    if len(layout_kps) < min_inliers:
        print("从大版图中提取的关键点过少,无法进行匹配。")
        if log_writer:
            log_writer.add_scalar("match/instances_found", 0, log_step)
        return []

    ransac_threshold = float(matching_cfg.ransac_reproj_threshold)

    # Template features depend only on the scale, not on which layout
    # keypoints are still active, so compute them once instead of re-running
    # the network for every scale on every instance-search iteration.
    template_features = _extract_template_features_per_scale(
        model, template_image, transform, matching_cfg, layout_kps.device
    )

    found_instances = []
    active_layout_mask = torch.ones(len(layout_kps), dtype=torch.bool, device=layout_kps.device)

    # 2. Iterative multi-instance detection.
    while True:
        current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
        # Stop when too few active keypoints remain.
        if len(current_active_indices) < min_inliers:
            break
        current_layout_descs = layout_descs[current_active_indices]

        best_match = None
        best_inliers = 0
        # 3. Image pyramid: try every template scale.
        print("在新尺度下搜索模板...")
        for scale, template_kps, template_descs in template_features:
            matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
            if len(matches) < 4:
                continue
            # Template keypoints are divided by the scale so the homography is
            # expressed relative to the original template size.
            src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale
            dst_pts_indices = current_active_indices[matches[:, 1]]
            dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
            homography, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)
            if homography is None:
                continue
            inlier_count = mask.sum()
            if inlier_count > best_inliers:
                best_inliers = inlier_count
                best_match = {'inliers': inlier_count, 'H': homography, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}

        # 4. Record the best match over all scales and mask out its region.
        if best_match is None or not best_match['inliers'] > min_inliers:
            # No good match at any scale: the search is over.
            print("在所有尺度下均未找到新的匹配实例,搜索结束。")
            break

        print(f"找到一个匹配实例!内点数: {best_match['inliers']}, 使用的模板尺度: {best_match['scale']:.2f}x")
        if log_writer:
            instance_index = len(found_instances)
            log_writer.add_scalar("match/instance_inliers", int(best_match['inliers']), log_step + instance_index)
            log_writer.add_scalar("match/instance_scale", float(best_match['scale']), log_step + instance_index)

        inlier_mask = best_match['mask'].ravel().astype(bool)
        inlier_layout_kps = best_match['dst_pts'][inlier_mask]
        x_min, y_min = inlier_layout_kps.min(axis=0)
        x_max, y_max = inlier_layout_kps.max(axis=0)
        found_instances.append({
            'x': int(x_min),
            'y': int(y_min),
            'width': int(x_max - x_min),
            'height': int(y_max - y_min),
            'homography': best_match['H'],
        })

        # Deactivate every layout keypoint inside the matched region so the
        # next iteration can find a different instance.
        kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
        region_mask = (
            (kp_x >= float(x_min)) & (kp_x <= float(x_max))
            & (kp_y >= float(y_min)) & (kp_y <= float(y_max))
        )
        active_layout_mask[region_mask] = False
        # Print the plain count rather than a tensor repr.
        print(f"剩余活动关键点: {int(active_layout_mask.sum())}")

    if log_writer:
        log_writer.add_scalar("match/instances_found", len(found_instances), log_step)
    return found_instances


def _extract_template_features_per_scale(model, template_image, transform, matching_cfg, device):
    """Return ``(scale, keypoints, descriptors)`` for every usable pyramid scale.

    A scale is skipped when the resized template would have a zero-sized side
    or when fewer than 4 keypoints are found (too few for a homography).
    """
    use_fpn = getattr(matching_cfg, 'use_fpn', False)
    keypoint_threshold = float(matching_cfg.keypoint_threshold)
    base_w, base_h = template_image.size

    features = []
    for scale in (float(s) for s in matching_cfg.pyramid_scales):
        new_w, new_h = int(base_w * scale), int(base_h * scale)
        if new_w < 1 or new_h < 1:
            # Resizing to an empty image (and later dividing by a zero scale)
            # would fail; such a scale cannot produce a match anyway.
            continue
        scaled_template = template_image.resize((new_w, new_h), Image.LANCZOS)
        template_tensor = transform(scaled_template).unsqueeze(0).to(device)
        if use_fpn:
            kps, descs = extract_from_pyramid(
                model, template_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {})
            )
        else:
            kps, descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
        if len(kps) < 4:
            continue
        features.append((scale, kps, descs))
    return features
def visualize_matches(layout_path, bboxes, output_path):
    """Draw the detected bounding boxes on the layout image and save the result.

    Args:
        layout_path: Path of the layout image to load.
        bboxes: Iterable of dicts with integer ``x``, ``y``, ``width`` and
            ``height`` entries.
        output_path: Destination path of the annotated image.

    Raises:
        FileNotFoundError: If the layout image cannot be read.
        OSError: If the annotated image cannot be written.
    """
    layout_img = cv2.imread(layout_path)
    if layout_img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Cannot read layout image: {layout_path}")

    green = (0, 255, 0)
    for i, bbox in enumerate(bboxes):
        x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
        cv2.rectangle(layout_img, (x, y), (x + w, y + h), green, 2)
        cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, green, 2)

    # cv2.imwrite reports failure through its return value; do not claim
    # success when nothing was written.
    if not cv2.imwrite(output_path, layout_img):
        raise OSError(f"Cannot write visualization to: {output_path}")
    print(f"可视化结果已保存至: {output_path}")
if __name__ == "__main__":
    # Command-line entry point: load the config and the model, run multi-scale
    # template matching on one layout/template pair, print the detected boxes
    # and optionally save a visualization and TensorBoard scalars.
    parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
    parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
    parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置")
    parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
    parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
    parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
    parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false)")
    parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重NMS")
    parser.add_argument('--layout', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    parser.add_argument('--output', type=str)
    args = parser.parse_args()

    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent
    matching_cfg = cfg.matching
    logging_cfg = cfg.get("logging", None)
    model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))

    # TensorBoard settings: config values first, then CLI overrides.
    use_tensorboard = False
    log_dir = None
    experiment_name = None
    if logging_cfg is not None:
        use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
        log_dir = logging_cfg.get("log_dir", "runs")
        experiment_name = logging_cfg.get("experiment_name", "default")
    if args.disable_tensorboard:
        use_tensorboard = False
    if args.log_dir is not None:
        log_dir = args.log_dir
    if args.experiment_name is not None:
        experiment_name = args.experiment_name

    should_log_matches = args.tb_log_matches and use_tensorboard and log_dir is not None
    writer = None
    if should_log_matches:
        log_root = Path(log_dir).expanduser()
        exp_folder = experiment_name or "default"
        tb_path = log_root / "match" / exp_folder
        tb_path.parent.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(tb_path.as_posix())

    try:
        # CLI shortcut switches override the YAML configuration.
        try:
            if args.fpn_off:
                matching_cfg.use_fpn = False
            if args.no_nms and hasattr(matching_cfg, 'nms'):
                matching_cfg.nms.enabled = False
        except Exception as exc:
            # Best effort: the config structure may be read-only. Keep going,
            # but say so, because the requested switch will not take effect.
            print(f"Warning: could not apply CLI overrides to the matching config ({exc}); using config file values.")

        transform = get_transform()
        # Fall back to CPU when CUDA is unavailable, and remap the checkpoint
        # so GPU-saved weights also load on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RoRD().to(device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        layout_image = Image.open(args.layout).convert('L')
        template_image = Image.open(args.template).convert('L')
        detected_bboxes = match_template_multiscale(
            model,
            layout_image,
            template_image,
            transform,
            matching_cfg,
            log_writer=writer,
            log_step=0,
        )

        print("\n检测到的边界框:")
        for bbox in detected_bboxes:
            print(bbox)
        if args.output:
            visualize_matches(args.layout, detected_bboxes, args.output)
        if writer:
            writer.add_scalar("match/output_instances", len(detected_bboxes), 0)
            writer.add_text("match/layout_path", args.layout, 0)
    finally:
        # Close the writer even when matching fails so pending events are
        # flushed.
        if writer:
            writer.close()