Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 149 additions & 36 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import inspect
import math
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -861,6 +862,96 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override]

self.scheduler = scheduler

@staticmethod
def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
    """
    Report whether ``scheduler.step`` declares a parameter called ``kwarg``.

    Args:
        scheduler: scheduler whose ``step`` callable is inspected.
        kwarg: parameter name to look for in the signature of ``step``.

    Returns:
        True when ``kwarg`` is a named parameter of ``scheduler.step``. False when it is not, or when
        ``inspect.signature`` raises ``TypeError`` or ``ValueError`` for the callable.
    """
    try:
        declared_parameters = inspect.signature(scheduler.step).parameters
        is_supported = kwarg in declared_parameters
    except (TypeError, ValueError):
        # the signature could not be introspected, so treat the keyword as unsupported
        is_supported = False
    return is_supported

@staticmethod
def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
    """
    Extract the previous sample from whatever ``scheduler.step`` returned.

    Three output layouts are recognised, checked in this order:

    - a non-empty tuple, whose first element is taken as the previous sample;
    - a mapping, read through its ``"prev_sample"`` key;
    - any other object exposing a ``prev_sample`` attribute.

    Args:
        step_output: value returned by a scheduler's ``step`` call.

    Returns:
        The previous sample found in ``step_output``.

    Raises:
        TypeError: if ``step_output`` matches none of the layouts above. This includes an empty tuple and a
            mapping without a ``"prev_sample"`` entry, so malformed outputs always fail the same way
            instead of leaking ``IndexError`` or ``KeyError``.
    """
    if isinstance(step_output, tuple) and len(step_output) > 0:
        return step_output[0]
    if isinstance(step_output, Mapping) and "prev_sample" in step_output:
        return step_output["prev_sample"]
    # also reached by mappings/tuples that carry the value as an attribute rather than an item
    if hasattr(step_output, "prev_sample"):
        return step_output.prev_sample
    raise TypeError(
        "Unsupported scheduler.step output. Expected a non-empty tuple, a mapping with a `prev_sample` key, "
        f"or an object with `prev_sample`; got {type(step_output).__name__}."
    )

@staticmethod
def _get_scheduler_name(scheduler: Scheduler) -> str:
    """
    Return a display name for ``scheduler``.

    Args:
        scheduler: scheduler instance to name.

    Returns:
        The result of ``scheduler._get_name()`` when the scheduler has a ``_get_name`` attribute, otherwise
        the name of the scheduler's class.
    """
    has_name_hook = hasattr(scheduler, "_get_name")
    return scheduler._get_name() if has_name_hook else scheduler.__class__.__name__

@staticmethod
def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
    """
    Look up a scheduler setting, preferring ``scheduler.config`` over the scheduler's own attributes.

    The lookup order is:

    1. ``scheduler.config[name]`` when ``config`` is a mapping that contains ``name``;
    2. ``getattr(scheduler.config, name)`` when ``config`` is a non-mapping object with that attribute;
    3. ``getattr(scheduler, name)`` when the scheduler itself has that attribute;
    4. ``default``.

    Args:
        scheduler: scheduler to query.
        name: name of the setting.
        default: value returned when the setting is found nowhere.

    Returns:
        The first value found following the order above.
    """
    not_found = object()  # private marker so that a stored ``None`` still counts as a hit
    config = getattr(scheduler, "config", None)

    if isinstance(config, Mapping):
        value = config[name] if name in config else not_found
    elif config is None:
        value = not_found
    else:
        value = getattr(config, name, not_found)

    if value is not_found:
        value = getattr(scheduler, name, default)
    return value

