Files
RoRD-Layout-Recognation/evaluate.py

114 lines
4.8 KiB
Python
Raw Normal View History

2025-06-08 15:38:56 +08:00
# evaluate.py
2025-06-07 23:45:32 +08:00
import torch
2025-06-08 15:38:56 +08:00
from PIL import Image
2025-06-07 23:45:32 +08:00
import json
import os
2025-06-08 15:38:56 +08:00
import argparse
import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
2025-07-22 23:43:35 +08:00
# (Modified) Import new matching function
2025-06-09 01:49:13 +08:00
from match import match_template_multiscale
2025-06-07 23:45:32 +08:00
def compute_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is a mapping with 'x', 'y', 'width' and 'height' keys; the box
    spans [x, x + width) horizontally and [y, y + height) vertically.
    Returns 0 when the union area is not positive (e.g. two empty boxes).
    """
    keys = ('x', 'y', 'width', 'height')
    ax, ay, aw, ah = (box1[k] for k in keys)
    bx, by, bw, bh = (box2[k] for k in keys)

    # Overlap rectangle; it may be inverted when the boxes do not intersect,
    # which the max(0, ...) clamps below turn into a zero area.
    left, top = max(ax, bx), max(ay, by)
    right, bottom = min(ax + aw, bx + bw), min(ay + ah, by + bh)
    overlap = max(0, right - left) * max(0, bottom - top)

    union = aw * ah + bw * bh - overlap
    if union > 0:
        return overlap / union
    return 0
2025-06-07 23:45:32 +08:00
2025-07-22 23:43:35 +08:00
# --- (Modified) Evaluation function ---
def _list_pngs(directory):
    """Return the ``.png`` file names in ``directory``, sorted for deterministic order."""
    return sorted(f for f in os.listdir(directory) if f.endswith('.png'))


def _count_true_positives(detected, gt_boxes, iou_threshold):
    """Greedily pair detections with ground-truth boxes and return the TP count.

    Detections are visited in the given order; each is paired with the
    still-unmatched ground-truth box of highest IoU, and the pair counts as a
    true positive when that IoU is strictly greater than ``iou_threshold``.
    Each ground-truth box can be matched at most once.
    """
    matched_gt = [False] * len(gt_boxes)
    tp = 0
    for det_box in detected:
        best_iou, best_gt_idx = 0, -1
        for i, gt_box in enumerate(gt_boxes):
            if matched_gt[i]:
                continue
            iou = compute_iou(det_box, gt_box)
            if iou > best_iou:
                best_iou, best_gt_idx = iou, i
        # best_gt_idx stays -1 when no candidate was found; guard so a
        # non-positive threshold can never index matched_gt[-1].
        if best_gt_idx >= 0 and best_iou > iou_threshold:
            matched_gt[best_gt_idx] = True
            tp += 1
    return tp


def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, iou_threshold=None):
    """Compute detection precision / recall / F1 of ``model`` on a validation set.

    For every ``.png`` layout in ``val_dataset_dir`` that has a same-stem
    ``.json`` annotation in ``val_annotations_dir``, every ``.png`` template in
    ``template_dir`` is searched for with ``match_template_multiscale``.
    Detections are greedily matched against the annotation boxes whose
    ``template`` basename equals the template file name, and TP/FP/FN counts
    are accumulated over all (layout, template) pairs.

    Args:
        model: network forwarded to ``match_template_multiscale``; switched to
            eval mode here.
        val_dataset_dir: directory containing the layout images.
        val_annotations_dir: directory containing JSON annotations; layouts
            without an annotation file are skipped.
        template_dir: directory containing the template images.
        iou_threshold: IoU a detection must strictly exceed to count as a true
            positive. Defaults to ``config.IOU_THRESHOLD``.

    Returns:
        dict with keys 'precision', 'recall' and 'f1' (each 0 when undefined).
    """
    if iou_threshold is None:
        iou_threshold = config.IOU_THRESHOLD

    model.eval()
    all_tp, all_fp, all_fn = 0, 0, 0

    # Only need a unified transform for internal use by matching function
    transform = get_transform()

    # Decode each template once up front instead of once per layout.
    templates = []
    for template_name in _list_pngs(template_dir):
        with Image.open(os.path.join(template_dir, template_name)) as img:
            templates.append((template_name, img.convert('L')))

    # Loop through each layout file in validation set
    for layout_name in _list_pngs(val_dataset_dir):
        print(f"\nEvaluating layout: {layout_name}")
        stem = os.path.splitext(layout_name)[0]
        annotation_path = os.path.join(val_annotations_dir, stem + '.json')

        # Check for the annotation before decoding the (possibly large) layout.
        if not os.path.exists(annotation_path):
            continue
        with open(annotation_path, 'r', encoding='utf-8') as f:
            annotation = json.load(f)

        # Load original PIL image to support sliding window
        with Image.open(os.path.join(val_dataset_dir, layout_name)) as img:
            layout_image = img.convert('L')

        # Group ground truth annotations by template file basename
        gt_by_template = {}
        for box in annotation.get('boxes', []):
            gt_by_template.setdefault(os.path.basename(box['template']), []).append(box)

        # Iterate through each template and perform matching on current layout
        for template_name, template_image in templates:
            detected = match_template_multiscale(model, layout_image, template_image, transform)
            gt_boxes = gt_by_template.get(template_name, [])

            tp = _count_true_positives(detected, gt_boxes, iou_threshold)
            all_tp += tp
            all_fp += len(detected) - tp
            all_fn += len(gt_boxes) - tp

    # Calculate final metrics
    precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
    recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return {'precision': precision, 'recall': recall, 'f1': f1}
2025-06-07 23:45:32 +08:00
if __name__ == "__main__":
    # Command-line options; every path falls back to the value in config.
    cli = argparse.ArgumentParser(description="Evaluate RoRD model performance")
    for flag, fallback in (
        ('--model_path', config.MODEL_PATH),
        ('--val_dir', config.VAL_IMG_DIR),
        ('--annotations_dir', config.VAL_ANN_DIR),
        ('--templates_dir', config.TEMPLATE_DIR),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    opts = cli.parse_args()

    # Build the network on the GPU, then restore the trained weights.
    model = RoRD().cuda()
    state = torch.load(opts.model_path)
    model.load_state_dict(state)

    # Paths are passed straight through; evaluate() reads the files itself.
    metrics = evaluate(model, opts.val_dir, opts.annotations_dir, opts.templates_dir)

    print("\n--- Evaluation Results ---")
    print(f" Precision: {metrics['precision']:.4f}")
    print(f" Recall: {metrics['recall']:.4f}")
    print(f" F1 Score: {metrics['f1']:.4f}")