69 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			69 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
"""
Quickly preview training pairs (original, transformed, H) from ICLayoutTrainingDataset.
Saves a grid image for visual inspection.
"""
 | ||
|  | from __future__ import annotations | ||
|  | 
 | ||
|  | import argparse | ||
|  | from pathlib import Path | ||
|  | 
 | ||
|  | import numpy as np | ||
|  | import torch | ||
|  | from PIL import Image | ||
|  | from torchvision.utils import make_grid, save_image | ||
|  | 
 | ||
|  | from data.ic_dataset import ICLayoutTrainingDataset | ||
|  | from utils.data_utils import get_transform | ||
|  | 
 | ||
|  | 
 | ||
def to_pil(t: torch.Tensor) -> "Image.Image":
    """Convert a channel-first image tensor into a 3-channel PIL image.

    A 3-channel input is treated as normalized to [-1, 1] and is mapped back
    with ``x * 0.5 + 0.5``. A 1-channel input is not un-normalized (presumably
    it is already in [0, 1] -- confirm against the dataset transform) and is
    replicated to three channels. Values are then scaled by 255, clamped to
    [0, 255] and truncated to uint8.

    Args:
        t: 3-D tensor whose first dimension has size 3 or 1.

    Returns:
        A PIL image built from the tensor with the channel axis moved last.

    Raises:
        ValueError: If ``t`` is not 3-D or its first dimension is not 1 or 3.
    """
    # Validate up front so bad input fails before any arithmetic is done.
    if t.dim() != 3 or t.size(0) not in (1, 3):
        raise ValueError("Unexpected tensor shape")

    x = t  # every operation below is out-of-place, so no defensive copy is needed
    if x.size(0) == 3:
        x = (x * 0.5) + 0.5  # unnormalize
    x = (x * 255.0).clamp(0, 255).byte()
    if x.size(0) == 1:
        # Replicate the single channel so the output always has three channels.
        x = x.repeat(3, 1, 1)

    # Channel-first -> channel-last layout expected by Image.fromarray.
    np_img = x.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(np_img)
|  | 
 | ||
|  | 
 | ||
def main() -> None:
    """Save a grid image of (original, transformed) pairs from the dataset.

    Parses the command line, builds an ``ICLayoutTrainingDataset`` over the
    given image directory, takes the first ``--n`` samples (or fewer if the
    dataset is smaller), and writes a two-column grid to ``--out``: each row
    holds one sample's original patch next to its transformed patch.

    Raises:
        SystemExit: If there is nothing to preview (empty dataset or
            ``--n`` <= 0).
    """
    parser = argparse.ArgumentParser(description="Preview dataset samples")
    parser.add_argument("--dir", dest="image_dir", type=str, required=True, help="PNG images directory")
    parser.add_argument("--out", dest="out_path", type=str, default="preview.png", help="Output image path")
    parser.add_argument("--n", dest="num", type=int, default=8, help="Number of samples to preview")
    parser.add_argument("--patch", dest="patch_size", type=int, default=256, help="Patch size passed to the dataset")
    parser.add_argument(
        "--elastic",
        dest="use_elastic",
        action="store_true",
        help="Forwarded to the dataset as use_albu",
    )
    args = parser.parse_args()

    transform = get_transform()
    ds = ICLayoutTrainingDataset(
        args.image_dir,
        patch_size=args.patch_size,
        transform=transform,
        scale_range=(1.0, 1.0),
        use_albu=args.use_elastic,
        albu_params={"prob": 0.5},
    )

    # torch.stack fails with an opaque error on an empty list; bail out early
    # with a message that names the likely cause instead.
    count = min(args.num, len(ds))
    if count <= 0:
        raise SystemExit("No samples to preview: dataset is empty or --n <= 0")

    images = []
    for i in range(count):
        # The third item (H) is not needed for the visual preview.
        orig, rot, _ = ds[i]
        # Adjacent entries land side-by-side because the grid has two columns.
        images.extend((orig, rot))

    # NOTE(review): to_pil in this file treats 3-channel tensors as normalized
    # to [-1, 1]; the grid is saved here without un-normalizing -- confirm the
    # saved preview is not clipped.
    grid = make_grid(torch.stack(images, dim=0), nrow=2, padding=2)

    # Create the destination directory so saving does not fail on a fresh path.
    Path(args.out_path).parent.mkdir(parents=True, exist_ok=True)
    save_image(grid, args.out_path)
    print(f"Saved preview to {args.out_path}")
|  | 
 | ||
|  | 
 | ||
# Run the preview only when executed as a script, not when imported.
if __name__ == "__main__":
    main()