@staticmethod
def _get_posterior_mean(
    scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
) -> torch.Tensor:
    """
    Compute the posterior mean as a weighted sum of ``x_0`` and ``x_t``.

    With ``a_t = alphas[timestep]``, ``b_t = betas[timestep]``, ``A_t = alphas_cumprod[timestep]`` and
    ``A_prev = alphas_cumprod[timestep - 1]`` (``scheduler.one`` when ``timestep`` is not positive), the
    result is ``sqrt(A_prev) * b_t / (1 - A_t) * x_0 + sqrt(a_t) * (1 - A_prev) / (1 - A_t) * x_t``.

    Args:
        scheduler: scheduler providing ``alphas``, ``alphas_cumprod``, ``betas`` and ``one``.
        timestep: timestep used to index the scheduler's schedules.
        x_0: sample weighted by the first coefficient.
        x_t: sample weighted by the second coefficient.

    Returns:
        The weighted combination of ``x_0`` and ``x_t``.
    """
    cumprod_now = scheduler.alphas_cumprod[timestep]
    if timestep > 0:
        cumprod_before = scheduler.alphas_cumprod[timestep - 1]
    else:
        # no earlier entry to index, so fall back to the scheduler's ``one`` value
        cumprod_before = scheduler.one

    weight_x_0 = cumprod_before.sqrt() * scheduler.betas[timestep] / (1 - cumprod_now)
    weight_x_t = scheduler.alphas[timestep].sqrt() * (1 - cumprod_before) / (1 - cumprod_now)

    return weight_x_0 * x_0 + weight_x_t * x_t

