第二次大修

This commit is contained in:
Jiao77
2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

View File

@@ -58,49 +58,33 @@ ic_layout_recognition/
└── README.md
```
### 1. 训练模型
## 🚀 使用方法
使用以下命令启动模型训练。训练过程采用自监督学习,通过对图像应用随机旋转来生成训练对,从而优化关键点检测和描述子生成。
### 1. 配置
首先,请修改 **`config.py`** 文件,设置正确的训练数据、验证数据和模型保存路径。
### 2. 训练模型
```bash
python train.py --data_dir path/to/layouts --save_dir path/to/save
python train.py --data_dir /path/to/your/layouts --save_dir /path/to/your/models --epochs 50
```
使用 `--help` 查看更多选项。
### 3. 模板匹配
```bash
python match.py --model_path /path/to/your/models/rord_model_final.pth \
--layout /path/to/layout.png \
--template /path/to/template.png \
--output /path/to/result.png
```
| 参数 | 描述 |
| :--- | :--- |
| `--data_dir` | **[必需]** 包含 PNG 格式 IC 版图图像的目录。 |
| `--save_dir` | **[必需]** 训练好的模型权重保存目录。 |
### 2. 评估模型
使用以下命令在验证集上评估模型的性能。评估脚本会计算基于 IoU 阈值的精确率、召回率和 F1 分数。
### 4. 评估模型
```bash
python evaluate.py --model_path path/to/model.pth --val_dir path/to/val/images --annotations_dir path/to/val/annotations --templates path/to/templates
python evaluate.py --model_path /path/to/your/models/rord_model_final.pth \
--val_dir /path/to/val/images \
--annotations_dir /path/to/val/annotations \
--templates_dir /path/to/templates
```
| 参数 | 描述 |
| :--- | :--- |
| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 |
| `--val_dir` | **[必需]** 验证集图像目录。 |
| `--annotations_dir` | **[必需]** 包含真实标注的 JSON 文件目录。 |
| `--templates` | **[必需]** 模板图像的路径列表。 |
### 3. 进行模板匹配
使用以下命令将模板图像与指定的版图图像进行匹配。匹配过程利用 RoRD 模型提取关键点和描述子通过互最近邻MNN匹配和 RANSAC 几何验证来定位模板。
```bash
python match.py --model_path path/to/model.pth --layout_path path/to/layout.png --template_path path/to/template.png --output_path path/to/output.png
```
| 参数 | 描述 |
| :--- | :--- |
| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 |
| `--layout_path` | **[必需]** 待匹配的版图图像路径。 |
| `--template_path` | **[必需]** 模板图像路径。 |
| `--output_path` | **[可选]** 保存可视化匹配结果的路径。 |
## 📦 数据准备
### 训练数据

31
config.py Normal file
View File

@@ -0,0 +1,31 @@
# config.py
# --- 训练参数 ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_EPOCHS = 20 # 增加了训练轮数
PATCH_SIZE = 256
# --- 匹配与评估参数 ---
# 关键点检测的置信度阈值
KEYPOINT_THRESHOLD = 0.5
# RANSAC 重投影误差阈值(像素)
RANSAC_REPROJ_THRESHOLD = 5.0
# RANSAC 判定为有效匹配所需的最小内点数
MIN_INLIERS = 15 # 适当提高以增加匹配的可靠性
# IoU (Intersection over Union) 阈值,用于评估
IOU_THRESHOLD = 0.5
# --- 文件路径 ---
# 训练数据目录
LAYOUT_DIR = 'path/to/layouts'
# 模型保存目录
SAVE_DIR = 'path/to/save'
# 验证集图像目录
VAL_IMG_DIR = 'path/to/val/images'
# 验证集标注目录
VAL_ANN_DIR = 'path/to/val/annotations'
# 模板图像目录
TEMPLATE_DIR = 'path/to/templates'
# 默认加载的模型路径
MODEL_PATH = 'path/to/save/model_final.pth'

0
data/__init__.py Normal file
View File

View File

@@ -1,123 +1,84 @@
from models.rord import RoRD
from data.ic_dataset import ICLayoutDataset
from utils.transforms import SobelTransform
from match import match_template_to_layout
# evaluate.py
import torch
from torchvision import transforms
from PIL import Image
import json
import os
from PIL import Image
import argparse
import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
from match import match_template_to_layout
def compute_iou(box1, box2):
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height']
inter_x1 = max(x1, x2)
inter_y1 = max(y1, y2)
inter_x2 = min(x1 + w1, x2 + w2)
inter_y2 = min(y1 + h1, y2 + h2)
inter_x1, inter_y1 = max(x1, x2), max(y1, y2)
inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
union_area = w1 * h1 + w2 * h2 - inter_area
return inter_area / union_area if union_area > 0 else 0
box1_area = w1 * h1
box2_area = w2 * h2
union_area = box1_area + box2_area - inter_area
iou = inter_area / union_area if union_area > 0 else 0
return iou
def evaluate(model, val_dataset, templates, iou_threshold=0.5):
def evaluate(model, val_dataset, template_dir):
model.eval()
all_true_positives = 0
all_false_positives = 0
all_false_negatives = 0
all_tp, all_fp, all_fn = 0, 0, 0
transform = get_transform()
for layout_idx in range(len(val_dataset)):
layout_image, annotation = val_dataset[layout_idx]
# layout_image is [3, H, W]
layout_tensor = layout_image.unsqueeze(0).cuda() # [1, 3, H, W]
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
# 假设 annotation 是 {"boxes": [{"template": "template1.png", "x": x, "y": y, "width": w, "height": h}, ...]}
gt_boxes_by_template = {}
for layout_tensor, annotation in val_dataset:
layout_tensor = layout_tensor.unsqueeze(0).cuda()
gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])}
for box in annotation.get('boxes', []):
template_name = box['template']
if template_name not in gt_boxes_by_template:
gt_boxes_by_template[template_name] = []
gt_boxes_by_template[template_name].append(box)
gt_by_template[box['template']].append(box)
for template_path in templates:
for template_path in template_paths:
template_name = os.path.basename(template_path)
template_image = Image.open(template_path).convert('L')
template_tensor = transform(template_image).unsqueeze(0).cuda() # [1, 3, H, W]
template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda()
# 执行匹配
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
detected = match_template_to_layout(model, layout_tensor, template_tensor)
gt_boxes = gt_by_template.get(template_name, [])
# 获取当前模板的 gt_boxes
gt_boxes = gt_boxes_by_template.get(template_name, [])
# 初始化已分配的 gt_box 索引
assigned_gt = set()
for det_box in detected_bboxes:
matched_gt = [False] * len(gt_boxes)
tp = 0
for det_box in detected:
best_iou = 0
best_gt_idx = -1
for idx, gt_box in enumerate(gt_boxes):
if idx in assigned_gt:
continue
for i, gt_box in enumerate(gt_boxes):
if matched_gt[i]: continue
iou = compute_iou(det_box, gt_box)
if iou > best_iou:
best_iou = iou
best_gt_idx = idx
if best_iou > iou_threshold and best_gt_idx != -1:
all_true_positives += 1
assigned_gt.add(best_gt_idx)
else:
all_false_positives += 1
best_iou, best_gt_idx = iou, i
# 计算 FN未分配的 gt_box
for idx in range(len(gt_boxes)):
if idx not in assigned_gt:
all_false_negatives += 1
if best_iou > config.IOU_THRESHOLD:
tp += 1
matched_gt[best_gt_idx] = True
# 计算评估指标
precision = all_true_positives / (all_true_positives + all_false_positives) if (all_true_positives + all_false_positives) > 0 else 0
recall = all_true_positives / (all_true_positives + all_false_negatives) if (all_true_positives + all_false_negatives) > 0 else 0
all_tp += tp
all_fp += len(detected) - tp
all_fn += len(gt_boxes) - tp
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return {
'precision': precision,
'recall': recall,
'f1': f1
}
return {'precision': precision, 'recall': recall, 'f1': f1}
if __name__ == "__main__":
# 设置变换
transform = transforms.Compose([
SobelTransform(),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # [1, H, W] -> [3, H, W]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
args = parser.parse_args()
# 加载模型
model = RoRD().cuda()
model.load_state_dict(torch.load('path/to/weights.pth'))
model.eval()
model.load_state_dict(torch.load(args.model_path))
val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform())
# 定义验证数据集
val_dataset = ICLayoutDataset(
image_dir='path/to/val/images',
annotation_dir='path/to/val/annotations',
transform=transform
)
# 定义模板列表
templates = ['path/to/templates/template1.png', 'path/to/templates/template2.png'] # 替换为实际模板路径
# 评估模型
results = evaluate(model, val_dataset, templates)
results = evaluate(model, val_dataset, args.templates_dir)
print("评估结果:")
print(f"精确率: {results['precision']:.4f}")
print(f"召回率: {results['recall']:.4f}")
print(f"F1 分数: {results['f1']:.4f}")
print(f" 精确率 (Precision): {results['precision']:.4f}")
print(f" 召回率 (Recall): {results['recall']:.4f}")
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")

231
match.py
View File

@@ -1,199 +1,108 @@
# match.py
import torch
import torch.nn.functional as F
from models.rord import RoRD
from torchvision import transforms
from utils.transforms import SobelTransform
import numpy as np
import cv2
from PIL import Image
import argparse
import os
def extract_keypoints_and_descriptors(model, image):
"""
从 RoRD 模型中提取关键点和描述子。
import config
from models.rord import RoRD
from utils.data_utils import get_transform
参数:
model (RoRD): RoRD 模型。
image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。
返回:
tuple: (keypoints_input, descriptors)
- keypoints_input: [N, 2] float tensor关键点在输入图像中的坐标。
- descriptors: [N, 128] float tensorL2 归一化的描述子。
"""
def extract_keypoints_and_descriptors(model, image, kp_thresh):
with torch.no_grad():
detection_map, _, desc_rord = model(image)
desc = desc_rord # 使用 RoRD 描述子头
detection_map, desc = model(image)
binary_map = (detection_map > kp_thresh).float()
coords = torch.nonzero(binary_map[0, 0]).float()
keypoints_input = coords[:, [1, 0]] * 8.0 # Stride of descriptor is 8
# 从检测图中提取关键点
thresh = 0.5
binary_map = (detection_map > thresh).float()
coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d)
keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标stride=16
# 从描述子图中提取描述子
# detection_map 的形状为 [1, 1, H/16, W/16]desc 的形状为 [1, 128, H/8, W/8]
# 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2)
keypoints_desc = (coords * 2).long() # [N, 2],整数坐标
H_desc, W_desc = desc.shape[2], desc.shape[3]
mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc)
keypoints_desc = keypoints_desc[mask]
keypoints_input = keypoints_input[mask]
# 提取描述子
descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128]
# L2 归一化描述子
descriptors = F.grid_sample(desc, coords.flip(1).view(1, -1, 1, 2) / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1, align_corners=True).squeeze().T
descriptors = F.normalize(descriptors, p=2, dim=1)
return keypoints_input, descriptors
def mutual_nearest_neighbor(template_descs, layout_descs):
"""
使用互最近邻MNN找到模板和版图之间的匹配。
参数:
template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。
layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。
返回:
list: [(i_template, i_layout)],互最近邻匹配对的列表。
"""
M, N = template_descs.size(0), layout_descs.size(0)
if M == 0 or N == 0:
return []
similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵
# 找到每个模板描述子的最近邻
nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M]
# 找到每个版图描述子的最近邻
nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N]
# 找到互最近邻
mutual_matches = []
for i in range(M):
j = nn_template_to_layout[i]
if nn_layout_to_template[j] == i:
mutual_matches.append((i.item(), j.item()))
return mutual_matches
def ransac_filter(matches, template_kps, layout_kps):
"""
使用 RANSAC 对匹配进行几何验证,并返回内点。
参数:
matches (list): [(i_template, i_layout)],匹配对列表。
template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。
layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。
返回:
tuple: (inlier_matches, num_inliers)
- inlier_matches: [(i_template, i_layout)],内点匹配对。
- num_inliers: int内点数量。
"""
src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches])
dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches])
if len(src_pts) < 4:
return [], 0
try:
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0)
if H is None:
return [], 0
inliers = mask.ravel() > 0
num_inliers = np.sum(inliers)
inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]]
return inlier_matches, num_inliers
except cv2.error:
return [], 0
def mutual_nearest_neighbor(descs1, descs2):
sim = descs1 @ descs2.T
nn12 = torch.max(sim, dim=1)
nn21 = torch.max(sim, dim=0)
ids1 = torch.arange(0, sim.shape[0], device=sim.device)
mask = (ids1 == nn21.indices[nn12.indices])
matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
return matches.cpu().numpy()
def match_template_to_layout(model, layout_image, template_image):
"""
使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD)
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD)
参数:
model (RoRD): RoRD 模型。
layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。
template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
found_instances = []
返回:
list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。
"""
# 提取版图和模板的关键点和描述子
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image)
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image)
# 初始化活动版图关键点掩码
active_layout = torch.ones(len(layout_kps), dtype=bool)
bboxes = []
while True:
# 获取当前活动的版图关键点和描述子
current_layout_kps = layout_kps[active_layout]
current_layout_descs = layout_descs[active_layout]
if len(current_layout_descs) == 0:
current_indices = torch.nonzero(active_layout_mask).squeeze(1)
if len(current_indices) < config.MIN_INLIERS:
break
# MNN 匹配
current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices]
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
if len(matches) == 0:
if len(matches) < 4: break
src_pts = template_kps[matches[:, 0]].cpu().numpy()
dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy()
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
if H is None or mask.sum() < config.MIN_INLIERS:
break
# 将当前版图索引映射回原始版图索引
active_indices = torch.nonzero(active_layout).squeeze(1)
matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches]
inlier_mask = mask.ravel().astype(bool)
# RANSAC 过滤
inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps)
# 区域屏蔽逻辑
inlier_layout_kps = dst_pts[inlier_mask]
x_min, y_min = inlier_layout_kps.min(axis=0)
x_max, y_max = inlier_layout_kps.max(axis=0)
if num_inliers > 10: # 设置内点阈值
# 获取内点在版图中的关键点
inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches]
inlier_layout_kps = np.array(inlier_layout_kps)
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H}
found_instances.append(instance)
# 计算边框
x_min = int(inlier_layout_kps[:, 0].min())
y_min = int(inlier_layout_kps[:, 1].min())
x_max = int(inlier_layout_kps[:, 0].max())
y_max = int(inlier_layout_kps[:, 1].max())
bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min})
kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
active_layout_mask[region_mask] = False
# 屏蔽内点
for _, j in inlier_matches:
active_layout[j] = False
else:
break
print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}")
return bboxes
return found_instances
def visualize_matches(layout_path, template_path, bboxes, output_path):
layout_img = cv2.imread(layout_path)
for i, bbox in enumerate(bboxes):
x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.imwrite(output_path, layout_img)
print(f"可视化结果已保存至: {output_path}")
if __name__ == "__main__":
# 设置变换
transform = transforms.Compose([
SobelTransform(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
parser.add_argument('--output', type=str)
args = parser.parse_args()
# 加载模型
transform = get_transform()
model = RoRD().cuda()
model.load_state_dict(torch.load('path/to/weights.pth'))
model.load_state_dict(torch.load(args.model_path))
model.eval()
# 加载版图和模板图像
layout_image = Image.open('path/to/layout.png').convert('L')
layout_tensor = transform(layout_image).unsqueeze(0).cuda()
layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda()
template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda()
template_image = Image.open('path/to/template.png').convert('L')
template_tensor = transform(template_image).unsqueeze(0).cuda()
# 执行匹配
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
# 打印检测到的边框
print("检测到的边框:")
print("\n检测到的边界框:")
for bbox in detected_bboxes:
print(bbox)
if args.output:
visualize_matches(args.layout, args.template, detected_bboxes, args.output)

0
models/__init__.py Normal file
View File

View File

@@ -1,16 +1,25 @@
# models/rord.py
import torch
import torch.nn as nn
from torchvision import models
class RoRD(nn.Module):
def __init__(self):
"""
修复后的 RoRD 模型。
- 实现了共享骨干网络,以提高计算效率和减少内存占用。
- 移除了冗余的 descriptor_head_vanilla。
"""
super(RoRD, self).__init__()
# 检测骨干网络VGG-16 直到 relu5_3层 0 到 29
self.backbone_det = models.vgg16(pretrained=True).features[:30]
# 描述骨干网络VGG-16 直到 relu4_3层 0 到 22
self.backbone_desc = models.vgg16(pretrained=True).features[:23]
# 检测头:输出关键点概率图
vgg16_features = models.vgg16(pretrained=True).features
# 共享骨干网络
self.slice1 = vgg16_features[:23] # 到 relu4_3
self.slice2 = vgg16_features[23:30] # 从 relu4_3 到 relu5_3
# 检测头
self.detection_head = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
@@ -18,16 +27,8 @@ class RoRD(nn.Module):
nn.Sigmoid()
)
# 普通描述子头D2-Net 风格)
self.descriptor_head_vanilla = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, kernel_size=1),
nn.InstanceNorm2d(128)
)
# RoRD 描述子头(旋转鲁棒)
self.descriptor_head_rord = nn.Sequential(
# 描述子头
self.descriptor_head = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, kernel_size=1),
@@ -35,13 +36,14 @@ class RoRD(nn.Module):
)
def forward(self, x):
# 检测分支
features_det = self.backbone_det(x)
detection = self.detection_head(features_det)
# 共享特征提取
features_shared = self.slice1(x)
# 描述分支
features_desc = self.backbone_desc(x)
desc_vanilla = self.descriptor_head_vanilla(features_desc)
desc_rord = self.descriptor_head_rord(features_desc)
# 描述分支
descriptors = self.descriptor_head(features_shared)
return detection, desc_vanilla, desc_rord
# 检测器分支
features_det = self.slice2(features_shared)
detection_map = self.detection_head(features_det)
return detection_map, descriptors

258
train.py
View File

@@ -1,236 +1,142 @@
# train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import os
from models.rord import RoRD
import argparse
# 数据集类:生成随机旋转的训练对
# 导入项目模块
import config
from models.rord import RoRD
from utils.data_utils import get_transform
# --- 训练专用数据集类 ---
class ICLayoutTrainingDataset(Dataset):
def __init__(self, image_dir, patch_size=256, transform=None):
"""
初始化 IC 版图训练数据集。
参数:
image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。
patch_size (int): 裁剪的 patch 大小(默认 256x256
transform (callable, optional): 应用于图像的变换。
"""
self.image_dir = image_dir
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
self.patch_size = patch_size
self.transform = transform
def __len__(self):
"""
返回数据集中的图像数量。
返回:
int: 数据集大小。
"""
return len(self.image_paths)
def __getitem__(self, index):
"""
获取指定索引的训练对(原始 patch、旋转 patch、Homography 矩阵)。
参数:
index (int): 图像索引。
返回:
tuple: (patch, rotated_patch, H_tensor)
- patch: 原始 patch 张量。
- rotated_patch: 旋转后的 patch 张量。
- H_tensor: Homography 矩阵张量。
"""
img_path = self.image_paths[index]
image = Image.open(img_path).convert('L') # 灰度图像
image = Image.open(img_path).convert('L')
# 获取图像大小
W, H = image.size
# 随机选择裁剪的左上角坐标
x = np.random.randint(0, W - self.patch_size + 1)
y = np.random.randint(0, H - self.patch_size + 1)
patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
# 转换为 NumPy 数组
patch_np = np.array(patch)
# 随机旋转角度0°~360°
theta = np.random.uniform(0, 360)
theta_rad = np.deg2rad(theta)
cos_theta = np.cos(theta_rad)
sin_theta = np.sin(theta_rad)
# 实现8个方向的离散几何变换
theta_deg = np.random.choice([0, 90, 180, 270])
is_mirrored = np.random.choice([True, False])
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)
# 计算旋转中心patch 的中心)
cx = self.patch_size / 2.0
cy = self.patch_size / 2.0
if is_mirrored:
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
M_mirror_3x3 = T2 @ Flip @ T1
M_3x3 = np.vstack([M, [0, 0, 1]])
H = (M_3x3 @ M_mirror_3x3).astype(np.float32)
else:
H = np.vstack([M, [0, 0, 1]]).astype(np.float32)
# 计算旋转的齐次矩阵Homography
H = np.array([
[cos_theta, -sin_theta, cx * (1 - cos_theta) + cy * sin_theta],
[sin_theta, cos_theta, cy * (1 - cos_theta) - cx * sin_theta],
[0, 0, 1]
], dtype=np.float32)
transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
transformed_patch = Image.fromarray(transformed_patch_np)
# 应用旋转到 patch
rotated_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
# 转换回 PIL Image
rotated_patch = Image.fromarray(rotated_patch_np)
# 应用变换
if self.transform:
patch = self.transform(patch)
rotated_patch = self.transform(rotated_patch)
transformed_patch = self.transform(transformed_patch)
# 转换 H 为张量
H_tensor = torch.from_numpy(H).float()
H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵
return patch, transformed_patch, H_tensor
return patch, rotated_patch, H_tensor
# 特征图变换函数
# --- 特征图变换与损失函数 ---
def warp_feature_map(feature_map, H_inv):
"""
使用逆 Homography 矩阵变换特征图。
参数:
feature_map (torch.Tensor): 输入特征图,形状为 [B, C, H, W]。
H_inv (torch.Tensor): 逆 Homography 矩阵,形状为 [B, 3, 3]。
返回:
torch.Tensor: 变换后的特征图,形状为 [B, C, H, W]。
"""
B, C, H, W = feature_map.size()
# 生成网格
grid_y, grid_x = torch.meshgrid(
torch.linspace(-1, 1, H, device=feature_map.device),
torch.linspace(-1, 1, W, device=feature_map.device),
indexing='ij'
)
grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1) # [H, W, 3]
grid = grid.unsqueeze(0).expand(B, H, W, 3) # [B, H, W, 3]
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
return F.grid_sample(feature_map, grid, align_corners=False)
# 将网格转换为齐次坐标并应用 H_inv
grid_flat = grid.view(B, -1, 3) # [B, H*W, 3]
grid_transformed = torch.bmm(grid_flat, H_inv.transpose(1, 2)) # [B, H*W, 3]
grid_transformed = grid_transformed.view(B, H, W, 3) # [B, H, W, 3]
grid_transformed = grid_transformed[..., :2] / (grid_transformed[..., 2:3] + 1e-8) # [B, H, W, 2]
# 使用 grid_sample 进行变换
warped_feature = F.grid_sample(feature_map, grid_transformed, align_corners=True)
return warped_feature
# 检测损失函数
def compute_detection_loss(det_original, det_rotated, H):
"""
计算检测损失MSE比较原始检测图与旋转检测图逆变换后
参数:
det_original (torch.Tensor): 原始图像的检测图,形状为 [B, 1, H, W]。
det_rotated (torch.Tensor): 旋转图像的检测图,形状为 [B, 1, H, W]。
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
返回:
torch.Tensor: 检测损失。
"""
H_inv = torch.inverse(H) # 计算逆 Homography
with torch.no_grad():
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
return F.mse_loss(det_original, warped_det_rotated)
# 描述子损失函数
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
"""
计算描述子损失(三元组损失),基于对应点的描述子。
B, C, H_feat, W_feat = desc_original.size()
num_samples = 100
参数:
desc_original (torch.Tensor): 原始图像的描述子图,形状为 [B, 128, H, W]。
desc_rotated (torch.Tensor): 旋转图像的描述子图,形状为 [B, 128, H, W]。
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
margin (float): 三元组损失的边距。
# 随机采样锚点坐标
coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 # [-1, 1]
返回:
torch.Tensor: 描述子损失。
"""
B, C, H, W = desc_original.size()
# 随机选择锚点anchor
num_samples = min(100, H * W) # 每张图像采样 100 个点
idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
idx_y = idx // W
idx_x = idx % W
coords = torch.stack((idx_x.float(), idx_y.float()), dim=-1) # [B, num_samples, 2]
# 提取锚点描述子
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 转换为齐次坐标
coords_hom = torch.cat((coords, torch.ones(B, num_samples, 1, device=coords.device)), dim=-1) # [B, num_samples, 3]
coords_transformed = torch.bmm(coords_hom, H.transpose(1, 2)) # [B, num_samples, 3]
coords_transformed = coords_transformed[..., :2] / (coords_transformed[..., 2:3] + 1e-8) # [B, num_samples, 2]
# 计算正样本坐标
coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
# 归一化到 [-1, 1] 用于 grid_sample
coords_transformed = coords_transformed / torch.tensor([W/2, H/2], device=coords.device) - 1
# 提取正样本描述子
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 提取锚点和正样本描述子
anchor = desc_original.view(B, C, -1)[:, :, idx.view(-1)] # [B, 128, num_samples]
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(2), align_corners=True).squeeze(3) # [B, 128, num_samples]
# 随机采样负样本
neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 随机选择负样本
neg_idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
negative = desc_rotated.view(B, C, -1)[:, :, neg_idx.view(-1)] # [B, 128, num_samples]
# 三元组损失
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
loss = triplet_loss(anchor.transpose(1, 2), positive.transpose(1, 2), negative.transpose(1, 2))
return loss
return triplet_loss(anchor, positive, negative)
# 定义变换
transform = transforms.Compose([
transforms.ToTensor(), # (1, 256, 256)
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # (3, 256, 256)
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# --- 主函数与命令行接口 ---
def main(args):
print("--- 开始训练 RoRD 模型 ---")
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
transform = get_transform()
dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
model = RoRD().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# 创建数据集和 DataLoader
dataset = ICLayoutTrainingDataset('path/to/layouts', patch_size=256, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
# 定义模型
model = RoRD().cuda()
# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
for epoch in range(args.epochs):
model.train()
total_loss = 0
for batch in dataloader:
original, rotated, H = batch
original = original.cuda()
rotated = rotated.cuda()
H = H.cuda()
total_loss_val = 0
for i, (original, rotated, H) in enumerate(dataloader):
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
det_original, desc_original = model(original)
det_rotated, desc_rotated = model(rotated)
# 前向传播
det_original, _, desc_rord_original = model(original)
det_rotated, _, desc_rord_rotated = model(rotated)
loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
# 计算损失
detection_loss = compute_detection_loss(det_original, det_rotated, H)
description_loss = compute_description_loss(desc_rord_original, desc_rord_rotated, H)
total_loss_batch = detection_loss + description_loss
# 反向传播
optimizer.zero_grad()
total_loss_batch.backward()
loss.backward()
optimizer.step()
total_loss_val += loss.item()
total_loss += total_loss_batch.item()
print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---")
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
torch.save(model.state_dict(), save_path)
print(f"模型已保存至: {save_path}")
# 保存模型
torch.save(model.state_dict(), 'path/to/save/model.pth')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
main(parser.parse_args())

0
utils/__init__.py Normal file
View File

14
utils/data_utils.py Normal file
View File

@@ -0,0 +1,14 @@
from torchvision import transforms
from .transforms import SobelTransform
def get_transform():
"""
获取统一的图像预处理管道。
确保训练、评估和推理使用完全相同的预处理。
"""
return transforms.Compose([
SobelTransform(), # 应用 Sobel 边缘检测
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # 适配 VGG 的三通道输入
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])