Add a data augmentation scheme and ideas for a diffusion-based generative model
@@ -91,7 +91,12 @@ class RoRD(nn.Module):
         # Default per-stage channels (VGG-aligned)
         c2_ch, c3_ch, c4_ch = 128, 256, 512
         if backbone_name == "resnet34":
-            res = models.resnet34(weights=models.ResNet34_Weights.DEFAULT if pretrained else None)
+            # Build the backbone and load the weights manually when requested, so a load summary can be printed
+            if pretrained:
+                res = models.resnet34(weights=None)
+                self._summarize_pretrained_load(res, models.ResNet34_Weights.DEFAULT, "resnet34")
+            else:
+                res = models.resnet34(weights=None)
             self.backbone = nn.Sequential(
                 res.conv1, res.bn1, res.relu, res.maxpool,
                 res.layer1, res.layer2, res.layer3, res.layer4,
@@ -102,14 +107,23 @@ class RoRD(nn.Module):
             # Use layer2/layer3/layer4 as C2/C3/C4
             c2_ch, c3_ch, c4_ch = 128, 256, 512
         elif backbone_name == "efficientnet_b0":
-            eff = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None)
+            if pretrained:
+                eff = models.efficientnet_b0(weights=None)
+                self._summarize_pretrained_load(eff, models.EfficientNet_B0_Weights.DEFAULT, "efficientnet_b0")
+            else:
+                eff = models.efficientnet_b0(weights=None)
             self.backbone = eff.features
             self._backbone_raw = eff
             out_channels_backbone = 1280
             # Use features[2]/[3]/[6] as C2/C3/C4 (roughly 24/40/192 channels)
             c2_ch, c3_ch, c4_ch = 24, 40, 192
         else:
-            vgg16_features = models.vgg16(weights=models.VGG16_Weights.DEFAULT if pretrained else None).features
+            if pretrained:
+                vgg = models.vgg16(weights=None)
+                self._summarize_pretrained_load(vgg, models.VGG16_Weights.DEFAULT, "vgg16")
+            else:
+                vgg = models.vgg16(weights=None)
+            vgg16_features = vgg.features
             # VGG16 feature-stage indices (sequence of conv & relu layers)
             # relu2_2 is index 8, relu3_3 is index 15, relu4_3 is index 22
             self.features = vgg16_features
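To double-check the tap points named in the comments above (layer2/3/4 for ResNet-34, features[2]/[3]/[6] for EfficientNet-B0, and feature indices 8/15/22 for VGG16), the following is a minimal standalone sketch, not part of this commit, that runs a dummy input through the torchvision backbones with weights=None and prints the channel counts at each stage:

import torch
import torch.nn as nn
from torchvision import models

x = torch.randn(1, 3, 224, 224)

# ResNet-34: layer2/layer3/layer4 give 128/256/512 channels
res = models.resnet34(weights=None)
stem = nn.Sequential(res.conv1, res.bn1, res.relu, res.maxpool, res.layer1)
c2 = res.layer2(stem(x)); c3 = res.layer3(c2); c4 = res.layer4(c3)
print("resnet34:", c2.shape[1], c3.shape[1], c4.shape[1])        # 128 256 512

# EfficientNet-B0: features[2]/[3]/[6] give 24/40/192 channels
eff = models.efficientnet_b0(weights=None).features
y = x
for i, block in enumerate(eff):
    y = block(y)
    if i in (2, 3, 6):
        print(f"efficientnet_b0 features[{i}]:", y.shape[1])      # 24, 40, 192

# VGG16: relu2_2/relu3_3/relu4_3 sit at feature indices 8/15/22 with 128/256/512 channels
feats = models.vgg16(weights=None).features
y = x
for i, layer in enumerate(feats):
    y = layer(y)
    if i in (8, 15, 22):
        print(f"vgg16 features[{i}]:", y.shape[1])                # 128, 256, 512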
@@ -263,4 +277,33 @@ class RoRD(nn.Module):
             x = feats[6](x); c4 = x
             return c2, c3, c4

-        raise RuntimeError(f"Unsupported backbone for FPN: {self.backbone_name}")
+        raise RuntimeError(f"Unsupported backbone for FPN: {self.backbone_name}")
+
+    # --- Utils ---
+    def _summarize_pretrained_load(self, torch_model: nn.Module, weights_enum, arch_name: str) -> None:
+        """Manually load torchvision pretrained weights and print a load summary.
+        - Uses strict=False to tolerate possible key mismatches; prints missing/unexpected keys.
+        - Reports parameter counts so the load can be checked at a glance.
+        """
+        try:
+            state_dict = weights_enum.get_state_dict(progress=False)
+        except Exception:
+            # Fallback: if the weights enum lacks get_state_dict, skip the summary (weights are usually loaded in the constructor already)
+            print(f"[Pretrained] {arch_name}: skip summary (weights enum lacks get_state_dict)")
+            return
+        incompatible = torch_model.load_state_dict(state_dict, strict=False)
+        total_params = sum(p.numel() for p in torch_model.parameters())
+        trainable_params = sum(p.numel() for p in torch_model.parameters() if p.requires_grad)
+        missing = list(getattr(incompatible, 'missing_keys', []))
+        unexpected = list(getattr(incompatible, 'unexpected_keys', []))
+        try:
+            matched = len(state_dict) - len(unexpected)
+        except Exception:
+            matched = 0
+        print(f"[Pretrained] {arch_name}: ImageNet weights loaded (strict=False)")
+        print(f"  params: total={total_params/1e6:.2f}M, trainable={trainable_params/1e6:.2f}M")
+        print(f"  keys: matched≈{matched} | missing={len(missing)} | unexpected={len(unexpected)}")
+        if missing and len(missing) <= 10:
+            print(f"  missing: {missing}")
+        if unexpected and len(unexpected) <= 10:
+            print(f"  unexpected: {unexpected}")
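For reference, the manual-loading pattern that _summarize_pretrained_load wraps can be reproduced in a few lines outside the class. This is a sketch rather than repository code, and it assumes torchvision 0.13+ where weights enums expose get_state_dict():

from torchvision import models

# Construct without weights, then pull the state dict from the weights enum and
# load it with strict=False, which is what the new helper does internally.
model = models.resnet34(weights=None)
sd = models.ResNet34_Weights.DEFAULT.get_state_dict(progress=False)
result = model.load_state_dict(sd, strict=False)   # returns (missing_keys, unexpected_keys)

total = sum(p.numel() for p in model.parameters())
print(f"total={total / 1e6:.2f}M, "
      f"missing={len(result.missing_keys)}, unexpected={len(result.unexpected_keys)}")
# For an unmodified resnet34 both lists should be empty; in the commit the summary is
# printed on the full model before the backbone is truncated, so the same holds there.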