Skip to content

whycoming/DeFlow-tiny

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DeFlow (轻量实现)

基于 volflow/DeFlow(CVPR 2021 Oral)的轻量重写版本, 面向单通道小数据量非配对场景的退化分布学习。


项目结构

DeFlow/
├── models/
│   ├── flow.py          # 轻量 RealNVP 风格 Flow 网络
│   └── condition.py     # 轻量条件编码器(替代 RRDB)
├── train.py             # 数据集 + 训练器
├── infer.py             # 推理:退化施加 + batch 化 patch 处理
├── main.py              # 主入口 + 配置加载
├── config/
│   ├── light.yaml       # 轻量配置
│   └── standard.yaml    # 标准配置
├── test_roundtrip.py    # encode↔decode 数值一致性测试
└── README.md

与 volflow/DeFlow 的区别

1. 架构层面

volflow/DeFlow 本项目
条件编码器 预训练 RRDB(来自 ESRGAN) 自定义轻量多尺度 CNN,联合训练
Flow 主干 Glow 风格(含 Split、多尺度高斯先验) 精简 RealNVP(ActNorm + 1×1 Conv + Affine Coupling)
输入通道 3(RGB) 1(单通道)
默认深度 L=3, K=16 L=2/3, K=8/12
参数量 数十 M 0.98M(light)/ 4.44M(standard)
依赖 BasicSR / MeasureLib / ESRGAN pretrain 纯 PyTorch,无外部预训练权重

2. 训练流程层面

volflow/DeFlow 本项目
数据归一化 [0, 1] [-1, 1]img/127.5 - 1.0
logdet 数值稳定 log(abs(det(W))) slogdet(W),避免 det→0 时 NaN
耦合层 scale 范围 tanh(s) * 2 tanh(s) * 0.5,防止 exp() 溢出
梯度裁剪 grad_norm ≤ 1.0
NaN/Inf 保护 检测到异常 loss/grad 时跳过该 step
Checkpoint 不含 scheduler state 包含 scheduler state,支持正确续训
续训步数语义 range(n_iter)(会多跑) range(n_iter - global_step)(跑到目标步)

3. 推理层面

volflow/DeFlow 本项目
大图处理 逐 patch 编码/解码 按 batch 批量处理 patch,显著加速 GPU 利用率
输出裁剪 clip(0, 1) clip(-1, 1),与训练域一致

使用方法

训练

# 轻量配置(数据少,显存紧张)
python main.py --config config/light.yaml --mode train \
    --data.clean_dir ./data/clean \
    --data.noisy_dir ./data/noisy

# 标准配置
python main.py --config config/standard.yaml --mode train \
    --data.clean_dir ./data/clean \
    --data.noisy_dir ./data/noisy

# 从 checkpoint 续训(续到 n_iter 总步,而非额外跑 n_iter 步)
python main.py --config config/light.yaml --mode train \
    --resume ./checkpoints/iter_20000.pth

推理(在清晰图像上施加学到的退化)

python main.py --config config/light.yaml --mode infer \
    --data.clean_dir ./data/clean \
    --infer.output_dir ./synthetic_lq \
    --infer.ckpt ./checkpoints/best.pth \
    --infer.temperature 1.0 \
    --infer.n_samples 3

数值自检

# 验证 encode→decode 往返误差(修复 slogdet 前曾出现 1e14 级别错误)
python test_roundtrip.py

训练监控

[  500/30000] loss=2.3421  nll_clean=1.1234  nll_noisy=1.2187
              μ_u=0.0312  σ_u=0.1823  lr=2.00e-04
  • nll_clean / nll_noisy:两域的负对数似然,均应稳步下降
  • μ_u:退化在 latent 空间的均值偏移,会收敛到稳定值
  • σ_u:退化在 latent 空间的方差,收敛值反映退化幅度
    • → 0:退化被压成 0(未学到有效差异),应加大 lr 或缩小 batch
    • 过大(>0.5):两域差异远超经典"退化"范畴,Flow 假设可能不成立

配置说明

两份预设配置均位于 config/

light standard
L (level 数) 2 3
K (每 level step 数) 8 12
hidden 64 96
cond_channels 32 32
batch_size 8 16
n_iter 30000 50000
参数量 0.98M 4.44M

所有字段都可通过 --<section>.<key> 在命令行覆盖。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages