添加数据增强方案以及扩散生成模型的想法
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -70,6 +70,8 @@ class ICLayoutTrainingDataset(Dataset):
|
||||
patch_size: int = 256,
|
||||
transform=None,
|
||||
scale_range: Tuple[float, float] = (1.0, 1.0),
|
||||
use_albu: bool = False,
|
||||
albu_params: Optional[dict] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.image_paths = [
|
||||
@@ -80,6 +82,28 @@ class ICLayoutTrainingDataset(Dataset):
|
||||
self.patch_size = patch_size
|
||||
self.transform = transform
|
||||
self.scale_range = scale_range
|
||||
# 可选的 albumentations 管道
|
||||
self.albu = None
|
||||
if use_albu:
|
||||
try:
|
||||
import albumentations as A # 延迟导入,避免环境未安装时报错
|
||||
p = albu_params or {}
|
||||
elastic_prob = float(p.get("prob", 0.3))
|
||||
alpha = float(p.get("alpha", 40))
|
||||
sigma = float(p.get("sigma", 6))
|
||||
alpha_affine = float(p.get("alpha_affine", 6))
|
||||
use_bc = bool(p.get("brightness_contrast", True))
|
||||
use_noise = bool(p.get("gauss_noise", True))
|
||||
transforms_list = [
|
||||
A.ElasticTransform(alpha=alpha, sigma=sigma, alpha_affine=alpha_affine, p=elastic_prob),
|
||||
]
|
||||
if use_bc:
|
||||
transforms_list.append(A.RandomBrightnessContrast(p=0.5))
|
||||
if use_noise:
|
||||
transforms_list.append(A.GaussNoise(var_limit=(5.0, 20.0), p=0.3))
|
||||
self.albu = A.Compose(transforms_list)
|
||||
except Exception:
|
||||
self.albu = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.image_paths)
|
||||
@@ -102,22 +126,27 @@ class ICLayoutTrainingDataset(Dataset):
|
||||
patch = image.crop((x, y, x + crop_size, y + crop_size))
|
||||
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
|
||||
|
||||
# 亮度/对比度增强
|
||||
if np.random.random() < 0.5:
|
||||
brightness_factor = np.random.uniform(0.8, 1.2)
|
||||
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
|
||||
|
||||
if np.random.random() < 0.5:
|
||||
contrast_factor = np.random.uniform(0.8, 1.2)
|
||||
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
|
||||
|
||||
if np.random.random() < 0.3:
|
||||
patch_np = np.array(patch, dtype=np.float32)
|
||||
noise = np.random.normal(0, 5, patch_np.shape)
|
||||
patch_np = np.clip(patch_np + noise, 0, 255)
|
||||
patch = Image.fromarray(patch_np.astype(np.uint8))
|
||||
|
||||
# photometric/elastic(在几何 H 之前)
|
||||
patch_np_uint8 = np.array(patch)
|
||||
if self.albu is not None:
|
||||
patch_np_uint8 = self.albu(image=patch_np_uint8)["image"]
|
||||
patch = Image.fromarray(patch_np_uint8)
|
||||
else:
|
||||
# 原有轻量光度增强
|
||||
if np.random.random() < 0.5:
|
||||
brightness_factor = np.random.uniform(0.8, 1.2)
|
||||
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
|
||||
|
||||
if np.random.random() < 0.5:
|
||||
contrast_factor = np.random.uniform(0.8, 1.2)
|
||||
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
|
||||
|
||||
if np.random.random() < 0.3:
|
||||
patch_np = np.array(patch, dtype=np.float32)
|
||||
noise = np.random.normal(0, 5, patch_np.shape)
|
||||
patch_np = np.clip(patch_np + noise, 0, 255)
|
||||
patch = Image.fromarray(patch_np.astype(np.uint8))
|
||||
patch_np_uint8 = np.array(patch)
|
||||
|
||||
# 随机旋转与镜像(8个离散变换)
|
||||
theta_deg = int(np.random.choice([0, 90, 180, 270]))
|
||||
|
||||
Reference in New Issue
Block a user