add scale robust way
This commit is contained in:
22
config.py
22
config.py
@@ -3,29 +3,27 @@
|
|||||||
# --- 训练参数 ---
|
# --- 训练参数 ---
|
||||||
LEARNING_RATE = 1e-4
|
LEARNING_RATE = 1e-4
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
NUM_EPOCHS = 20 # 增加了训练轮数
|
NUM_EPOCHS = 20
|
||||||
PATCH_SIZE = 256
|
PATCH_SIZE = 256
|
||||||
|
# (新增) 训练时尺度抖动范围
|
||||||
|
SCALE_JITTER_RANGE = (0.7, 1.5)
|
||||||
|
|
||||||
# --- 匹配与评估参数 ---
|
# --- 匹配与评估参数 ---
|
||||||
# 关键点检测的置信度阈值
|
|
||||||
KEYPOINT_THRESHOLD = 0.5
|
KEYPOINT_THRESHOLD = 0.5
|
||||||
# RANSAC 重投影误差阈值(像素)
|
|
||||||
RANSAC_REPROJ_THRESHOLD = 5.0
|
RANSAC_REPROJ_THRESHOLD = 5.0
|
||||||
# RANSAC 判定为有效匹配所需的最小内点数
|
MIN_INLIERS = 15
|
||||||
MIN_INLIERS = 15 # 适当提高以增加匹配的可靠性
|
|
||||||
# IoU (Intersection over Union) 阈值,用于评估
|
|
||||||
IOU_THRESHOLD = 0.5
|
IOU_THRESHOLD = 0.5
|
||||||
|
# (新增) 推理时模板匹配的图像金字塔尺度
|
||||||
|
PYRAMID_SCALES = [0.75, 1.0, 1.5]
|
||||||
|
# (新增) 推理时处理大版图的滑动窗口参数
|
||||||
|
INFERENCE_WINDOW_SIZE = 1024
|
||||||
|
INFERENCE_STRIDE = 768 # 小于INFERENCE_WINDOW_SIZE以保证重叠
|
||||||
|
|
||||||
# --- 文件路径 ---
|
# --- 文件路径 ---
|
||||||
# 训练数据目录
|
# (路径保持不变, 请根据您的环境修改)
|
||||||
LAYOUT_DIR = 'path/to/layouts'
|
LAYOUT_DIR = 'path/to/layouts'
|
||||||
# 模型保存目录
|
|
||||||
SAVE_DIR = 'path/to/save'
|
SAVE_DIR = 'path/to/save'
|
||||||
# 验证集图像目录
|
|
||||||
VAL_IMG_DIR = 'path/to/val/images'
|
VAL_IMG_DIR = 'path/to/val/images'
|
||||||
# 验证集标注目录
|
|
||||||
VAL_ANN_DIR = 'path/to/val/annotations'
|
VAL_ANN_DIR = 'path/to/val/annotations'
|
||||||
# 模板图像目录
|
|
||||||
TEMPLATE_DIR = 'path/to/templates'
|
TEMPLATE_DIR = 'path/to/templates'
|
||||||
# 默认加载的模型路径
|
|
||||||
MODEL_PATH = 'path/to/save/model_final.pth'
|
MODEL_PATH = 'path/to/save/model_final.pth'
|
||||||
78
evaluate.py
78
evaluate.py
@@ -10,7 +10,8 @@ import config
|
|||||||
from models.rord import RoRD
|
from models.rord import RoRD
|
||||||
from utils.data_utils import get_transform
|
from utils.data_utils import get_transform
|
||||||
from data.ic_dataset import ICLayoutDataset
|
from data.ic_dataset import ICLayoutDataset
|
||||||
from match import match_template_to_layout
|
# (已修改) 导入新的匹配函数
|
||||||
|
from match import match_template_multiscale
|
||||||
|
|
||||||
def compute_iou(box1, box2):
|
def compute_iou(box1, box2):
|
||||||
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
|
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
|
||||||
@@ -21,45 +22,73 @@ def compute_iou(box1, box2):
|
|||||||
union_area = w1 * h1 + w2 * h2 - inter_area
|
union_area = w1 * h1 + w2 * h2 - inter_area
|
||||||
return inter_area / union_area if union_area > 0 else 0
|
return inter_area / union_area if union_area > 0 else 0
|
||||||
|
|
||||||
def evaluate(model, val_dataset, template_dir):
|
# --- (已修改) 评估函数 ---
|
||||||
|
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
|
||||||
model.eval()
|
model.eval()
|
||||||
all_tp, all_fp, all_fn = 0, 0, 0
|
all_tp, all_fp, all_fn = 0, 0, 0
|
||||||
|
|
||||||
|
# 只需要一个统一的 transform 给匹配函数内部使用
|
||||||
transform = get_transform()
|
transform = get_transform()
|
||||||
|
|
||||||
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
|
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
|
||||||
|
layout_image_names = [f for f in os.listdir(val_dataset_dir) if f.endswith('.png')]
|
||||||
|
|
||||||
for layout_tensor, annotation in val_dataset:
|
# (已修改) 循环遍历验证集中的每个版图文件
|
||||||
layout_tensor = layout_tensor.unsqueeze(0).cuda()
|
for layout_name in layout_image_names:
|
||||||
gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])}
|
print(f"\n正在评估版图: {layout_name}")
|
||||||
|
layout_path = os.path.join(val_dataset_dir, layout_name)
|
||||||
|
annotation_path = os.path.join(val_annotations_dir, layout_name.replace('.png', '.json'))
|
||||||
|
|
||||||
|
# 加载原始PIL图像,以支持滑动窗口
|
||||||
|
layout_image = Image.open(layout_path).convert('L')
|
||||||
|
|
||||||
|
# 加载标注信息
|
||||||
|
if not os.path.exists(annotation_path):
|
||||||
|
continue
|
||||||
|
with open(annotation_path, 'r') as f:
|
||||||
|
annotation = json.load(f)
|
||||||
|
|
||||||
|
# 按模板对真实标注进行分组
|
||||||
|
gt_by_template = {os.path.basename(box['template']): [] for box in annotation.get('boxes', [])}
|
||||||
for box in annotation.get('boxes', []):
|
for box in annotation.get('boxes', []):
|
||||||
gt_by_template[box['template']].append(box)
|
gt_by_template[os.path.basename(box['template'])].append(box)
|
||||||
|
|
||||||
|
# 遍历每个模板,在当前版图上进行匹配
|
||||||
for template_path in template_paths:
|
for template_path in template_paths:
|
||||||
template_name = os.path.basename(template_path)
|
template_name = os.path.basename(template_path)
|
||||||
template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda()
|
template_image = Image.open(template_path).convert('L')
|
||||||
|
|
||||||
|
# (已修改) 调用新的多尺度匹配函数
|
||||||
|
detected = match_template_multiscale(model, layout_image, template_image, transform)
|
||||||
|
|
||||||
detected = match_template_to_layout(model, layout_tensor, template_tensor)
|
|
||||||
gt_boxes = gt_by_template.get(template_name, [])
|
gt_boxes = gt_by_template.get(template_name, [])
|
||||||
|
|
||||||
|
# 计算 TP, FP, FN (这部分逻辑不变)
|
||||||
matched_gt = [False] * len(gt_boxes)
|
matched_gt = [False] * len(gt_boxes)
|
||||||
tp = 0
|
tp = 0
|
||||||
for det_box in detected:
|
if len(detected) > 0:
|
||||||
best_iou = 0
|
for det_box in detected:
|
||||||
best_gt_idx = -1
|
best_iou = 0
|
||||||
for i, gt_box in enumerate(gt_boxes):
|
best_gt_idx = -1
|
||||||
if matched_gt[i]: continue
|
for i, gt_box in enumerate(gt_boxes):
|
||||||
iou = compute_iou(det_box, gt_box)
|
if matched_gt[i]: continue
|
||||||
if iou > best_iou:
|
iou = compute_iou(det_box, gt_box)
|
||||||
best_iou, best_gt_idx = iou, i
|
if iou > best_iou:
|
||||||
|
best_iou, best_gt_idx = iou, i
|
||||||
|
|
||||||
if best_iou > config.IOU_THRESHOLD:
|
if best_iou > config.IOU_THRESHOLD:
|
||||||
tp += 1
|
if not matched_gt[best_gt_idx]:
|
||||||
matched_gt[best_gt_idx] = True
|
tp += 1
|
||||||
|
matched_gt[best_gt_idx] = True
|
||||||
|
|
||||||
|
fp = len(detected) - tp
|
||||||
|
fn = len(gt_boxes) - tp
|
||||||
|
|
||||||
all_tp += tp
|
all_tp += tp
|
||||||
all_fp += len(detected) - tp
|
all_fp += fp
|
||||||
all_fn += len(gt_boxes) - tp
|
all_fn += fn
|
||||||
|
|
||||||
|
# 计算最终指标
|
||||||
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
|
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
|
||||||
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
|
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
|
||||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||||
@@ -75,10 +104,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model = RoRD().cuda()
|
model = RoRD().cuda()
|
||||||
model.load_state_dict(torch.load(args.model_path))
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform())
|
|
||||||
|
|
||||||
results = evaluate(model, val_dataset, args.templates_dir)
|
# (已修改) 不再需要预加载数据集,直接传入路径
|
||||||
print("评估结果:")
|
results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
|
||||||
|
|
||||||
|
print("\n--- 评估结果 ---")
|
||||||
print(f" 精确率 (Precision): {results['precision']:.4f}")
|
print(f" 精确率 (Precision): {results['precision']:.4f}")
|
||||||
print(f" 召回率 (Recall): {results['recall']:.4f}")
|
print(f" 召回率 (Recall): {results['recall']:.4f}")
|
||||||
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
||||||
195
match.py
195
match.py
@@ -12,69 +12,174 @@ import config
|
|||||||
from models.rord import RoRD
|
from models.rord import RoRD
|
||||||
from utils.data_utils import get_transform
|
from utils.data_utils import get_transform
|
||||||
|
|
||||||
def extract_keypoints_and_descriptors(model, image, kp_thresh):
|
# --- 特征提取函数 (基本无变动) ---
|
||||||
|
def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
detection_map, desc = model(image)
|
detection_map, desc = model(image_tensor)
|
||||||
binary_map = (detection_map > kp_thresh).float()
|
|
||||||
coords = torch.nonzero(binary_map[0, 0]).float()
|
|
||||||
keypoints_input = coords[:, [1, 0]] * 8.0 # Stride of descriptor is 8
|
|
||||||
|
|
||||||
descriptors = F.grid_sample(desc, coords.flip(1).view(1, -1, 1, 2) / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1, align_corners=True).squeeze().T
|
device = detection_map.device
|
||||||
descriptors = F.normalize(descriptors, p=2, dim=1)
|
binary_map = (detection_map > kp_thresh).squeeze(0).squeeze(0)
|
||||||
return keypoints_input, descriptors
|
coords = torch.nonzero(binary_map).float() # y, x
|
||||||
|
|
||||||
|
if len(coords) == 0:
|
||||||
|
return torch.tensor([], device=device), torch.tensor([], device=device)
|
||||||
|
|
||||||
|
# 描述子采样
|
||||||
|
coords_for_grid = coords.flip(1).view(1, -1, 1, 2) # N, 2 -> 1, N, 1, 2 (x,y)
|
||||||
|
# 归一化到 [-1, 1]
|
||||||
|
coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=device) - 1
|
||||||
|
|
||||||
|
descriptors = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T
|
||||||
|
descriptors = F.normalize(descriptors, p=2, dim=1)
|
||||||
|
|
||||||
|
# 将关键点坐标从特征图尺度转换回图像尺度
|
||||||
|
# VGG到relu4_3的下采样率为8
|
||||||
|
keypoints = coords.flip(1) * 8.0 # x, y
|
||||||
|
|
||||||
|
return keypoints, descriptors
|
||||||
|
|
||||||
|
# --- (新增) 滑动窗口特征提取函数 ---
|
||||||
|
def extract_features_sliding_window(model, large_image, transform):
|
||||||
|
"""
|
||||||
|
使用滑动窗口从大图上提取所有关键点和描述子
|
||||||
|
"""
|
||||||
|
print("使用滑动窗口提取大版图特征...")
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
W, H = large_image.size
|
||||||
|
window_size = config.INFERENCE_WINDOW_SIZE
|
||||||
|
stride = config.INFERENCE_STRIDE
|
||||||
|
|
||||||
|
all_kps = []
|
||||||
|
all_descs = []
|
||||||
|
|
||||||
|
for y in range(0, H, stride):
|
||||||
|
for x in range(0, W, stride):
|
||||||
|
# 确保窗口不越界
|
||||||
|
x_end = min(x + window_size, W)
|
||||||
|
y_end = min(y + window_size, H)
|
||||||
|
|
||||||
|
# 裁剪窗口
|
||||||
|
patch = large_image.crop((x, y, x_end, y_end))
|
||||||
|
|
||||||
|
# 预处理
|
||||||
|
patch_tensor = transform(patch).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
|
||||||
|
|
||||||
|
if len(kps) > 0:
|
||||||
|
# 将局部坐标转换为全局坐标
|
||||||
|
kps[:, 0] += x
|
||||||
|
kps[:, 1] += y
|
||||||
|
all_kps.append(kps)
|
||||||
|
all_descs.append(descs)
|
||||||
|
|
||||||
|
if not all_kps:
|
||||||
|
return torch.tensor([], device=device), torch.tensor([], device=device)
|
||||||
|
|
||||||
|
print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。")
|
||||||
|
return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 互近邻匹配 (无变动) ---
|
||||||
def mutual_nearest_neighbor(descs1, descs2):
|
def mutual_nearest_neighbor(descs1, descs2):
|
||||||
|
if len(descs1) == 0 or len(descs2) == 0:
|
||||||
|
return torch.empty((0, 2), dtype=torch.int64)
|
||||||
sim = descs1 @ descs2.T
|
sim = descs1 @ descs2.T
|
||||||
nn12 = torch.max(sim, dim=1)
|
nn12 = torch.max(sim, dim=1)
|
||||||
nn21 = torch.max(sim, dim=0)
|
nn21 = torch.max(sim, dim=0)
|
||||||
ids1 = torch.arange(0, sim.shape[0], device=sim.device)
|
ids1 = torch.arange(0, sim.shape[0], device=sim.device)
|
||||||
mask = (ids1 == nn21.indices[nn12.indices])
|
mask = (ids1 == nn21.indices[nn12.indices])
|
||||||
matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
|
matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
|
||||||
return matches.cpu().numpy()
|
return matches
|
||||||
|
|
||||||
def match_template_to_layout(model, layout_image, template_image):
|
# --- (已修改) 多尺度、多实例匹配主函数 ---
|
||||||
layout_kps, layout_descs = extract_keypoints_and_descriptors(model, layout_image, config.KEYPOINT_THRESHOLD)
|
def match_template_multiscale(model, layout_image, template_image, transform):
|
||||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_image, config.KEYPOINT_THRESHOLD)
|
"""
|
||||||
|
在不同尺度下搜索模板,并检测多个实例
|
||||||
|
"""
|
||||||
|
# 1. 对大版图使用滑动窗口提取全部特征
|
||||||
|
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
|
||||||
|
|
||||||
|
if len(layout_kps) < config.MIN_INLIERS:
|
||||||
|
print("从大版图中提取的关键点过少,无法进行匹配。")
|
||||||
|
return []
|
||||||
|
|
||||||
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
|
|
||||||
found_instances = []
|
found_instances = []
|
||||||
|
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
|
||||||
|
|
||||||
|
# 2. 多实例迭代检测
|
||||||
while True:
|
while True:
|
||||||
current_indices = torch.nonzero(active_layout_mask).squeeze(1)
|
current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
|
||||||
if len(current_indices) < config.MIN_INLIERS:
|
|
||||||
|
# 如果剩余活动关键点过少,则停止
|
||||||
|
if len(current_active_indices) < config.MIN_INLIERS:
|
||||||
break
|
break
|
||||||
|
|
||||||
current_layout_kps, current_layout_descs = layout_kps[current_indices], layout_descs[current_indices]
|
current_layout_kps = layout_kps[current_active_indices]
|
||||||
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
current_layout_descs = layout_descs[current_active_indices]
|
||||||
|
|
||||||
if len(matches) < 4: break
|
best_match_info = {'inliers': 0, 'H': None, 'src_pts': None, 'dst_pts': None, 'mask': None}
|
||||||
|
|
||||||
src_pts = template_kps[matches[:, 0]].cpu().numpy()
|
# 3. 图像金字塔:遍历模板的每个尺度
|
||||||
dst_pts = current_layout_kps[matches[:, 1]].cpu().numpy()
|
print("在新尺度下搜索模板...")
|
||||||
|
for scale in config.PYRAMID_SCALES:
|
||||||
|
W, H = template_image.size
|
||||||
|
new_W, new_H = int(W * scale), int(H * scale)
|
||||||
|
|
||||||
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
|
# 缩放模板
|
||||||
if H is None or mask.sum() < config.MIN_INLIERS:
|
scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS)
|
||||||
|
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
|
||||||
|
|
||||||
|
# 提取缩放后模板的特征
|
||||||
|
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
|
||||||
|
|
||||||
|
if len(template_kps) < 4: continue
|
||||||
|
|
||||||
|
# 匹配当前尺度的模板和活动状态的版图特征
|
||||||
|
matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
|
||||||
|
|
||||||
|
if len(matches) < 4: continue
|
||||||
|
|
||||||
|
# RANSAC
|
||||||
|
# 注意:模板关键点坐标需要还原到原始尺寸,才能计算正确的H
|
||||||
|
src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale
|
||||||
|
dst_pts_indices = current_active_indices[matches[:, 1]]
|
||||||
|
dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
|
||||||
|
|
||||||
|
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
|
||||||
|
|
||||||
|
if H is not None and mask.sum() > best_match_info['inliers']:
|
||||||
|
best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
|
||||||
|
|
||||||
|
# 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
|
||||||
|
if best_match_info['inliers'] > config.MIN_INLIERS:
|
||||||
|
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
|
||||||
|
|
||||||
|
inlier_mask = best_match_info['mask'].ravel().astype(bool)
|
||||||
|
inlier_layout_kps = best_match_info['dst_pts'][inlier_mask]
|
||||||
|
|
||||||
|
x_min, y_min = inlier_layout_kps.min(axis=0)
|
||||||
|
x_max, y_max = inlier_layout_kps.max(axis=0)
|
||||||
|
|
||||||
|
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': best_match_info['H']}
|
||||||
|
found_instances.append(instance)
|
||||||
|
|
||||||
|
# 屏蔽已匹配区域的关键点,以便检测下一个实例
|
||||||
|
kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
|
||||||
|
region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
|
||||||
|
active_layout_mask[region_mask] = False
|
||||||
|
|
||||||
|
print(f"剩余活动关键点: {active_layout_mask.sum()}")
|
||||||
|
else:
|
||||||
|
# 如果在所有尺度下都找不到好的匹配,则结束搜索
|
||||||
|
print("在所有尺度下均未找到新的匹配实例,搜索结束。")
|
||||||
break
|
break
|
||||||
|
|
||||||
inlier_mask = mask.ravel().astype(bool)
|
|
||||||
|
|
||||||
# 区域屏蔽逻辑
|
|
||||||
inlier_layout_kps = dst_pts[inlier_mask]
|
|
||||||
x_min, y_min = inlier_layout_kps.min(axis=0)
|
|
||||||
x_max, y_max = inlier_layout_kps.max(axis=0)
|
|
||||||
|
|
||||||
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': H}
|
|
||||||
found_instances.append(instance)
|
|
||||||
|
|
||||||
kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
|
|
||||||
region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
|
|
||||||
active_layout_mask[region_mask] = False
|
|
||||||
|
|
||||||
print(f"找到实例,内点数: {mask.sum()}。剩余活动关键点: {active_layout_mask.sum()}")
|
|
||||||
|
|
||||||
return found_instances
|
return found_instances
|
||||||
|
|
||||||
def visualize_matches(layout_path, template_path, bboxes, output_path):
|
|
||||||
|
def visualize_matches(layout_path, bboxes, output_path):
|
||||||
layout_img = cv2.imread(layout_path)
|
layout_img = cv2.imread(layout_path)
|
||||||
for i, bbox in enumerate(bboxes):
|
for i, bbox in enumerate(bboxes):
|
||||||
x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
|
x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
|
||||||
@@ -83,8 +188,9 @@ def visualize_matches(layout_path, template_path, bboxes, output_path):
|
|||||||
cv2.imwrite(output_path, layout_img)
|
cv2.imwrite(output_path, layout_img)
|
||||||
print(f"可视化结果已保存至: {output_path}")
|
print(f"可视化结果已保存至: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="使用 RoRD 进行模板匹配")
|
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
|
||||||
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
||||||
parser.add_argument('--layout', type=str, required=True)
|
parser.add_argument('--layout', type=str, required=True)
|
||||||
parser.add_argument('--template', type=str, required=True)
|
parser.add_argument('--template', type=str, required=True)
|
||||||
@@ -96,13 +202,14 @@ if __name__ == "__main__":
|
|||||||
model.load_state_dict(torch.load(args.model_path))
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
layout_tensor = transform(Image.open(args.layout).convert('L')).unsqueeze(0).cuda()
|
layout_image = Image.open(args.layout).convert('L')
|
||||||
template_tensor = transform(Image.open(args.template).convert('L')).unsqueeze(0).cuda()
|
template_image = Image.open(args.template).convert('L')
|
||||||
|
|
||||||
|
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
|
||||||
|
|
||||||
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
|
|
||||||
print("\n检测到的边界框:")
|
print("\n检测到的边界框:")
|
||||||
for bbox in detected_bboxes:
|
for bbox in detected_bboxes:
|
||||||
print(bbox)
|
print(bbox)
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
visualize_matches(args.layout, args.template, detected_bboxes, args.output)
|
visualize_matches(args.layout, detected_bboxes, args.output)
|
||||||
61
train.py
61
train.py
@@ -15,13 +15,14 @@ import config
|
|||||||
from models.rord import RoRD
|
from models.rord import RoRD
|
||||||
from utils.data_utils import get_transform
|
from utils.data_utils import get_transform
|
||||||
|
|
||||||
# --- 训练专用数据集类 ---
|
# --- (已修改) 训练专用数据集类 ---
|
||||||
class ICLayoutTrainingDataset(Dataset):
|
class ICLayoutTrainingDataset(Dataset):
|
||||||
def __init__(self, image_dir, patch_size=256, transform=None):
|
def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
|
||||||
self.image_dir = image_dir
|
self.image_dir = image_dir
|
||||||
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
|
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
self.scale_range = scale_range # 新增尺度范围参数
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.image_paths)
|
return len(self.image_paths)
|
||||||
@@ -29,14 +30,30 @@ class ICLayoutTrainingDataset(Dataset):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
img_path = self.image_paths[index]
|
img_path = self.image_paths[index]
|
||||||
image = Image.open(img_path).convert('L')
|
image = Image.open(img_path).convert('L')
|
||||||
|
|
||||||
W, H = image.size
|
W, H = image.size
|
||||||
x = np.random.randint(0, W - self.patch_size + 1)
|
|
||||||
y = np.random.randint(0, H - self.patch_size + 1)
|
# --- 新增:尺度抖动数据增强 ---
|
||||||
patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
|
# 1. 随机选择一个缩放比例
|
||||||
|
scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
|
||||||
|
# 2. 根据缩放比例计算需要从原图裁剪的尺寸
|
||||||
|
crop_size = int(self.patch_size / scale)
|
||||||
|
|
||||||
|
# 确保裁剪尺寸不超过图像边界
|
||||||
|
if crop_size > min(W, H):
|
||||||
|
crop_size = min(W, H)
|
||||||
|
|
||||||
|
# 3. 随机裁剪
|
||||||
|
x = np.random.randint(0, W - crop_size + 1)
|
||||||
|
y = np.random.randint(0, H - crop_size + 1)
|
||||||
|
patch = image.crop((x, y, x + crop_size, y + crop_size))
|
||||||
|
|
||||||
|
# 4. 将裁剪出的图像块缩放回标准的 patch_size
|
||||||
|
patch = patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)
|
||||||
|
# --- 尺度抖动结束 ---
|
||||||
|
|
||||||
patch_np = np.array(patch)
|
patch_np = np.array(patch)
|
||||||
|
|
||||||
# 实现8个方向的离散几何变换
|
# 实现8个方向的离散几何变换 (这部分逻辑不变)
|
||||||
theta_deg = np.random.choice([0, 90, 180, 270])
|
theta_deg = np.random.choice([0, 90, 180, 270])
|
||||||
is_mirrored = np.random.choice([True, False])
|
is_mirrored = np.random.choice([True, False])
|
||||||
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
|
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
|
||||||
@@ -59,10 +76,10 @@ class ICLayoutTrainingDataset(Dataset):
|
|||||||
patch = self.transform(patch)
|
patch = self.transform(patch)
|
||||||
transformed_patch = self.transform(transformed_patch)
|
transformed_patch = self.transform(transformed_patch)
|
||||||
|
|
||||||
H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵
|
H_tensor = torch.from_numpy(H[:2, :]).float()
|
||||||
return patch, transformed_patch, H_tensor
|
return patch, transformed_patch, H_tensor
|
||||||
|
|
||||||
# --- 特征图变换与损失函数 ---
|
# --- 特征图变换与损失函数 (无变动) ---
|
||||||
def warp_feature_map(feature_map, H_inv):
|
def warp_feature_map(feature_map, H_inv):
|
||||||
B, C, H, W = feature_map.size()
|
B, C, H, W = feature_map.size()
|
||||||
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
|
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
|
||||||
@@ -77,34 +94,29 @@ def compute_detection_loss(det_original, det_rotated, H):
|
|||||||
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
|
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
|
||||||
B, C, H_feat, W_feat = desc_original.size()
|
B, C, H_feat, W_feat = desc_original.size()
|
||||||
num_samples = 100
|
num_samples = 100
|
||||||
|
coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
|
||||||
# 随机采样锚点坐标
|
|
||||||
coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 # [-1, 1]
|
|
||||||
|
|
||||||
# 提取锚点描述子
|
|
||||||
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
# 计算正样本坐标
|
|
||||||
coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
|
coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
|
||||||
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
|
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
|
||||||
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
|
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
|
||||||
|
|
||||||
# 提取正样本描述子
|
|
||||||
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
# 随机采样负样本
|
|
||||||
neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
|
neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
|
||||||
negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
|
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
|
||||||
return triplet_loss(anchor, positive, negative)
|
return triplet_loss(anchor, positive, negative)
|
||||||
|
|
||||||
# --- 主函数与命令行接口 ---
|
# --- (已修改) 主函数与命令行接口 ---
|
||||||
def main(args):
|
def main(args):
|
||||||
print("--- 开始训练 RoRD 模型 ---")
|
print("--- 开始训练 RoRD 模型 ---")
|
||||||
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
|
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
|
||||||
transform = get_transform()
|
transform = get_transform()
|
||||||
dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform)
|
# 在数据集初始化时传入尺度抖动范围
|
||||||
|
dataset = ICLayoutTrainingDataset(
|
||||||
|
args.data_dir,
|
||||||
|
patch_size=config.PATCH_SIZE,
|
||||||
|
transform=transform,
|
||||||
|
scale_range=config.SCALE_JITTER_RANGE
|
||||||
|
)
|
||||||
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||||
model = RoRD().cuda()
|
model = RoRD().cuda()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
@@ -116,14 +128,11 @@ def main(args):
|
|||||||
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
|
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
|
||||||
det_original, desc_original = model(original)
|
det_original, desc_original = model(original)
|
||||||
det_rotated, desc_rotated = model(rotated)
|
det_rotated, desc_rotated = model(rotated)
|
||||||
|
|
||||||
loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
|
loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
total_loss_val += loss.item()
|
total_loss_val += loss.item()
|
||||||
|
|
||||||
print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---")
|
print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---")
|
||||||
|
|
||||||
if not os.path.exists(args.save_dir):
|
if not os.path.exists(args.save_dir):
|
||||||
|
|||||||
Reference in New Issue
Block a user