From f0b2e1b60576abc00bd550c9c3774d5e82d544eb Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Sun, 8 Jun 2025 15:38:56 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=BA=8C=E6=AC=A1=E5=A4=A7=E4=BF=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 54 +++------ config.py | 31 +++++ data/__init__.py | 0 evaluate.py | 157 ++++++++++--------------- match.py | 239 ++++++++++++------------------------- models/__init__.py | 0 models/rord.py | 48 ++++---- train.py | 280 +++++++++++++++----------------------------- utils/__init__.py | 0 utils/data_utils.py | 14 +++ 10 files changed, 315 insertions(+), 508 deletions(-) create mode 100644 config.py create mode 100644 data/__init__.py create mode 100644 models/__init__.py create mode 100644 utils/__init__.py create mode 100644 utils/data_utils.py diff --git a/README.md b/README.md index 3c6b114..43bd11e 100644 --- a/README.md +++ b/README.md @@ -58,49 +58,33 @@ ic_layout_recognition/ └── README.md ``` -### 1. 训练模型 +## 🚀 使用方法 -使用以下命令启动模型训练。训练过程采用自监督学习,通过对图像应用随机旋转来生成训练对,从而优化关键点检测和描述子生成。 +### 1. 配置 +首先,请修改 **`config.py`** 文件,设置正确的训练数据、验证数据和模型保存路径。 +### 2. 训练模型 ```bash -python train.py --data_dir path/to/layouts --save_dir path/to/save +python train.py --data_dir /path/to/your/layouts --save_dir /path/to/your/models --epochs 50 +``` +使用 `--help` 查看更多选项。 + +### 3. 模板匹配 +```bash +python match.py --model_path /path/to/your/models/rord_model_final.pth \ + --layout /path/to/layout.png \ + --template /path/to/template.png \ + --output /path/to/result.png ``` -| 参数 | 描述 | -| :--- | :--- | -| `--data_dir` | **[必需]** 包含 PNG 格式 IC 版图图像的目录。 | -| `--save_dir` | **[必需]** 训练好的模型权重保存目录。 | - -### 2. 评估模型 - -使用以下命令在验证集上评估模型的性能。评估脚本会计算基于 IoU 阈值的精确率、召回率和 F1 分数。 - +### 4. 评估模型 ```bash -python evaluate.py --model_path path/to/model.pth --val_dir path/to/val/images --annotations_dir path/to/val/annotations --templates path/to/templates +python evaluate.py --model_path /path/to/your/models/rord_model_final.pth \ + --val_dir /path/to/val/images \ + --annotations_dir /path/to/val/annotations \ + --templates_dir /path/to/templates ``` -| 参数 | 描述 | -| :--- | :--- | -| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 | -| `--val_dir` | **[必需]** 验证集图像目录。 | -| `--annotations_dir` | **[必需]** 包含真实标注的 JSON 文件目录。 | -| `--templates` | **[必需]** 模板图像的路径列表。 | - -### 3. 进行模板匹配 - -使用以下命令将模板图像与指定的版图图像进行匹配。匹配过程利用 RoRD 模型提取关键点和描述子,通过互最近邻(MNN)匹配和 RANSAC 几何验证来定位模板。 - -```bash -python match.py --model_path path/to/model.pth --layout_path path/to/layout.png --template_path path/to/template.png --output_path path/to/output.png -``` - -| 参数 | 描述 | -| :--- | :--- | -| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 | -| `--layout_path` | **[必需]** 待匹配的版图图像路径。 | -| `--template_path` | **[必需]** 模板图像路径。 | -| `--output_path` | **[可选]** 保存可视化匹配结果的路径。 | - ## 📦 数据准备 ### 训练数据 diff --git a/config.py b/config.py new file mode 100644 index 0000000..ea31e86 --- /dev/null +++ b/config.py @@ -0,0 +1,31 @@ +# config.py + +# --- 训练参数 --- +LEARNING_RATE = 1e-4 +BATCH_SIZE = 4 +NUM_EPOCHS = 20 # 增加了训练轮数 +PATCH_SIZE = 256 + +# --- 匹配与评估参数 --- +# 关键点检测的置信度阈值 +KEYPOINT_THRESHOLD = 0.5 +# RANSAC 重投影误差阈值(像素) +RANSAC_REPROJ_THRESHOLD = 5.0 +# RANSAC 判定为有效匹配所需的最小内点数 +MIN_INLIERS = 15 # 适当提高以增加匹配的可靠性 +# IoU (Intersection over Union) 阈值,用于评估 +IOU_THRESHOLD = 0.5 + +# --- 文件路径 --- +# 训练数据目录 +LAYOUT_DIR = 'path/to/layouts' +# 模型保存目录 +SAVE_DIR = 'path/to/save' +# 验证集图像目录 +VAL_IMG_DIR = 'path/to/val/images' +# 验证集标注目录 +VAL_ANN_DIR = 'path/to/val/annotations' +# 模板图像目录 +TEMPLATE_DIR = 'path/to/templates' +# 默认加载的模型路径 +MODEL_PATH = 'path/to/save/model_final.pth' \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluate.py b/evaluate.py index 1f1c5c3..51ec5e9 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,123 +1,84 @@ -from models.rord import RoRD -from data.ic_dataset import ICLayoutDataset -from utils.transforms import SobelTransform -from match import match_template_to_layout +# evaluate.py + import torch -from torchvision import transforms +from PIL import Image import json import os -from PIL import Image +import argparse + +import config +from models.rord import RoRD +from utils.data_utils import get_transform +from data.ic_dataset import ICLayoutDataset +from match import match_template_to_layout def compute_iou(box1, box2): x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height'] x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height'] - - inter_x1 = max(x1, x2) - inter_y1 = max(y1, y2) - inter_x2 = min(x1 + w1, x2 + w2) - inter_y2 = min(y1 + h1, y2 + h2) - + inter_x1, inter_y1 = max(x1, x2), max(y1, y2) + inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2) inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) - - box1_area = w1 * h1 - box2_area = w2 * h2 - union_area = box1_area + box2_area - inter_area - - iou = inter_area / union_area if union_area > 0 else 0 - return iou + union_area = w1 * h1 + w2 * h2 - inter_area + return inter_area / union_area if union_area > 0 else 0 -def evaluate(model, val_dataset, templates, iou_threshold=0.5): +def evaluate(model, val_dataset, template_dir): model.eval() - all_true_positives = 0 - all_false_positives = 0 - all_false_negatives = 0 + all_tp, all_fp, all_fn = 0, 0, 0 + transform = get_transform() + + template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')] - for layout_idx in range(len(val_dataset)): - layout_image, annotation = val_dataset[layout_idx] - # layout_image is [3, H, W] - layout_tensor = layout_image.unsqueeze(0).cuda() # [1, 3, H, W] - - # 假设 annotation 是 {"boxes": [{"template": "template1.png", "x": x, "y": y, "width": w, "height": h}, ...]} - gt_boxes_by_template = {} + for layout_tensor, annotation in val_dataset: + layout_tensor = layout_tensor.unsqueeze(0).cuda() + gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])} for box in annotation.get('boxes', []): - template_name = box['template'] - if template_name not in gt_boxes_by_template: - gt_boxes_by_template[template_name] = [] - gt_boxes_by_template[template_name].append(box) + gt_by_template[box['template']].append(box) - for template_path in templates: + for template_path in template_paths: template_name = os.path.basename(template_path) - template_image = Image.open(template_path).convert('L') - template_tensor = transform(template_image).unsqueeze(0).cuda() # [1, 3, H, W] - - # 执行匹配 - detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor) - - # 获取当前模板的 gt_boxes - gt_boxes = gt_boxes_by_template.get(template_name, []) - - # 初始化已分配的 gt_box 索引 - assigned_gt = set() - - for det_box in detected_bboxes: + template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda() + + detected = match_template_to_layout(model, layout_tensor, template_tensor) + gt_boxes = gt_by_template.get(template_name, []) + + matched_gt = [False] * len(gt_boxes) + tp = 0 + for det_box in detected: best_iou = 0 best_gt_idx = -1 - for idx, gt_box in enumerate(gt_boxes): - if idx in assigned_gt: - continue + for i, gt_box in enumerate(gt_boxes): + if matched_gt[i]: continue iou = compute_iou(det_box, gt_box) if iou > best_iou: - best_iou = iou - best_gt_idx = idx - if best_iou > iou_threshold and best_gt_idx != -1: - all_true_positives += 1 - assigned_gt.add(best_gt_idx) - else: - all_false_positives += 1 + best_iou, best_gt_idx = iou, i + + if best_iou > config.IOU_THRESHOLD: + tp += 1 + matched_gt[best_gt_idx] = True + + all_tp += tp + all_fp += len(detected) - tp + all_fn += len(gt_boxes) - tp - # 计算 FN:未分配的 gt_box - for idx in range(len(gt_boxes)): - if idx not in assigned_gt: - all_false_negatives += 1 - - # 计算评估指标 - precision = all_true_positives / (all_true_positives + all_false_positives) if (all_true_positives + all_false_positives) > 0 else 0 - recall = all_true_positives / (all_true_positives + all_false_negatives) if (all_true_positives + all_false_negatives) > 0 else 0 + precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0 + recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0 f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - - return { - 'precision': precision, - 'recall': recall, - 'f1': f1 - } + return {'precision': precision, 'recall': recall, 'f1': f1} if __name__ == "__main__": - # 设置变换 - transform = transforms.Compose([ - SobelTransform(), - transforms.ToTensor(), - transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # [1, H, W] -> [3, H, W] - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - ]) + parser = argparse.ArgumentParser(description="评估 RoRD 模型性能") + parser.add_argument('--model_path', type=str, default=config.MODEL_PATH) + parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR) + parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR) + parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR) + args = parser.parse_args() - # 加载模型 model = RoRD().cuda() - model.load_state_dict(torch.load('path/to/weights.pth')) - model.eval() - - # 定义验证数据集 - val_dataset = ICLayoutDataset( - image_dir='path/to/val/images', - annotation_dir='path/to/val/annotations', - transform=transform - ) - - # 定义模板列表 - templates = ['path/to/templates/template1.png', 'path/to/templates/template2.png'] # 替换为实际模板路径 - - # 评估模型 - results = evaluate(model, val_dataset, templates) + model.load_state_dict(torch.load(args.model_path)) + val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform()) + + results = evaluate(model, val_dataset, args.templates_dir) print("评估结果:") - print(f"精确率: {results['precision']:.4f}") - print(f"召回率: {results['recall']:.4f}") - print(f"F1 分数: {results['f1']:.4f}") \ No newline at end of file + print(f" 精确率 (Precision): {results['precision']:.4f}") + print(f" 召回率 (Recall): {results['recall']:.4f}") + print(f" F1 分数 (F1 Score): {results['f1']:.4f}") \ No newline at end of file diff --git a/match.py b/match.py index 9a99ee3..79cdcd0 100644 --- a/match.py +++ b/match.py @@ -1,199 +1,108 @@ +# match.py + import torch import torch.nn.functional as F -from models.rord import RoRD -from torchvision import transforms -from utils.transforms import SobelTransform import numpy as np import cv2 from PIL import Image +import argparse +import os -def extract_keypoints_and_descriptors(model, image): - """ - 从 RoRD 模型中提取关键点和描述子。 +import config +from models.rord import RoRD +from utils.data_utils import get_transform - 参数: - model (RoRD): RoRD 模型。 - image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。 - - 返回: - tuple: (keypoints_input, descriptors) - - keypoints_input: [N, 2] float tensor,关键点在输入图像中的坐标。 - - descriptors: [N, 128] float tensor,L2 归一化的描述子。 - """ +def extract_keypoints_and_descriptors(model, image, kp_thresh): with torch.no_grad(): - detection_map, _, desc_rord = model(image) - desc = desc_rord # 使用 RoRD 描述子头 + detection_map, desc = model(image) + binary_map = (detection_map > kp_thresh).float() + coords = torch.nonzero(binary_map[0, 0]).float() + keypoints_input = coords[:, [1, 0]] * 8.0 # Stride of descriptor is 8 - # 从检测图中提取关键点 - thresh = 0.5 - binary_map = (detection_map > thresh).float() - coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d) - keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标(stride=16) - - # 从描述子图中提取描述子 - # detection_map 的形状为 [1, 1, H/16, W/16],desc 的形状为 [1, 128, H/8, W/8] - # 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2) - keypoints_desc = (coords * 2).long() # [N, 2],整数坐标 - H_desc, W_desc = desc.shape[2], desc.shape[3] - mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc) - keypoints_desc = keypoints_desc[mask] - keypoints_input = keypoints_input[mask] - - # 提取描述子 - descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128] - - # L2 归一化描述子 + descriptors = F.grid_sample(desc, coords.flip(1).view(1, -1, 1, 2) / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1, align_corners=True).squeeze().T descriptors = F.normalize(descriptors, p=2, dim=1) - return keypoints_input, descriptors -def mutual_nearest_neighbor(template_descs, layout_descs): - """ - 使用互最近邻(MNN)找到模板和版图之间的匹配。 - - 参数: - template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。 - layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。 - - 返回: - list: [(i_template, i_layout)],互最近邻匹配对的列表。 - """ - M, N = template_descs.size(0), layout_descs.size(0) - if M == 0 or N == 0: - return [] - similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵 - - # 找到每个模板描述子的最近邻 - nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M] - - # 找到每个版图描述子的最近邻 - nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N] - - # 找到互最近邻 - mutual_matches = [] - for i in range(M): - j = nn_template_to_layout[i] - if nn_layout_to_template[j] == i: - mutual_matches.append((i.item(), j.item())) - - return mutual_matches - -def ransac_filter(matches, template_kps, layout_kps): - """ - 使用 RANSAC 对匹配进行几何验证,并返回内点。 - - 参数: - matches (list): [(i_template, i_layout)],匹配对列表。 - template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。 - layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。 - - 返回: - tuple: (inlier_matches, num_inliers) - - inlier_matches: [(i_template, i_layout)],内点匹配对。 - - num_inliers: int,内点数量。 - """ - src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches]) - dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches]) - - if len(src_pts) < 4: - return [], 0 - - try: - H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0) - if H is None: - return [], 0 - inliers = mask.ravel() > 0 - num_inliers = np.sum(inliers) - inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]] - return inlier_matches, num_inliers - except cv2.error: - return [], 0 +def mutual_nearest_neighbor(descs1, descs2): + sim = descs1 @ descs2.T + nn12 = torch.max(sim, dim=1) + nn21 = torch.max(sim, dim=0) + ids1 = torch.arange(0, sim.shape[0], device=sim.device) + mask = (ids1 == nn21.indices[nn12.indices]) + matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1) + return matches.cpu().numpy() def match_template_to_layout(model, layout_image, template_image): - """ - 使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。 + layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD) + template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD) - 参数: - model (RoRD): RoRD 模型。 - layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。 - template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。 + active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device) + found_instances = [] - 返回: - list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。 - """ - # 提取版图和模板的关键点和描述子 - layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image) - template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image) - - # 初始化活动版图关键点掩码 - active_layout = torch.ones(len(layout_kps), dtype=bool) - - bboxes = [] while True: - # 获取当前活动的版图关键点和描述子 - current_layout_kps = layout_kps[active_layout] - current_layout_descs = layout_descs[active_layout] - - if len(current_layout_descs) == 0: + current_indices = torch.nonzero(active_layout_mask).squeeze(1) + if len(current_indices) < config.MIN_INLIERS: break - # MNN 匹配 + current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices] matches = mutual_nearest_neighbor(template_descs, current_layout_descs) + + if len(matches) < 4: break - if len(matches) == 0: + src_pts = template_kps[matches[:, 0]].cpu().numpy() + dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy() + + H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD) + if H is None or mask.sum() < config.MIN_INLIERS: break - # 将当前版图索引映射回原始版图索引 - active_indices = torch.nonzero(active_layout).squeeze(1) - matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches] + inlier_mask = mask.ravel().astype(bool) + + # 区域屏蔽逻辑 + inlier_layout_kps = dst_pts[inlier_mask] + x_min, y_min = inlier_layout_kps.min(axis=0) + x_max, y_max = inlier_layout_kps.max(axis=0) + + instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H} + found_instances.append(instance) - # RANSAC 过滤 - inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps) + kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1] + region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max) + active_layout_mask[region_mask] = False + + print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}") + + return found_instances - if num_inliers > 10: # 设置内点阈值 - # 获取内点在版图中的关键点 - inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches] - inlier_layout_kps = np.array(inlier_layout_kps) - - # 计算边框 - x_min = int(inlier_layout_kps[:, 0].min()) - y_min = int(inlier_layout_kps[:, 1].min()) - x_max = int(inlier_layout_kps[:, 0].max()) - y_max = int(inlier_layout_kps[:, 1].max()) - bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min}) - - # 屏蔽内点 - for _, j in inlier_matches: - active_layout[j] = False - else: - break - - return bboxes +def visualize_matches(layout_path, template_path, bboxes, output_path): + layout_img = cv2.imread(layout_path) + for i, bbox in enumerate(bboxes): + x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height'] + cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2) + cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + cv2.imwrite(output_path, layout_img) + print(f"可视化结果已保存至: {output_path}") if __name__ == "__main__": - # 设置变换 - transform = transforms.Compose([ - SobelTransform(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5], std=[0.5]) - ]) + parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配") + parser.add_argument('--model_path', type=str, default=config.MODEL_PATH) + parser.add_argument('--layout', type=str, required=True) + parser.add_argument('--template', type=str, required=True) + parser.add_argument('--output', type=str) + args = parser.parse_args() - # 加载模型 + transform = get_transform() model = RoRD().cuda() - model.load_state_dict(torch.load('path/to/weights.pth')) + model.load_state_dict(torch.load(args.model_path)) model.eval() - # 加载版图和模板图像 - layout_image = Image.open('path/to/layout.png').convert('L') - layout_tensor = transform(layout_image).unsqueeze(0).cuda() + layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda() + template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda() - template_image = Image.open('path/to/template.png').convert('L') - template_tensor = transform(template_image).unsqueeze(0).cuda() - - # 执行匹配 detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor) - - # 打印检测到的边框 - print("检测到的边框:") + print("\n检测到的边界框:") for bbox in detected_bboxes: - print(bbox) \ No newline at end of file + print(bbox) + + if args.output: + visualize_matches(args.layout, args.template, detected_bboxes, args.output) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/rord.py b/models/rord.py index 0e7d61f..17b94fd 100644 --- a/models/rord.py +++ b/models/rord.py @@ -1,16 +1,25 @@ +# models/rord.py + import torch import torch.nn as nn from torchvision import models class RoRD(nn.Module): def __init__(self): + """ + 修复后的 RoRD 模型。 + - 实现了共享骨干网络,以提高计算效率和减少内存占用。 + - 移除了冗余的 descriptor_head_vanilla。 + """ super(RoRD, self).__init__() - # 检测骨干网络:VGG-16 直到 relu5_3(层 0 到 29) - self.backbone_det = models.vgg16(pretrained=True).features[:30] - # 描述骨干网络:VGG-16 直到 relu4_3(层 0 到 22) - self.backbone_desc = models.vgg16(pretrained=True).features[:23] - # 检测头:输出关键点概率图 + vgg16_features = models.vgg16(pretrained=True).features + + # 共享骨干网络 + self.slice1 = vgg16_features[:23] # 到 relu4_3 + self.slice2 = vgg16_features[23:30] # 从 relu4_3 到 relu5_3 + + # 检测头 self.detection_head = nn.Sequential( nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), @@ -18,16 +27,8 @@ class RoRD(nn.Module): nn.Sigmoid() ) - # 普通描述子头(D2-Net 风格) - self.descriptor_head_vanilla = nn.Sequential( - nn.Conv2d(512, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 128, kernel_size=1), - nn.InstanceNorm2d(128) - ) - - # RoRD 描述子头(旋转鲁棒) - self.descriptor_head_rord = nn.Sequential( + # 描述子头 + self.descriptor_head = nn.Sequential( nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 128, kernel_size=1), @@ -35,13 +36,14 @@ class RoRD(nn.Module): ) def forward(self, x): - # 检测分支 - features_det = self.backbone_det(x) - detection = self.detection_head(features_det) + # 共享特征提取 + features_shared = self.slice1(x) - # 描述分支 - features_desc = self.backbone_desc(x) - desc_vanilla = self.descriptor_head_vanilla(features_desc) - desc_rord = self.descriptor_head_rord(features_desc) + # 描述子分支 + descriptors = self.descriptor_head(features_shared) - return detection, desc_vanilla, desc_rord \ No newline at end of file + # 检测器分支 + features_det = self.slice2(features_shared) + detection_map = self.detection_head(features_det) + + return detection_map, descriptors \ No newline at end of file diff --git a/train.py b/train.py index 6a87606..73c3e64 100644 --- a/train.py +++ b/train.py @@ -1,236 +1,142 @@ +# train.py + import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader -from torchvision import transforms from PIL import Image import numpy as np import cv2 import os -from models.rord import RoRD +import argparse -# 数据集类:生成随机旋转的训练对 +# 导入项目模块 +import config +from models.rord import RoRD +from utils.data_utils import get_transform + +# --- 训练专用数据集类 --- class ICLayoutTrainingDataset(Dataset): def __init__(self, image_dir, patch_size=256, transform=None): - """ - 初始化 IC 版图训练数据集。 - - 参数: - image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。 - patch_size (int): 裁剪的 patch 大小(默认 256x256)。 - transform (callable, optional): 应用于图像的变换。 - """ self.image_dir = image_dir self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')] self.patch_size = patch_size self.transform = transform def __len__(self): - """ - 返回数据集中的图像数量。 - - 返回: - int: 数据集大小。 - """ return len(self.image_paths) def __getitem__(self, index): - """ - 获取指定索引的训练对(原始 patch、旋转 patch、Homography 矩阵)。 - - 参数: - index (int): 图像索引。 - - 返回: - tuple: (patch, rotated_patch, H_tensor) - - patch: 原始 patch 张量。 - - rotated_patch: 旋转后的 patch 张量。 - - H_tensor: Homography 矩阵张量。 - """ img_path = self.image_paths[index] - image = Image.open(img_path).convert('L') # 灰度图像 + image = Image.open(img_path).convert('L') - # 获取图像大小 W, H = image.size - - # 随机选择裁剪的左上角坐标 x = np.random.randint(0, W - self.patch_size + 1) y = np.random.randint(0, H - self.patch_size + 1) patch = image.crop((x, y, x + self.patch_size, y + self.patch_size)) - - # 转换为 NumPy 数组 patch_np = np.array(patch) + + # 实现8个方向的离散几何变换 + theta_deg = np.random.choice([0, 90, 180, 270]) + is_mirrored = np.random.choice([True, False]) + cx, cy = self.patch_size / 2.0, self.patch_size / 2.0 + M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1) - # 随机旋转角度(0°~360°) - theta = np.random.uniform(0, 360) - theta_rad = np.deg2rad(theta) - cos_theta = np.cos(theta_rad) - sin_theta = np.sin(theta_rad) + if is_mirrored: + T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]]) + Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + M_mirror_3x3 = T2 @ Flip @ T1 + M_3x3 = np.vstack([M, [0, 0, 1]]) + H = (M_3x3 @ M_mirror_3x3).astype(np.float32) + else: + H = np.vstack([M, [0, 0, 1]]).astype(np.float32) + + transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size)) + transformed_patch = Image.fromarray(transformed_patch_np) - # 计算旋转中心(patch 的中心) - cx = self.patch_size / 2.0 - cy = self.patch_size / 2.0 - - # 计算旋转的齐次矩阵(Homography) - H = np.array([ - [cos_theta, -sin_theta, cx * (1 - cos_theta) + cy * sin_theta], - [sin_theta, cos_theta, cy * (1 - cos_theta) - cx * sin_theta], - [0, 0, 1] - ], dtype=np.float32) - - # 应用旋转到 patch - rotated_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size)) - - # 转换回 PIL Image - rotated_patch = Image.fromarray(rotated_patch_np) - - # 应用变换 if self.transform: patch = self.transform(patch) - rotated_patch = self.transform(rotated_patch) + transformed_patch = self.transform(transformed_patch) - # 转换 H 为张量 - H_tensor = torch.from_numpy(H).float() + H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵 + return patch, transformed_patch, H_tensor - return patch, rotated_patch, H_tensor - -# 特征图变换函数 +# --- 特征图变换与损失函数 --- def warp_feature_map(feature_map, H_inv): - """ - 使用逆 Homography 矩阵变换特征图。 - - 参数: - feature_map (torch.Tensor): 输入特征图,形状为 [B, C, H, W]。 - H_inv (torch.Tensor): 逆 Homography 矩阵,形状为 [B, 3, 3]。 - - 返回: - torch.Tensor: 变换后的特征图,形状为 [B, C, H, W]。 - """ B, C, H, W = feature_map.size() - # 生成网格 - grid_y, grid_x = torch.meshgrid( - torch.linspace(-1, 1, H, device=feature_map.device), - torch.linspace(-1, 1, W, device=feature_map.device), - indexing='ij' - ) - grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1) # [H, W, 3] - grid = grid.unsqueeze(0).expand(B, H, W, 3) # [B, H, W, 3] + grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device) + return F.grid_sample(feature_map, grid, align_corners=False) - # 将网格转换为齐次坐标并应用 H_inv - grid_flat = grid.view(B, -1, 3) # [B, H*W, 3] - grid_transformed = torch.bmm(grid_flat, H_inv.transpose(1, 2)) # [B, H*W, 3] - grid_transformed = grid_transformed.view(B, H, W, 3) # [B, H, W, 3] - grid_transformed = grid_transformed[..., :2] / (grid_transformed[..., 2:3] + 1e-8) # [B, H, W, 2] - - # 使用 grid_sample 进行变换 - warped_feature = F.grid_sample(feature_map, grid_transformed, align_corners=True) - return warped_feature - -# 检测损失函数 def compute_detection_loss(det_original, det_rotated, H): - """ - 计算检测损失(MSE),比较原始检测图与旋转检测图(逆变换后)。 - - 参数: - det_original (torch.Tensor): 原始图像的检测图,形状为 [B, 1, H, W]。 - det_rotated (torch.Tensor): 旋转图像的检测图,形状为 [B, 1, H, W]。 - H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。 - - 返回: - torch.Tensor: 检测损失。 - """ - H_inv = torch.inverse(H) # 计算逆 Homography + with torch.no_grad(): + H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :] warped_det_rotated = warp_feature_map(det_rotated, H_inv) return F.mse_loss(det_original, warped_det_rotated) -# 描述子损失函数 def compute_description_loss(desc_original, desc_rotated, H, margin=1.0): - """ - 计算描述子损失(三元组损失),基于对应点的描述子。 + B, C, H_feat, W_feat = desc_original.size() + num_samples = 100 + + # 随机采样锚点坐标 + coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 # [-1, 1] + + # 提取锚点描述子 + anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) + + # 计算正样本坐标 + coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2) + M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1)) + coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2] + + # 提取正样本描述子 + positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) + + # 随机采样负样本 + neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 + negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) - 参数: - desc_original (torch.Tensor): 原始图像的描述子图,形状为 [B, 128, H, W]。 - desc_rotated (torch.Tensor): 旋转图像的描述子图,形状为 [B, 128, H, W]。 - H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。 - margin (float): 三元组损失的边距。 - - 返回: - torch.Tensor: 描述子损失。 - """ - B, C, H, W = desc_original.size() - # 随机选择锚点(anchor) - num_samples = min(100, H * W) # 每张图像采样 100 个点 - idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device) - idx_y = idx // W - idx_x = idx % W - coords = torch.stack((idx_x.float(), idx_y.float()), dim=-1) # [B, num_samples, 2] - - # 转换为齐次坐标 - coords_hom = torch.cat((coords, torch.ones(B, num_samples, 1, device=coords.device)), dim=-1) # [B, num_samples, 3] - coords_transformed = torch.bmm(coords_hom, H.transpose(1, 2)) # [B, num_samples, 3] - coords_transformed = coords_transformed[..., :2] / (coords_transformed[..., 2:3] + 1e-8) # [B, num_samples, 2] - - # 归一化到 [-1, 1] 用于 grid_sample - coords_transformed = coords_transformed / torch.tensor([W/2, H/2], device=coords.device) - 1 - - # 提取锚点和正样本描述子 - anchor = desc_original.view(B, C, -1)[:, :, idx.view(-1)] # [B, 128, num_samples] - positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(2), align_corners=True).squeeze(3) # [B, 128, num_samples] - - # 随机选择负样本 - neg_idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device) - negative = desc_rotated.view(B, C, -1)[:, :, neg_idx.view(-1)] # [B, 128, num_samples] - - # 三元组损失 triplet_loss = nn.TripletMarginLoss(margin=margin, p=2) - loss = triplet_loss(anchor.transpose(1, 2), positive.transpose(1, 2), negative.transpose(1, 2)) - return loss + return triplet_loss(anchor, positive, negative) -# 定义变换 -transform = transforms.Compose([ - transforms.ToTensor(), # (1, 256, 256) - transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # (3, 256, 256) - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -]) +# --- 主函数与命令行接口 --- +def main(args): + print("--- 开始训练 RoRD 模型 ---") + print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}") + transform = get_transform() + dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) + model = RoRD().cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) -# 创建数据集和 DataLoader -dataset = ICLayoutTrainingDataset('path/to/layouts', patch_size=256, transform=transform) -dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4) + for epoch in range(args.epochs): + model.train() + total_loss_val = 0 + for i, (original, rotated, H) in enumerate(dataloader): + original, rotated, H = original.cuda(), rotated.cuda(), H.cuda() + det_original, desc_original = model(original) + det_rotated, desc_rotated = model(rotated) -# 定义模型 -model = RoRD().cuda() + loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss_val += loss.item() -# 定义优化器 -optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---") -# 训练循环 -num_epochs = 10 -for epoch in range(num_epochs): - model.train() - total_loss = 0 - for batch in dataloader: - original, rotated, H = batch - original = original.cuda() - rotated = rotated.cuda() - H = H.cuda() + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + save_path = os.path.join(args.save_dir, 'rord_model_final.pth') + torch.save(model.state_dict(), save_path) + print(f"模型已保存至: {save_path}") - # 前向传播 - det_original, _, desc_rord_original = model(original) - det_rotated, _, desc_rord_rotated = model(rotated) - - # 计算损失 - detection_loss = compute_detection_loss(det_original, det_rotated, H) - description_loss = compute_description_loss(desc_rord_original, desc_rord_rotated, H) - total_loss_batch = detection_loss + description_loss - - # 反向传播 - optimizer.zero_grad() - total_loss_batch.backward() - optimizer.step() - - total_loss += total_loss_batch.item() - - print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}") - -# 保存模型 -torch.save(model.state_dict(), 'path/to/save/model.pth') \ No newline at end of file +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="训练 RoRD 模型") + parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR) + parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR) + parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS) + parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE) + parser.add_argument('--lr', type=float, default=config.LEARNING_RATE) + main(parser.parse_args()) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000..5891439 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,14 @@ +from torchvision import transforms +from .transforms import SobelTransform + +def get_transform(): + """ + 获取统一的图像预处理管道。 + 确保训练、评估和推理使用完全相同的预处理。 + """ + return transforms.Compose([ + SobelTransform(), # 应用 Sobel 边缘检测 + transforms.ToTensor(), + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # 适配 VGG 的三通道输入 + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ]) \ No newline at end of file