第二次大修

This commit is contained in:
Jiao77
2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

280
train.py
View File

@@ -1,236 +1,142 @@
# train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import os
from models.rord import RoRD
import argparse
# 数据集类:生成随机旋转的训练对
# 导入项目模块
import config
from models.rord import RoRD
from utils.data_utils import get_transform
# --- 训练专用数据集类 ---
class ICLayoutTrainingDataset(Dataset):
def __init__(self, image_dir, patch_size=256, transform=None):
"""
初始化 IC 版图训练数据集。
参数:
image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。
patch_size (int): 裁剪的 patch 大小(默认 256x256
transform (callable, optional): 应用于图像的变换。
"""
self.image_dir = image_dir
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
self.patch_size = patch_size
self.transform = transform
def __len__(self):
"""
返回数据集中的图像数量。
返回:
int: 数据集大小。
"""
return len(self.image_paths)
def __getitem__(self, index):
"""
获取指定索引的训练对(原始 patch、旋转 patch、Homography 矩阵)。
参数:
index (int): 图像索引。
返回:
tuple: (patch, rotated_patch, H_tensor)
- patch: 原始 patch 张量。
- rotated_patch: 旋转后的 patch 张量。
- H_tensor: Homography 矩阵张量。
"""
img_path = self.image_paths[index]
image = Image.open(img_path).convert('L') # 灰度图像
image = Image.open(img_path).convert('L')
# 获取图像大小
W, H = image.size
# 随机选择裁剪的左上角坐标
x = np.random.randint(0, W - self.patch_size + 1)
y = np.random.randint(0, H - self.patch_size + 1)
patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))
# 转换为 NumPy 数组
patch_np = np.array(patch)
# 实现8个方向的离散几何变换
theta_deg = np.random.choice([0, 90, 180, 270])
is_mirrored = np.random.choice([True, False])
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)
# 随机旋转角度0°~360°
theta = np.random.uniform(0, 360)
theta_rad = np.deg2rad(theta)
cos_theta = np.cos(theta_rad)
sin_theta = np.sin(theta_rad)
if is_mirrored:
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
M_mirror_3x3 = T2 @ Flip @ T1
M_3x3 = np.vstack([M, [0, 0, 1]])
H = (M_3x3 @ M_mirror_3x3).astype(np.float32)
else:
H = np.vstack([M, [0, 0, 1]]).astype(np.float32)
transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
transformed_patch = Image.fromarray(transformed_patch_np)
# 计算旋转中心patch 的中心)
cx = self.patch_size / 2.0
cy = self.patch_size / 2.0
# 计算旋转的齐次矩阵Homography
H = np.array([
[cos_theta, -sin_theta, cx * (1 - cos_theta) + cy * sin_theta],
[sin_theta, cos_theta, cy * (1 - cos_theta) - cx * sin_theta],
[0, 0, 1]
], dtype=np.float32)
# 应用旋转到 patch
rotated_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
# 转换回 PIL Image
rotated_patch = Image.fromarray(rotated_patch_np)
# 应用变换
if self.transform:
patch = self.transform(patch)
rotated_patch = self.transform(rotated_patch)
transformed_patch = self.transform(transformed_patch)
# 转换 H 为张量
H_tensor = torch.from_numpy(H).float()
H_tensor = torch.from_numpy(H[:2, :]).float() # 通常损失函数需要2x3的仿射矩阵
return patch, transformed_patch, H_tensor
return patch, rotated_patch, H_tensor
# 特征图变换函数
# --- 特征图变换与损失函数 ---
def warp_feature_map(feature_map, H_inv):
"""
使用逆 Homography 矩阵变换特征图。
参数:
feature_map (torch.Tensor): 输入特征图,形状为 [B, C, H, W]。
H_inv (torch.Tensor): 逆 Homography 矩阵,形状为 [B, 3, 3]。
返回:
torch.Tensor: 变换后的特征图,形状为 [B, C, H, W]。
"""
B, C, H, W = feature_map.size()
# 生成网格
grid_y, grid_x = torch.meshgrid(
torch.linspace(-1, 1, H, device=feature_map.device),
torch.linspace(-1, 1, W, device=feature_map.device),
indexing='ij'
)
grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1) # [H, W, 3]
grid = grid.unsqueeze(0).expand(B, H, W, 3) # [B, H, W, 3]
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
return F.grid_sample(feature_map, grid, align_corners=False)
# 将网格转换为齐次坐标并应用 H_inv
grid_flat = grid.view(B, -1, 3) # [B, H*W, 3]
grid_transformed = torch.bmm(grid_flat, H_inv.transpose(1, 2)) # [B, H*W, 3]
grid_transformed = grid_transformed.view(B, H, W, 3) # [B, H, W, 3]
grid_transformed = grid_transformed[..., :2] / (grid_transformed[..., 2:3] + 1e-8) # [B, H, W, 2]
# 使用 grid_sample 进行变换
warped_feature = F.grid_sample(feature_map, grid_transformed, align_corners=True)
return warped_feature
# 检测损失函数
def compute_detection_loss(det_original, det_rotated, H):
"""
计算检测损失MSE比较原始检测图与旋转检测图逆变换后
参数:
det_original (torch.Tensor): 原始图像的检测图,形状为 [B, 1, H, W]。
det_rotated (torch.Tensor): 旋转图像的检测图,形状为 [B, 1, H, W]。
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
返回:
torch.Tensor: 检测损失。
"""
H_inv = torch.inverse(H) # 计算逆 Homography
with torch.no_grad():
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
return F.mse_loss(det_original, warped_det_rotated)
# 描述子损失函数
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
"""
计算描述子损失(三元组损失),基于对应点的描述子。
B, C, H_feat, W_feat = desc_original.size()
num_samples = 100
# 随机采样锚点坐标
coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1 # [-1, 1]
# 提取锚点描述子
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 计算正样本坐标
coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
# 提取正样本描述子
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 随机采样负样本
neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
参数:
desc_original (torch.Tensor): 原始图像的描述子图,形状为 [B, 128, H, W]。
desc_rotated (torch.Tensor): 旋转图像的描述子图,形状为 [B, 128, H, W]。
H (torch.Tensor): Homography 矩阵,形状为 [B, 3, 3]。
margin (float): 三元组损失的边距。
返回:
torch.Tensor: 描述子损失。
"""
B, C, H, W = desc_original.size()
# 随机选择锚点anchor
num_samples = min(100, H * W) # 每张图像采样 100 个点
idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
idx_y = idx // W
idx_x = idx % W
coords = torch.stack((idx_x.float(), idx_y.float()), dim=-1) # [B, num_samples, 2]
# 转换为齐次坐标
coords_hom = torch.cat((coords, torch.ones(B, num_samples, 1, device=coords.device)), dim=-1) # [B, num_samples, 3]
coords_transformed = torch.bmm(coords_hom, H.transpose(1, 2)) # [B, num_samples, 3]
coords_transformed = coords_transformed[..., :2] / (coords_transformed[..., 2:3] + 1e-8) # [B, num_samples, 2]
# 归一化到 [-1, 1] 用于 grid_sample
coords_transformed = coords_transformed / torch.tensor([W/2, H/2], device=coords.device) - 1
# 提取锚点和正样本描述子
anchor = desc_original.view(B, C, -1)[:, :, idx.view(-1)] # [B, 128, num_samples]
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(2), align_corners=True).squeeze(3) # [B, 128, num_samples]
# 随机选择负样本
neg_idx = torch.randint(0, H * W, (B, num_samples), device=desc_original.device)
negative = desc_rotated.view(B, C, -1)[:, :, neg_idx.view(-1)] # [B, 128, num_samples]
# 三元组损失
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
loss = triplet_loss(anchor.transpose(1, 2), positive.transpose(1, 2), negative.transpose(1, 2))
return loss
return triplet_loss(anchor, positive, negative)
# 定义变换
transform = transforms.Compose([
transforms.ToTensor(), # (1, 256, 256)
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # (3, 256, 256)
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# --- 主函数与命令行接口 ---
def main(args):
print("--- 开始训练 RoRD 模型 ---")
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
transform = get_transform()
dataset = ICLayoutTrainingDataset(args.data_dir, patch_size=config.PATCH_SIZE, transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
model = RoRD().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# 创建数据集和 DataLoader
dataset = ICLayoutTrainingDataset('path/to/layouts', patch_size=256, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
for epoch in range(args.epochs):
model.train()
total_loss_val = 0
for i, (original, rotated, H) in enumerate(dataloader):
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
det_original, desc_original = model(original)
det_rotated, desc_rotated = model(rotated)
# 定义模型
model = RoRD().cuda()
loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss_val += loss.item()
# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / len(dataloader):.4f} ---")
# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
model.train()
total_loss = 0
for batch in dataloader:
original, rotated, H = batch
original = original.cuda()
rotated = rotated.cuda()
H = H.cuda()
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
torch.save(model.state_dict(), save_path)
print(f"模型已保存至: {save_path}")
# 前向传播
det_original, _, desc_rord_original = model(original)
det_rotated, _, desc_rord_rotated = model(rotated)
# 计算损失
detection_loss = compute_detection_loss(det_original, det_rotated, H)
description_loss = compute_description_loss(desc_rord_original, desc_rord_rotated, H)
total_loss_batch = detection_loss + description_loss
# 反向传播
optimizer.zero_grad()
total_loss_batch.backward()
optimizer.step()
total_loss += total_loss_batch.item()
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
# 保存模型
torch.save(model.state_dict(), 'path/to/save/model.pth')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
main(parser.parse_args())