Second major revision

Author: Jiao77
Date: 2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

README.md

@@ -58,49 +58,33 @@ ic_layout_recognition/
└── README.md
```

## 🚀 Usage

### 1. Configuration

First, edit **`config.py`** and set the correct paths for the training data, the validation data, and the model save directory.

### 2. Train the Model

```bash
python train.py --data_dir /path/to/your/layouts --save_dir /path/to/your/models --epochs 50
```

Use `--help` to see the remaining options.

### 3. Template Matching

```bash
python match.py --model_path /path/to/your/models/rord_model_final.pth \
    --layout /path/to/layout.png \
    --template /path/to/template.png \
    --output /path/to/result.png
```

### 4. Evaluate the Model

```bash
python evaluate.py --model_path /path/to/your/models/rord_model_final.pth \
    --val_dir /path/to/val/images \
    --annotations_dir /path/to/val/annotations \
    --templates_dir /path/to/templates
```

## 📦 Data Preparation

### Training Data

config.py (new file)

@@ -0,0 +1,31 @@
# config.py

# --- Training parameters ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_EPOCHS = 20  # increased the number of training epochs
PATCH_SIZE = 256

# --- Matching and evaluation parameters ---
# Confidence threshold for keypoint detection
KEYPOINT_THRESHOLD = 0.5
# RANSAC reprojection error threshold (pixels)
RANSAC_REPROJ_THRESHOLD = 5.0
# Minimum number of RANSAC inliers for a match to count as valid
MIN_INLIERS = 15  # raised slightly to make matches more reliable
# IoU (Intersection over Union) threshold used during evaluation
IOU_THRESHOLD = 0.5

# --- File paths ---
# Training data directory
LAYOUT_DIR = 'path/to/layouts'
# Model save directory
SAVE_DIR = 'path/to/save'
# Validation image directory
VAL_IMG_DIR = 'path/to/val/images'
# Validation annotation directory
VAL_ANN_DIR = 'path/to/val/annotations'
# Template image directory
TEMPLATE_DIR = 'path/to/templates'
# Default model path (matches the filename train.py saves)
MODEL_PATH = 'path/to/save/rord_model_final.pth'
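
Every CLI script wires these constants in as its argparse defaults, so flags only need to be passed to override a value. A minimal sketch of the pattern used by train.py and evaluate.py:

```python
# Sketch of the config-as-defaults pattern shared by train.py and evaluate.py:
# passing no flag falls back to the value defined in config.py.
import argparse
import config

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
args = parser.parse_args([])      # empty argv, for illustration only
print(args.model_path)            # 'path/to/save/rord_model_final.pth'
```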

data/__init__.py (new, empty file)

evaluate.py

@@ -1,123 +1,84 @@
# evaluate.py
import torch
from PIL import Image
import json
import os
import argparse

import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
from match import match_template_to_layout

def compute_iou(box1, box2):
    x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
    x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height']
    inter_x1, inter_y1 = max(x1, x2), max(y1, y2)
    inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
    inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
    union_area = w1 * h1 + w2 * h2 - inter_area
    return inter_area / union_area if union_area > 0 else 0

def evaluate(model, val_dataset, template_dir):
    model.eval()
    all_tp, all_fp, all_fn = 0, 0, 0
    transform = get_transform()
    template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]

    for layout_tensor, annotation in val_dataset:
        layout_tensor = layout_tensor.unsqueeze(0).cuda()
        # Group ground-truth boxes by the template they belong to
        gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])}
        for box in annotation.get('boxes', []):
            gt_by_template[box['template']].append(box)

        for template_path in template_paths:
            template_name = os.path.basename(template_path)
            template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda()
            detected = match_template_to_layout(model, layout_tensor, template_tensor)
            gt_boxes = gt_by_template.get(template_name, [])

            # Greedy one-to-one assignment of detections to ground truth
            matched_gt = [False] * len(gt_boxes)
            tp = 0
            for det_box in detected:
                best_iou = 0
                best_gt_idx = -1
                for i, gt_box in enumerate(gt_boxes):
                    if matched_gt[i]:
                        continue
                    iou = compute_iou(det_box, gt_box)
                    if iou > best_iou:
                        best_iou, best_gt_idx = iou, i
                if best_iou > config.IOU_THRESHOLD:
                    tp += 1
                    matched_gt[best_gt_idx] = True

            all_tp += tp
            all_fp += len(detected) - tp
            all_fn += len(gt_boxes) - tp

    precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
    recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return {'precision': precision, 'recall': recall, 'f1': f1}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate RoRD model performance")
    parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
    parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
    parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
    parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
    args = parser.parse_args()

    model = RoRD().cuda()
    model.load_state_dict(torch.load(args.model_path))
    val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform())
    results = evaluate(model, val_dataset, args.templates_dir)

    print("Evaluation results:")
    print(f"  Precision: {results['precision']:.4f}")
    print(f"  Recall:    {results['recall']:.4f}")
    print(f"  F1 Score:  {results['f1']:.4f}")

match.py

@@ -1,199 +1,108 @@
# match.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import argparse
import os

import config
from models.rord import RoRD
from utils.data_utils import get_transform

def extract_keypoints_and_descriptors(model, image, kp_thresh):
    with torch.no_grad():
        detection_map, desc = model(image)
    binary_map = (detection_map > kp_thresh).float()
    coords = torch.nonzero(binary_map[0, 0]).float()  # (row, col) on the stride-16 detection map
    keypoints_input = coords[:, [1, 0]] * 16.0        # (x, y) in input-image pixels
    # Sample the stride-8 descriptor map at the detected locations;
    # grid_sample expects (x, y) coordinates normalized to [-1, 1].
    desc_coords = coords.flip(1) * 2.0                # detection-map units -> descriptor-map units
    scale = torch.tensor([(desc.shape[3] - 1) / 2, (desc.shape[2] - 1) / 2], device=desc.device)
    grid = desc_coords.view(1, -1, 1, 2) / scale - 1
    descriptors = F.grid_sample(desc, grid, align_corners=True).squeeze(0).squeeze(-1).T
    descriptors = F.normalize(descriptors, p=2, dim=1)
    return keypoints_input, descriptors

def mutual_nearest_neighbor(descs1, descs2):
    sim = descs1 @ descs2.T
    nn12 = torch.max(sim, dim=1)
    nn21 = torch.max(sim, dim=0)
    ids1 = torch.arange(0, sim.shape[0], device=sim.device)
    mask = (ids1 == nn21.indices[nn12.indices])
    matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
    return matches.cpu().numpy()

def match_template_to_layout(model, layout_image, template_image):
    layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD)
    template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD)

    active_layout_mask = torch.ones(len(layout_kps), dtype=torch.bool, device=layout_kps.device)
    found_instances = []

    while True:
        current_indices = torch.nonzero(active_layout_mask).squeeze(1)
        if len(current_indices) < config.MIN_INLIERS:
            break
        current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices]
        matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
        if len(matches) < 4:
            break
        src_pts = template_kps[matches[:, 0]].cpu().numpy()
        dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy()
        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
        if H is None or mask.sum() < config.MIN_INLIERS:
            break
        inlier_mask = mask.ravel().astype(bool)

        # Region masking: record the instance, then deactivate every layout
        # keypoint inside its bounding box so later iterations find new instances.
        inlier_layout_kps = dst_pts[inlier_mask]
        x_min, y_min = inlier_layout_kps.min(axis=0)
        x_max, y_max = inlier_layout_kps.max(axis=0)
        instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H}
        found_instances.append(instance)

        kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
        region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
        active_layout_mask[region_mask] = False
        print(f"Found an instance with {int(mask.sum())} inliers; {int(active_layout_mask.sum())} active keypoints remain.")

    return found_instances

def visualize_matches(layout_path, template_path, bboxes, output_path):
    layout_img = cv2.imread(layout_path)
    for i, bbox in enumerate(bboxes):
        x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
        cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.imwrite(output_path, layout_img)
    print(f"Visualization saved to: {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Template matching with RoRD")
    parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
    parser.add_argument('--layout', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    parser.add_argument('--output', type=str)
    args = parser.parse_args()

    transform = get_transform()
    model = RoRD().cuda()
    model.load_state_dict(torch.load(args.model_path))
    model.eval()

    layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda()
    template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda()

    detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
    print("\nDetected bounding boxes:")
    for bbox in detected_bboxes:
        print(bbox)
    if args.output:
        visualize_matches(args.layout, args.template, detected_bboxes, args.output)
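
To make the mutual-nearest-neighbor step concrete, here is a toy run with hand-made descriptors (values invented for illustration): a pair survives only when each descriptor is the other's nearest neighbor.

```python
# Toy check of mutual_nearest_neighbor with invented 2-D descriptors.
import torch
import torch.nn.functional as F
from match import mutual_nearest_neighbor

descs1 = F.normalize(torch.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=1)
descs2 = F.normalize(torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]), dim=1)
print(mutual_nearest_neighbor(descs1, descs2))
# [[0 0]
#  [1 1]]  descriptor 0 pairs with 0 and 1 with 1; the ambiguous third
#          layout descriptor never forms a mutual pair.
```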

models/__init__.py (new, empty file)

models/rord.py

@@ -1,16 +1,25 @@
# models/rord.py
import torch
import torch.nn as nn
from torchvision import models

class RoRD(nn.Module):
    def __init__(self):
        """
        Revised RoRD model.
        - Shares a single backbone between both heads to cut computation and memory use.
        - Removes the redundant descriptor_head_vanilla.
        """
        super(RoRD, self).__init__()
        vgg16_features = models.vgg16(pretrained=True).features

        # Shared backbone
        self.slice1 = vgg16_features[:23]    # up to relu4_3
        self.slice2 = vgg16_features[23:30]  # relu4_3 to relu5_3

        # Detection head
        self.detection_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
@@ -18,16 +27,8 @@ class RoRD(nn.Module):
            nn.Sigmoid()
        )

        # Descriptor head
        self.descriptor_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=1),
@@ -35,13 +36,14 @@
        )

    def forward(self, x):
        # Shared feature extraction
        features_shared = self.slice1(x)

        # Descriptor branch
        descriptors = self.descriptor_head(features_shared)

        # Detector branch
        features_det = self.slice2(features_shared)
        detection_map = self.detection_head(features_det)

        return detection_map, descriptors
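
A quick shape check of the shared-backbone wiring, as a sketch: the descriptor branch runs at stride 8 (relu4_3) and the detector branch at stride 16 (relu5_3). This assumes the detection head's final layer is the single-channel 1x1 convolution hidden in the unchanged hunk above, and instantiating RoRD downloads the VGG-16 weights on first use.

```python
# Shape check: slice1 ends at relu4_3 (stride 8, 512 channels); slice2 adds
# pool4..relu5_3 (stride 16); both heads preserve spatial size.
import torch
from models.rord import RoRD

model = RoRD().eval()
with torch.no_grad():
    det, desc = model(torch.randn(1, 3, 256, 256))
print(det.shape)   # torch.Size([1, 1, 16, 16])   keypoint probability map
print(desc.shape)  # torch.Size([1, 128, 32, 32]) 128-D descriptor map
```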

train.py

@@ -1,236 +1,142 @@
# train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import cv2
import os
import argparse

# Project modules
import config
from models.rord import RoRD
from utils.data_utils import get_transform

# --- Training dataset ---
class ICLayoutTrainingDataset(Dataset):
    def __init__(self, image_dir, patch_size=256, transform=None):
        self.image_dir = image_dir
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
        self.patch_size = patch_size
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert('L')
        W, H = image.size
        x = np.random.randint(0, W - self.patch_size + 1)
        y = np.random.randint(0, H - self.patch_size + 1)
        patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
        patch_np = np.array(patch)

        # Discrete geometric transforms covering the 8 layout symmetries
        # (4 rotations x optional mirror)
        theta_deg = np.random.choice([0, 90, 180, 270])
        is_mirrored = np.random.choice([True, False])
        cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
        M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)

        if is_mirrored:
            T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
            Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
            T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
            M_mirror_3x3 = T2 @ Flip @ T1
            M_3x3 = np.vstack([M, [0, 0, 1]])
            H_mat = (M_3x3 @ M_mirror_3x3).astype(np.float32)
        else:
            H_mat = np.vstack([M, [0, 0, 1]]).astype(np.float32)

        transformed_patch_np = cv2.warpPerspective(patch_np, H_mat, (self.patch_size, self.patch_size))
        transformed_patch = Image.fromarray(transformed_patch_np)

        if self.transform:
            patch = self.transform(patch)
            transformed_patch = self.transform(transformed_patch)

        H_tensor = torch.from_numpy(H_mat[:2, :]).float()  # the loss functions expect a 2x3 affine matrix
        return patch, transformed_patch, H_tensor

# --- Feature-map warping and loss functions ---
def warp_feature_map(feature_map, H_inv):
    grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
    return F.grid_sample(feature_map, grid, align_corners=False)

def compute_detection_loss(det_original, det_rotated, H):
    with torch.no_grad():
        bottom_row = torch.tensor([0.0, 0.0, 1.0], device=H.device).view(1, 1, 3).repeat(H.shape[0], 1, 1)
        H_inv = torch.inverse(torch.cat([H, bottom_row], dim=1))[:, :2, :]
    warped_det_rotated = warp_feature_map(det_rotated, H_inv)
    return F.mse_loss(det_original, warped_det_rotated)

def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
    B, C, H_feat, W_feat = desc_original.size()
    num_samples = 100
    # Randomly sample anchor coordinates in [-1, 1]
    coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
    # Anchor descriptors
    anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
    # Positive-sample coordinates
    coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
    bottom_row = torch.tensor([0.0, 0.0, 1.0], device=H.device).view(1, 1, 3).repeat(H.shape[0], 1, 1)
    M_inv = torch.inverse(torch.cat([H, bottom_row], dim=1))
    coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
    # Positive descriptors
    positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
    # Randomly sampled negatives
    neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
    negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)

    triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
    return triplet_loss(anchor, positive, negative)

# --- Main entry point and CLI ---
def main(args):
    print("--- Starting RoRD training ---")
    print(f"Training parameters: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
    transform = get_transform()
    dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    model = RoRD().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        model.train()
        total_loss_val = 0
        for original, rotated, H in dataloader:
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)

            loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss_val += loss.item()
        print(f"--- Epoch {epoch+1} finished, average loss: {total_loss_val / len(dataloader):.4f} ---")

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to: {save_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
    parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
    parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
    parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
    parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
    main(parser.parse_args())
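
As a small check of the homography construction in `ICLayoutTrainingDataset.__getitem__` above: for theta = 90 degrees with mirroring, composing the centered flip with the rotation reduces to a coordinate swap, so the warp is pixel-exact. A sketch using the same conventions as the dataset code:

```python
# Check the composed transform for theta=90 with mirroring: it maps pixel
# (x, y) to (y, x), so the warped patch is an exact transpose of the original.
import cv2
import numpy as np

patch_size = 256
cx, cy = patch_size / 2.0, patch_size / 2.0
M = cv2.getRotationMatrix2D((cx, cy), 90, 1)
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
H_mat = np.vstack([M, [0, 0, 1]]) @ (T2 @ Flip @ T1)

print(np.round(H_mat @ np.array([10.0, 40.0, 1.0])))  # [40. 10.  1.]

patch = np.random.randint(0, 256, (patch_size, patch_size), dtype=np.uint8)
warped = cv2.warpPerspective(patch, H_mat, (patch_size, patch_size))
print(np.array_equal(warped, patch.T))  # True
```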

utils/__init__.py (new, empty file)

utils/data_utils.py (new file)

@@ -0,0 +1,14 @@
from torchvision import transforms
from .transforms import SobelTransform

def get_transform():
    """
    Return the shared image preprocessing pipeline.
    Guarantees that training, evaluation, and inference use identical preprocessing.
    """
    return transforms.Compose([
        SobelTransform(),                                # Sobel edge detection
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # grayscale -> 3-channel input for VGG
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
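
For reference, a short sketch of what this pipeline produces (the image path is a placeholder, and the [-1, 1] range assumes SobelTransform returns an 8-bit image):

```python
# Apply the shared pipeline to one grayscale layout; the result is the
# normalized 3-channel tensor every script feeds to the VGG backbone.
from PIL import Image
from utils.data_utils import get_transform

transform = get_transform()
img = Image.open('path/to/layout.png').convert('L')     # placeholder path
x = transform(img)
print(x.shape)                                          # torch.Size([3, H, W])
print(x.min().item(), x.max().item())                   # within [-1.0, 1.0]
```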