199 lines
7.2 KiB
Python
199 lines
7.2 KiB
Python
|
|
import torch
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from models.rord import RoRD
|
|||
|
|
from torchvision import transforms
|
|||
|
|
from utils.transforms import SobelTransform
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
from PIL import Image
|
|||
|
|
|
|||
|
|
def extract_keypoints_and_descriptors(model, image):
|
|||
|
|
"""
|
|||
|
|
从 RoRD 模型中提取关键点和描述子。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
model (RoRD): RoRD 模型。
|
|||
|
|
image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。
|
|||
|
|
|
|||
|
|
返回:
|
|||
|
|
tuple: (keypoints_input, descriptors)
|
|||
|
|
- keypoints_input: [N, 2] float tensor,关键点在输入图像中的坐标。
|
|||
|
|
- descriptors: [N, 128] float tensor,L2 归一化的描述子。
|
|||
|
|
"""
|
|||
|
|
with torch.no_grad():
|
|||
|
|
detection_map, _, desc_rord = model(image)
|
|||
|
|
desc = desc_rord # 使用 RoRD 描述子头
|
|||
|
|
|
|||
|
|
# 从检测图中提取关键点
|
|||
|
|
thresh = 0.5
|
|||
|
|
binary_map = (detection_map > thresh).float()
|
|||
|
|
coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d)
|
|||
|
|
keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标(stride=16)
|
|||
|
|
|
|||
|
|
# 从描述子图中提取描述子
|
|||
|
|
# detection_map 的形状为 [1, 1, H/16, W/16],desc 的形状为 [1, 128, H/8, W/8]
|
|||
|
|
# 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2)
|
|||
|
|
keypoints_desc = (coords * 2).long() # [N, 2],整数坐标
|
|||
|
|
H_desc, W_desc = desc.shape[2], desc.shape[3]
|
|||
|
|
mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc)
|
|||
|
|
keypoints_desc = keypoints_desc[mask]
|
|||
|
|
keypoints_input = keypoints_input[mask]
|
|||
|
|
|
|||
|
|
# 提取描述子
|
|||
|
|
descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128]
|
|||
|
|
|
|||
|
|
# L2 归一化描述子
|
|||
|
|
descriptors = F.normalize(descriptors, p=2, dim=1)
|
|||
|
|
|
|||
|
|
return keypoints_input, descriptors
|
|||
|
|
|
|||
|
|
def mutual_nearest_neighbor(template_descs, layout_descs):
|
|||
|
|
"""
|
|||
|
|
使用互最近邻(MNN)找到模板和版图之间的匹配。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。
|
|||
|
|
layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。
|
|||
|
|
|
|||
|
|
返回:
|
|||
|
|
list: [(i_template, i_layout)],互最近邻匹配对的列表。
|
|||
|
|
"""
|
|||
|
|
M, N = template_descs.size(0), layout_descs.size(0)
|
|||
|
|
if M == 0 or N == 0:
|
|||
|
|
return []
|
|||
|
|
similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵
|
|||
|
|
|
|||
|
|
# 找到每个模板描述子的最近邻
|
|||
|
|
nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M]
|
|||
|
|
|
|||
|
|
# 找到每个版图描述子的最近邻
|
|||
|
|
nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N]
|
|||
|
|
|
|||
|
|
# 找到互最近邻
|
|||
|
|
mutual_matches = []
|
|||
|
|
for i in range(M):
|
|||
|
|
j = nn_template_to_layout[i]
|
|||
|
|
if nn_layout_to_template[j] == i:
|
|||
|
|
mutual_matches.append((i.item(), j.item()))
|
|||
|
|
|
|||
|
|
return mutual_matches
|
|||
|
|
|
|||
|
|
def ransac_filter(matches, template_kps, layout_kps):
|
|||
|
|
"""
|
|||
|
|
使用 RANSAC 对匹配进行几何验证,并返回内点。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
matches (list): [(i_template, i_layout)],匹配对列表。
|
|||
|
|
template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。
|
|||
|
|
layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。
|
|||
|
|
|
|||
|
|
返回:
|
|||
|
|
tuple: (inlier_matches, num_inliers)
|
|||
|
|
- inlier_matches: [(i_template, i_layout)],内点匹配对。
|
|||
|
|
- num_inliers: int,内点数量。
|
|||
|
|
"""
|
|||
|
|
src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches])
|
|||
|
|
dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches])
|
|||
|
|
|
|||
|
|
if len(src_pts) < 4:
|
|||
|
|
return [], 0
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0)
|
|||
|
|
if H is None:
|
|||
|
|
return [], 0
|
|||
|
|
inliers = mask.ravel() > 0
|
|||
|
|
num_inliers = np.sum(inliers)
|
|||
|
|
inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]]
|
|||
|
|
return inlier_matches, num_inliers
|
|||
|
|
except cv2.error:
|
|||
|
|
return [], 0
|
|||
|
|
|
|||
|
|
def match_template_to_layout(model, layout_image, template_image):
|
|||
|
|
"""
|
|||
|
|
使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
model (RoRD): RoRD 模型。
|
|||
|
|
layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。
|
|||
|
|
template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。
|
|||
|
|
|
|||
|
|
返回:
|
|||
|
|
list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。
|
|||
|
|
"""
|
|||
|
|
# 提取版图和模板的关键点和描述子
|
|||
|
|
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image)
|
|||
|
|
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image)
|
|||
|
|
|
|||
|
|
# 初始化活动版图关键点掩码
|
|||
|
|
active_layout = torch.ones(len(layout_kps), dtype=bool)
|
|||
|
|
|
|||
|
|
bboxes = []
|
|||
|
|
while True:
|
|||
|
|
# 获取当前活动的版图关键点和描述子
|
|||
|
|
current_layout_kps = layout_kps[active_layout]
|
|||
|
|
current_layout_descs = layout_descs[active_layout]
|
|||
|
|
|
|||
|
|
if len(current_layout_descs) == 0:
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
# MNN 匹配
|
|||
|
|
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
|||
|
|
|
|||
|
|
if len(matches) == 0:
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
# 将当前版图索引映射回原始版图索引
|
|||
|
|
active_indices = torch.nonzero(active_layout).squeeze(1)
|
|||
|
|
matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches]
|
|||
|
|
|
|||
|
|
# RANSAC 过滤
|
|||
|
|
inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps)
|
|||
|
|
|
|||
|
|
if num_inliers > 10: # 设置内点阈值
|
|||
|
|
# 获取内点在版图中的关键点
|
|||
|
|
inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches]
|
|||
|
|
inlier_layout_kps = np.array(inlier_layout_kps)
|
|||
|
|
|
|||
|
|
# 计算边框
|
|||
|
|
x_min = int(inlier_layout_kps[:, 0].min())
|
|||
|
|
y_min = int(inlier_layout_kps[:, 1].min())
|
|||
|
|
x_max = int(inlier_layout_kps[:, 0].max())
|
|||
|
|
y_max = int(inlier_layout_kps[:, 1].max())
|
|||
|
|
bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min})
|
|||
|
|
|
|||
|
|
# 屏蔽内点
|
|||
|
|
for _, j in inlier_matches:
|
|||
|
|
active_layout[j] = False
|
|||
|
|
else:
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
return bboxes
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 设置变换
|
|||
|
|
transform = transforms.Compose([
|
|||
|
|
SobelTransform(),
|
|||
|
|
transforms.ToTensor(),
|
|||
|
|
transforms.Normalize(mean=[0.5], std=[0.5])
|
|||
|
|
])
|
|||
|
|
|
|||
|
|
# 加载模型
|
|||
|
|
model = RoRD().cuda()
|
|||
|
|
model.load_state_dict(torch.load('path/to/weights.pth'))
|
|||
|
|
model.eval()
|
|||
|
|
|
|||
|
|
# 加载版图和模板图像
|
|||
|
|
layout_image = Image.open('path/to/layout.png').convert('L')
|
|||
|
|
layout_tensor = transform(layout_image).unsqueeze(0).cuda()
|
|||
|
|
|
|||
|
|
template_image = Image.open('path/to/template.png').convert('L')
|
|||
|
|
template_tensor = transform(template_image).unsqueeze(0).cuda()
|
|||
|
|
|
|||
|
|
# 执行匹配
|
|||
|
|
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
|||
|
|
|
|||
|
|
# 打印检测到的边框
|
|||
|
|
print("检测到的边框:")
|
|||
|
|
for bbox in detected_bboxes:
|
|||
|
|
print(bbox)
|