Change to English version

This commit is contained in:
Jiao77
2025-07-22 23:43:35 +08:00
parent 4f81daad3c
commit 9cbfc34436
8 changed files with 166 additions and 166 deletions

View File

@@ -12,7 +12,7 @@ import config
from models.rord import RoRD
from utils.data_utils import get_transform
# --- 特征提取函数 (基本无变动) ---
# --- Feature extraction functions (unchanged) ---
def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
with torch.no_grad():
detection_map, desc = model(image_tensor)
@@ -24,26 +24,26 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
if len(coords) == 0:
return torch.tensor([], device=device), torch.tensor([], device=device)
# 描述子采样
# Descriptor sampling
coords_for_grid = coords.flip(1).view(1, -1, 1, 2) # N, 2 -> 1, N, 1, 2 (x,y)
# 归一化到 [-1, 1]
# Normalize to [-1, 1]
coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=device) - 1
descriptors = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T
descriptors = F.normalize(descriptors, p=2, dim=1)
# 将关键点坐标从特征图尺度转换回图像尺度
# VGGrelu4_3的下采样率为8
# Convert keypoint coordinates from feature map scale back to image scale
# VGG downsampling rate to relu4_3 is 8
keypoints = coords.flip(1) * 8.0 # x, y
return keypoints, descriptors
# --- (新增) 滑动窗口特征提取函数 ---
# --- (New) Sliding window feature extraction function ---
def extract_features_sliding_window(model, large_image, transform):
"""
使用滑动窗口从大图上提取所有关键点和描述子
Extract all keypoints and descriptors from large image using sliding window
"""
print("使用滑动窗口提取大版图特征...")
print("Using sliding window to extract features from large layout...")
device = next(model.parameters()).device
W, H = large_image.size
window_size = config.INFERENCE_WINDOW_SIZE
@@ -54,21 +54,21 @@ def extract_features_sliding_window(model, large_image, transform):
for y in range(0, H, stride):
for x in range(0, W, stride):
# 确保窗口不越界
# Ensure window does not exceed boundaries
x_end = min(x + window_size, W)
y_end = min(y + window_size, H)
# 裁剪窗口
# Crop window
patch = large_image.crop((x, y, x_end, y_end))
# 预处理
# Preprocess
patch_tensor = transform(patch).unsqueeze(0).to(device)
# 提取特征
# Extract features
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
if len(kps) > 0:
# 将局部坐标转换为全局坐标
# Convert local coordinates to global coordinates
kps[:, 0] += x
kps[:, 1] += y
all_kps.append(kps)
@@ -77,11 +77,11 @@ def extract_features_sliding_window(model, large_image, transform):
if not all_kps:
return torch.tensor([], device=device), torch.tensor([], device=device)
print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。")
print(f"Large layout feature extraction completed, found {sum(len(k) for k in all_kps)} keypoints in total.")
return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
# --- 互近邻匹配 (无变动) ---
# --- Mutual nearest neighbor matching (unchanged) ---
def mutual_nearest_neighbor(descs1, descs2):
if len(descs1) == 0 or len(descs2) == 0:
return torch.empty((0, 2), dtype=torch.int64)
@@ -93,26 +93,26 @@ def mutual_nearest_neighbor(descs1, descs2):
matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
return matches
# --- (已修改) 多尺度、多实例匹配主函数 ---
# --- (Modified) Multi-scale, multi-instance matching main function ---
def match_template_multiscale(model, layout_image, template_image, transform):
"""
在不同尺度下搜索模板,并检测多个实例
Search for template at different scales and detect multiple instances
"""
# 1. 对大版图使用滑动窗口提取全部特征
# 1. Use sliding window to extract all features from large layout
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
if len(layout_kps) < config.MIN_INLIERS:
print("从大版图中提取的关键点过少,无法进行匹配。")
print("Too few keypoints extracted from large layout, cannot perform matching.")
return []
found_instances = []
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
# 2. 多实例迭代检测
# 2. Multi-instance iterative detection
while True:
current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
# 如果剩余活动关键点过少,则停止
# Stop if remaining active keypoints are too few
if len(current_active_indices) < config.MIN_INLIERS:
break
@@ -121,28 +121,28 @@ def match_template_multiscale(model, layout_image, template_image, transform):
best_match_info = {'inliers': 0, 'H': None, 'src_pts': None, 'dst_pts': None, 'mask': None}
# 3. 图像金字塔:遍历模板的每个尺度
print("在新尺度下搜索模板...")
# 3. Image pyramid: iterate through each scale of template
print("Searching for template at new scale...")
for scale in config.PYRAMID_SCALES:
W, H = template_image.size
new_W, new_H = int(W * scale), int(H * scale)
# 缩放模板
# Scale template
scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS)
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
# 提取缩放后模板的特征
# Extract features from scaled template
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
if len(template_kps) < 4: continue
# 匹配当前尺度的模板和活动状态的版图特征
# Match current scale template with active layout features
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
if len(matches) < 4: continue
# RANSAC
# 注意:模板关键点坐标需要还原到原始尺寸,才能计算正确的H
# Note: template keypoint coordinates need to be restored to original size to calculate correct H
src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale
dst_pts_indices = current_active_indices[matches[:, 1]]
dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
@@ -152,9 +152,9 @@ def match_template_multiscale(model, layout_image, template_image, transform):
if H is not None and mask.sum() > best_match_info['inliers']:
best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
# 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
# 4. If best match found across all scales, record and mask
if best_match_info['inliers'] > config.MIN_INLIERS:
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
print(f"Found a matching instance! Inliers: {best_match_info['inliers']}, Template scale used: {best_match_info['scale']:.2f}x")
inlier_mask = best_match_info['mask'].ravel().astype(bool)
inlier_layout_kps = best_match_info['dst_pts'][inlier_mask]
@@ -165,15 +165,15 @@ def match_template_multiscale(model, layout_image, template_image, transform):
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': best_match_info['H']}
found_instances.append(instance)
# 屏蔽已匹配区域的关键点,以便检测下一个实例
# Mask keypoints in matched region to detect next instance
kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
active_layout_mask[region_mask] = False
print(f"剩余活动关键点: {active_layout_mask.sum()}")
print(f"Remaining active keypoints: {active_layout_mask.sum()}")
else:
# 如果在所有尺度下都找不到好的匹配,则结束搜索
print("在所有尺度下均未找到新的匹配实例,搜索结束。")
# If no good match found across all scales, end search
print("No new matching instances found across all scales, search ended.")
break
return found_instances
@@ -186,11 +186,11 @@ def visualize_matches(layout_path, bboxes, output_path):
cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.imwrite(output_path, layout_img)
print(f"可视化结果已保存至: {output_path}")
print(f"Visualization result saved to: {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
parser = argparse.ArgumentParser(description="Multi-scale template matching using RoRD")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
@@ -207,7 +207,7 @@ if __name__ == "__main__":
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
print("\n检测到的边界框:")
print("\nDetected bounding boxes:")
for bbox in detected_bboxes:
print(bbox)