def _get_posterior_variance(
    self, scheduler: Scheduler, timestep: int | torch.Tensor, predicted_variance: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Compute the posterior variance at ``timestep`` according to the scheduler's ``variance_type``.

    The base value is ``(1 - A_prev) / (1 - A_t) * betas[timestep]`` with ``A_t = alphas_cumprod[timestep]``
    and ``A_prev = alphas_cumprod[timestep - 1]`` (``scheduler.one`` when ``timestep`` is not positive).
    The ``variance_type`` setting then selects the result:

    - ``"fixed_small"``: the base value clamped to a minimum of ``1e-20``;
    - ``"fixed_large"``: ``betas[timestep]``;
    - ``"learned"`` with a ``predicted_variance``: ``predicted_variance`` unchanged;
    - ``"learned_range"`` with a ``predicted_variance``: interpolation between the base value and
      ``betas[timestep]`` using ``(predicted_variance + 1) / 2`` as the weight on ``betas[timestep]``;
    - anything else: the base value.

    Args:
        scheduler: scheduler providing ``alphas_cumprod``, ``betas`` and ``one``.
        timestep: timestep used to index the scheduler's schedules.
        predicted_variance: variance predicted by the model, if any.

    Returns:
        The variance selected as described above.
    """
    cumprod_now = scheduler.alphas_cumprod[timestep]
    cumprod_before = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one
    base_variance = (1 - cumprod_before) / (1 - cumprod_now) * scheduler.betas[timestep]
    mode = self._get_scheduler_config_value(scheduler, "variance_type")

    if mode == "fixed_small":
        return torch.clamp(base_variance, min=1e-20)
    if mode == "fixed_large":
        return scheduler.betas[timestep]
    if predicted_variance is None:
        # the learned modes need a model prediction; without one the base value is used
        return base_variance
    if mode == "learned":
        return predicted_variance
    if mode == "learned_range":
        upper_weight = (predicted_variance + 1) / 2
        return upper_weight * scheduler.betas[timestep] + (1 - upper_weight) * base_variance
    return base_variance

def _scheduler_step(
    self,
    scheduler: Scheduler,
    model_output: torch.Tensor,
    timestep: int | torch.Tensor,
    sample: torch.Tensor,
    next_timestep: int | torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Run one ``scheduler.step`` call and return the previous sample extracted from its output.

    ``return_dict=False`` is passed to ``scheduler.step`` only when its signature declares that parameter.
    ``next_timestep`` is forwarded, as a fourth positional argument, only for ``RFlowScheduler`` instances.

    Args:
        scheduler: scheduler performing the step.
        model_output: output of the diffusion model for the current step.
        timestep: current timestep.
        sample: current sample.
        next_timestep: following timestep; only used with ``RFlowScheduler``.

    Returns:
        The previous sample taken from the value returned by ``scheduler.step``.
    """
    accepts_return_dict = self._scheduler_step_supports_kwarg(scheduler, "return_dict")
    optional_kwargs = {"return_dict": False} if accepts_return_dict else {}

    positional_args = [model_output, timestep, sample]
    if isinstance(scheduler, RFlowScheduler):
        positional_args.append(next_timestep)

    raw_output = scheduler.step(*positional_args, **optional_kwargs)  # type: ignore
    return self._get_previous_sample_from_step_output(raw_output)

def __call__( # type: ignore[override]
self,
inputs: torch.Tensor,
Expand Down Expand Up @@ -940,7 +1031,12 @@ def sample(
scheduler = self.scheduler
image = input_noise

all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
all_next_timesteps = torch.cat(
(
scheduler.timesteps[1:],
torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device),
)
)
if verbose and has_tqdm:
progress_bar = tqdm(
zip(scheduler.timesteps, all_next_timesteps),
Expand Down Expand Up @@ -984,10 +1080,9 @@ def sample(
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

# 2. compute previous image: x_t -> x_t-1
if not isinstance(scheduler, RFlowScheduler):
image, _ = scheduler.step(model_output, t, image) # type: ignore
else:
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
image = self._scheduler_step(
scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t
)
if save_intermediates and t % intermediate_steps == 0:
intermediates.append(image)

Expand Down Expand Up @@ -1028,10 +1123,10 @@ def get_likelihood(

if not scheduler:
scheduler = self.scheduler
if scheduler._get_name() != "DDPMScheduler":
scheduler_name = self._get_scheduler_name(scheduler)
if scheduler_name != "DDPMScheduler":
raise NotImplementedError(
f"Likelihood computation is only compatible with DDPMScheduler,"
f" you are using {scheduler._get_name()}"
f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}"
)
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")
Expand All @@ -1046,7 +1141,7 @@ def get_likelihood(
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
diffusion_model = (
partial(diffusion_model, seg=seg)
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
Expand All @@ -1059,7 +1154,8 @@ def get_likelihood(
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)

# get the model's predicted mean, and variance if it is predicted
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
variance_type = self._get_scheduler_config_value(scheduler, "variance_type")
if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
else:
predicted_variance = None
Expand All @@ -1072,15 +1168,17 @@ def get_likelihood(

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if scheduler.prediction_type == "epsilon":
prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type")
if prediction_type == "epsilon":
pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif scheduler.prediction_type == "sample":
elif prediction_type == "sample":
pred_original_sample = model_output
elif scheduler.prediction_type == "v_prediction":
elif prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
# 3. Clip "predicted x_0"
if scheduler.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
if self._get_scheduler_config_value(scheduler, "clip_sample"):
clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0)
pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range)

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
Expand All @@ -1092,11 +1190,15 @@ def get_likelihood(
predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

# get the posterior mean and variance
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]
posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image)
posterior_variance = self._get_posterior_variance(
scheduler=scheduler, timestep=t, predicted_variance=predicted_variance
)

log_posterior_variance = torch.log(posterior_variance)
log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
log_predicted_variance = (
torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
)

if t == 0:
# compute -log p(x_0|x_1)
Expand Down Expand Up @@ -1509,7 +1611,12 @@ def sample( # type: ignore[override]
scheduler = self.scheduler
image = input_noise

all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
all_next_timesteps = torch.cat(
(
scheduler.timesteps[1:],
torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device),
)
)
if verbose and has_tqdm:
progress_bar = tqdm(
zip(scheduler.timesteps, all_next_timesteps),
Expand Down Expand Up @@ -1583,10 +1690,9 @@ def sample( # type: ignore[override]
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

# 3. compute previous image: x_t -> x_t-1
if not isinstance(scheduler, RFlowScheduler):
image, _ = scheduler.step(model_output, t, image) # type: ignore
else:
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
image = self._scheduler_step(
scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t
)

if save_intermediates and t % intermediate_steps == 0:
intermediates.append(image)
Expand Down Expand Up @@ -1631,10 +1737,10 @@ def get_likelihood( # type: ignore[override]

if not scheduler:
scheduler = self.scheduler
if scheduler._get_name() != "DDPMScheduler":
scheduler_name = self._get_scheduler_name(scheduler)
if scheduler_name != "DDPMScheduler":
raise NotImplementedError(
f"Likelihood computation is only compatible with DDPMScheduler,"
f" you are using {scheduler._get_name()}"
f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}"
)
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")
Expand All @@ -1647,7 +1753,7 @@ def get_likelihood( # type: ignore[override]
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
Expand Down Expand Up @@ -1680,7 +1786,8 @@ def get_likelihood( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
# get the model's predicted mean, and variance if it is predicted
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
variance_type = self._get_scheduler_config_value(scheduler, "variance_type")
if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
else:
predicted_variance = None
Expand All @@ -1693,15 +1800,17 @@ def get_likelihood( # type: ignore[override]

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if scheduler.prediction_type == "epsilon":
prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type")
if prediction_type == "epsilon":
pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif scheduler.prediction_type == "sample":
elif prediction_type == "sample":
pred_original_sample = model_output
elif scheduler.prediction_type == "v_prediction":
elif prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
# 3. Clip "predicted x_0"
if scheduler.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
if self._get_scheduler_config_value(scheduler, "clip_sample"):
clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0)
pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range)

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
Expand All @@ -1713,11 +1822,15 @@ def get_likelihood( # type: ignore[override]
predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

# get the posterior mean and variance
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]
posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image)
posterior_variance = self._get_posterior_variance(
scheduler=scheduler, timestep=t, predicted_variance=predicted_variance
)

log_posterior_variance = torch.log(posterior_variance)
log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
log_predicted_variance = (
torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
)

if t == 0:
# compute -log p(x_0|x_1)
Expand Down
58 changes: 58 additions & 0 deletions tests/inferers/test_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

_, has_scipy = optional_import("scipy")
_, has_einops = optional_import("einops")
DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler")

TEST_CASES = [
[
Expand Down Expand Up @@ -126,6 +127,63 @@ def test_ddpm_sampler(self, model_params, input_shape):
)
self.assertEqual(len(intermediates), 10)

@skipUnless(has_einops and has_diffusers, "Requires einops and diffusers")
def test_diffusers_ddpm_call(self):
    """Check output shapes of the inferer call and of ``sample`` when using a ``diffusers`` DDPM scheduler."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    unet_kwargs = {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 1,
        "channels": [32, 64],
        "attention_levels": [False, True],
        "num_res_blocks": 1,
        "num_head_channels": 32,
    }
    model = DiffusionModelUNet(**unet_kwargs)
    model.to(device)
    model.eval()

    scheduler = DiffusersDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon")
    scheduler.set_timesteps(num_inference_steps=50)
    inferer = DiffusionInferer(scheduler=scheduler)

    batch_size = 2
    image_size = 32
    input_shape = (batch_size, 1, image_size, image_size)
    inputs = torch.randn(*input_shape).to(device)
    noise = torch.randn_like(inputs)
    max_timestep = scheduler.config.num_train_timesteps
    timesteps = torch.randint(0, max_timestep, (batch_size,)).long().to(device)

    # forward call: the prediction must keep the shape of the inputs
    with torch.no_grad():
        prediction = inferer(inputs=inputs, diffusion_model=model, noise=noise, timesteps=timesteps)
    self.assertEqual(prediction.shape, inputs.shape)

    # sampling with only two inference steps
    scheduler.set_timesteps(num_inference_steps=2)
    sample = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False)
    self.assertEqual(sample.shape, inputs.shape)

@skipUnless(has_einops and has_diffusers, "Requires einops and diffusers")
def test_diffusers_ddpm_get_likelihood(self):
    """Check ``get_likelihood`` outputs when the inferer is driven by a ``diffusers`` DDPM scheduler."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    unet_kwargs = {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 1,
        "channels": [8],
        "norm_num_groups": 8,
        "attention_levels": [True],
        "num_res_blocks": 1,
        "num_head_channels": 8,
    }
    model = DiffusionModelUNet(**unet_kwargs)
    model.to(device)
    model.eval()

    inputs = torch.randn(2, 1, 8, 8).to(device)
    scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon")
    inferer = DiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)

    likelihood, intermediates = inferer.get_likelihood(
        inputs=inputs, diffusion_model=model, scheduler=scheduler, save_intermediates=True
    )

    # ten intermediates are expected for the ten inference steps set above
    self.assertEqual(len(intermediates), 10)
    self.assertEqual(intermediates[0].shape, inputs.shape)
    self.assertEqual(likelihood.shape[0], inputs.shape[0])

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_ddim_sampler(self, model_params, input_shape):
Expand Down
Loading
Loading