第二次大修

This commit is contained in:
Jiao77
2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

View File

@@ -1,123 +1,84 @@
from models.rord import RoRD
from data.ic_dataset import ICLayoutDataset
from utils.transforms import SobelTransform
from match import match_template_to_layout
# evaluate.py
import torch
from torchvision import transforms
from PIL import Image
import json
import os
from PIL import Image
import argparse
import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
from match import match_template_to_layout
def compute_iou(box1, box2):
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
x2, y2, w2, h2 = box2['x'], box2['y'], box2['width'], box2['height']
inter_x1 = max(x1, x2)
inter_y1 = max(y1, y2)
inter_x2 = min(x1 + w1, x2 + w2)
inter_y2 = min(y1 + h1, y2 + h2)
inter_x1, inter_y1 = max(x1, x2), max(y1, y2)
inter_x2, inter_y2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
box1_area = w1 * h1
box2_area = w2 * h2
union_area = box1_area + box2_area - inter_area
iou = inter_area / union_area if union_area > 0 else 0
return iou
union_area = w1 * h1 + w2 * h2 - inter_area
return inter_area / union_area if union_area > 0 else 0
def evaluate(model, val_dataset, templates, iou_threshold=0.5):
def evaluate(model, val_dataset, template_dir):
model.eval()
all_true_positives = 0
all_false_positives = 0
all_false_negatives = 0
all_tp, all_fp, all_fn = 0, 0, 0
transform = get_transform()
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
for layout_idx in range(len(val_dataset)):
layout_image, annotation = val_dataset[layout_idx]
# layout_image is [3, H, W]
layout_tensor = layout_image.unsqueeze(0).cuda() # [1, 3, H, W]
# 假设 annotation 是 {"boxes": [{"template": "template1.png", "x": x, "y": y, "width": w, "height": h}, ...]}
gt_boxes_by_template = {}
for layout_tensor, annotation in val_dataset:
layout_tensor = layout_tensor.unsqueeze(0).cuda()
gt_by_template = {box['template']: [] for box in annotation.get('boxes', [])}
for box in annotation.get('boxes', []):
template_name = box['template']
if template_name not in gt_boxes_by_template:
gt_boxes_by_template[template_name] = []
gt_boxes_by_template[template_name].append(box)
gt_by_template[box['template']].append(box)
for template_path in templates:
for template_path in template_paths:
template_name = os.path.basename(template_path)
template_image = Image.open(template_path).convert('L')
template_tensor = transform(template_image).unsqueeze(0).cuda() # [1, 3, H, W]
# 执行匹配
detected_bboxes = match_template_to_layout(model, layout_tensor, template_tensor)
# 获取当前模板的 gt_boxes
gt_boxes = gt_boxes_by_template.get(template_name, [])
# 初始化已分配的 gt_box 索引
assigned_gt = set()
for det_box in detected_bboxes:
template_tensor = transform(Image.open(template_path).convert('L')).unsqueeze(0).cuda()
detected = match_template_to_layout(model, layout_tensor, template_tensor)
gt_boxes = gt_by_template.get(template_name, [])
matched_gt = [False] * len(gt_boxes)
tp = 0
for det_box in detected:
best_iou = 0
best_gt_idx = -1
for idx, gt_box in enumerate(gt_boxes):
if idx in assigned_gt:
continue
for i, gt_box in enumerate(gt_boxes):
if matched_gt[i]: continue
iou = compute_iou(det_box, gt_box)
if iou > best_iou:
best_iou = iou
best_gt_idx = idx
if best_iou > iou_threshold and best_gt_idx != -1:
all_true_positives += 1
assigned_gt.add(best_gt_idx)
else:
all_false_positives += 1
best_iou, best_gt_idx = iou, i
if best_iou > config.IOU_THRESHOLD:
tp += 1
matched_gt[best_gt_idx] = True
all_tp += tp
all_fp += len(detected) - tp
all_fn += len(gt_boxes) - tp
# 计算 FN未分配的 gt_box
for idx in range(len(gt_boxes)):
if idx not in assigned_gt:
all_false_negatives += 1
# 计算评估指标
precision = all_true_positives / (all_true_positives + all_false_positives) if (all_true_positives + all_false_positives) > 0 else 0
recall = all_true_positives / (all_true_positives + all_false_negatives) if (all_true_positives + all_false_negatives) > 0 else 0
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return {
'precision': precision,
'recall': recall,
'f1': f1
}
return {'precision': precision, 'recall': recall, 'f1': f1}
if __name__ == "__main__":
# 设置变换
transform = transforms.Compose([
SobelTransform(),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # [1, H, W] -> [3, H, W]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
args = parser.parse_args()
# 加载模型
model = RoRD().cuda()
model.load_state_dict(torch.load('path/to/weights.pth'))
model.eval()
# 定义验证数据集
val_dataset = ICLayoutDataset(
image_dir='path/to/val/images',
annotation_dir='path/to/val/annotations',
transform=transform
)
# 定义模板列表
templates = ['path/to/templates/template1.png', 'path/to/templates/template2.png'] # 替换为实际模板路径
# 评估模型
results = evaluate(model, val_dataset, templates)
model.load_state_dict(torch.load(args.model_path))
val_dataset = ICLayoutDataset(args.val_dir, args.annotations_dir, get_transform())
results = evaluate(model, val_dataset, args.templates_dir)
print("评估结果:")
print(f"精确率: {results['precision']:.4f}")
print(f"召回率: {results['recall']:.4f}")
print(f"F1 分数: {results['f1']:.4f}")
print(f" 精确率 (Precision): {results['precision']:.4f}")
print(f" 召回率 (Recall): {results['recall']:.4f}")
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")