Files
RoRD-Layout-Recognation/data/ic_dataset.py
2025-09-25 20:20:24 +08:00

149 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
from typing import Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class ICLayoutDataset(Dataset):
    """Dataset of PNG IC-layout images with optional JSON annotations.

    Each item is a tuple ``(image, annotation)`` where ``image`` is the
    (optionally transformed) grayscale image and ``annotation`` is the parsed
    JSON dict for that image, or ``{}`` when no annotation file exists.
    """

    def __init__(self, image_dir, annotation_dir=None, transform=None):
        """
        Initialize the IC layout dataset.

        Args:
            image_dir (str): Directory containing PNG IC-layout images.
            annotation_dir (str, optional): Directory containing JSON
                annotation files named after the images (``x.png`` -> ``x.json``).
            transform (callable, optional): Optional transform applied to each
                image (e.g. Sobel edge detection).
        """
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        # Sort for deterministic ordering: os.listdir order is arbitrary and
        # would otherwise make indices irreproducible across runs/machines.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        if annotation_dir:
            # Swap only the extension. str.replace('.png', '.json') would
            # corrupt names containing ".png" in the middle (e.g. "a.png.bak.png").
            self.annotations = [os.path.splitext(f)[0] + '.json' for f in self.images]
        else:
            self.annotations = [None] * len(self.images)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale, with its annotation.

        Args:
            idx (int): Image index.

        Returns:
            tuple: ``(image, annotation)`` — ``annotation`` is ``{}`` when
            no annotation directory was given or the file is missing.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert('L')  # force grayscale
        if self.transform:
            image = self.transform(image)
        annotation = {}
        if self.annotation_dir and self.annotations[idx]:
            ann_path = os.path.join(self.annotation_dir, self.annotations[idx])
            if os.path.exists(ann_path):
                # Explicit encoding: JSON files are UTF-8 by specification,
                # independent of the platform's locale default.
                with open(ann_path, 'r', encoding='utf-8') as f:
                    annotation = json.load(f)
        return image, annotation
class ICLayoutTrainingDataset(Dataset):
    """Self-supervised IC-layout dataset producing augmented patch pairs.

    Each item is ``(patch_tensor, transformed_tensor, H_tensor)``: the second
    patch is the first warped by a random discrete symmetry (rotation by
    0/90/180/270 degrees, optionally mirrored), and ``H_tensor`` is the 2x3
    affine part of the homography mapping the first patch onto the second —
    a geometric-registration training label.
    """

    def __init__(
        self,
        image_dir: str,
        patch_size: int = 256,
        transform=None,
        scale_range: Tuple[float, float] = (1.0, 1.0),
    ) -> None:
        # Directory of PNG layout images; every .png file is a sample source.
        self.image_dir = image_dir
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.endswith('.png')
        ]
        # Side length (pixels) of the square patches returned.
        self.patch_size = patch_size
        # Optional callable applied to both patches; when None, patches are
        # converted to [0, 1] float tensors with a channel dimension.
        self.transform = transform
        # (min, max) range for random scale jitter applied at crop time.
        self.scale_range = scale_range

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert('L')
        width, height = image.size
        # Random scale jitter: scale > 1 crops a smaller region (zoom in),
        # scale < 1 crops a larger one (zoom out); resize normalizes size.
        scale = float(np.random.uniform(self.scale_range[0], self.scale_range[1]))
        crop_size = int(self.patch_size / max(scale, 1e-6))
        crop_size = min(crop_size, width, height)
        if crop_size <= 0:
            raise ValueError("crop_size must be positive; check scale_range configuration")
        # Random top-left corner; max(..., 1) keeps np.random.randint's upper
        # bound valid when the image is exactly crop_size wide/tall.
        x = np.random.randint(0, max(width - crop_size + 1, 1))
        y = np.random.randint(0, max(height - crop_size + 1, 1))
        patch = image.crop((x, y, x + crop_size, y + crop_size))
        patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
        # Photometric augmentation: brightness and contrast jitter, each
        # applied independently with probability 0.5.
        if np.random.random() < 0.5:
            brightness_factor = np.random.uniform(0.8, 1.2)
            patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
        if np.random.random() < 0.5:
            contrast_factor = np.random.uniform(0.8, 1.2)
            patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
        # Additive Gaussian noise (sigma = 5) with probability 0.3.
        if np.random.random() < 0.3:
            patch_np = np.array(patch, dtype=np.float32)
            noise = np.random.normal(0, 5, patch_np.shape)
            patch_np = np.clip(patch_np + noise, 0, 255)
            patch = Image.fromarray(patch_np.astype(np.uint8))
        patch_np_uint8 = np.array(patch)
        # Geometric augmentation: one of the 8 discrete symmetries of the
        # square (4 rotations x optional mirror).
        theta_deg = int(np.random.choice([0, 90, 180, 270]))
        is_mirrored = bool(np.random.choice([True, False]))
        center_x, center_y = self.patch_size / 2.0, self.patch_size / 2.0
        rotation_matrix = cv2.getRotationMatrix2D((center_x, center_y), theta_deg, 1.0)
        if is_mirrored:
            # Mirror about the vertical axis through the patch center:
            # translate center to origin, negate x, translate back.
            translate_to_origin = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
            mirror = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
            translate_back = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
            mirror_matrix = translate_back @ mirror @ translate_to_origin
            # Lift the 2x3 affine rotation to a 3x3 homography so the two
            # transforms compose; mirror is applied first, then rotation
            # (matrices apply right-to-left).
            rotation_matrix_h = np.vstack([rotation_matrix, [0, 0, 1]])
            homography = (rotation_matrix_h @ mirror_matrix).astype(np.float32)
        else:
            homography = np.vstack([rotation_matrix, [0, 0, 1]]).astype(np.float32)
        transformed_patch_np = cv2.warpPerspective(patch_np_uint8, homography, (self.patch_size, self.patch_size))
        transformed_patch = Image.fromarray(transformed_patch_np)
        if self.transform:
            patch_tensor = self.transform(patch)
            transformed_tensor = self.transform(transformed_patch)
        else:
            # Default tensorization: uint8 grayscale -> float in [0, 1],
            # with a leading channel dimension (1, H, W).
            patch_tensor = torch.from_numpy(np.array(patch)).float().unsqueeze(0) / 255.0
            transformed_tensor = torch.from_numpy(np.array(transformed_patch)).float().unsqueeze(0) / 255.0
        # Label is the top two rows of the 3x3 homography — sufficient here
        # because the bottom row is always [0, 0, 1] (affine transforms).
        H_tensor = torch.from_numpy(homography[:2, :]).float()
        return patch_tensor, transformed_tensor, H_tensor