Finish Experiment Tracking and Evaluation

This commit is contained in:
Jiao77
2025-09-25 21:24:41 +08:00
parent 05ec32bac1
commit 17d3f419f6
9 changed files with 565 additions and 37 deletions

View File

@@ -9,6 +9,10 @@ import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError: # pragma: no cover - fallback for environments without torch tensorboard
from tensorboardX import SummaryWriter # type: ignore
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
@@ -97,16 +101,28 @@ def mutual_nearest_neighbor(descs1, descs2):
return matches
# --- (Modified) Main function for multi-scale, multi-instance matching ---
def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
def match_template_multiscale(
model,
layout_image,
template_image,
transform,
matching_cfg,
log_writer: SummaryWriter | None = None,
log_step: int = 0,
):
"""
Search for the template at different scales and detect multiple instances.
"""
# 1. Extract all features from the large layout using a sliding window
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
if log_writer:
log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)
min_inliers = int(matching_cfg.min_inliers)
if len(layout_kps) < min_inliers:
print("从大版图中提取的关键点过少,无法进行匹配。")
if log_writer:
log_writer.add_scalar("match/instances_found", 0, log_step)
return []
found_instances = []
@@ -162,6 +178,10 @@ def match_template_multiscale(model, layout_image, template_image, transform, ma
# 4. If a best match was found across all scales, record it and mask it out
if best_match_info['inliers'] > min_inliers:
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
if log_writer:
instance_index = len(found_instances)
log_writer.add_scalar("match/instance_inliers", int(best_match_info['inliers']), log_step + instance_index)
log_writer.add_scalar("match/instance_scale", float(best_match_info['scale']), log_step + instance_index)
inlier_mask = best_match_info['mask'].ravel().astype(bool)
inlier_layout_kps = best_match_info['dst_pts'][inlier_mask]
@@ -183,6 +203,9 @@ def match_template_multiscale(model, layout_image, template_image, transform, ma
print("在所有尺度下均未找到新的匹配实例,搜索结束。")
break
if log_writer:
log_writer.add_scalar("match/instances_found", len(found_instances), log_step)
return found_instances
@@ -200,6 +223,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置")
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
parser.add_argument('--output', type=str)
@@ -208,8 +235,33 @@ if __name__ == "__main__":
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
matching_cfg = cfg.matching
logging_cfg = cfg.get("logging", None)
model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
use_tensorboard = False
log_dir = None
experiment_name = None
if logging_cfg is not None:
use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
log_dir = logging_cfg.get("log_dir", "runs")
experiment_name = logging_cfg.get("experiment_name", "default")
if args.disable_tensorboard:
use_tensorboard = False
if args.log_dir is not None:
log_dir = args.log_dir
if args.experiment_name is not None:
experiment_name = args.experiment_name
should_log_matches = args.tb_log_matches and use_tensorboard and log_dir is not None
writer = None
if should_log_matches:
log_root = Path(log_dir).expanduser()
exp_folder = experiment_name or "default"
tb_path = log_root / "match" / exp_folder
tb_path.parent.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(tb_path.as_posix())
transform = get_transform()
model = RoRD().cuda()
model.load_state_dict(torch.load(model_path))
@@ -218,11 +270,24 @@ if __name__ == "__main__":
layout_image = Image.open(args.layout).convert('L')
template_image = Image.open(args.template).convert('L')
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
detected_bboxes = match_template_multiscale(
model,
layout_image,
template_image,
transform,
matching_cfg,
log_writer=writer,
log_step=0,
)
print("\n检测到的边界框:")
for bbox in detected_bboxes:
print(bbox)
if args.output:
visualize_matches(args.layout, detected_bboxes, args.output)
visualize_matches(args.layout, detected_bboxes, args.output)
if writer:
writer.add_scalar("match/output_instances", len(detected_bboxes), 0)
writer.add_text("match/layout_path", args.layout, 0)
writer.close()