complete code struction update

This commit is contained in:
Jiao77
2025-09-25 20:20:24 +08:00
parent e0b250e77f
commit 8c9926c815
10 changed files with 480 additions and 290 deletions

View File

@@ -1,7 +1,12 @@
import os
import json
from typing import Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import json
class ICLayoutDataset(Dataset):
def __init__(self, image_dir, annotation_dir=None, transform=None):
@@ -53,4 +58,92 @@ class ICLayoutDataset(Dataset):
with open(ann_path, 'r') as f:
annotation = json.load(f)
return image, annotation
return image, annotation
class ICLayoutTrainingDataset(Dataset):
"""自监督训练用的 IC 版图数据集,带数据增强与几何配准标签。"""
def __init__(
self,
image_dir: str,
patch_size: int = 256,
transform=None,
scale_range: Tuple[float, float] = (1.0, 1.0),
) -> None:
self.image_dir = image_dir
self.image_paths = [
os.path.join(image_dir, f)
for f in os.listdir(image_dir)
if f.endswith('.png')
]
self.patch_size = patch_size
self.transform = transform
self.scale_range = scale_range
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, index: int):
img_path = self.image_paths[index]
image = Image.open(img_path).convert('L')
width, height = image.size
# 随机尺度抖动
scale = float(np.random.uniform(self.scale_range[0], self.scale_range[1]))
crop_size = int(self.patch_size / max(scale, 1e-6))
crop_size = min(crop_size, width, height)
if crop_size <= 0:
raise ValueError("crop_size must be positive; check scale_range configuration")
x = np.random.randint(0, max(width - crop_size + 1, 1))
y = np.random.randint(0, max(height - crop_size + 1, 1))
patch = image.crop((x, y, x + crop_size, y + crop_size))
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# 亮度/对比度增强
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
patch_np_uint8 = np.array(patch)
# 随机旋转与镜像8个离散变换
theta_deg = int(np.random.choice([0, 90, 180, 270]))
is_mirrored = bool(np.random.choice([True, False]))
center_x, center_y = self.patch_size / 2.0, self.patch_size / 2.0
rotation_matrix = cv2.getRotationMatrix2D((center_x, center_y), theta_deg, 1.0)
if is_mirrored:
translate_to_origin = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
mirror = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
translate_back = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
mirror_matrix = translate_back @ mirror @ translate_to_origin
rotation_matrix_h = np.vstack([rotation_matrix, [0, 0, 1]])
homography = (rotation_matrix_h @ mirror_matrix).astype(np.float32)
else:
homography = np.vstack([rotation_matrix, [0, 0, 1]]).astype(np.float32)
transformed_patch_np = cv2.warpPerspective(patch_np_uint8, homography, (self.patch_size, self.patch_size))
transformed_patch = Image.fromarray(transformed_patch_np)
if self.transform:
patch_tensor = self.transform(patch)
transformed_tensor = self.transform(transformed_patch)
else:
patch_tensor = torch.from_numpy(np.array(patch)).float().unsqueeze(0) / 255.0
transformed_tensor = torch.from_numpy(np.array(transformed_patch)).float().unsqueeze(0) / 255.0
H_tensor = torch.from_numpy(homography[:2, :]).float()
return patch_tensor, transformed_tensor, H_tensor