基于 volflow/DeFlow(CVPR 2021 Oral)的轻量重写版本,
面向单通道、小数据量、非配对场景的退化分布学习。
DeFlow/
├── models/
│ ├── flow.py # 轻量 RealNVP 风格 Flow 网络
│ └── condition.py # 轻量条件编码器(替代 RRDB)
├── train.py # 数据集 + 训练器
├── infer.py # 推理:退化施加 + batch 化 patch 处理
├── main.py # 主入口 + 配置加载
├── config/
│ ├── light.yaml # 轻量配置
│ └── standard.yaml # 标准配置
├── test_roundtrip.py # encode↔decode 数值一致性测试
└── README.md
|
volflow/DeFlow |
本项目 |
| 条件编码器 |
预训练 RRDB(来自 ESRGAN) |
自定义轻量多尺度 CNN,联合训练 |
| Flow 主干 |
Glow 风格(含 Split、多尺度高斯先验) |
精简 RealNVP(ActNorm + 1×1 Conv + Affine Coupling) |
| 输入通道 |
3(RGB) |
1(单通道) |
| 默认深度 |
L=3, K=16 |
L=2/3, K=8/12 |
| 参数量 |
数十 M |
0.98M(light)/ 4.44M(standard) |
| 依赖 |
BasicSR / MeasureLib / ESRGAN pretrain |
纯 PyTorch,无外部预训练权重 |
|
volflow/DeFlow |
本项目 |
| 数据归一化 |
[0, 1] |
[-1, 1](img/127.5 - 1.0) |
| logdet 数值稳定 |
log(abs(det(W))) |
slogdet(W),避免 det→0 时 NaN |
| 耦合层 scale 范围 |
tanh(s) * 2 |
tanh(s) * 0.5,防止 exp() 溢出 |
| 梯度裁剪 |
— |
grad_norm ≤ 1.0 |
| NaN/Inf 保护 |
— |
检测到异常 loss/grad 时跳过该 step |
| Checkpoint |
不含 scheduler state |
包含 scheduler state,支持正确续训 |
| 续训步数语义 |
range(n_iter)(会多跑) |
range(n_iter - global_step)(跑到目标步) |
|
volflow/DeFlow |
本项目 |
| 大图处理 |
逐 patch 编码/解码 |
按 batch 批量处理 patch,显著加速 GPU 利用率 |
| 输出裁剪 |
clip(0, 1) |
clip(-1, 1),与训练域一致 |
# 轻量配置(数据少,显存紧张)
python main.py --config config/light.yaml --mode train \
--data.clean_dir ./data/clean \
--data.noisy_dir ./data/noisy
# 标准配置
python main.py --config config/standard.yaml --mode train \
--data.clean_dir ./data/clean \
--data.noisy_dir ./data/noisy
# 从 checkpoint 续训(续到 n_iter 总步,而非额外跑 n_iter 步)
python main.py --config config/light.yaml --mode train \
--resume ./checkpoints/iter_20000.pth
python main.py --config config/light.yaml --mode infer \
--data.clean_dir ./data/clean \
--infer.output_dir ./synthetic_lq \
--infer.ckpt ./checkpoints/best.pth \
--infer.temperature 1.0 \
--infer.n_samples 3
# 验证 encode→decode 往返误差(修复 slogdet 前曾出现 1e14 级别错误)
python test_roundtrip.py
[ 500/30000] loss=2.3421 nll_clean=1.1234 nll_noisy=1.2187
μ_u=0.0312 σ_u=0.1823 lr=2.00e-04
nll_clean / nll_noisy:两域的负对数似然,均应稳步下降
μ_u:退化在 latent 空间的均值偏移,会收敛到稳定值
σ_u:退化在 latent 空间的方差,收敛值反映退化幅度
- → 0:退化被压成 0(未学到有效差异),应加大 lr 或缩小 batch
- 过大(>0.5):两域差异远超经典"退化"范畴,Flow 假设可能不成立
两份预设配置均位于 config/:
|
light |
standard |
| L (level 数) |
2 |
3 |
| K (每 level step 数) |
8 |
12 |
| hidden |
64 |
96 |
| cond_channels |
32 |
32 |
| batch_size |
8 |
16 |
| n_iter |
30000 |
50000 |
| 参数量 |
0.98M |
4.44M |
所有字段都可通过 --<section>.<key> 在命令行覆盖。