chenge to english version

This commit is contained in:
Jiao77
2025-07-22 23:43:35 +08:00
parent 4f81daad3c
commit 9cbfc34436
8 changed files with 166 additions and 166 deletions

140
train.py
View File

@@ -12,14 +12,14 @@ import argparse
import logging
from datetime import datetime
# 导入项目模块
# Import project modules
import config
from models.rord import RoRD
from utils.data_utils import get_transform
# 设置日志记录
# Setup logging
def setup_logging(save_dir):
"""设置训练日志记录"""
"""Setup training logging"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -34,14 +34,14 @@ def setup_logging(save_dir):
)
return logging.getLogger(__name__)
# --- (已修改) 训练专用数据集类 ---
# --- (Modified) Training-specific dataset class ---
class ICLayoutTrainingDataset(Dataset):
def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
self.image_dir = image_dir
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
self.patch_size = patch_size
self.transform = transform
self.scale_range = scale_range # 新增尺度范围参数
self.scale_range = scale_range # New scale range parameter
def __len__(self):
return len(self.image_paths)
@@ -51,47 +51,47 @@ class ICLayoutTrainingDataset(Dataset):
image = Image.open(img_path).convert('L')
W, H = image.size
# --- 新增:尺度抖动数据增强 ---
# 1. 随机选择一个缩放比例
# --- New: Scale jittering data augmentation ---
# 1. Randomly select a scaling factor
scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
# 2. 根据缩放比例计算需要从原图裁剪的尺寸
# 2. Calculate crop size from original image based on scaling factor
crop_size = int(self.patch_size / scale)
# 确保裁剪尺寸不超过图像边界
if crop_size > min(W, H):
crop_size = min(W, H)
# 3. 随机裁剪
# 3. Random cropping
x = np.random.randint(0, W - crop_size + 1)
y = np.random.randint(0, H - crop_size + 1)
patch = image.crop((x, y, x + crop_size, y + crop_size))
# 4. 将裁剪出的图像块缩放回标准的 patch_size
# 4. Resize cropped patch back to standard patch_size
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# --- 尺度抖动结束 ---
# --- Scale jittering end ---
# --- 新增:额外的数据增强 ---
# 亮度调整
# --- New: Additional data augmentation ---
# Brightness adjustment
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(x * brightness_factor))
# 对比度调整
# Contrast adjustment
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
# 添加噪声
# Add noise
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
# --- 额外数据增强结束 ---
# --- Additional data augmentation end ---
patch_np = np.array(patch)
# 实现8个方向的离散几何变换 (这部分逻辑不变)
# Implement 8-direction discrete geometric transformations (this logic remains unchanged)
theta_deg = np.random.choice([0, 90, 180, 270])
is_mirrored = np.random.choice([True, False])
cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
@@ -117,57 +117,57 @@ class ICLayoutTrainingDataset(Dataset):
H_tensor = torch.from_numpy(H[:2, :]).float()
return patch, transformed_patch, H_tensor
# --- 特征图变换与损失函数 (改进版) ---
# --- (Modified) Feature map transformation and loss functions (improved version) ---
def warp_feature_map(feature_map, H_inv):
B, C, H, W = feature_map.size()
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
return F.grid_sample(feature_map, grid, align_corners=False)
def compute_detection_loss(det_original, det_rotated, H):
"""改进的检测损失使用BCE损失替代MSE"""
"""Improved detection loss: use BCE loss instead of MSE"""
with torch.no_grad():
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
# 使用BCE损失更适合二分类问题
# Use BCE loss, more suitable for binary classification problems
bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
# 添加平滑L1损失作为辅助
# Add smooth L1 loss as auxiliary
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
return bce_loss + 0.1 * smooth_l1_loss
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
"""IC版图专用几何感知描述子损失:编码曼哈顿几何特征"""
"""IC layout-specific geometric-aware descriptor loss: encodes Manhattan geometric features"""
B, C, H_feat, W_feat = desc_original.size()
# 曼哈顿几何感知采样:重点采样边缘和角点区域
# Manhattan geometric-aware sampling: focus on edge and corner regions
num_samples = 200
# 生成曼哈顿对齐的采样网格(水平和垂直优先)
# Generate Manhattan-aligned sampling grid (horizontal and vertical priority)
h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
# 增加曼哈顿方向的采样密度
# Increase sampling density in Manhattan directions
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1).unsqueeze(0).repeat(B, 1, 1)
# 采样anchor点
# Sample anchor points
anchor = F.grid_sample(desc_original, manhattan_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# 计算对应的正样本点
# Calculate corresponding positive samples
coords_hom = torch.cat([manhattan_coords, torch.ones(B, manhattan_coords.size(1), 1, device=manhattan_coords.device)], dim=2)
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# IC版图专用负样本策略:考虑重复结构
# IC layout-specific negative sample strategy: consider repetitive structures
with torch.no_grad():
# 1. 几何感知的负样本:曼哈顿变换后的不同区域
# 1. Geometric-aware negative samples: different regions after Manhattan transformation
neg_coords = []
for b in range(B):
# 生成曼哈顿变换后的坐标90度旋转等
# Generate coordinates after Manhattan transformation (90-degree rotation, etc.)
angles = [0, 90, 180, 270]
for angle in angles:
if angle != 0:
@@ -181,55 +181,55 @@ def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
neg_coords = torch.stack(neg_coords[:B*num_samples//2]).reshape(B, -1, 2)
# 2. 特征空间困难负样本
# 2. Feature space hard negative samples
negative_candidates = F.grid_sample(desc_rotated, neg_coords, align_corners=False).squeeze(2).transpose(1, 2)
# 3. 曼哈顿距离约束的困难样本选择
# 3. Manhattan distance constrained hard sample selection
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
# 使用曼哈顿距离而非欧氏距离
# Use Manhattan distance instead of Euclidean distance
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
hard_indices = torch.topk(manhattan_dist, k=anchor.size(1)//2, largest=False)[1]
negative = torch.gather(negative_candidates, 1, hard_indices)
# IC版图专用的几何一致性损失
# 1. 曼哈顿方向一致性损失
# IC layout-specific geometric consistency loss
# 1. Manhattan direction consistency loss
manhattan_loss = 0
for i in range(anchor.size(1)):
# 计算水平和垂直方向的几何一致性
# Calculate geometric consistency in horizontal and vertical directions
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
# 鼓励描述子对曼哈顿变换不变
# Encourage descriptor invariance to Manhattan transformations
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
manhattan_loss += torch.mean(1 - cos_sim)
# 2. 稀疏性正则化IC版图特征稀疏
# 2. Sparsity regularization (IC layout features are sparse)
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
# 3. 二值化特征距离(处理二值化输入)
# 3. Binary feature distance (handles binary input)
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
# 综合损失
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean') # 使用L1距离
# Combined loss
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean') # Use L1 distance
geometric_triplet = triplet_loss(anchor, positive, negative)
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
# --- (已修改) 主函数与命令行接口 ---
# --- (Modified) Main function and command-line interface ---
def main(args):
# 设置日志记录
# Setup logging
logger = setup_logging(args.save_dir)
logger.info("--- 开始训练 RoRD 模型 ---")
logger.info(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
logger.info(f"数据目录: {args.data_dir}")
logger.info(f"保存目录: {args.save_dir}")
logger.info("--- Starting RoRD model training ---")
logger.info(f"Training parameters: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
logger.info(f"Data directory: {args.data_dir}")
logger.info(f"Save directory: {args.save_dir}")
transform = get_transform()
# 在数据集初始化时传入尺度抖动范围
# Pass scale jittering range during dataset initialization
dataset = ICLayoutTrainingDataset(
args.data_dir,
patch_size=config.PATCH_SIZE,
@@ -237,35 +237,35 @@ def main(args):
scale_range=config.SCALE_JITTER_RANGE
)
logger.info(f"数据集大小: {len(dataset)}")
logger.info(f"Dataset size: {len(dataset)}")
# 分割训练集和验证集
# Split training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
logger.info(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
model = RoRD().cuda()
logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
logger.info(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
# 添加学习率调度器
# Add learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5
)
# 早停机制
# Early stopping mechanism
best_val_loss = float('inf')
patience_counter = 0
patience = 10
for epoch in range(args.epochs):
# 训练阶段
# Training phase
model.train()
total_train_loss = 0
total_det_loss = 0
@@ -284,7 +284,7 @@ def main(args):
optimizer.zero_grad()
loss.backward()
# 梯度裁剪,防止梯度爆炸
# Gradient clipping to prevent gradient explosion
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
@@ -300,7 +300,7 @@ def main(args):
avg_det_loss = total_det_loss / len(train_dataloader)
avg_desc_loss = total_desc_loss / len(train_dataloader)
# 验证阶段
# Validation phase
model.eval()
total_val_loss = 0
total_val_det_loss = 0
@@ -325,20 +325,20 @@ def main(args):
avg_val_det_loss = total_val_det_loss / len(val_dataloader)
avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
# 学习率调度
# Learning rate scheduling
scheduler.step(avg_val_loss)
logger.info(f"--- Epoch {epoch+1} 完成 ---")
logger.info(f"训练 - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
logger.info(f"验证 - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
logger.info(f"学习率: {optimizer.param_groups[0]['lr']:.2e}")
logger.info(f"--- Epoch {epoch+1} completed ---")
logger.info(f"Training - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
logger.info(f"Validation - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
# 早停检查
# Early stopping check
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
# 保存最佳模型
# Save best model
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
@@ -353,14 +353,14 @@ def main(args):
'epochs': args.epochs
}
}, save_path)
logger.info(f"最佳模型已保存至: {save_path}")
logger.info(f"Best model saved to: {save_path}")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"早停触发!{patience} epoch没有改善")
logger.info(f"Early stopping triggered! No improvement for {patience} epochs")
break
# 保存最终模型
# Save final model
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
torch.save({
'epoch': args.epochs,
@@ -373,11 +373,11 @@ def main(args):
'epochs': args.epochs
}
}, save_path)
logger.info(f"最终模型已保存至: {save_path}")
logger.info("训练完成!")
logger.info(f"Final model saved to: {save_path}")
logger.info("Training completed!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
parser = argparse.ArgumentParser(description="Train RoRD model")
parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)