第二次大修
This commit is contained in:
54
README.md
54
README.md
@@ -58,49 +58,33 @@ ic_layout_recognition/
|
|||||||
└── README.md
|
└── README.md
|
||||||
```
|
```
|
||||||
|
|
||||||
### 1. 训练模型
|
## 🚀 使用方法
|
||||||
|
|
||||||
使用以下命令启动模型训练。训练过程采用自监督学习,通过对图像应用随机旋转来生成训练对,从而优化关键点检测和描述子生成。
|
### 1. 配置
|
||||||
|
首先,请修改 **`config.py`** 文件,设置正确的训练数据、验证数据和模型保存路径。
|
||||||
|
|
||||||
|
### 2. 训练模型
|
||||||
```bash
|
```bash
|
||||||
python train.py --data_dir path/to/layouts --save_dir path/to/save
|
python train.py --data_dir /path/to/your/layouts --save_dir /path/to/your/models --epochs 50
|
||||||
|
```
|
||||||
|
使用 `--help` 查看更多选项。
|
||||||
|
|
||||||
|
### 3. 模板匹配
|
||||||
|
```bash
|
||||||
|
python match.py --model_path /path/to/your/models/rord_model_final.pth \
|
||||||
|
--layout /path/to/layout.png \
|
||||||
|
--template /path/to/template.png \
|
||||||
|
--output /path/to/result.png
|
||||||
```
|
```
|
||||||
|
|
||||||
| 参数 | 描述 |
|
### 4. 评估模型
|
||||||
| :--- | :--- |
|
|
||||||
| `--data_dir` | **[必需]** 包含 PNG 格式 IC 版图图像的目录。 |
|
|
||||||
| `--save_dir` | **[必需]** 训练好的模型权重保存目录。 |
|
|
||||||
|
|
||||||
### 2. 评估模型
|
|
||||||
|
|
||||||
使用以下命令在验证集上评估模型的性能。评估脚本会计算基于 IoU 阈值的精确率、召回率和 F1 分数。
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python evaluate.py --model_path path/to/model.pth --val_dir path/to/val/images --annotations_dir path/to/val/annotations --templates path/to/templates
|
python evaluate.py --model_path /path/to/your/models/rord_model_final.pth \
|
||||||
|
--val_dir /path/to/val/images \
|
||||||
|
--annotations_dir /path/to/val/annotations \
|
||||||
|
--templates_dir /path/to/templates
|
||||||
```
|
```
|
||||||
|
|
||||||
| 参数 | 描述 |
|
|
||||||
| :--- | :--- |
|
|
||||||
| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 |
|
|
||||||
| `--val_dir` | **[必需]** 验证集图像目录。 |
|
|
||||||
| `--annotations_dir` | **[必需]** 包含真实标注的 JSON 文件目录。 |
|
|
||||||
| `--templates` | **[必需]** 模板图像的路径列表。 |
|
|
||||||
|
|
||||||
### 3. 进行模板匹配
|
|
||||||
|
|
||||||
使用以下命令将模板图像与指定的版图图像进行匹配。匹配过程利用 RoRD 模型提取关键点和描述子,通过互最近邻(MNN)匹配和 RANSAC 几何验证来定位模板。
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python match.py --model_path path/to/model.pth --layout_path path/to/layout.png --template_path path/to/template.png --output_path path/to/output.png
|
|
||||||
```
|
|
||||||
|
|
||||||
| 参数 | 描述 |
|
|
||||||
| :--- | :--- |
|
|
||||||
| `--model_path` | **[必需]** 训练好的模型权重 (`.pth`) 文件路径。 |
|
|
||||||
| `--layout_path` | **[必需]** 待匹配的版图图像路径。 |
|
|
||||||
| `--template_path` | **[必需]** 模板图像路径。 |
|
|
||||||
| `--output_path` | **[可选]** 保存可视化匹配结果的路径。 |
|
|
||||||
|
|
||||||
## 📦 数据准备
|
## 📦 数据准备
|
||||||
|
|
||||||
### 训练数据
|
### 训练数据
|
||||||
|
|||||||
31
config.py
Normal file
31
config.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# config.py
|
||||||
|
|
||||||
|
# --- 训练参数 ---
|
||||||
|
LEARNING_RATE = 1e-4
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
NUM_EPOCHS = 20 # 增加了训练轮数
|
||||||
|
PATCH_SIZE = 256
|
||||||
|
|
||||||
|
# --- 匹配与评估参数 ---
|
||||||
|
# 关键点检测的置信度阈值
|
||||||
|
KEYPOINT_THRESHOLD = 0.5
|
||||||
|
# RANSAC 重投影误差阈值(像素)
|
||||||
|
RANSAC_REPROJ_THRESHOLD = 5.0
|
||||||
|
# RANSAC 判定为有效匹配所需的最小内点数
|
||||||
|
MIN_INLIERS = 15 # 适当提高以增加匹配的可靠性
|
||||||
|
# IoU (Intersection over Union) 阈值,用于评估
|
||||||
|
IOU_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
# --- 文件路径 ---
|
||||||
|
# 训练数据目录
|
||||||
|
LAYOUT_DIR = 'path/to/layouts'
|
||||||
|
# 模型保存目录
|
||||||
|
SAVE_DIR = 'path/to/save'
|
||||||
|
# 验证集图像目录
|
||||||
|
VAL_IMG_DIR = 'path/to/val/images'
|
||||||
|
# 验证集标注目录
|
||||||
|
VAL_ANN_DIR = 'path/to/val/annotations'
|
||||||
|
# 模板图像目录
|
||||||
|
TEMPLATE_DIR = 'path/to/templates'
|
||||||
|
# 默认加载的模型路径
|
||||||
|
MODEL_PATH = 'path/to/save/model_final.pth'
|
||||||
0
data/__init__.py
Normal file
0
data/__init__.py
Normal file
147
evaluate.py
147
evaluate.py
@@ -1,123 +1,84 @@
|
|||||||
from models.rord import RoRD
|
# evaluate.py
|
||||||
from data.ic_dataset import ICLayoutDataset
|
|
||||||
from utils.transforms import SobelTransform
|
|
||||||
from match import match_template_to_layout
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from PIL import Image
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
import argparse
|
||||||
|
|
||||||
|
import config
|
||||||
|
from models.rord import RoRD
|
||||||
|
from utils.data_utils import get_transform
|
||||||
|
from data.ic_dataset import ICLayoutDataset
|
||||||
|
from match import match_template_to_layout
|
||||||
|
|
||||||
def compute_iou(box1, box2):
|
def compute_iou(box1, box2):
|
||||||
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
|
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
|
||||||
x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height']
|
x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height']
|
||||||
|
inter_x1, inter_y1 = max(x1, x2), max(y1, y2)
|
||||||
inter_x1 = max(x1, x2)
|
inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
|
||||||
inter_y1 = max(y1, y2)
|
|
||||||
inter_x2 = min(x1 + w1, x2 + w2)
|
|
||||||
inter_y2 = min(y1 + h1, y2 + h2)
|
|
||||||
|
|
||||||
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
|
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
|
||||||
|
union_area = w1 * h1 + w2 * h2 - inter_area
|
||||||
|
return inter_area / union_area if union_area > 0 else 0
|
||||||
|
|
||||||
box1_area = w1 * h1
|
def evaluate(model, val_dataset, template_dir):
|
||||||
box2_area = w2 * h2
|
|
||||||
union_area = box1_area + box2_area - inter_area
|
|
||||||
|
|
||||||
iou = inter_area / union_area if union_area > 0 else 0
|
|
||||||
return iou
|
|
||||||
|
|
||||||
def evaluate(model, val_dataset, templates, iou_threshold=0.5):
|
|
||||||
model.eval()
|
model.eval()
|
||||||
all_true_positives = 0
|
all_tp, all_fp, all_fn = 0, 0, 0
|
||||||
all_false_positives = 0
|
transform = get_transform()
|
||||||
all_false_negatives = 0
|
|
||||||
|
|
||||||
for layout_idx in range(len(val_dataset)):
|
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
|
||||||
layout_image, annotation = val_dataset[layout_idx]
|
|
||||||
# layout_image is [3, H, W]
|
|
||||||
layout_tensor = layout_image.unsqueeze(0).cuda() # [1, 3, H, W]
|
|
||||||
|
|
||||||
# 假设 annotation 是 {"boxes": [{"template": "template1.png", "x": x, "y": y, "width": w, "height": h}, ...]}
|
for layout_tensor, annotation in val_dataset:
|
||||||
gt_boxes_by_template = {}
|
layout_tensor = layout_tensor.unsqueeze(0).cuda()
|
||||||
|
gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])}
|
||||||
for box in annotation.get('boxes', []):
|
for box in annotation.get('boxes', []):
|
||||||
template_name = box['template']
|
gt_by_template[box['template']].append(box)
|
||||||
if template_name not in gt_boxes_by_template:
|
|
||||||
gt_boxes_by_template[template_name] = []
|
|
||||||
gt_boxes_by_template[template_name].append(box)
|
|
||||||
|
|
||||||
for template_path in templates:
|
for template_path in template_paths:
|
||||||
template_name = os.path.basename(template_path)
|
template_name = os.path.basename(template_path)
|
||||||
template_image = Image.open(template_path).convert('L')
|
template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda()
|
||||||
template_tensor = transform(template_image).unsqueeze(0).cuda() # [1, 3, H, W]
|
|
||||||
|
|
||||||
# 执行匹配
|
detected = match_template_to_layout(model, layout_tensor, template_tensor)
|
||||||
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
gt_boxes = gt_by_template.get(template_name, [])
|
||||||
|
|
||||||
# 获取当前模板的 gt_boxes
|
matched_gt = [False] * len(gt_boxes)
|
||||||
gt_boxes = gt_boxes_by_template.get(template_name, [])
|
tp = 0
|
||||||
|
for det_box in detected:
|
||||||
# 初始化已分配的 gt_box 索引
|
|
||||||
assigned_gt = set()
|
|
||||||
|
|
||||||
for det_box in detected_bboxes:
|
|
||||||
best_iou = 0
|
best_iou = 0
|
||||||
best_gt_idx = -1
|
best_gt_idx = -1
|
||||||
for idx, gt_box in enumerate(gt_boxes):
|
for i, gt_box in enumerate(gt_boxes):
|
||||||
if idx in assigned_gt:
|
if matched_gt[i]: continue
|
||||||
continue
|
|
||||||
iou = compute_iou(det_box, gt_box)
|
iou = compute_iou(det_box, gt_box)
|
||||||
if iou > best_iou:
|
if iou > best_iou:
|
||||||
best_iou = iou
|
best_iou, best_gt_idx = iou, i
|
||||||
best_gt_idx = idx
|
|
||||||
if best_iou > iou_threshold and best_gt_idx != -1:
|
|
||||||
all_true_positives += 1
|
|
||||||
assigned_gt.add(best_gt_idx)
|
|
||||||
else:
|
|
||||||
all_false_positives += 1
|
|
||||||
|
|
||||||
# 计算 FN:未分配的 gt_box
|
if best_iou > config.IOU_THRESHOLD:
|
||||||
for idx in range(len(gt_boxes)):
|
tp += 1
|
||||||
if idx not in assigned_gt:
|
matched_gt[best_gt_idx] = True
|
||||||
all_false_negatives += 1
|
|
||||||
|
|
||||||
# 计算评估指标
|
all_tp += tp
|
||||||
precision = all_true_positives / (all_true_positives + all_false_positives) if (all_true_positives + all_false_positives) > 0 else 0
|
all_fp += len(detected) - tp
|
||||||
recall = all_true_positives / (all_true_positives + all_false_negatives) if (all_true_positives + all_false_negatives) > 0 else 0
|
all_fn += len(gt_boxes) - tp
|
||||||
|
|
||||||
|
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
|
||||||
|
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
|
||||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||||
|
return {'precision': precision, 'recall': recall, 'f1': f1}
|
||||||
return {
|
|
||||||
'precision': precision,
|
|
||||||
'recall': recall,
|
|
||||||
'f1': f1
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 设置变换
|
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
|
||||||
transform = transforms.Compose([
|
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
||||||
SobelTransform(),
|
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
|
||||||
transforms.ToTensor(),
|
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
|
||||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # [1, H, W] -> [3, H, W]
|
parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
|
||||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
args = parser.parse_args()
|
||||||
])
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
model = RoRD().cuda()
|
model = RoRD().cuda()
|
||||||
model.load_state_dict(torch.load('path/to/weights.pth'))
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
model.eval()
|
val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform())
|
||||||
|
|
||||||
# 定义验证数据集
|
results = evaluate(model, val_dataset, args.templates_dir)
|
||||||
val_dataset = ICLayoutDataset(
|
|
||||||
image_dir='path/to/val/images',
|
|
||||||
annotation_dir='path/to/val/annotations',
|
|
||||||
transform=transform
|
|
||||||
)
|
|
||||||
|
|
||||||
# 定义模板列表
|
|
||||||
templates = ['path/to/templates/template1.png', 'path/to/templates/template2.png'] # 替换为实际模板路径
|
|
||||||
|
|
||||||
# 评估模型
|
|
||||||
results = evaluate(model, val_dataset, templates)
|
|
||||||
print("评估结果:")
|
print("评估结果:")
|
||||||
print(f"精确率: {results['precision']:.4f}")
|
print(f" 精确率 (Precision): {results['precision']:.4f}")
|
||||||
print(f"召回率: {results['recall']:.4f}")
|
print(f" 召回率 (Recall): {results['recall']:.4f}")
|
||||||
print(f"F1 分数: {results['f1']:.4f}")
|
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
||||||
231
match.py
231
match.py
@@ -1,199 +1,108 @@
|
|||||||
|
# match.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.rord import RoRD
|
|
||||||
from torchvision import transforms
|
|
||||||
from utils.transforms import SobelTransform
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
def extract_keypoints_and_descriptors(model, image):
|
import config
|
||||||
"""
|
from models.rord import RoRD
|
||||||
从 RoRD 模型中提取关键点和描述子。
|
from utils.data_utils import get_transform
|
||||||
|
|
||||||
参数:
|
def extract_keypoints_and_descriptors(model, image, kp_thresh):
|
||||||
model (RoRD): RoRD 模型。
|
|
||||||
image (torch.Tensor): 输入图像张量,形状为 [1, 1, H, W]。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
tuple: (keypoints_input, descriptors)
|
|
||||||
- keypoints_input: [N, 2] float tensor,关键点在输入图像中的坐标。
|
|
||||||
- descriptors: [N, 128] float tensor,L2 归一化的描述子。
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
detection_map, _, desc_rord = model(image)
|
detection_map, desc = model(image)
|
||||||
desc = desc_rord # 使用 RoRD 描述子头
|
binary_map = (detection_map > kp_thresh).float()
|
||||||
|
coords = torch.nonzero(binary_map[0, 0]).float()
|
||||||
|
keypoints_input = coords[:, [1, 0]] * 8.0 # Stride of descriptor is 8
|
||||||
|
|
||||||
# 从检测图中提取关键点
|
descriptors = F.grid_sample(desc, coords.flip(1).view(1, -1, 1, 2) / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1, align_corners=True).squeeze().T
|
||||||
thresh = 0.5
|
|
||||||
binary_map = (detection_map > thresh).float()
|
|
||||||
coords = torch.nonzero(binary_map[0, 0] > thresh).float() # [N, 2],每个行是 (i_d, j_d)
|
|
||||||
keypoints_input = coords * 16.0 # 将特征图坐标映射到输入图像坐标(stride=16)
|
|
||||||
|
|
||||||
# 从描述子图中提取描述子
|
|
||||||
# detection_map 的形状为 [1, 1, H/16, W/16],desc 的形状为 [1, 128, H/8, W/8]
|
|
||||||
# 将 detection_map 的坐标映射到 desc 的坐标:(i_d * 2, j_d * 2)
|
|
||||||
keypoints_desc = (coords * 2).long() # [N, 2],整数坐标
|
|
||||||
H_desc, W_desc = desc.shape[2], desc.shape[3]
|
|
||||||
mask = (keypoints_desc[:, 0] < H_desc) & (keypoints_desc[:, 1] < W_desc)
|
|
||||||
keypoints_desc = keypoints_desc[mask]
|
|
||||||
keypoints_input = keypoints_input[mask]
|
|
||||||
|
|
||||||
# 提取描述子
|
|
||||||
descriptors = desc[0, :, keypoints_desc[:, 0], keypoints_desc[:, 1]].T # [N, 128]
|
|
||||||
|
|
||||||
# L2 归一化描述子
|
|
||||||
descriptors = F.normalize(descriptors, p=2, dim=1)
|
descriptors = F.normalize(descriptors, p=2, dim=1)
|
||||||
|
|
||||||
return keypoints_input, descriptors
|
return keypoints_input, descriptors
|
||||||
|
|
||||||
def mutual_nearest_neighbor(template_descs, layout_descs):
|
def mutual_nearest_neighbor(descs1, descs2):
|
||||||
"""
|
sim = descs1 @ descs2.T
|
||||||
使用互最近邻(MNN)找到模板和版图之间的匹配。
|
nn12 = torch.max(sim, dim=1)
|
||||||
|
nn21 = torch.max(sim, dim=0)
|
||||||
参数:
|
ids1 = torch.arange(0, sim.shape[0], device=sim.device)
|
||||||
template_descs (torch.Tensor): 模板描述子,形状为 [M, 128]。
|
mask = (ids1 == nn21.indices[nn12.indices])
|
||||||
layout_descs (torch.Tensor): 版图描述子,形状为 [N, 128]。
|
matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
|
||||||
|
return matches.cpu().numpy()
|
||||||
返回:
|
|
||||||
list: [(i_template, i_layout)],互最近邻匹配对的列表。
|
|
||||||
"""
|
|
||||||
M, N = template_descs.size(0), layout_descs.size(0)
|
|
||||||
if M == 0 or N == 0:
|
|
||||||
return []
|
|
||||||
similarity_matrix = template_descs @ layout_descs.T # [M, N],点积矩阵
|
|
||||||
|
|
||||||
# 找到每个模板描述子的最近邻
|
|
||||||
nn_template_to_layout = torch.argmax(similarity_matrix, dim=1) # [M]
|
|
||||||
|
|
||||||
# 找到每个版图描述子的最近邻
|
|
||||||
nn_layout_to_template = torch.argmax(similarity_matrix, dim=0) # [N]
|
|
||||||
|
|
||||||
# 找到互最近邻
|
|
||||||
mutual_matches = []
|
|
||||||
for i in range(M):
|
|
||||||
j = nn_template_to_layout[i]
|
|
||||||
if nn_layout_to_template[j] == i:
|
|
||||||
mutual_matches.append((i.item(), j.item()))
|
|
||||||
|
|
||||||
return mutual_matches
|
|
||||||
|
|
||||||
def ransac_filter(matches, template_kps, layout_kps):
|
|
||||||
"""
|
|
||||||
使用 RANSAC 对匹配进行几何验证,并返回内点。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
matches (list): [(i_template, i_layout)],匹配对列表。
|
|
||||||
template_kps (torch.Tensor): 模板关键点,形状为 [M, 2]。
|
|
||||||
layout_kps (torch.Tensor): 版图关键点,形状为 [N, 2]。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
tuple: (inlier_matches, num_inliers)
|
|
||||||
- inlier_matches: [(i_template, i_layout)],内点匹配对。
|
|
||||||
- num_inliers: int,内点数量。
|
|
||||||
"""
|
|
||||||
src_pts = np.array([template_kps[i].cpu().numpy() for i, _ in matches])
|
|
||||||
dst_pts = np.array([layout_kps[j].cpu().numpy() for _, j in matches])
|
|
||||||
|
|
||||||
if len(src_pts) < 4:
|
|
||||||
return [], 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=5.0)
|
|
||||||
if H is None:
|
|
||||||
return [], 0
|
|
||||||
inliers = mask.ravel() > 0
|
|
||||||
num_inliers = np.sum(inliers)
|
|
||||||
inlier_matches = [matches[k] for k in range(len(matches)) if inliers[k]]
|
|
||||||
return inlier_matches, num_inliers
|
|
||||||
except cv2.error:
|
|
||||||
return [], 0
|
|
||||||
|
|
||||||
def match_template_to_layout(model, layout_image, template_image):
|
def match_template_to_layout(model, layout_image, template_image):
|
||||||
"""
|
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD)
|
||||||
使用 RoRD 模型执行模板匹配,迭代找到所有匹配并屏蔽已匹配区域。
|
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD)
|
||||||
|
|
||||||
参数:
|
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
|
||||||
model (RoRD): RoRD 模型。
|
found_instances = []
|
||||||
layout_image (torch.Tensor): 版图图像张量,形状为 [1, 1, H_layout, W_layout]。
|
|
||||||
template_image (torch.Tensor): 模板图像张量,形状为 [1, 1, H_template, W_template]。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
list: [{'x': x_min, 'y': y_min, 'width': w, 'height': h}],所有检测到的边框。
|
|
||||||
"""
|
|
||||||
# 提取版图和模板的关键点和描述子
|
|
||||||
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image)
|
|
||||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image)
|
|
||||||
|
|
||||||
# 初始化活动版图关键点掩码
|
|
||||||
active_layout = torch.ones(len(layout_kps), dtype=bool)
|
|
||||||
|
|
||||||
bboxes = []
|
|
||||||
while True:
|
while True:
|
||||||
# 获取当前活动的版图关键点和描述子
|
current_indices = torch.nonzero(active_layout_mask).squeeze(1)
|
||||||
current_layout_kps = layout_kps[active_layout]
|
if len(current_indices) < config.MIN_INLIERS:
|
||||||
current_layout_descs = layout_descs[active_layout]
|
|
||||||
|
|
||||||
if len(current_layout_descs) == 0:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# MNN 匹配
|
current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices]
|
||||||
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
||||||
|
|
||||||
if len(matches) == 0:
|
if len(matches) < 4: break
|
||||||
|
|
||||||
|
src_pts = template_kps[matches[:, 0]].cpu().numpy()
|
||||||
|
dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy()
|
||||||
|
|
||||||
|
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
|
||||||
|
if H is None or mask.sum() < config.MIN_INLIERS:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 将当前版图索引映射回原始版图索引
|
inlier_mask = mask.ravel().astype(bool)
|
||||||
active_indices = torch.nonzero(active_layout).squeeze(1)
|
|
||||||
matches_original = [(i_template, active_indices[i_layout].item()) for i_template, i_layout in matches]
|
|
||||||
|
|
||||||
# RANSAC 过滤
|
# 区域屏蔽逻辑
|
||||||
inlier_matches, num_inliers = ransac_filter(matches_original, template_kps, layout_kps)
|
inlier_layout_kps = dst_pts[inlier_mask]
|
||||||
|
x_min, y_min = inlier_layout_kps.min(axis=0)
|
||||||
|
x_max, y_max = inlier_layout_kps.max(axis=0)
|
||||||
|
|
||||||
if num_inliers > 10: # 设置内点阈值
|
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H}
|
||||||
# 获取内点在版图中的关键点
|
found_instances.append(instance)
|
||||||
inlier_layout_kps = [layout_kps[j].cpu().numpy() for _, j in inlier_matches]
|
|
||||||
inlier_layout_kps = np.array(inlier_layout_kps)
|
|
||||||
|
|
||||||
# 计算边框
|
kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
|
||||||
x_min = int(inlier_layout_kps[:, 0].min())
|
region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
|
||||||
y_min = int(inlier_layout_kps[:, 1].min())
|
active_layout_mask[region_mask] = False
|
||||||
x_max = int(inlier_layout_kps[:, 0].max())
|
|
||||||
y_max = int(inlier_layout_kps[:, 1].max())
|
|
||||||
bboxes.append({'x': x_min, 'y': y_min, 'width': x_max - x_min, 'height': y_max - y_min})
|
|
||||||
|
|
||||||
# 屏蔽内点
|
print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}")
|
||||||
for _, j in inlier_matches:
|
|
||||||
active_layout[j] = False
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
return bboxes
|
return found_instances
|
||||||
|
|
||||||
|
def visualize_matches(layout_path, template_path, bboxes, output_path):
|
||||||
|
layout_img = cv2.imread(layout_path)
|
||||||
|
for i, bbox in enumerate(bboxes):
|
||||||
|
x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
|
||||||
|
cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
||||||
|
cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||||
|
cv2.imwrite(output_path, layout_img)
|
||||||
|
print(f"可视化结果已保存至: {output_path}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 设置变换
|
parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配")
|
||||||
transform = transforms.Compose([
|
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
||||||
SobelTransform(),
|
parser.add_argument('--layout', type=str, required=True)
|
||||||
transforms.ToTensor(),
|
parser.add_argument('--template', type=str, required=True)
|
||||||
transforms.Normalize(mean=[0.5], std=[0.5])
|
parser.add_argument('--output', type=str)
|
||||||
])
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 加载模型
|
transform = get_transform()
|
||||||
model = RoRD().cuda()
|
model = RoRD().cuda()
|
||||||
model.load_state_dict(torch.load('path/to/weights.pth'))
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# 加载版图和模板图像
|
layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda()
|
||||||
layout_image = Image.open('path/to/layout.png').convert('L')
|
template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda()
|
||||||
layout_tensor = transform(layout_image).unsqueeze(0).cuda()
|
|
||||||
|
|
||||||
template_image = Image.open('path/to/template.png').convert('L')
|
|
||||||
template_tensor = transform(template_image).unsqueeze(0).cuda()
|
|
||||||
|
|
||||||
# 执行匹配
|
|
||||||
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
||||||
|
print("\n检测到的边界框:")
|
||||||
# 打印检测到的边框
|
|
||||||
print("检测到的边框:")
|
|
||||||
for bbox in detected_bboxes:
|
for bbox in detected_bboxes:
|
||||||
print(bbox)
|
print(bbox)
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
visualize_matches(args.layout, args.template, detected_bboxes, args.output)
|
||||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
@@ -1,16 +1,25 @@
|
|||||||
|
# models/rord.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision import models
|
from torchvision import models
|
||||||
|
|
||||||
class RoRD(nn.Module):
|
class RoRD(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
修复后的 RoRD 模型。
|
||||||
|
- 实现了共享骨干网络,以提高计算效率和减少内存占用。
|
||||||
|
- 移除了冗余的 descriptor_head_vanilla。
|
||||||
|
"""
|
||||||
super(RoRD, self).__init__()
|
super(RoRD, self).__init__()
|
||||||
# 检测骨干网络:VGG-16 直到 relu5_3(层 0 到 29)
|
|
||||||
self.backbone_det = models.vgg16(pretrained=True).features[:30]
|
|
||||||
# 描述骨干网络:VGG-16 直到 relu4_3(层 0 到 22)
|
|
||||||
self.backbone_desc = models.vgg16(pretrained=True).features[:23]
|
|
||||||
|
|
||||||
# 检测头:输出关键点概率图
|
vgg16_features = models.vgg16(pretrained=True).features
|
||||||
|
|
||||||
|
# 共享骨干网络
|
||||||
|
self.slice1 = vgg16_features[:23] # 到 relu4_3
|
||||||
|
self.slice2 = vgg16_features[23:30] # 从 relu4_3 到 relu5_3
|
||||||
|
|
||||||
|
# 检测头
|
||||||
self.detection_head = nn.Sequential(
|
self.detection_head = nn.Sequential(
|
||||||
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
@@ -18,16 +27,8 @@ class RoRD(nn.Module):
|
|||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 普通描述子头(D2-Net 风格)
|
# 描述子头
|
||||||
self.descriptor_head_vanilla = nn.Sequential(
|
self.descriptor_head = nn.Sequential(
|
||||||
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(256, 128, kernel_size=1),
|
|
||||||
nn.InstanceNorm2d(128)
|
|
||||||
)
|
|
||||||
|
|
||||||
# RoRD 描述子头(旋转鲁棒)
|
|
||||||
self.descriptor_head_rord = nn.Sequential(
|
|
||||||
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(256, 128, kernel_size=1),
|
nn.Conv2d(256, 128, kernel_size=1),
|
||||||
@@ -35,13 +36,14 @@ class RoRD(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# 检测分支
|
# 共享特征提取
|
||||||
features_det = self.backbone_det(x)
|
features_shared = self.slice1(x)
|
||||||
detection = self.detection_head(features_det)
|
|
||||||
|
|
||||||
# 描述分支
|
# 描述子分支
|
||||||
features_desc = self.backbone_desc(x)
|
descriptors = self.descriptor_head(features_shared)
|
||||||
desc_vanilla = self.descriptor_head_vanilla(features_desc)
|
|
||||||
desc_rord = self.descriptor_head_rord(features_desc)
|
|
||||||
|
|
||||||
return detection, desc_vanilla, desc_rord
|
# 检测器分支
|
||||||
|
features_det = self.slice2(features_shared)
|
||||||
|
detection_map = self.detection_head(features_det)
|
||||||
|
|
||||||
|
return detection_map, descriptors
|
||||||
258
train.py
258
train.py
@@ -1,236 +1,142 @@
|
|||||||
|
# train.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torchvision import transforms
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
import os
|
import os
|
||||||
from models.rord import RoRD
|
import argparse
|
||||||
|
|
||||||
# 数据集类:生成随机旋转的训练对
|
# 导入项目模块
|
||||||
|
import config
|
||||||
|
from models.rord import RoRD
|
||||||
|
from utils.data_utils import get_transform
|
||||||
|
|
||||||
|
# --- 训练专用数据集类 ---
|
||||||
class ICLayoutTrainingDataset(Dataset):
|
class ICLayoutTrainingDataset(Dataset):
|
||||||
def __init__(self, image_dir, patch_size=256, transform=None):
|
def __init__(self, image_dir, patch_size=256, transform=None):
|
||||||
"""
|
|
||||||
初始化 IC 版图训练数据集。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。
|
|
||||||
patch_size (int): 裁剪的 patch 大小(默认 256x256)。
|
|
||||||
transform (callable, optional): 应用于图像的变换。
|
|
||||||
"""
|
|
||||||
self.image_dir = image_dir
|
self.image_dir = image_dir
|
||||||
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
|
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
|
||||||
返回数据集中的图像数量。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
int: 数据集大小。
|
|
||||||
"""
|
|
||||||
return len(self.image_paths)
|
return len(self.image_paths)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
"""
|
|
||||||
获取指定索引的训练对(原始 patch、旋转 patch、Homography 矩阵)。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
index (int): 图像索引。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
tuple: (patch, rotated_patch, H_tensor)
|
|
||||||
- patch: 原始 patch 张量。
|
|
||||||
- rotated_patch: 旋转后的 patch 张量。
|
|
||||||
- H_tensor: Homography 矩阵张量。
|
|
||||||
"""
|
|
||||||
img_path = self.image_paths[index]
|
img_path = self.image_paths[index]
|
||||||
image = Image.open(img_path).convert('L') # 灰度图像
|
image = Image.open(img_path).convert('L')
|
||||||
|
|
||||||
# 获取图像大小
|
|
||||||
W, H = image.size
|
W, H = image.size
|
||||||
|
|
||||||
# 随机选择裁剪的左上角坐标
|
|
||||||
x = np.random.randint(0, W - self.patch_size + 1)
|
x = np.random.randint(0, W - self.patch_size + 1)
|
||||||
y = np.random.randint(0, H - self.patch_size + 1)
|
y = np.random.randint(0, H - self.patch_size + 1)
|
||||||
patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
|
patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
|
||||||
|
|
||||||
# 转换为 NumPy 数组
|
|
||||||
patch_np = np.array(patch)
|
patch_np = np.array(patch)
|
||||||
|
|
||||||
# 随机旋转角度(0°~360°)
|
# 实现8个方向的离散几何变换
|
||||||
theta = np.random.uniform(0, 360)
|
theta_deg = np.random.choice([0, 90, 180, 270])
|
||||||
theta_rad = np.deg2rad(theta)
|
is_mirrored = np.random.choice([True, False])
|
||||||
cos_theta = np.cos(theta_rad)
|
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
|
||||||
sin_theta = np.sin(theta_rad)
|
M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)
|
||||||
|
|
||||||
# 计算旋转中心(patch 的中心)
|
if is_mirrored:
|
||||||
cx = self.patch_size / 2.0
|
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
|
||||||
cy = self.patch_size / 2.0
|
Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
|
||||||
|
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
|
||||||
|
M_mirror_3x3 = T2 @ Flip @ T1
|
||||||
|
M_3x3 = np.vstack([M, [0, 0, 1]])
|
||||||
|
H = (M_3x3 @ M_mirror_3x3).astype(np.float32)
|
||||||
|
else:
|
||||||
|
H = np.vstack([M, [0, 0, 1]]).astype(np.float32)
|
||||||
|
|
||||||
# 计算旋转的齐次矩阵(Homography)
|
transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
|
||||||
H = np.array([
|
transformed_patch = Image.fromarray(transformed_patch_np)
|
||||||
[cos_theta, -sin_theta, cx * (1 - cos_theta) + cy * sin_theta],
|
|
||||||
[sin_theta, cos_theta, cy * (1 - cos_theta) - cx * sin_theta],
|
|
||||||
[0, 0, 1]
|
|
||||||
], dtype=np.float32)
|
|
||||||
|
|
||||||
# 应用旋转到 patch
|
|
||||||
rotated_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
|
|
||||||
|
|
||||||
# 转换回 PIL Image
|
|
||||||
rotated_patch = Image.fromarray(rotated_patch_np)
|
|
||||||
|
|
||||||
# 应用变换
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
patch = self.transform(patch)
|
patch = self.transform(patch)
|
||||||
rotated_patch = self.transform(rotated_patch)
|
transformed_patch = self.transform(transformed_patch)
|
||||||
|
|
||||||
# 转换 H 为张量
|
H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵
|
||||||
H_tensor = torch.from_numpy(H).float()
|
return patch, transformed_patch, H_tensor
|
||||||
|
|
||||||
return patch, rotated_patch, H_tensor
|
# --- 特征图变换与损失函数 ---
|
||||||
|
|
||||||
# 特征图变换函数
|
|
||||||
def warp_feature_map(feature_map, H_inv):
|
def warp_feature_map(feature_map, H_inv):
|
||||||
"""
|
|
||||||
使用逆 Homography 矩阵变换特征图。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
feature_map (torch.Tensor): 输入特征图,形状为 [B, C, H, W]。
|
|
||||||
H_inv (torch.Tensor): 逆 Homography 矩阵,形状为 [B, 3, 3]。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
torch.Tensor: 变换后的特征图,形状为 [B, C, H, W]。
|
|
||||||
"""
|
|
||||||
B, C, H, W = feature_map.size()
|
B, C, H, W = feature_map.size()
|
||||||
# 生成网格
|
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
|
||||||
grid_y, grid_x = torch.meshgrid(
|
return F.grid_sample(feature_map, grid, align_corners=False)
|
||||||
torch.linspace(-1, 1, H, device=feature_map.device),
|
|
||||||
torch.linspace(-1, 1, W, device=feature_map.device),
|
|
||||||
indexing='ij'
|
|
||||||
)
|
|
||||||
grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1) # [H, W, 3]
|
|
||||||
grid = grid.unsqueeze(0).expand(B, H, W, 3) # [B, H, W, 3]
|
|
||||||
|
|
||||||
# 将网格转换为齐次坐标并应用 H_inv
|
|
||||||
grid_flat = grid.view(B, -1, 3) # [B, H*W, 3]
|
|
||||||
grid_transformed = torch.bmm(grid_flat, H_inv.transpose(1, 2)) # [B, H*W, 3]
|
|
||||||
grid_transformed = grid_transformed.view(B, H, W, 3) # [B, H, W, 3]
|
|
||||||
grid_transformed = grid_transformed[..., :2] / (grid_transformed[..., 2:3] + 1e-8) # [B, H, W, 2]
|
|
||||||
|
|
||||||
# 使用 grid_sample 进行变换
|
|
||||||
warped_feature = F.grid_sample(feature_map, grid_transformed, align_corners=True)
|
|
||||||
return warped_feature
|
|
||||||
|
|
||||||
# 检测损失函数
|
|
||||||
def compute_detection_loss(det_original, det_rotated, H):
|
def compute_detection_loss(det_original, det_rotated, H):
|
||||||
"""
|
with torch.no_grad():
|
||||||
计算检测损失(MSE),比较原始检测图与旋转检测图(逆变换后)。
|
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
|
||||||
|
|
||||||
参数:
|
|
||||||
det_original (torch.Tensor): 原始图像的检测图,形状为 [B, 1, H, W]。
|
|
||||||
det_rotated (torch.Tensor): 旋转图像的检测图,形状为 [B, 1, H, W]。
|
|
||||||
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
torch.Tensor: 检测损失。
|
|
||||||
"""
|
|
||||||
H_inv = torch.inverse(H) # 计算逆 Homography
|
|
||||||
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
|
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
|
||||||
return F.mse_loss(det_original, warped_det_rotated)
|
return F.mse_loss(det_original, warped_det_rotated)
|
||||||
|
|
||||||
# 描述子损失函数
|
|
||||||
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
|
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
|
||||||
"""
|
B, C, H_feat, W_feat = desc_original.size()
|
||||||
计算描述子损失(三元组损失),基于对应点的描述子。
|
num_samples = 100
|
||||||
|
|
||||||
参数:
|
# 随机采样锚点坐标
|
||||||
desc_original (torch.Tensor): 原始图像的描述子图,形状为 [B, 128, H, W]。
|
coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 # [-1, 1]
|
||||||
desc_rotated (torch.Tensor): 旋转图像的描述子图,形状为 [B, 128, H, W]。
|
|
||||||
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
|
|
||||||
margin (float): 三元组损失的边距。
|
|
||||||
|
|
||||||
返回:
|
# 提取锚点描述子
|
||||||
torch.Tensor: 描述子损失。
|
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
"""
|
|
||||||
B, C, H, W = desc_original.size()
|
|
||||||
# 随机选择锚点(anchor)
|
|
||||||
num_samples = min(100, H * W) # 每张图像采样 100 个点
|
|
||||||
idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
|
|
||||||
idx_y = idx // W
|
|
||||||
idx_x = idx % W
|
|
||||||
coords = torch.stack((idx_x.float(), idx_y.float()), dim=-1) # [B, num_samples, 2]
|
|
||||||
|
|
||||||
# 转换为齐次坐标
|
# 计算正样本坐标
|
||||||
coords_hom = torch.cat((coords, torch.ones(B, num_samples, 1, device=coords.device)), dim=-1) # [B, num_samples, 3]
|
coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
|
||||||
coords_transformed = torch.bmm(coords_hom, H.transpose(1, 2)) # [B, num_samples, 3]
|
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
|
||||||
coords_transformed = coords_transformed[..., :2] / (coords_transformed[..., 2:3] + 1e-8) # [B, num_samples, 2]
|
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
|
||||||
|
|
||||||
# 归一化到 [-1, 1] 用于 grid_sample
|
# 提取正样本描述子
|
||||||
coords_transformed = coords_transformed / torch.tensor([W/2, H/2], device=coords.device) - 1
|
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
# 提取锚点和正样本描述子
|
# 随机采样负样本
|
||||||
anchor = desc_original.view(B, C, -1)[:, :, idx.view(-1)] # [B, 128, num_samples]
|
neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
|
||||||
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(2), align_corners=True).squeeze(3) # [B, 128, num_samples]
|
negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
# 随机选择负样本
|
|
||||||
neg_idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
|
|
||||||
negative = desc_rotated.view(B, C, -1)[:, :, neg_idx.view(-1)] # [B, 128, num_samples]
|
|
||||||
|
|
||||||
# 三元组损失
|
|
||||||
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
|
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
|
||||||
loss = triplet_loss(anchor.transpose(1, 2), positive.transpose(1, 2), negative.transpose(1, 2))
|
return triplet_loss(anchor, positive, negative)
|
||||||
return loss
|
|
||||||
|
|
||||||
# 定义变换
|
# --- 主函数与命令行接口 ---
|
||||||
transform = transforms.Compose([
|
def main(args):
|
||||||
transforms.ToTensor(), # (1, 256, 256)
|
print("--- 开始训练 RoRD 模型 ---")
|
||||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # (3, 256, 256)
|
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
|
||||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
transform = get_transform()
|
||||||
])
|
dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||||
|
model = RoRD().cuda()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
# 创建数据集和 DataLoader
|
for epoch in range(args.epochs):
|
||||||
dataset = ICLayoutTrainingDataset('path/to/layouts', patch_size=256, transform=transform)
|
|
||||||
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
|
|
||||||
|
|
||||||
# 定义模型
|
|
||||||
model = RoRD().cuda()
|
|
||||||
|
|
||||||
# 定义优化器
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
|
||||||
|
|
||||||
# 训练循环
|
|
||||||
num_epochs = 10
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss_val = 0
|
||||||
for batch in dataloader:
|
for i, (original, rotated, H) in enumerate(dataloader):
|
||||||
original, rotated, H = batch
|
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
|
||||||
original = original.cuda()
|
det_original, desc_original = model(original)
|
||||||
rotated = rotated.cuda()
|
det_rotated, desc_rotated = model(rotated)
|
||||||
H = H.cuda()
|
|
||||||
|
|
||||||
# 前向传播
|
loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
|
||||||
det_original, _, desc_rord_original = model(original)
|
|
||||||
det_rotated, _, desc_rord_rotated = model(rotated)
|
|
||||||
|
|
||||||
# 计算损失
|
|
||||||
detection_loss = compute_detection_loss(det_original, det_rotated, H)
|
|
||||||
description_loss = compute_description_loss(desc_rord_original, desc_rord_rotated, H)
|
|
||||||
total_loss_batch = detection_loss + description_loss
|
|
||||||
|
|
||||||
# 反向传播
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
total_loss_batch.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
total_loss_val += loss.item()
|
||||||
|
|
||||||
total_loss += total_loss_batch.item()
|
print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---")
|
||||||
|
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
|
if not os.path.exists(args.save_dir):
|
||||||
|
os.makedirs(args.save_dir)
|
||||||
|
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
print(f"模型已保存至: {save_path}")
|
||||||
|
|
||||||
# 保存模型
|
if __name__ == "__main__":
|
||||||
torch.save(model.state_dict(), 'path/to/save/model.pth')
|
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
|
||||||
|
parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
|
||||||
|
parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
|
||||||
|
parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
|
||||||
|
parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
|
||||||
|
main(parser.parse_args())
|
||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
14
utils/data_utils.py
Normal file
14
utils/data_utils.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from torchvision import transforms
|
||||||
|
from .transforms import SobelTransform
|
||||||
|
|
||||||
|
def get_transform():
|
||||||
|
"""
|
||||||
|
获取统一的图像预处理管道。
|
||||||
|
确保训练、评估和推理使用完全相同的预处理。
|
||||||
|
"""
|
||||||
|
return transforms.Compose([
|
||||||
|
SobelTransform(), # 应用 Sobel 边缘检测
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # 适配 VGG 的三通道输入
|
||||||
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
])
|
||||||
Reference in New Issue
Block a user