Merge pull request 'taishi-addtodolist' (#1) from taishi-addtodolist into main

Reviewed-on: #1
2025-09-25 12:31:08 +00:00
13 changed files with 641 additions and 449 deletions


@@ -1,29 +1,34 @@
# config.py
"""Legacy config shim loading values from YAML."""
from __future__ import annotations
from pathlib import Path
from omegaconf import OmegaConf
_BASE_CONFIG_PATH = Path(__file__).resolve().parent / "configs" / "base_config.yaml"
_CFG = OmegaConf.load(_BASE_CONFIG_PATH)
# --- Training parameters ---
LEARNING_RATE = 5e-5 # Lower learning rate to improve training stability
BATCH_SIZE = 8 # Larger batch size to improve training throughput
NUM_EPOCHS = 50 # More training epochs
PATCH_SIZE = 256
# (Optimized) Scale-jitter range during training; narrowed for stability
SCALE_JITTER_RANGE = (0.8, 1.2)
LEARNING_RATE = float(_CFG.training.learning_rate)
BATCH_SIZE = int(_CFG.training.batch_size)
NUM_EPOCHS = int(_CFG.training.num_epochs)
PATCH_SIZE = int(_CFG.training.patch_size)
SCALE_JITTER_RANGE = tuple(float(x) for x in _CFG.training.scale_jitter_range)
# --- Matching & evaluation parameters ---
KEYPOINT_THRESHOLD = 0.5
RANSAC_REPROJ_THRESHOLD = 5.0
MIN_INLIERS = 15
IOU_THRESHOLD = 0.5
# (New) Image-pyramid scales for template matching at inference time
PYRAMID_SCALES = [0.75, 1.0, 1.5]
# (New) Sliding-window parameters for processing large layouts at inference time
INFERENCE_WINDOW_SIZE = 1024
INFERENCE_STRIDE = 768 # Smaller than INFERENCE_WINDOW_SIZE to guarantee overlap
KEYPOINT_THRESHOLD = float(_CFG.matching.keypoint_threshold)
RANSAC_REPROJ_THRESHOLD = float(_CFG.matching.ransac_reproj_threshold)
MIN_INLIERS = int(_CFG.matching.min_inliers)
PYRAMID_SCALES = [float(s) for s in _CFG.matching.pyramid_scales]
INFERENCE_WINDOW_SIZE = int(_CFG.matching.inference_window_size)
INFERENCE_STRIDE = int(_CFG.matching.inference_stride)
IOU_THRESHOLD = float(_CFG.evaluation.iou_threshold)
# --- File paths ---
# (Paths unchanged; adjust them for your environment)
LAYOUT_DIR = 'path/to/layouts'
SAVE_DIR = 'path/to/save'
VAL_IMG_DIR = 'path/to/val/images'
VAL_ANN_DIR = 'path/to/val/annotations'
TEMPLATE_DIR = 'path/to/templates'
MODEL_PATH = 'path/to/save/model_final.pth'
def _resolve_path(value: str) -> str:
    """Resolve a config path relative to the config file's directory unless absolute."""
    path = Path(value)
    return str(path) if path.is_absolute() else str((_BASE_CONFIG_PATH.parent / path).resolve())

LAYOUT_DIR = _resolve_path(_CFG.paths.layout_dir)
SAVE_DIR = _resolve_path(_CFG.paths.save_dir)
VAL_IMG_DIR = _resolve_path(_CFG.paths.val_img_dir)
VAL_ANN_DIR = _resolve_path(_CFG.paths.val_ann_dir)
TEMPLATE_DIR = _resolve_path(_CFG.paths.template_dir)
MODEL_PATH = _resolve_path(_CFG.paths.model_path)

configs/base_config.yaml (new file, 25 lines)

@@ -0,0 +1,25 @@
training:
  learning_rate: 5.0e-5
  batch_size: 8
  num_epochs: 50
  patch_size: 256
  scale_jitter_range: [0.8, 1.2]

matching:
  keypoint_threshold: 0.5
  ransac_reproj_threshold: 5.0
  min_inliers: 15
  pyramid_scales: [0.75, 1.0, 1.5]
  inference_window_size: 1024
  inference_stride: 768

evaluation:
  iou_threshold: 0.5

paths:
  layout_dir: "path/to/layouts"
  save_dir: "path/to/save"
  val_img_dir: "path/to/val/images"
  val_ann_dir: "path/to/val/annotations"
  template_dir: "path/to/templates"
  model_path: "path/to/save/model_final.pth"


@@ -0,0 +1 @@
from .ic_dataset import ICLayoutDataset, ICLayoutTrainingDataset


@@ -1,7 +1,12 @@
import os
import json
from typing import Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import json
class ICLayoutDataset(Dataset):
def __init__(self, image_dir, annotation_dir=None, transform=None):
@@ -54,3 +59,91 @@ class ICLayoutDataset(Dataset):
annotation = json.load(f)
return image, annotation
class ICLayoutTrainingDataset(Dataset):
"""自监督训练用的 IC 版图数据集,带数据增强与几何配准标签。"""
def __init__(
self,
image_dir: str,
patch_size: int = 256,
transform=None,
scale_range: Tuple[float, float] = (1.0, 1.0),
) -> None:
self.image_dir = image_dir
self.image_paths = [
os.path.join(image_dir, f)
for f in os.listdir(image_dir)
if f.endswith('.png')
]
self.patch_size = patch_size
self.transform = transform
self.scale_range = scale_range
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, index: int):
img_path = self.image_paths[index]
image = Image.open(img_path).convert('L')
width, height = image.size
# Random scale jitter
scale = float(np.random.uniform(self.scale_range[0], self.scale_range[1]))
crop_size = int(self.patch_size / max(scale, 1e-6))
crop_size = min(crop_size, width, height)
if crop_size <= 0:
raise ValueError("crop_size must be positive; check scale_range configuration")
x = np.random.randint(0, max(width - crop_size + 1, 1))
y = np.random.randint(0, max(height - crop_size + 1, 1))
patch = image.crop((x, y, x + crop_size, y + crop_size))
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# Brightness/contrast augmentation
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
patch_np_uint8 = np.array(patch)
# Random rotation and mirroring (8 discrete transforms)
theta_deg = int(np.random.choice([0, 90, 180, 270]))
is_mirrored = bool(np.random.choice([True, False]))
center_x, center_y = self.patch_size / 2.0, self.patch_size / 2.0
rotation_matrix = cv2.getRotationMatrix2D((center_x, center_y), theta_deg, 1.0)
if is_mirrored:
translate_to_origin = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
mirror = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
translate_back = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
mirror_matrix = translate_back @ mirror @ translate_to_origin
rotation_matrix_h = np.vstack([rotation_matrix, [0, 0, 1]])
homography = (rotation_matrix_h @ mirror_matrix).astype(np.float32)
else:
homography = np.vstack([rotation_matrix, [0, 0, 1]]).astype(np.float32)
transformed_patch_np = cv2.warpPerspective(patch_np_uint8, homography, (self.patch_size, self.patch_size))
transformed_patch = Image.fromarray(transformed_patch_np)
if self.transform:
patch_tensor = self.transform(patch)
transformed_tensor = self.transform(transformed_patch)
else:
patch_tensor = torch.from_numpy(np.array(patch)).float().unsqueeze(0) / 255.0
transformed_tensor = torch.from_numpy(np.array(transformed_patch)).float().unsqueeze(0) / 255.0
H_tensor = torch.from_numpy(homography[:2, :]).float()
return patch_tensor, transformed_tensor, H_tensor

docs/feature_work.md (new file, 166 lines)

@@ -0,0 +1,166 @@
# Future Work

This document consolidates the RoRD project's optimization to-do list and training requirements, to guide future development and experiments.

---

## RoRD Project Optimization To-Do List

This list provides a set of actionable optimization tasks for the RoRD (Rotation-Robust Descriptors) project. Tasks are grouped by priority and module; pick them up according to project progress and available resources.

### 1. Data Strategy & Augmentation

> *Goal: improve the model's robustness and generalization while reducing dependence on large amounts of real data.*

- [ ] **Introduce Elastic Transformations**
  - **✔️ Value**: Simulates the small physical deformations that can occur in chip manufacturing, making the model more robust to non-rigid variation.
  - **📝 Plan**:
    1. Add the `albumentations` library as a project dependency.
    2. In the `ICLayoutTrainingDataset` class in `train.py`, integrate `A.ElasticTransform` into the augmentation pipeline (see the sketch below).
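A minimal sketch of what step 2 could look like (the transform parameters here are assumed placeholders, not tuned values):

```python
# Hypothetical integration of elastic deformation into the patch pipeline.
import albumentations as A
import numpy as np

elastic = A.Compose([
    A.ElasticTransform(alpha=1.0, sigma=50.0, p=0.5),  # mild non-rigid warp
])

def augment_patch(patch_np: np.ndarray) -> np.ndarray:
    """Apply elastic deformation to a grayscale (H, W) uint8 patch."""
    return elastic(image=patch_np)["image"]
```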
- [ ] **Create a Synthetic Layout Generator**
  - **✔️ Value**: Real layout data is hard to obtain and scarce; procedural generation can produce large numbers of diverse training samples.
  - **📝 Plan**:
    1. Create a new script, e.g. `tools/generate_synthetic_layouts.py`.
    2. Use the `gdstk` library to write functions that procedurally generate GDSII files containing standard cells of varying sizes, densities, and types (see the sketch below).
    3. Reuse the logic of `tools/layout2png.py` to batch-convert the generated layouts to PNG images and expand the training set.
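A tiny sketch of the generator idea (the grid, pitch, and size ranges are arbitrary assumptions):

```python
# Hypothetical synthetic-layout generator: a grid of random-sized rectangles.
import random
import gdstk

lib = gdstk.Library()
cell = lib.new_cell("TOP")
for i in range(10):
    for j in range(10):
        w, h = random.uniform(0.5, 2.0), random.uniform(0.5, 2.0)
        x, y = i * 3.0, j * 3.0  # 3 um pitch
        cell.add(gdstk.rectangle((x, y), (x + w, y + h), layer=1))
lib.write_gds("synthetic_layout.gds")
```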
### 2. Model Architecture

> *Goal: improve feature-extraction efficiency and accuracy while reducing compute cost.*

- [ ] **Experiment with a More Modern Backbone**
  - **✔️ Value**: VGG-16 is classic but relatively inefficient. Newer architectures (e.g. ResNet, EfficientNet) can match or beat its performance with fewer parameters and less compute.
  - **📝 Plan**:
    1. In `models/rord.py`, modify the `__init__` method of the `RoRD` class.
    2. Replace `vgg16` via `torchvision.models`; `models.resnet34(pretrained=True)` or `models.efficientnet_b0(pretrained=True)` are candidate substitutes.
    3. Adjust the input channel counts of the detection and descriptor heads accordingly (see the sketch below).
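A sketch of the swap; the truncation point and the resulting channel count are illustrative assumptions, not the project's settled design:

```python
# Hypothetical backbone swap: truncated ResNet-34 in place of VGG-16.
import torch.nn as nn
from torchvision import models

resnet = models.resnet34(pretrained=True)
# Keep stages up to layer3: output stride 16, 256 channels.
backbone = nn.Sequential(
    resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
    resnet.layer1, resnet.layer2, resnet.layer3,
)
# The detection and descriptor heads would then take 256 input channels
# instead of VGG-16's 512.
```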
- [ ] **Integrate an Attention Mechanism**
  - **✔️ Value**: Guides the model to focus on the key geometric structures of a layout (corners, junctions) and to ignore large blank or repetitive regions, improving feature quality.
  - **📝 Plan**:
    1. Pick a reliable implementation of an attention module such as CBAM or SE-Net.
    2. In `models/rord.py`, insert the module between `self.backbone` and the two heads (a minimal SE block is sketched below).
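For reference, a generic SE-style block (the channel count must match whatever the backbone outputs; this is a sketch, not the module the project has chosen):

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Squeeze-and-excitation: reweight channels by global context."""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction), nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        weights = self.fc(x.mean(dim=(2, 3))).view(b, c, 1, 1)  # squeeze -> excite
        return x * weights
```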
### 3. Training & Loss Function

> *Goal: stabilize training and improve convergence.*

- [ ] **Implement Automatic Loss Weighting**
  - **✔️ Value**: The detection and descriptor losses are currently summed with equal weights, which is hard to tune by hand. Automatic weighting lets the model balance the two tasks' optimization difficulty by itself.
  - **📝 Plan**:
    1. See the literature on uncertainty weighting in multi-task learning.
    2. In `train.py`, define the loss weights as two learnable parameters, `log_var_a` and `log_var_b`.
    3. Change the total loss to `loss = torch.exp(-log_var_a) * det_loss + log_var_a + torch.exp(-log_var_b) * desc_loss + log_var_b`.
    4. Add the two new parameters to the optimizer (see the sketch below).
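A sketch of steps 2-4; how the parameters are registered and stepped is an assumption about the surrounding training loop:

```python
import torch
import torch.nn as nn

# Learnable per-task log-variances; zero init means both weights start at 1.0.
log_var_a = nn.Parameter(torch.zeros(()))
log_var_b = nn.Parameter(torch.zeros(()))

def weighted_loss(det_loss: torch.Tensor, desc_loss: torch.Tensor) -> torch.Tensor:
    """Uncertainty-weighted sum of the detection and descriptor losses."""
    return (torch.exp(-log_var_a) * det_loss + log_var_a
            + torch.exp(-log_var_b) * desc_loss + log_var_b)

# The new parameters must be optimized jointly with the model, e.g.
# torch.optim.Adam(list(model.parameters()) + [log_var_a, log_var_b], lr=lr)
```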
- [ ] **Implement Hard-Sample Mining Based on Keypoint Response**
  - **✔️ Value**: Improves descriptor-learning efficiency. Sampling only where the model sees "keypoints" focuses learning on discriminative features.
  - **📝 Plan**:
    1. Work inside the `compute_description_loss` function in `train.py`.
    2. Take the `det_original` output map and apply thresholding or top-K selection to obtain keypoint coordinates.
    3. Use those coordinates, instead of the `torch.linspace` grid, as the sampling locations for the `anchor`, `positive`, and `negative` descriptors (see the sketch below).
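A sketch of the top-K selection in step 2, converting flat indices into the normalized coordinates `F.grid_sample` expects (the value of `k` is an assumption):

```python
import torch

def topk_coords(det_map: torch.Tensor, k: int = 200) -> torch.Tensor:
    """det_map: (B, 1, H, W) detection scores -> (B, K, 2) coords in [-1, 1]."""
    b, _, h, w = det_map.shape
    idx = det_map.view(b, -1).topk(k, dim=1).indices  # flat indices of top-K
    ys = torch.div(idx, w, rounding_mode="floor").float() / (h - 1) * 2 - 1
    xs = (idx % w).float() / (w - 1) * 2 - 1
    return torch.stack([xs, ys], dim=-1)  # (x, y) order, as grid_sample expects
```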
### 4. Inference & Matching

> *Goal: greatly speed up matching on large layouts and improve multi-scale detection.*

- [ ] **Rework the Model into a Feature Pyramid Network (FPN)**
  - **✔️ Value**: The current multi-scale matching repeatedly rescales the image and re-runs inference, which is slow. An FPN yields features at every scale from a single forward pass, greatly accelerating matching.
  - **📝 Plan**:
    1. Modify `models/rord.py` to extract feature maps from several backbone stages (e.g. VGG's `relu2_2`, `relu3_3`, `relu4_3`).
    2. Add upsampling and lateral connections to fuse those maps into a feature pyramid.
    3. Modify `match.py` to consume features directly from the FPN levels, replacing the image-pyramid loop (see the sketch below).
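A sketch using torchvision's ready-made FPN head; the channel counts correspond to VGG-16's relu2_2/relu3_3/relu4_3 stages, and the spatial sizes below are illustrative:

```python
from collections import OrderedDict

import torch
from torchvision.ops import FeaturePyramidNetwork

fpn = FeaturePyramidNetwork(in_channels_list=[128, 256, 512], out_channels=256)

# Dummy multi-stage features for a 512x512 input (sizes are illustrative).
feats = OrderedDict(
    c2=torch.randn(1, 128, 256, 256),  # relu2_2
    c3=torch.randn(1, 256, 128, 128),  # relu3_3
    c4=torch.randn(1, 512, 64, 64),    # relu4_3
)
pyramid = fpn(feats)  # same keys, all levels projected to 256 channels
```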
- [ ] **Add Keypoint Deduplication After Sliding-Window Matching**
  - **✔️ Value**: The sliding window in `match.py` produces many duplicate keypoints in overlap regions, inflating matching cost and potentially hurting accuracy.
  - **📝 Plan**:
    1. Work in `match.py`, just before `extract_features_sliding_window` returns.
    2. Implement a non-maximum suppression (NMS) algorithm (sketched below).
    3. Filter `all_kps` and `all_descs` by keypoint position and detection score (the model would need to output a score map) to remove redundant points.
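A sketch of step 2 as a simple radius-based NMS; it assumes per-keypoint scores, which the current pipeline does not yet expose:

```python
import torch

def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float = 4.0) -> torch.Tensor:
    """kps: (N, 2) float xy positions; scores: (N,). Returns indices of kept points."""
    order = scores.argsort(descending=True)
    suppressed = torch.zeros(len(kps), dtype=torch.bool, device=kps.device)
    keep = []
    for i in order.tolist():
        if suppressed[i]:
            continue
        keep.append(i)
        suppressed |= (kps - kps[i]).norm(dim=1) < radius  # drop close neighbours
    return torch.tensor(keep, dtype=torch.long, device=kps.device)
```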
### 5. Code & Project Structure

> *Goal: improve maintainability, extensibility, and ease of use.*

- [ ] **Migrate Configuration to YAML Files**
  - **✔️ Value**: `config.py` makes managing multiple experiment configurations awkward. YAML files keep each experiment's parameters separate and explicit, which aids reproducibility.
  - **📝 Plan**:
    1. Create a `configs` directory and write a `base_config.yaml`.
    2. Introduce the `OmegaConf` or `Hydra` library.
    3. Modify `train.py`, `match.py`, and the other scripts to load configuration from the YAML file instead of importing `config.py`.
- [ ] **Decouple Code Modules**
  - **✔️ Value**: `train.py` is too long and does too much. Decoupling clarifies the structure and follows the single-responsibility principle.
  - **📝 Plan**:
    1. Move the `ICLayoutTrainingDataset` class from `train.py` to `data/ic_dataset.py`.
    2. Create a new `losses.py` and move `compute_detection_loss` and `compute_description_loss` into it.
### 6. Experiment Tracking & Evaluation

> *Goal: establish a rigorous experimental workflow with more comprehensive performance metrics.*

- [ ] **Integrate an Experiment Tracker (TensorBoard / W&B)**
  - **✔️ Value**: Plain log files make comparing runs hard. A visualization tool can monitor and compare losses and metrics across experiments in real time.
  - **📝 Plan**:
    1. In `train.py`, import `torch.utils.tensorboard.SummaryWriter`.
    2. In the training loop, record each loss term with `writer.add_scalar()`.
    3. After each validation pass, record the evaluation metrics, learning rate, and so on (see the sketch below).
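A sketch of the wiring (tag names and the log directory are placeholders):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/rord")

# Inside the training loop:
#   writer.add_scalar("loss/detection", det_loss.item(), global_step)
#   writer.add_scalar("loss/descriptor", desc_loss.item(), global_step)
# After each validation pass:
#   writer.add_scalar("val/loss", avg_val_loss, epoch)
#   writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)

writer.close()
```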
- [ ] **Add More Comprehensive Evaluation Metrics**
  - **✔️ Value**: The current metrics focus mainly on bounding-box overlap. Adding mAP and geometric-error measures would characterize model performance more fully.
  - **📝 Plan**:
    1. Implement mAP (mean Average Precision) computation in `evaluate.py`.
    2. After an IoU match succeeds, decompose the homography `H` returned by `match_template_multiscale` into rotation/translation parameters and compare them against the ground-truth transform to compute the error (see the sketch below).
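A sketch of the geometric-error part of step 2, assuming the matches are well approximated by a similarity transform (rotation + translation + uniform scale):

```python
import numpy as np

def decompose_similarity(H: np.ndarray) -> tuple[float, float, float, float]:
    """Return (angle_deg, tx, ty, scale) from a 3x3 similarity-like homography."""
    angle = float(np.degrees(np.arctan2(H[1, 0], H[0, 0])))
    scale = float(np.hypot(H[0, 0], H[1, 0]))
    return angle, float(H[0, 2]), float(H[1, 2]), scale
```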
---

## Training Requirements

### 1. Dataset Type

* **Format**: Training data are PNG images of integrated-circuit (IC) layouts, either binarized black-and-white or grayscale.
* **Source**: Can be generated by rasterizing GDSII (.gds) or OASIS (.oas) layout files.
* **Content**: The dataset should cover a variety of regions and design styles to ensure the model generalizes.
* **Annotation**: **No manual annotation is required for training.** The model is trained with self-supervision: training pairs are generated automatically by rotating, mirroring, and otherwise geometrically transforming the source images.

### 2. Dataset Size

* **Bootstrap (pipeline validation)**: **100-200** high-resolution (e.g. 2048x2048) layout images; enough to verify the training pipeline runs and the losses converge.
* **First usable model**: **1,000-2,000** layout images. At this scale the model can learn reasonably robust geometric features and perform well on layouts similar to the training data.
* **Production-grade model**: **5,000-10,000+** layout images. Strong generalization across processes and design styles requires a large, diverse dataset.

The training script `train.py` automatically splits the provided dataset 80/20 into training and validation sets.

### 3. Compute Resources

* **Hardware**: **A CUDA-capable NVIDIA GPU is required.** Given the VGG-16 backbone and the complex geometry-aware loss, a mid- to high-end GPU speeds training up considerably.
* **Recommended models**:
  * **Entry level**: NVIDIA RTX 3060 / 4060
  * **Mainstream**: NVIDIA RTX 3080 / 4070 / A4000
  * **Professional**: NVIDIA RTX 3090 / 4090 / A6000
* **CPU & RAM**: At least an 8-core CPU and 32 GB of RAM are recommended so that data preprocessing and loading do not become the bottleneck.

### 4. VRAM Requirements

The required VRAM can be estimated from the parameters in `config.py` and `train.py`:

* **Architecture**: VGG-16 based.
* **Batch size**: 8 by default.
* **Patch size**: 256x256.

Accounting for gradients and optimizer state on top of these, **at least 12 GB of VRAM is recommended**. If VRAM is insufficient, reduce `BATCH_SIZE` (e.g. to 4 or 2), at some cost to training speed and stability.

### 5. Training-Time Estimate

Assuming one **NVIDIA RTX 3080 (10 GB)** GPU and a dataset of **2,000** layout images:

* **Per epoch**: roughly 15-25 minutes.
* **Total training time**: the configured number of epochs is 50.
  * `50 epochs * 20 min/epoch ≈ 16.7 hours`
* **Convergence**: early stopping (patience=10) halts training if the validation loss fails to improve for 10 epochs, so the actual training time will likely fall between **10 and 20 hours**.

### 6. Iterative Tuning Time

Tuning is an iterative, time-consuming process. Based on the optimization points in `TRAINING_STRATEGY_ANALYSIS.md` and the suggestions above, the tuning phase may include:

* **Exploring augmentation strategies (1-2 weeks)**: adjusting the scale-jitter range and brightness/contrast parameters, trying different noise types, etc.
* **Balancing loss weights (1-2 weeks)**: `loss_function.md` describes several loss terms (BCE, SmoothL1, Triplet, Manhattan, Sparsity, Binary); tuning their relative weights is critical to performance.
* **Hyperparameter search (2-4 weeks)**: grid search or Bayesian optimization over learning rate, batch size, optimizer (Adam, SGD, etc.), and learning-rate schedule.
* **Architecture fine-tuning (optional, 2-4 weeks)**: trying different backbones (e.g. ResNet) or changing the depth/width of the detection and descriptor heads.

**In total, from data preparation through final tuning, reaching a stable, reliable, production-grade model with strong generalization is expected to take one and a half to three months.**


@@ -1,17 +1,17 @@
# evaluate.py
import argparse
import json
import os
from pathlib import Path
import torch
from PIL import Image
import json
import os
import argparse
import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
# (Modified) Import the new matching function
from match import match_template_multiscale
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
def compute_iou(box1, box2):
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
@@ -23,7 +23,7 @@ def compute_iou(box1, box2):
return inter_area / union_area if union_area > 0 else 0
# --- (Modified) Evaluation function ---
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching_cfg, iou_threshold):
model.eval()
all_tp, all_fp, all_fn = 0, 0, 0
@@ -59,7 +59,7 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
template_image = Image.open(template_path).convert('L')
# (已修改) 调用新的多尺度匹配函数
detected = match_template_multiscale(model, layout_image, template_image, transform)
detected = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
gt_boxes = gt_by_template.get(template_name, [])
@@ -76,7 +76,7 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
if iou > best_iou:
best_iou, best_gt_idx = iou, i
if best_iou > config.IOU_THRESHOLD:
if best_iou > iou_threshold:
if not matched_gt[best_gt_idx]:
tp += 1
matched_gt[best_gt_idx] = True
@@ -96,17 +96,29 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
parser.add_argument('--val_dir', type=str, default=None, help="验证图像目录,若未提供则使用配置文件中的路径")
parser.add_argument('--annotations_dir', type=str, default=None, help="验证标注目录,若未提供则使用配置文件中的路径")
parser.add_argument('--templates_dir', type=str, default=None, help="模板目录,若未提供则使用配置文件中的路径")
args = parser.parse_args()
model = RoRD().cuda()
model.load_state_dict(torch.load(args.model_path))
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
paths_cfg = cfg.paths
matching_cfg = cfg.matching
eval_cfg = cfg.evaluation
# (Modified) No need to preload the dataset; pass the paths directly
results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
model_path = args.model_path or str(to_absolute_path(paths_cfg.model_path, config_dir))
val_dir = args.val_dir or str(to_absolute_path(paths_cfg.val_img_dir, config_dir))
annotations_dir = args.annotations_dir or str(to_absolute_path(paths_cfg.val_ann_dir, config_dir))
templates_dir = args.templates_dir or str(to_absolute_path(paths_cfg.template_dir, config_dir))
iou_threshold = float(eval_cfg.iou_threshold)
model = RoRD().cuda()
model.load_state_dict(torch.load(model_path))
results = evaluate(model, val_dir, annotations_dir, templates_dir, matching_cfg, iou_threshold)
print("\n--- 评估结果 ---")
print(f" 精确率 (Precision): {results['precision']:.4f}")

losses.py (new file, 138 lines)

@@ -0,0 +1,138 @@
"""Loss utilities for RoRD training."""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
"""Append the third row [0, 0, 1] to build a full 3x3 homography."""
if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
raise ValueError("Expected homography with shape (B, 2, 3)")
batch_size = h_2x3.size(0)
device = h_2x3.device
bottom_row = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=h_2x3.dtype)
bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
return torch.cat([h_2x3, bottom_row], dim=1)
def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
"""Warp feature map according to inverse homography."""
return F.grid_sample(
feature_map,
F.affine_grid(h_inv, feature_map.size(), align_corners=False),
align_corners=False,
)
def compute_detection_loss(
det_original: torch.Tensor,
det_rotated: torch.Tensor,
h: torch.Tensor,
) -> torch.Tensor:
"""Binary cross-entropy + smooth L1 detection loss."""
h_full = _augment_homography_matrix(h)
h_inv = torch.inverse(h_full)[:, :2, :]
warped_det = warp_feature_map(det_rotated, h_inv)
bce_loss = F.binary_cross_entropy(det_original, warped_det)
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
return bce_loss + 0.1 * smooth_l1_loss
def compute_description_loss(
desc_original: torch.Tensor,
desc_rotated: torch.Tensor,
h: torch.Tensor,
margin: float = 1.0,
) -> torch.Tensor:
"""Triplet-style descriptor loss with Manhattan-aware sampling."""
batch_size, channels, height, width = desc_original.size()
num_samples = 200
grid_side = int(math.sqrt(num_samples))
h_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
w_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
anchor = F.grid_sample(
desc_original,
manhattan_coords.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
coords_hom = torch.cat(
[manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=desc_original.device)],
dim=2,
)
h_full = _augment_homography_matrix(h)
h_inv = torch.inverse(h_full)
coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]
positive = F.grid_sample(
desc_rotated,
coords_transformed.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
negative_list = []
if manhattan_coords.size(1) > 0:
angles = [0, 90, 180, 270]
for angle in angles:
if angle == 0:
continue
theta = torch.tensor(angle * math.pi / 180.0, device=desc_original.device)
cos_t = torch.cos(theta)
sin_t = torch.sin(theta)
rot = torch.stack(
[
torch.stack([cos_t, -sin_t]),
torch.stack([sin_t, cos_t]),
]
)
rotated_coords = manhattan_coords @ rot.T
negative_list.append(rotated_coords)
if negative_list:
neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
negative_candidates = F.grid_sample(
desc_rotated,
neg_coords.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
k = max(anchor.size(1) // 2, 1)
hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
negative = torch.gather(negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand)
negative = negative.mean(dim=2)
else:
negative = torch.zeros_like(anchor)
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
geometric_triplet = triplet_loss(anchor, positive, negative)
manhattan_loss = 0.0
for i in range(anchor.size(1)):
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
manhattan_loss += torch.mean(1 - cos_sim)
manhattan_loss = manhattan_loss / max(anchor.size(1), 1)
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss


@@ -1,15 +1,17 @@
# match.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import config
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
# --- Feature extraction functions (largely unchanged) ---
@@ -39,15 +41,16 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
return keypoints, descriptors
# --- (New) Sliding-window feature extraction ---
def extract_features_sliding_window(model, large_image, transform):
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
"""
Extract all keypoints and descriptors from a large image using a sliding window.
"""
print("Extracting features from the large layout with a sliding window...")
device = next(model.parameters()).device
W, H = large_image.size
window_size = config.INFERENCE_WINDOW_SIZE
stride = config.INFERENCE_STRIDE
window_size = int(matching_cfg.inference_window_size)
stride = int(matching_cfg.inference_stride)
keypoint_threshold = float(matching_cfg.keypoint_threshold)
all_kps = []
all_descs = []
@@ -65,7 +68,7 @@ def extract_features_sliding_window(model, large_image, transform):
patch_tensor = transform(patch).unsqueeze(0).to(device)
# Extract features
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)
if len(kps) > 0:
# Convert local coordinates to global coordinates
@@ -94,26 +97,30 @@ def mutual_nearest_neighbor(descs1, descs2):
return matches
# --- (Modified) Multi-scale, multi-instance matching entry point ---
def match_template_multiscale(model, layout_image, template_image, transform):
def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
"""
Search for the template across scales and detect multiple instances.
"""
# 1. Extract all features from the large layout with a sliding window
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
if len(layout_kps) < config.MIN_INLIERS:
min_inliers = int(matching_cfg.min_inliers)
if len(layout_kps) < min_inliers:
print("从大版图中提取的关键点过少,无法进行匹配。")
return []
found_instances = []
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales]
keypoint_threshold = float(matching_cfg.keypoint_threshold)
ransac_threshold = float(matching_cfg.ransac_reproj_threshold)
# 2. Iterative multi-instance detection
while True:
current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
# Stop if too few active keypoints remain
if len(current_active_indices) < config.MIN_INLIERS:
if len(current_active_indices) < min_inliers:
break
current_layout_kps = layout_kps[current_active_indices]
@@ -123,7 +130,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
# 3. Image pyramid: iterate over each template scale
print("Searching for the template at a new scale...")
for scale in config.PYRAMID_SCALES:
for scale in pyramid_scales:
W, H = template_image.size
new_W, new_H = int(W * scale), int(H * scale)
@@ -132,7 +139,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
# Extract features from the scaled template
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
if len(template_kps) < 4: continue
@@ -147,13 +154,13 @@ def match_template_multiscale(model, layout_image, template_image, transform):
dst_pts_indices = current_active_indices[matches[:, 1]]
dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)
if H is not None and mask.sum() > best_match_info['inliers']:
best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
# 4. If a best match was found across all scales, record it and mask its inliers
if best_match_info['inliers'] > config.MIN_INLIERS:
if best_match_info['inliers'] > min_inliers:
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
inlier_mask = best_match_info['mask'].ravel().astype(bool)
@@ -191,21 +198,27 @@ def visualize_matches(layout_path, bboxes, output_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
parser.add_argument('--output', type=str)
args = parser.parse_args()
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
matching_cfg = cfg.matching
model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
transform = get_transform()
model = RoRD().cuda()
model.load_state_dict(torch.load(args.model_path))
model.load_state_dict(torch.load(model_path))
model.eval()
layout_image = Image.open(args.layout).convert('L')
template_image = Image.open(args.template).convert('L')
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
print("\n检测到的边界框:")
for bbox in detected_bboxes:


@@ -14,6 +14,7 @@ dependencies = [
"pillow>=11.2.1",
"torch>=2.7.1",
"torchvision>=0.22.1",
"omegaconf>=2.3.0",
]
[[tool.uv.index]]


@@ -1,159 +0,0 @@
# tools/klayoutconvertor.py
#!/usr/bin/env python3
"""
KLayout GDS to PNG Converter
This script uses KLayout's Python API to convert GDS files to PNG images.
It accepts command-line arguments for input parameters.
Requirements:
pip install klayout
Usage:
python klayoutconvertor.py input.gds output.png [options]
"""
import klayout.db as pya
import klayout.lay as lay
from PIL import Image
import os
import argparse
import sys
Image.MAX_IMAGE_PIXELS = None
def export_gds_as_image(
gds_path: str,
output_path: str,
layers: list = [1, 2],
center_um: tuple = (0, 0),
view_size_um: float = 100.0,
resolution: int = 2048,
binarize: bool = True
) -> None:
"""
Export GDS file as PNG image using KLayout.
Args:
gds_path: Input GDS file path
output_path: Output PNG file path
layers: List of layer numbers to include
center_um: Center coordinates in micrometers (x, y)
view_size_um: View size in micrometers
resolution: Output image resolution
binarize: Whether to convert to black and white
"""
if not os.path.exists(gds_path):
raise FileNotFoundError(f"Input file not found: {gds_path}")
# Ensure output directory exists
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
layout = pya.Layout()
layout.read(gds_path)
top = layout.top_cell()
# Create layout view
view = lay.LayoutView()
view.set_config("background-color", "#ffffff")
view.set_config("grid-visible", "false")
# Load layout into view correctly
view.load_layout(gds_path)
# Add all layers
view.add_missing_layers()
# Configure view to show entire layout with reasonable resolution
if view_size_um > 0:
# Use specified view size
box = pya.DBox(
center_um[0] - view_size_um / 2,
center_um[1] - view_size_um / 2,
center_um[0] + view_size_um / 2,
center_um[1] + view_size_um / 2
)
else:
# Use full layout bounds with size limit
bbox = top.bbox()
if bbox:
# Convert to micrometers (KLayout uses database units)
dbu = layout.dbu
box = pya.DBox(
bbox.left * dbu,
bbox.bottom * dbu,
bbox.right * dbu,
bbox.top * dbu
)
else:
# Fallback to 100x100 um if empty layout
box = pya.DBox(-50, -50, 50, 50)
view.max_hier()
view.zoom_box(box)
# Save to temporary file first, then load with PIL
import tempfile
temp_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name
try:
view.save_image(temp_path, resolution, resolution)
img = Image.open(temp_path)
if binarize:
# Convert to grayscale and binarize
img = img.convert("L")
img = img.point(lambda x: 255 if x > 128 else 0, '1')
else:
# Convert to grayscale
img = img.convert("L")
img.save(output_path)
finally:
# Clean up temp file
if os.path.exists(temp_path):
os.unlink(temp_path)
def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description='Convert GDS to PNG using KLayout')
parser.add_argument('input', help='Input GDS file')
parser.add_argument('output', help='Output PNG file')
parser.add_argument('--layers', nargs='+', type=int, default=[1, 2],
help='Layers to include (default: 1 2)')
parser.add_argument('--center-x', type=float, default=0,
help='Center X coordinate in micrometers (default: 0)')
parser.add_argument('--center-y', type=float, default=0,
help='Center Y coordinate in micrometers (default: 0)')
parser.add_argument('--size', type=float, default=0,
help='View size in micrometers (default: 0 = full layout)')
parser.add_argument('--resolution', type=int, default=2048,
help='Output image resolution (default: 2048)')
parser.add_argument('--no-binarize', action='store_true',
help='Disable binarization (keep grayscale)')
args = parser.parse_args()
try:
export_gds_as_image(
gds_path=args.input,
output_path=args.output,
layers=args.layers,
center_um=(args.center_x, args.center_y),
view_size_um=args.size,
resolution=args.resolution,
binarize=not args.no_binarize
)
print("Conversion completed successfully!")
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == '__main__':
main()

train.py (271 lines changed)

@@ -1,20 +1,18 @@
# train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import cv2
import os
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path
# Project module imports
import config
import torch
from torch.utils.data import DataLoader
from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
# Logging setup
@@ -34,207 +32,33 @@ def setup_logging(save_dir):
)
return logging.getLogger(__name__)
# --- (Modified) Training-only dataset class ---
class ICLayoutTrainingDataset(Dataset):
def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
self.image_dir = image_dir
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
self.patch_size = patch_size
self.transform = transform
self.scale_range = scale_range  # New scale-range parameter
def __len__(self):
return len(self.image_paths)
def __getitem__(self, index):
img_path = self.image_paths[index]
image = Image.open(img_path).convert('L')
W, H = image.size
# --- New: scale-jitter augmentation ---
# 1. Randomly choose a scale factor
scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
# 2. Compute the crop size to take from the original image for that scale
crop_size = int(self.patch_size / scale)
# Make sure the crop does not exceed the image bounds
if crop_size > min(W, H):
crop_size = min(W, H)
# 3. Random crop
x = np.random.randint(0, W - crop_size + 1)
y = np.random.randint(0, H - crop_size + 1)
patch = image.crop((x, y, x + crop_size, y + crop_size))
# 4. Resize the crop back to the standard patch_size
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# --- End of scale jitter ---
# --- New: extra data augmentations ---
# Brightness adjustment
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(x * brightness_factor))
# Contrast adjustment
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
# Add noise
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
# --- End of extra augmentations ---
patch_np = np.array(patch)
# Apply one of the 8 discrete geometric transforms (this logic is unchanged)
theta_deg = np.random.choice([0, 90, 180, 270])
is_mirrored = np.random.choice([True, False])
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)
if is_mirrored:
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
M_mirror_3x3 = T2 @ Flip @ T1
M_3x3 = np.vstack([M, [0, 0, 1]])
H = (M_3x3 @ M_mirror_3x3).astype(np.float32)
else:
H = np.vstack([M, [0, 0, 1]]).astype(np.float32)
transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
transformed_patch = Image.fromarray(transformed_patch_np)
if self.transform:
patch = self.transform(patch)
transformed_patch = self.transform(transformed_patch)
H_tensor = torch.from_numpy(H[:2, :]).float()
return patch, transformed_patch, H_tensor
# --- Feature-map warping and loss functions (improved) ---
def warp_feature_map(feature_map, H_inv):
B, C, H, W = feature_map.size()
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
return F.grid_sample(feature_map, grid, align_corners=False)
def compute_detection_loss(det_original, det_rotated, H):
"""改进的检测损失使用BCE损失替代MSE"""
with torch.no_grad():
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
# BCE loss suits the binary classification problem better
bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
# Add smooth L1 loss as an auxiliary term
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
return bce_loss + 0.1 * smooth_l1_loss
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
"""IC版图专用几何感知描述子损失编码曼哈顿几何特征"""
B, C, H_feat, W_feat = desc_original.size()
# Manhattan-geometry-aware sampling: focus on edge and corner regions
num_samples = 200
# Generate a Manhattan-aligned sampling grid (horizontal and vertical first)
h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
# Densify sampling along the Manhattan directions
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1).unsqueeze(0).repeat(B, 1, 1)
# Sample anchor descriptors
anchor = F.grid_sample(desc_original, manhattan_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# Compute the corresponding positive sample points
coords_hom = torch.cat([manhattan_coords, torch.ones(B, manhattan_coords.size(1), 1, device=manhattan_coords.device)], dim=2)
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# IC-layout-specific negative sampling strategy (accounts for repeated structures)
with torch.no_grad():
# 1. Geometry-aware negatives: different regions under Manhattan transforms
neg_coords = []
for b in range(B):
# Generate Manhattan-transformed coordinates (90-degree rotations, etc.)
angles = [0, 90, 180, 270]
for angle in angles:
if angle != 0:
theta = torch.tensor([angle * np.pi / 180])
rot_matrix = torch.tensor([
[torch.cos(theta), -torch.sin(theta), 0],
[torch.sin(theta), torch.cos(theta), 0]
])
rotated_coords = manhattan_coords[b] @ rot_matrix[:2, :2].T
neg_coords.append(rotated_coords)
neg_coords = torch.stack(neg_coords[:B*num_samples//2]).reshape(B, -1, 2)
# 2. Hard negatives in feature space
negative_candidates = F.grid_sample(desc_rotated, neg_coords, align_corners=False).squeeze(2).transpose(1, 2)
# 3. Hard-sample selection constrained by Manhattan distance
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
# Use Manhattan distance instead of Euclidean distance
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
hard_indices = torch.topk(manhattan_dist, k=anchor.size(1)//2, largest=False)[1]
negative = torch.gather(negative_candidates, 1, hard_indices)
# IC-layout-specific geometric consistency losses
# 1. Manhattan-direction consistency loss
manhattan_loss = 0
for i in range(anchor.size(1)):
# Measure geometric consistency along the horizontal and vertical directions
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
# Encourage descriptors to be invariant to Manhattan transforms
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
manhattan_loss += torch.mean(1 - cos_sim)
# 2. Sparsity regularization (IC layout features are sparse)
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
# 3. Binarized feature distance (handles binarized inputs)
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
# Combined loss
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')  # L1 distance
geometric_triplet = triplet_loss(anchor, positive, negative)
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
# --- (Modified) Main function and CLI ---
def main(args):
# Set up logging
logger = setup_logging(args.save_dir)
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
patch_size = int(cfg.training.patch_size)
scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
logger = setup_logging(save_dir)
logger.info("--- 开始训练 RoRD 模型 ---")
logger.info(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
logger.info(f"数据目录: {args.data_dir}")
logger.info(f"保存目录: {args.save_dir}")
logger.info(f"训练参数: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
logger.info(f"数据目录: {data_dir}")
logger.info(f"保存目录: {save_dir}")
transform = get_transform()
# Pass the scale-jitter range when constructing the dataset
dataset = ICLayoutTrainingDataset(
args.data_dir,
patch_size=config.PATCH_SIZE,
data_dir,
patch_size=patch_size,
transform=transform,
scale_range=config.SCALE_JITTER_RANGE
scale_range=scale_range,
)
logger.info(f"数据集大小: {len(dataset)}")
@@ -246,13 +70,13 @@ def main(args):
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
model = RoRD().cuda()
logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
# Add a learning-rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
@@ -264,7 +88,7 @@ def main(args):
patience_counter = 0
patience = 10
for epoch in range(args.epochs):
for epoch in range(epochs):
# Training phase
model.train()
total_train_loss = 0
@@ -339,18 +163,19 @@ def main(args):
patience_counter = 0
# Save the best model
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = os.path.join(save_dir, 'rord_model_best.pth')
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'config': {
'learning_rate': args.lr,
'batch_size': args.batch_size,
'epochs': args.epochs
'learning_rate': lr,
'batch_size': batch_size,
'epochs': epochs,
'config_path': str(Path(args.config).resolve()),
}
}, save_path)
logger.info(f"最佳模型已保存至: {save_path}")
@@ -361,16 +186,17 @@ def main(args):
break
# Save the final model
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
save_path = os.path.join(save_dir, 'rord_model_final.pth')
torch.save({
'epoch': args.epochs,
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'final_val_loss': avg_val_loss,
'config': {
'learning_rate': args.lr,
'batch_size': args.batch_size,
'epochs': args.epochs
'learning_rate': lr,
'batch_size': batch_size,
'epochs': epochs,
'config_path': str(Path(args.config).resolve()),
}
}, save_path)
logger.info(f"最终模型已保存至: {save_path}")
@@ -378,9 +204,10 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--data_dir', type=str, default=None, help="训练数据目录,若未提供则使用配置文件中的路径")
parser.add_argument('--save_dir', type=str, default=None, help="模型保存目录,若未提供则使用配置文件中的路径")
parser.add_argument('--epochs', type=int, default=None, help="训练轮数,若未提供则使用配置文件中的值")
parser.add_argument('--batch_size', type=int, default=None, help="批次大小,若未提供则使用配置文件中的值")
parser.add_argument('--lr', type=float, default=None, help="学习率,若未提供则使用配置文件中的值")
main(parser.parse_args())

utils/config_loader.py (new file, 23 lines)

@@ -0,0 +1,23 @@
"""Configuration loading utilities using OmegaConf."""
from __future__ import annotations
from pathlib import Path
from typing import Union
from omegaconf import DictConfig, OmegaConf
def load_config(config_path: Union[str, Path]) -> DictConfig:
"""Load a YAML configuration file into a DictConfig."""
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")
return OmegaConf.load(path)
def to_absolute_path(path_str: str, base_dir: Union[str, Path]) -> Path:
"""Resolve a possibly relative path against the configuration file directory."""
path = Path(path_str).expanduser()
if path.is_absolute():
return path.resolve()
return (Path(base_dir) / path).resolve()
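A short usage example for the loader (the config path below is the repository default; the base directory argument is illustrative):

```python
from utils.config_loader import load_config, to_absolute_path

cfg = load_config("configs/base_config.yaml")
model_path = to_absolute_path(cfg.paths.model_path, "configs")
print(cfg.training.batch_size, model_path)
```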

uv.lock (generated, 47 lines changed)

@@ -7,6 +7,12 @@ resolution-markers = [
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
]
[[package]]
name = "antlr4-python3-runtime"
version = "4.9.3"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" }
[[package]]
name = "cairocffi"
version = "1.7.1"
@@ -408,6 +414,19 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" },
]
[[package]]
name = "omegaconf"
version = "2.3.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "antlr4-python3-runtime" },
{ name = "pyyaml" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" },
]
[[package]]
name = "opencv-python"
version = "4.11.0.86"
@@ -475,6 +494,32 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552, upload-time = "2024-03-30T13:22:20.476Z" },
]
[[package]]
name = "pyyaml"
version = "6.0.2"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631, upload-time = "2024-08-06T20:33:50.674Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873, upload-time = "2024-08-06T20:32:25.131Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302, upload-time = "2024-08-06T20:32:26.511Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154, upload-time = "2024-08-06T20:32:28.363Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223, upload-time = "2024-08-06T20:32:30.058Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542, upload-time = "2024-08-06T20:32:31.881Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164, upload-time = "2024-08-06T20:32:37.083Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611, upload-time = "2024-08-06T20:32:38.898Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591, upload-time = "2024-08-06T20:32:40.241Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309, upload-time = "2024-08-06T20:32:43.4Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679, upload-time = "2024-08-06T20:32:44.801Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428, upload-time = "2024-08-06T20:32:46.432Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361, upload-time = "2024-08-06T20:32:51.188Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523, upload-time = "2024-08-06T20:32:53.019Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660, upload-time = "2024-08-06T20:32:54.708Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597, upload-time = "2024-08-06T20:32:56.985Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527, upload-time = "2024-08-06T20:33:03.001Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" },
]
[[package]]
name = "rord-layout-recognation"
version = "0.1.0"
@@ -485,6 +530,7 @@ dependencies = [
{ name = "gdstk" },
{ name = "klayout" },
{ name = "numpy" },
{ name = "omegaconf" },
{ name = "opencv-python" },
{ name = "pillow" },
{ name = "torch" },
@@ -498,6 +544,7 @@ requires-dist = [
{ name = "gdstk", specifier = ">=0.9.60" },
{ name = "klayout", specifier = ">=0.30.2" },
{ name = "numpy", specifier = ">=2.3.0" },
{ name = "omegaconf", specifier = ">=2.3.0" },
{ name = "opencv-python", specifier = ">=4.11.0.86" },
{ name = "pillow", specifier = ">=11.2.1" },
{ name = "torch", specifier = ">=2.7.1" },