diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 8aedddd3..5617e3b7 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -62,7 +62,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) - transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) return pipeline @classmethod diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 9eb8c3e9..9efccf90 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -70,8 +70,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) - low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) - high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) return pipeline @classmethod