第二次大修
This commit is contained in:
		
							
								
								
									
										54
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								README.md
									
									
									
									
									
								
							| @@ -58,49 +58,33 @@ ic_layout_recognition/ | ||||
| └── README.md | ||||
| ``` | ||||
|  | ||||
| ### 1. 训练模型 | ||||
| ## 🚀 使用方法 | ||||
|  | ||||
| 使用以下命令启动模型训练。训练过程采用自监督学习,通过对图像应用随机旋转来生成训练对,从而优化关键点检测和描述子生成。 | ||||
| ### 1. 配置 | ||||
| 首先,请修改 **`config.py`** 文件,设置正确的训练数据、验证数据和模型保存路径。 | ||||
|  | ||||
| ### 2. 训练模型 | ||||
| ```bash | ||||
| python train.py --data_dir path/to/layouts --save_dir path/to/save | ||||
| python train.py --data_dir /path/to/your/layouts --save_dir /path/to/your/models --epochs 50 | ||||
| ``` | ||||
| 使用 `--help` 查看更多选项。 | ||||
|  | ||||
| ### 3. 模板匹配 | ||||
| ```bash | ||||
| python match.py --model_path /path/to/your/models/rord_model_final.pth \ | ||||
|                 --layout /path/to/layout.png \ | ||||
|                 --template /path/to/template.png \ | ||||
|                 --output /path/to/result.png | ||||
| ``` | ||||
|  | ||||
| | 参数 | 描述 | | ||||
| | :--- | :--- | | ||||
| | `--data_dir` | **[必需]** 包含 PNG 格式 IC 版图图像的目录。 | | ||||
| | `--save_dir` | **[必需]** 训练好的模型权重保存目录。 | | ||||
|  | ||||
| ### 2. 评估模型 | ||||
|  | ||||
| 使用以下命令在验证集上评估模型的性能。评估脚本会计算基于 IoU 阈值的精确率、召回率和 F1 分数。 | ||||
|  | ||||
| ### 4. 评估模型 | ||||
| ```bash | ||||
| python evaluate.py --model_path path/to/model.pth --val_dir path/to/val/images --annotations_dir path/to/val/annotations --templates path/to/templates | ||||
| python evaluate.py --model_path /path/to/your/models/rord_model_final.pth \ | ||||
|                    --val_dir /path/to/val/images \ | ||||
|                    --annotations_dir /path/to/val/annotations \ | ||||
|                    --templates_dir /path/to/templates | ||||
| ``` | ||||
|  | ||||
| | 参数 | 描述 | | ||||
| | :--- | :--- | | ||||
| | `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 | | ||||
| | `--val_dir` | **[必需]** 验证集图像目录。 | | ||||
| | `--annotations_dir` | **[必需]** 包含真实标注的 JSON 文件目录。 | | ||||
| | `--templates` | **[必需]** 模板图像的路径列表。 | | ||||
|  | ||||
| ### 3. 进行模板匹配 | ||||
|  | ||||
| 使用以下命令将模板图像与指定的版图图像进行匹配。匹配过程利用 RoRD 模型提取关键点和描述子,通过互最近邻(MNN)匹配和 RANSAC 几何验证来定位模板。 | ||||
|  | ||||
| ```bash | ||||
| python match.py --model_path path/to/model.pth --layout_path path/to/layout.png --template_path path/to/template.png --output_path path/to/output.png | ||||
| ``` | ||||
|  | ||||
| | 参数 | 描述 | | ||||
| | :--- | :--- | | ||||
| | `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 | | ||||
| | `--layout_path` | **[必需]** 待匹配的版图图像路径。 | | ||||
| | `--template_path` | **[必需]** 模板图像路径。 | | ||||
| | `--output_path` | **[可选]** 保存可视化匹配结果的路径。 | | ||||
|  | ||||
| ## 📦 数据准备 | ||||
|  | ||||
| ### 训练数据 | ||||
|   | ||||
							
								
								
									
										31
									
								
								config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								config.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| # config.py | ||||
|  | ||||
| # --- 训练参数 --- | ||||
| LEARNING_RATE = 1e-4 | ||||
| BATCH_SIZE = 4 | ||||
| NUM_EPOCHS = 20  # 增加了训练轮数 | ||||
| PATCH_SIZE = 256 | ||||
|  | ||||
| # --- 匹配与评估参数 --- | ||||
| # 关键点检测的置信度阈值 | ||||
| KEYPOINT_THRESHOLD = 0.5 | ||||
| # RANSAC 重投影误差阈值(像素) | ||||
| RANSAC_REPROJ_THRESHOLD = 5.0 | ||||
| # RANSAC 判定为有效匹配所需的最小内点数 | ||||
| MIN_INLIERS = 15 # 适当提高以增加匹配的可靠性 | ||||
| # IoU (Intersection over Union) 阈值,用于评估 | ||||
| IOU_THRESHOLD = 0.5 | ||||
|  | ||||
| # --- 文件路径 --- | ||||
| # 训练数据目录 | ||||
| LAYOUT_DIR = 'path/to/layouts' | ||||
| # 模型保存目录 | ||||
| SAVE_DIR = 'path/to/save' | ||||
| # 验证集图像目录 | ||||
| VAL_IMG_DIR = 'path/to/val/images' | ||||
| # 验证集标注目录 | ||||
| VAL_ANN_DIR = 'path/to/val/annotations' | ||||
| # 模板图像目录 | ||||
| TEMPLATE_DIR = 'path/to/templates' | ||||
| # 默认加载的模型路径 | ||||
| MODEL_PATH = 'path/to/save/model_final.pth' | ||||
							
								
								
									
										0
									
								
								data/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								data/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										147
									
								
								evaluate.py
									
									
									
									
									
								
							
							
						
						
									
										147
									
								
								evaluate.py
									
									
									
									
									
								
							| @@ -1,123 +1,84 @@ | ||||
| from models.rord import RoRD | ||||
| from data.ic_dataset import ICLayoutDataset | ||||
| from utils.transforms import SobelTransform | ||||
| from match import match_template_to_layout | ||||
| # evaluate.py | ||||
|  | ||||
| import torch | ||||
| from torchvision import transforms | ||||
| from PIL import Image | ||||
| import json | ||||
| import os | ||||
| from PIL import Image | ||||
| import argparse | ||||
|  | ||||
| import config | ||||
| from models.rord import RoRD | ||||
| from utils.data_utils import get_transform | ||||
| from data.ic_dataset import ICLayoutDataset | ||||
| from match import match_template_to_layout | ||||
|  | ||||
| def compute_iou(box1, box2): | ||||
|     x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height'] | ||||
|     x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height'] | ||||
|      | ||||
|     inter_x1 = max(x1, x2) | ||||
|     inter_y1 = max(y1, y2) | ||||
|     inter_x2 = min(x1 + w1, x2 + w2) | ||||
|     inter_y2 = min(y1 + h1, y2 + h2) | ||||
|      | ||||
|     inter_x1, inter_y1 = max(x1, x2), max(y1, y2) | ||||
|     inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2) | ||||
|     inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) | ||||
|     union_area = w1 * h1 + w2 * h2 - inter_area | ||||
|     return inter_area / union_area if union_area > 0 else 0 | ||||
|  | ||||
|     box1_area = w1 * h1 | ||||
|     box2_area = w2 * h2 | ||||
|     union_area = box1_area + box2_area - inter_area | ||||
|      | ||||
|     iou = inter_area / union_area if union_area > 0 else 0 | ||||
|     return iou | ||||
|  | ||||
| def evaluate(model, val_dataset, templates, iou_threshold=0.5): | ||||
| def evaluate(model, val_dataset, template_dir): | ||||
|     model.eval() | ||||
|     all_true_positives = 0 | ||||
|     all_false_positives = 0 | ||||
|     all_false_negatives = 0 | ||||
|     all_tp, all_fp, all_fn = 0, 0, 0 | ||||
|     transform = get_transform() | ||||
|      | ||||
|     for layout_idx in range(len(val_dataset)): | ||||
|         layout_image, annotation = val_dataset[layout_idx] | ||||
|         # layout_image is [3, H, W] | ||||
|         layout_tensor = layout_image.unsqueeze(0).cuda()  # [1, 3, H, W] | ||||
|     template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')] | ||||
|  | ||||
|         # 假设 annotation 是 {"boxes": [{"template": "template1.png", "x": x, "y": y, "width": w, "height": h}, ...]} | ||||
|         gt_boxes_by_template = {} | ||||
|     for layout_tensor, annotation in val_dataset: | ||||
|         layout_tensor = layout_tensor.unsqueeze(0).cuda() | ||||
|         gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])} | ||||
|         for box in annotation.get('boxes', []): | ||||
|             template_name = box['template'] | ||||
|             if template_name not in gt_boxes_by_template: | ||||
|                 gt_boxes_by_template[template_name] = [] | ||||
|             gt_boxes_by_template[template_name].append(box) | ||||
|             gt_by_template[box['template']].append(box) | ||||
|  | ||||
|         for template_path in templates: | ||||
|         for template_path in template_paths: | ||||
|             template_name = os.path.basename(template_path) | ||||
|             template_image = Image.open(template_path).convert('L') | ||||
|             template_tensor = transform(template_image).unsqueeze(0).cuda()  # [1, 3, H, W] | ||||
|             template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda() | ||||
|              | ||||
|             # 执行匹配 | ||||
|             detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor) | ||||
|             detected = match_template_to_layout(model, layout_tensor, template_tensor) | ||||
|             gt_boxes = gt_by_template.get(template_name, []) | ||||
|              | ||||
|             # 获取当前模板的 gt_boxes | ||||
|             gt_boxes = gt_boxes_by_template.get(template_name, []) | ||||
|  | ||||
|             # 初始化已分配的 gt_box 索引 | ||||
|             assigned_gt = set() | ||||
|  | ||||
|             for det_box in detected_bboxes: | ||||
|             matched_gt = [False] * len(gt_boxes) | ||||
|             tp = 0 | ||||
|             for det_box in detected: | ||||
|                 best_iou = 0 | ||||
|                 best_gt_idx = -1 | ||||
|                 for idx, gt_box in enumerate(gt_boxes): | ||||
|                     if idx in assigned_gt: | ||||
|                         continue | ||||
|                 for i, gt_box in enumerate(gt_boxes): | ||||
|                     if matched_gt[i]: continue | ||||
|                     iou = compute_iou(det_box, gt_box) | ||||
|                     if iou > best_iou: | ||||
|                         best_iou = iou | ||||
|                         best_gt_idx = idx | ||||
|                 if best_iou > iou_threshold and best_gt_idx != -1: | ||||
|                     all_true_positives += 1 | ||||
|                     assigned_gt.add(best_gt_idx) | ||||
|                 else: | ||||
|                     all_false_positives += 1 | ||||
|                         best_iou, best_gt_idx = iou, i | ||||
|                  | ||||
|             # 计算 FN:未分配的 gt_box | ||||
|             for idx in range(len(gt_boxes)): | ||||
|                 if idx not in assigned_gt: | ||||
|                     all_false_negatives += 1 | ||||
|                 if best_iou > config.IOU_THRESHOLD: | ||||
|                     tp += 1 | ||||
|                     matched_gt[best_gt_idx] = True | ||||
|              | ||||
|     # 计算评估指标 | ||||
|     precision = all_true_positives / (all_true_positives + all_false_positives) if (all_true_positives + all_false_positives) > 0 else 0 | ||||
|     recall = all_true_positives / (all_true_positives + all_false_negatives) if (all_true_positives + all_false_negatives) > 0 else 0 | ||||
|             all_tp += tp | ||||
|             all_fp += len(detected) - tp | ||||
|             all_fn += len(gt_boxes) - tp | ||||
|  | ||||
|     precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0 | ||||
|     recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0 | ||||
|     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 | ||||
|  | ||||
|     return { | ||||
|         'precision': precision, | ||||
|         'recall': recall, | ||||
|         'f1': f1 | ||||
|     } | ||||
|     return {'precision': precision, 'recall': recall, 'f1': f1} | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # 设置变换 | ||||
|     transform = transforms.Compose([ | ||||
|         SobelTransform(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # [1, H, W] -> [3, H, W] | ||||
|         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||||
|     ]) | ||||
|     parser = argparse.ArgumentParser(description="评估 RoRD 模型性能") | ||||
|     parser.add_argument('--model_path', type=str, default=config.MODEL_PATH) | ||||
|     parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR) | ||||
|     parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR) | ||||
|     parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # 加载模型 | ||||
|     model = RoRD().cuda() | ||||
|     model.load_state_dict(torch.load('path/to/weights.pth')) | ||||
|     model.eval() | ||||
|     model.load_state_dict(torch.load(args.model_path)) | ||||
|     val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform()) | ||||
|      | ||||
|     # 定义验证数据集 | ||||
|     val_dataset = ICLayoutDataset( | ||||
|         image_dir='path/to/val/images', | ||||
|         annotation_dir='path/to/val/annotations', | ||||
|         transform=transform | ||||
|     ) | ||||
|  | ||||
|     # 定义模板列表 | ||||
|     templates = ['path/to/templates/template1.png', 'path/to/templates/template2.png']  # 替换为实际模板路径 | ||||
|  | ||||
|     # 评估模型 | ||||
|     results = evaluate(model, val_dataset, templates) | ||||
|     results = evaluate(model, val_dataset, args.templates_dir) | ||||
|     print("评估结果:") | ||||
|     print(f"精确率: {results['precision']:.4f}") | ||||
|     print(f"召回率: {results['recall']:.4f}") | ||||
|     print(f"F1 分数: {results['f1']:.4f}") | ||||
|     print(f"  精确率 (Precision): {results['precision']:.4f}") | ||||
|     print(f"  召回率 (Recall):    {results['recall']:.4f}") | ||||
|     print(f"  F1 分数 (F1 Score):  {results['f1']:.4f}") | ||||
							
								
								
									
										231
									
								
								match.py
									
									
									
									
									
								
							
							
						
						
									
										231
									
								
								match.py
									
									
									
									
									
								
							| @@ -1,199 +1,108 @@ | ||||
| # match.py | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from models.rord import RoRD | ||||
| from torchvision import transforms | ||||
| from utils.transforms import SobelTransform | ||||
| import numpy as np | ||||
| import cv2 | ||||
| from PIL import Image | ||||
| import argparse | ||||
| import os | ||||
|  | ||||
| def extract_keypoints_and_descriptors(model, image): | ||||
|     """ | ||||
|     从 RoRD 模型中提取关键点和描述子。 | ||||
| import config | ||||
| from models.rord import RoRD | ||||
| from utils.data_utils import get_transform | ||||
|  | ||||
|     参数: | ||||
|         model (RoRD): RoRD 模型。 | ||||
|         image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。 | ||||
|  | ||||
|     返回: | ||||
|         tuple: (keypoints_input, descriptors) | ||||
|             - keypoints_input: [N, 2] float tensor,关键点在输入图像中的坐标。 | ||||
|             - descriptors: [N, 128] float tensor,L2 归一化的描述子。 | ||||
|     """ | ||||
| def extract_keypoints_and_descriptors(model, image, kp_thresh): | ||||
|     with torch.no_grad(): | ||||
|         detection_map, _, desc_rord = model(image) | ||||
|         desc = desc_rord  # 使用 RoRD 描述子头 | ||||
|         detection_map, desc = model(image) | ||||
|         binary_map = (detection_map > kp_thresh).float() | ||||
|         coords = torch.nonzero(binary_map[0, 0]).float() | ||||
|         keypoints_input = coords[:, [1, 0]] * 8.0 # Stride of descriptor is 8 | ||||
|  | ||||
|         # 从检测图中提取关键点 | ||||
|         thresh = 0.5 | ||||
|         binary_map = (detection_map > thresh).float() | ||||
|         coords = torch.nonzero(binary_map[0, 0] > thresh).float()  # [N, 2],每个行是 (i_d, j_d) | ||||
|         keypoints_input = coords * 16.0  # 将特征图坐标映射到输入图像坐标(stride=16) | ||||
|  | ||||
|         # 从描述子图中提取描述子 | ||||
|         # detection_map 的形状为 [1, 1, H/16, W/16],desc 的形状为 [1, 128, H/8, W/8] | ||||
|         # 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2) | ||||
|         keypoints_desc = (coords * 2).long()  # [N, 2],整数坐标 | ||||
|         H_desc, W_desc = desc.shape[2], desc.shape[3] | ||||
|         mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc) | ||||
|         keypoints_desc = keypoints_desc[mask] | ||||
|         keypoints_input = keypoints_input[mask] | ||||
|  | ||||
|         # 提取描述子 | ||||
|         descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T  # [N, 128] | ||||
|  | ||||
|         # L2 归一化描述子 | ||||
|         descriptors = F.grid_sample(desc, coords.flip(1).view(1, -1, 1, 2) / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1, align_corners=True).squeeze().T | ||||
|         descriptors = F.normalize(descriptors, p=2, dim=1) | ||||
|  | ||||
|         return keypoints_input, descriptors | ||||
|  | ||||
| def mutual_nearest_neighbor(template_descs, layout_descs): | ||||
|     """ | ||||
|     使用互最近邻(MNN)找到模板和版图之间的匹配。 | ||||
|  | ||||
|     参数: | ||||
|         template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。 | ||||
|         layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。 | ||||
|  | ||||
|     返回: | ||||
|         list: [(i_template, i_layout)],互最近邻匹配对的列表。 | ||||
|     """ | ||||
|     M, N = template_descs.size(0), layout_descs.size(0) | ||||
|     if M == 0 or N == 0: | ||||
|         return [] | ||||
|     similarity_matrix = template_descs @ layout_descs.T  # [M, N],点积矩阵 | ||||
|  | ||||
|     # 找到每个模板描述子的最近邻 | ||||
|     nn_template_to_layout = torch.argmax(similarity_matrix, dim=1)  # [M] | ||||
|  | ||||
|     # 找到每个版图描述子的最近邻 | ||||
|     nn_layout_to_template = torch.argmax(similarity_matrix, dim=0)  # [N] | ||||
|  | ||||
|     # 找到互最近邻 | ||||
|     mutual_matches = [] | ||||
|     for i in range(M): | ||||
|         j = nn_template_to_layout[i] | ||||
|         if nn_layout_to_template[j] == i: | ||||
|             mutual_matches.append((i.item(), j.item())) | ||||
|  | ||||
|     return mutual_matches | ||||
|  | ||||
| def ransac_filter(matches, template_kps, layout_kps): | ||||
|     """ | ||||
|     使用 RANSAC 对匹配进行几何验证,并返回内点。 | ||||
|  | ||||
|     参数: | ||||
|         matches (list): [(i_template, i_layout)],匹配对列表。 | ||||
|         template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。 | ||||
|         layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。 | ||||
|  | ||||
|     返回: | ||||
|         tuple: (inlier_matches, num_inliers) | ||||
|             - inlier_matches: [(i_template, i_layout)],内点匹配对。 | ||||
|             - num_inliers: int,内点数量。 | ||||
|     """ | ||||
|     src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches]) | ||||
|     dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches]) | ||||
|  | ||||
|     if len(src_pts) < 4: | ||||
|         return [], 0 | ||||
|  | ||||
|     try: | ||||
|         H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0) | ||||
|         if H is None: | ||||
|             return [], 0 | ||||
|         inliers = mask.ravel() > 0 | ||||
|         num_inliers = np.sum(inliers) | ||||
|         inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]] | ||||
|         return inlier_matches, num_inliers | ||||
|     except cv2.error: | ||||
|         return [], 0 | ||||
| def mutual_nearest_neighbor(descs1, descs2): | ||||
|     sim = descs1 @ descs2.T | ||||
|     nn12 = torch.max(sim, dim=1) | ||||
|     nn21 = torch.max(sim, dim=0) | ||||
|     ids1 = torch.arange(0, sim.shape[0], device=sim.device) | ||||
|     mask = (ids1 == nn21.indices[nn12.indices]) | ||||
|     matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1) | ||||
|     return matches.cpu().numpy() | ||||
|  | ||||
| def match_template_to_layout(model, layout_image, template_image): | ||||
|     """ | ||||
|     使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。 | ||||
|     layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD) | ||||
|     template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD) | ||||
|  | ||||
|     参数: | ||||
|         model (RoRD): RoRD 模型。 | ||||
|         layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。 | ||||
|         template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。 | ||||
|     active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device) | ||||
|     found_instances = [] | ||||
|  | ||||
|     返回: | ||||
|         list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。 | ||||
|     """ | ||||
|     # 提取版图和模板的关键点和描述子 | ||||
|     layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image) | ||||
|     template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image) | ||||
|  | ||||
|     # 初始化活动版图关键点掩码 | ||||
|     active_layout = torch.ones(len(layout_kps), dtype=bool) | ||||
|  | ||||
|     bboxes = [] | ||||
|     while True: | ||||
|         # 获取当前活动的版图关键点和描述子 | ||||
|         current_layout_kps = layout_kps[active_layout] | ||||
|         current_layout_descs = layout_descs[active_layout] | ||||
|  | ||||
|         if len(current_layout_descs) == 0: | ||||
|         current_indices = torch.nonzero(active_layout_mask).squeeze(1) | ||||
|         if len(current_indices) < config.MIN_INLIERS: | ||||
|             break | ||||
|  | ||||
|         # MNN 匹配 | ||||
|         current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices] | ||||
|         matches = mutual_nearest_neighbor(template_descs, current_layout_descs) | ||||
|          | ||||
|         if len(matches) == 0: | ||||
|         if len(matches) < 4: break | ||||
|  | ||||
|         src_pts = template_kps[matches[:, 0]].cpu().numpy() | ||||
|         dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy() | ||||
|  | ||||
|         H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD) | ||||
|         if H is None or mask.sum() < config.MIN_INLIERS: | ||||
|             break | ||||
|  | ||||
|         # 将当前版图索引映射回原始版图索引 | ||||
|         active_indices = torch.nonzero(active_layout).squeeze(1) | ||||
|         matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches] | ||||
|         inlier_mask = mask.ravel().astype(bool) | ||||
|          | ||||
|         # RANSAC 过滤 | ||||
|         inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps) | ||||
|         # 区域屏蔽逻辑 | ||||
|         inlier_layout_kps = dst_pts[inlier_mask] | ||||
|         x_min, y_min = inlier_layout_kps.min(axis=0) | ||||
|         x_max, y_max = inlier_layout_kps.max(axis=0) | ||||
|          | ||||
|         if num_inliers > 10:  # 设置内点阈值 | ||||
|             # 获取内点在版图中的关键点 | ||||
|             inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches] | ||||
|             inlier_layout_kps = np.array(inlier_layout_kps) | ||||
|         instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H} | ||||
|         found_instances.append(instance) | ||||
|  | ||||
|             # 计算边框 | ||||
|             x_min = int(inlier_layout_kps[:, 0].min()) | ||||
|             y_min = int(inlier_layout_kps[:, 1].min()) | ||||
|             x_max = int(inlier_layout_kps[:, 0].max()) | ||||
|             y_max = int(inlier_layout_kps[:, 1].max()) | ||||
|             bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min}) | ||||
|         kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1] | ||||
|         region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max) | ||||
|         active_layout_mask[region_mask] = False | ||||
|          | ||||
|             # 屏蔽内点 | ||||
|             for _, j in inlier_matches: | ||||
|                 active_layout[j] = False | ||||
|         else: | ||||
|             break | ||||
|         print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}") | ||||
|              | ||||
|     return bboxes | ||||
|     return found_instances | ||||
|  | ||||
| def visualize_matches(layout_path, template_path, bboxes, output_path): | ||||
|     layout_img = cv2.imread(layout_path) | ||||
|     for i, bbox in enumerate(bboxes): | ||||
|         x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height'] | ||||
|         cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2) | ||||
|         cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) | ||||
|     cv2.imwrite(output_path, layout_img) | ||||
|     print(f"可视化结果已保存至: {output_path}") | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # 设置变换 | ||||
|     transform = transforms.Compose([ | ||||
|         SobelTransform(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(mean=[0.5], std=[0.5]) | ||||
|     ]) | ||||
|     parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配") | ||||
|     parser.add_argument('--model_path', type=str, default=config.MODEL_PATH) | ||||
|     parser.add_argument('--layout', type=str, required=True) | ||||
|     parser.add_argument('--template', type=str, required=True) | ||||
|     parser.add_argument('--output', type=str) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # 加载模型 | ||||
|     transform = get_transform() | ||||
|     model = RoRD().cuda() | ||||
|     model.load_state_dict(torch.load('path/to/weights.pth')) | ||||
|     model.load_state_dict(torch.load(args.model_path)) | ||||
|     model.eval() | ||||
|  | ||||
|     # 加载版图和模板图像 | ||||
|     layout_image = Image.open('path/to/layout.png').convert('L') | ||||
|     layout_tensor = transform(layout_image).unsqueeze(0).cuda() | ||||
|     layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda() | ||||
|     template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda() | ||||
|  | ||||
|     template_image = Image.open('path/to/template.png').convert('L') | ||||
|     template_tensor = transform(template_image).unsqueeze(0).cuda() | ||||
|  | ||||
|     # 执行匹配 | ||||
|     detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor) | ||||
|  | ||||
|     # 打印检测到的边框 | ||||
|     print("检测到的边框:") | ||||
|     print("\n检测到的边界框:") | ||||
|     for bbox in detected_bboxes: | ||||
|         print(bbox) | ||||
|  | ||||
|     if args.output: | ||||
|         visualize_matches(args.layout, args.template, detected_bboxes, args.output) | ||||
							
								
								
									
										0
									
								
								models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								models/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -1,16 +1,25 @@ | ||||
| # models/rord.py | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torchvision import models | ||||
|  | ||||
| class RoRD(nn.Module): | ||||
|     def __init__(self): | ||||
|         """ | ||||
|         修复后的 RoRD 模型。 | ||||
|         - 实现了共享骨干网络,以提高计算效率和减少内存占用。 | ||||
|         - 移除了冗余的 descriptor_head_vanilla。 | ||||
|         """ | ||||
|         super(RoRD, self).__init__() | ||||
|         # 检测骨干网络:VGG-16 直到 relu5_3(层 0 到 29) | ||||
|         self.backbone_det = models.vgg16(pretrained=True).features[:30] | ||||
|         # 描述骨干网络:VGG-16 直到 relu4_3(层 0 到 22) | ||||
|         self.backbone_desc = models.vgg16(pretrained=True).features[:23] | ||||
|          | ||||
|         # 检测头:输出关键点概率图 | ||||
|         vgg16_features = models.vgg16(pretrained=True).features | ||||
|          | ||||
|         # 共享骨干网络 | ||||
|         self.slice1 = vgg16_features[:23]  # 到 relu4_3 | ||||
|         self.slice2 = vgg16_features[23:30] # 从 relu4_3 到 relu5_3 | ||||
|  | ||||
|         # 检测头 | ||||
|         self.detection_head = nn.Sequential( | ||||
|             nn.Conv2d(512, 256, kernel_size=3, padding=1), | ||||
|             nn.ReLU(inplace=True), | ||||
| @@ -18,16 +27,8 @@ class RoRD(nn.Module): | ||||
|             nn.Sigmoid() | ||||
|         ) | ||||
|          | ||||
|         # 普通描述子头(D2-Net 风格) | ||||
|         self.descriptor_head_vanilla = nn.Sequential( | ||||
|             nn.Conv2d(512, 256, kernel_size=3, padding=1), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(256, 128, kernel_size=1), | ||||
|             nn.InstanceNorm2d(128) | ||||
|         ) | ||||
|          | ||||
|         # RoRD 描述子头(旋转鲁棒) | ||||
|         self.descriptor_head_rord = nn.Sequential( | ||||
|         # 描述子头 | ||||
|         self.descriptor_head = nn.Sequential( | ||||
|             nn.Conv2d(512, 256, kernel_size=3, padding=1), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(256, 128, kernel_size=1), | ||||
| @@ -35,13 +36,14 @@ class RoRD(nn.Module): | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         # 检测分支 | ||||
|         features_det = self.backbone_det(x) | ||||
|         detection = self.detection_head(features_det) | ||||
|         # 共享特征提取 | ||||
|         features_shared = self.slice1(x) | ||||
|          | ||||
|         # 描述分支 | ||||
|         features_desc = self.backbone_desc(x) | ||||
|         desc_vanilla = self.descriptor_head_vanilla(features_desc) | ||||
|         desc_rord = self.descriptor_head_rord(features_desc) | ||||
|         # 描述子分支 | ||||
|         descriptors = self.descriptor_head(features_shared) | ||||
|          | ||||
|         return detection, desc_vanilla, desc_rord | ||||
|         # 检测器分支 | ||||
|         features_det = self.slice2(features_shared) | ||||
|         detection_map = self.detection_head(features_det) | ||||
|          | ||||
|         return detection_map, descriptors | ||||
							
								
								
									
										258
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										258
									
								
								train.py
									
									
									
									
									
								
							| @@ -1,236 +1,142 @@ | ||||
| # train.py | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from torch.utils.data import Dataset, DataLoader | ||||
| from torchvision import transforms | ||||
| from PIL import Image | ||||
| import numpy as np | ||||
| import cv2 | ||||
| import os | ||||
| from models.rord import RoRD | ||||
| import argparse | ||||
|  | ||||
| # 数据集类:生成随机旋转的训练对 | ||||
| # 导入项目模块 | ||||
| import config | ||||
| from models.rord import RoRD | ||||
| from utils.data_utils import get_transform | ||||
|  | ||||
| # --- 训练专用数据集类 --- | ||||
| class ICLayoutTrainingDataset(Dataset): | ||||
|     def __init__(self, image_dir, patch_size=256, transform=None): | ||||
|         """ | ||||
|         初始化 IC 版图训练数据集。 | ||||
|  | ||||
|         参数: | ||||
|             image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。 | ||||
|             patch_size (int): 裁剪的 patch 大小(默认 256x256)。 | ||||
|             transform (callable, optional): 应用于图像的变换。 | ||||
|         """ | ||||
|         self.image_dir = image_dir | ||||
|         self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')] | ||||
|         self.patch_size = patch_size | ||||
|         self.transform = transform | ||||
|  | ||||
|     def __len__(self): | ||||
|         """ | ||||
|         返回数据集中的图像数量。 | ||||
|  | ||||
|         返回: | ||||
|             int: 数据集大小。 | ||||
|         """ | ||||
|         return len(self.image_paths) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         """ | ||||
|         获取指定索引的训练对(原始 patch、旋转 patch、Homography 矩阵)。 | ||||
|  | ||||
|         参数: | ||||
|             index (int): 图像索引。 | ||||
|  | ||||
|         返回: | ||||
|             tuple: (patch, rotated_patch, H_tensor) | ||||
|                 - patch: 原始 patch 张量。 | ||||
|                 - rotated_patch: 旋转后的 patch 张量。 | ||||
|                 - H_tensor: Homography 矩阵张量。 | ||||
|         """ | ||||
|         img_path = self.image_paths[index] | ||||
|         image = Image.open(img_path).convert('L')  # 灰度图像 | ||||
|         image = Image.open(img_path).convert('L') | ||||
|  | ||||
|         # 获取图像大小 | ||||
|         W, H = image.size | ||||
|  | ||||
|         # 随机选择裁剪的左上角坐标 | ||||
|         x = np.random.randint(0, W - self.patch_size + 1) | ||||
|         y = np.random.randint(0, H - self.patch_size + 1) | ||||
|         patch = image.crop((x, y, x + self.patch_size, y + self.patch_size)) | ||||
|  | ||||
|         # 转换为 NumPy 数组 | ||||
|         patch_np = np.array(patch) | ||||
|          | ||||
|         # 随机旋转角度(0°~360°) | ||||
|         theta = np.random.uniform(0, 360) | ||||
|         theta_rad = np.deg2rad(theta) | ||||
|         cos_theta = np.cos(theta_rad) | ||||
|         sin_theta = np.sin(theta_rad) | ||||
|         # 实现8个方向的离散几何变换 | ||||
|         theta_deg = np.random.choice([0, 90, 180, 270]) | ||||
|         is_mirrored = np.random.choice([True, False]) | ||||
|         cx, cy = self.patch_size / 2.0, self.patch_size / 2.0 | ||||
|         M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1) | ||||
|  | ||||
|         # 计算旋转中心(patch 的中心) | ||||
|         cx = self.patch_size / 2.0 | ||||
|         cy = self.patch_size / 2.0 | ||||
|         if is_mirrored: | ||||
|             T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]]) | ||||
|             Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) | ||||
|             T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) | ||||
|             M_mirror_3x3 = T2 @ Flip @ T1 | ||||
|             M_3x3 = np.vstack([M, [0, 0, 1]]) | ||||
|             H = (M_3x3 @ M_mirror_3x3).astype(np.float32) | ||||
|         else: | ||||
|             H = np.vstack([M, [0, 0, 1]]).astype(np.float32) | ||||
|          | ||||
|         # 计算旋转的齐次矩阵(Homography) | ||||
|         H = np.array([ | ||||
|             [cos_theta, -sin_theta, cx * (1 - cos_theta) + cy * sin_theta], | ||||
|             [sin_theta,  cos_theta, cy * (1 - cos_theta) - cx * sin_theta], | ||||
|             [0,          0,          1] | ||||
|         ], dtype=np.float32) | ||||
|         transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size)) | ||||
|         transformed_patch = Image.fromarray(transformed_patch_np) | ||||
|  | ||||
|         # 应用旋转到 patch | ||||
|         rotated_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size)) | ||||
|  | ||||
|         # 转换回 PIL Image | ||||
|         rotated_patch = Image.fromarray(rotated_patch_np) | ||||
|  | ||||
|         # 应用变换 | ||||
|         if self.transform: | ||||
|             patch = self.transform(patch) | ||||
|             rotated_patch = self.transform(rotated_patch) | ||||
|             transformed_patch = self.transform(transformed_patch) | ||||
|  | ||||
|         # 转换 H 为张量 | ||||
|         H_tensor = torch.from_numpy(H).float() | ||||
|         H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵 | ||||
|         return patch, transformed_patch, H_tensor | ||||
|  | ||||
|         return patch, rotated_patch, H_tensor | ||||
|  | ||||
| # 特征图变换函数 | ||||
| # --- 特征图变换与损失函数 --- | ||||
| def warp_feature_map(feature_map, H_inv): | ||||
|     """ | ||||
|     使用逆 Homography 矩阵变换特征图。 | ||||
|  | ||||
|     参数: | ||||
|         feature_map (torch.Tensor): 输入特征图,形状为 [B, C, H, W]。 | ||||
|         H_inv (torch.Tensor): 逆 Homography 矩阵,形状为 [B, 3, 3]。 | ||||
|  | ||||
|     返回: | ||||
|         torch.Tensor: 变换后的特征图,形状为 [B, C, H, W]。 | ||||
|     """ | ||||
|     B, C, H, W = feature_map.size() | ||||
|     # 生成网格 | ||||
|     grid_y, grid_x = torch.meshgrid( | ||||
|         torch.linspace(-1, 1, H, device=feature_map.device), | ||||
|         torch.linspace(-1, 1, W, device=feature_map.device), | ||||
|         indexing='ij' | ||||
|     ) | ||||
|     grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1)  # [H, W, 3] | ||||
|     grid = grid.unsqueeze(0).expand(B, H, W, 3)  # [B, H, W, 3] | ||||
|     grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device) | ||||
|     return F.grid_sample(feature_map, grid, align_corners=False) | ||||
|  | ||||
|     # 将网格转换为齐次坐标并应用 H_inv | ||||
|     grid_flat = grid.view(B, -1, 3)  # [B, H*W, 3] | ||||
|     grid_transformed = torch.bmm(grid_flat, H_inv.transpose(1, 2))  # [B, H*W, 3] | ||||
|     grid_transformed = grid_transformed.view(B, H, W, 3)  # [B, H, W, 3] | ||||
|     grid_transformed = grid_transformed[..., :2] / (grid_transformed[..., 2:3] + 1e-8)  # [B, H, W, 2] | ||||
|  | ||||
|     # 使用 grid_sample 进行变换 | ||||
|     warped_feature = F.grid_sample(feature_map, grid_transformed, align_corners=True) | ||||
|     return warped_feature | ||||
|  | ||||
| # 检测损失函数 | ||||
| def compute_detection_loss(det_original, det_rotated, H): | ||||
|     """ | ||||
|     计算检测损失(MSE),比较原始检测图与旋转检测图(逆变换后)。 | ||||
|  | ||||
|     参数: | ||||
|         det_original (torch.Tensor): 原始图像的检测图,形状为 [B, 1, H, W]。 | ||||
|         det_rotated (torch.Tensor): 旋转图像的检测图,形状为 [B, 1, H, W]。 | ||||
|         H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。 | ||||
|  | ||||
|     返回: | ||||
|         torch.Tensor: 检测损失。 | ||||
|     """ | ||||
|     H_inv = torch.inverse(H)  # 计算逆 Homography | ||||
|     with torch.no_grad(): | ||||
|         H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :] | ||||
|     warped_det_rotated = warp_feature_map(det_rotated, H_inv) | ||||
|     return F.mse_loss(det_original, warped_det_rotated) | ||||
|  | ||||
| # 描述子损失函数 | ||||
| def compute_description_loss(desc_original, desc_rotated, H, margin=1.0): | ||||
|     """ | ||||
|     计算描述子损失(三元组损失),基于对应点的描述子。 | ||||
|     B, C, H_feat, W_feat = desc_original.size() | ||||
|     num_samples = 100 | ||||
|      | ||||
|     参数: | ||||
|         desc_original (torch.Tensor): 原始图像的描述子图,形状为 [B, 128, H, W]。 | ||||
|         desc_rotated (torch.Tensor): 旋转图像的描述子图,形状为 [B, 128, H, W]。 | ||||
|         H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。 | ||||
|         margin (float): 三元组损失的边距。 | ||||
|     # 随机采样锚点坐标 | ||||
|     coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1  # [-1, 1] | ||||
|      | ||||
|     返回: | ||||
|         torch.Tensor: 描述子损失。 | ||||
|     """ | ||||
|     B, C, H, W = desc_original.size() | ||||
|     # 随机选择锚点(anchor) | ||||
|     num_samples = min(100, H * W)  # 每张图像采样 100 个点 | ||||
|     idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device) | ||||
|     idx_y = idx // W | ||||
|     idx_x = idx % W | ||||
|     coords = torch.stack((idx_x.float(), idx_y.float()), dim=-1)  # [B, num_samples, 2] | ||||
|     # 提取锚点描述子 | ||||
|     anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) | ||||
|      | ||||
|     # 转换为齐次坐标 | ||||
|     coords_hom = torch.cat((coords, torch.ones(B, num_samples, 1, device=coords.device)), dim=-1)  # [B, num_samples, 3] | ||||
|     coords_transformed = torch.bmm(coords_hom, H.transpose(1, 2))  # [B, num_samples, 3] | ||||
|     coords_transformed = coords_transformed[..., :2] / (coords_transformed[..., 2:3] + 1e-8)  # [B, num_samples, 2] | ||||
|     # 计算正样本坐标 | ||||
|     coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2) | ||||
|     M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1)) | ||||
|     coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2] | ||||
|      | ||||
|     # 归一化到 [-1, 1] 用于 grid_sample | ||||
|     coords_transformed = coords_transformed / torch.tensor([W/2, H/2], device=coords.device) - 1 | ||||
|     # 提取正样本描述子 | ||||
|     positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) | ||||
|      | ||||
|     # 提取锚点和正样本描述子 | ||||
|     anchor = desc_original.view(B, C, -1)[:, :, idx.view(-1)]  # [B, 128, num_samples] | ||||
|     positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(2), align_corners=True).squeeze(3)  # [B, 128, num_samples] | ||||
|     # 随机采样负样本 | ||||
|     neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 | ||||
|     negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2) | ||||
|  | ||||
|     # 随机选择负样本 | ||||
|     neg_idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device) | ||||
|     negative = desc_rotated.view(B, C, -1)[:, :, neg_idx.view(-1)]  # [B, 128, num_samples] | ||||
|  | ||||
|     # 三元组损失 | ||||
|     triplet_loss = nn.TripletMarginLoss(margin=margin, p=2) | ||||
|     loss = triplet_loss(anchor.transpose(1, 2), positive.transpose(1, 2), negative.transpose(1, 2)) | ||||
|     return loss | ||||
|     return triplet_loss(anchor, positive, negative) | ||||
|  | ||||
| # 定义变换 | ||||
| transform = transforms.Compose([ | ||||
|     transforms.ToTensor(),  # (1, 256, 256) | ||||
|     transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # (3, 256, 256) | ||||
|     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||||
| ]) | ||||
| # --- 主函数与命令行接口 --- | ||||
| def main(args): | ||||
|     print("--- 开始训练 RoRD 模型 ---") | ||||
|     print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}") | ||||
|     transform = get_transform() | ||||
|     dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform) | ||||
|     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) | ||||
|     model = RoRD().cuda() | ||||
|     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) | ||||
|  | ||||
| # 创建数据集和 DataLoader | ||||
| dataset = ICLayoutTrainingDataset('path/to/layouts', patch_size=256, transform=transform) | ||||
| dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4) | ||||
|  | ||||
| # 定义模型 | ||||
| model = RoRD().cuda() | ||||
|  | ||||
| # 定义优化器 | ||||
| optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) | ||||
|  | ||||
| # 训练循环 | ||||
| num_epochs = 10 | ||||
| for epoch in range(num_epochs): | ||||
|     for epoch in range(args.epochs): | ||||
|         model.train() | ||||
|     total_loss = 0 | ||||
|     for batch in dataloader: | ||||
|         original, rotated, H = batch | ||||
|         original = original.cuda() | ||||
|         rotated = rotated.cuda() | ||||
|         H = H.cuda() | ||||
|         total_loss_val = 0 | ||||
|         for i, (original, rotated, H) in enumerate(dataloader): | ||||
|             original, rotated, H = original.cuda(), rotated.cuda(), H.cuda() | ||||
|             det_original, desc_original = model(original) | ||||
|             det_rotated, desc_rotated = model(rotated) | ||||
|  | ||||
|         # 前向传播 | ||||
|         det_original, _, desc_rord_original = model(original) | ||||
|         det_rotated, _, desc_rord_rotated = model(rotated) | ||||
|             loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H) | ||||
|              | ||||
|         # 计算损失 | ||||
|         detection_loss = compute_detection_loss(det_original, det_rotated, H) | ||||
|         description_loss = compute_description_loss(desc_rord_original, desc_rord_rotated, H) | ||||
|         total_loss_batch = detection_loss + description_loss | ||||
|  | ||||
|         # 反向传播 | ||||
|             optimizer.zero_grad() | ||||
|         total_loss_batch.backward() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             total_loss_val += loss.item() | ||||
|  | ||||
|         total_loss += total_loss_batch.item() | ||||
|         print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---") | ||||
|  | ||||
|     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}") | ||||
|     if not os.path.exists(args.save_dir): | ||||
|         os.makedirs(args.save_dir) | ||||
|     save_path = os.path.join(args.save_dir, 'rord_model_final.pth') | ||||
|     torch.save(model.state_dict(), save_path) | ||||
|     print(f"模型已保存至: {save_path}") | ||||
|  | ||||
| # 保存模型 | ||||
| torch.save(model.state_dict(), 'path/to/save/model.pth') | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="训练 RoRD 模型") | ||||
|     parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR) | ||||
|     parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR) | ||||
|     parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS) | ||||
|     parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE) | ||||
|     parser.add_argument('--lr', type=float, default=config.LEARNING_RATE) | ||||
|     main(parser.parse_args()) | ||||
							
								
								
									
										0
									
								
								utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										14
									
								
								utils/data_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								utils/data_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| from torchvision import transforms | ||||
| from .transforms import SobelTransform | ||||
|  | ||||
| def get_transform(): | ||||
|     """ | ||||
|     获取统一的图像预处理管道。 | ||||
|     确保训练、评估和推理使用完全相同的预处理。 | ||||
|     """ | ||||
|     return transforms.Compose([ | ||||
|         SobelTransform(),  # 应用 Sobel 边缘检测 | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # 适配 VGG 的三通道输入 | ||||
|         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||||
|     ]) | ||||
		Reference in New Issue
	
	Block a user
	 Jiao77
					Jiao77