From 6ecc205e07e1ea013e379e53cef722679624e87d Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 20 Jan 2026 22:08:22 +0530 Subject: [PATCH 1/5] sharding before call to vae encoder --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 0bc93f0c..efc2a649 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -529,6 +529,11 @@ def prepare_latents_i2v_base( vae_dtype = getattr(self.vae, "dtype", jnp.float32) video_condition = video_condition.astype(vae_dtype) + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint( + video_condition, + sharding_spec + ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() From 33d16f73e7689ff8adb3677d2382ed75e67aa63c Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 20 Jan 2026 22:20:52 +0530 Subject: [PATCH 2/5] sharding before call to vae encoder --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index efc2a649..4a436bea 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -529,13 +529,12 @@ def prepare_latents_i2v_base( vae_dtype = getattr(self.vae, "dtype", jnp.float32) video_condition = video_condition.astype(vae_dtype) - sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) - video_condition = jax.lax.with_sharding_constraint( - video_condition, - sharding_spec - ) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint( + video_condition, + sharding_spec + ) encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() # Normalize latents From f9e842a11fcd9d7717f3b4bae2943ddd25b2b3ae Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 21 Jan 2026 00:40:31 +0530 Subject: [PATCH 3/5] ruff check --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 4a436bea..21ec5754 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -532,7 +532,7 @@ def prepare_latents_i2v_base( with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) video_condition = jax.lax.with_sharding_constraint( - video_condition, + video_condition, sharding_spec ) encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() From 04c2320440a4bddc508840762bc1e95916fb383d Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 21 Jan 2026 04:16:29 +0000 Subject: [PATCH 4/5] pyink checks --- src/maxdiffusion/__init__.py | 378 +++++++++--------- .../transformers/transformer_flux_flax.py | 88 ++-- .../transformers_pytorch/attention.py | 1 + .../wan/transformers/transformer_wan_vace.py | 12 +- .../attention_comparison.py | 1 + .../dataset_tf_cache_to_tfrecord.py | 1 + .../pedagogical_examples/parameter_count.py | 1 + .../pedagogical_examples/to_tfrecords.py | 14 +- src/maxdiffusion/pipelines/__init__.py | 38 +- .../pipelines/controlnet/__init__.py | 1 + src/maxdiffusion/pipelines/flux/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 13 +- .../pipelines/stable_diffusion_xl/__init__.py | 1 + .../pipelines/wan/wan_pipeline.py | 72 ++-- .../scheduling_dpmsolver_multistep_flax.py | 12 +- .../scheduling_unipc_multistep_flax.py | 12 +- .../schedulers/scheduling_utils_flax.py | 3 +- .../tests/configuration_utils_test.py | 1 + .../tests/flop_calculations_test.py | 1 + .../tests/generate_flux_smoke_test.py | 1 + .../tests/generate_sdxl_smoke_test.py | 1 + src/maxdiffusion/tests/generate_smoke_test.py | 1 + src/maxdiffusion/utils/deprecation_utils.py | 1 + src/maxdiffusion/utils/export_utils.py | 1 + src/maxdiffusion/utils/import_utils.py | 44 +- src/maxdiffusion/utils/loading_utils.py | 1 + src/maxdiffusion/utils/pil_utils.py | 1 + src/maxdiffusion/utils/testing_utils.py | 1 + 28 files changed, 377 insertions(+), 326 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index a1a2c2f5..2d084be6 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -84,23 +84,25 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend([ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ]) + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -111,52 +113,56 @@ "get_scheduler", ] - _import_structure["pipelines"].extend([ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ]) - _import_structure["schedulers"].extend([ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ]) + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) _import_structure["training_utils"] = ["EMAModel"] try: @@ -196,98 +202,100 @@ ] else: - _import_structure["pipelines"].extend([ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ]) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: if not (is_torch_available() and is_k_diffusion_available()): @@ -313,14 +321,16 @@ ] else: - _import_structure["pipelines"].extend([ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ]) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: if not (is_torch_available() and is_librosa_available()): @@ -366,17 +376,19 @@ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend([ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: @@ -391,14 +403,16 @@ else: - _import_structure["pipelines"].extend([ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ]) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: if not (is_note_seq_available()): diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 814e21ea..a4cfab1b 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -202,27 +202,29 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.img_mlp = nn.Sequential([ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ]) + self.img_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) self.txt_norm2 = nn.LayerNorm( use_bias=False, @@ -231,27 +233,29 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.txt_mlp = nn.Sequential([ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ]) + self.txt_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) # let chunk size default to None self._chunk_size = None diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index ceede943..93b16851 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect from importlib import import_module from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index ce73ac5d..5f10fa68 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -460,11 +460,13 @@ def __call__( control_hidden_states = self.vace_patch_embedding(control_hidden_states) control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1) - control_hidden_states_padding = jnp.zeros(( - batch_size, - control_hidden_states.shape[1], - hidden_states.shape[2] - control_hidden_states.shape[2], - )) + control_hidden_states_padding = jnp.zeros( + ( + batch_size, + control_hidden_states.shape[1], + hidden_states.shape[2] - control_hidden_states.shape[2], + ) + ) control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2) diff --git a/src/maxdiffusion/pedagogical_examples/attention_comparison.py b/src/maxdiffusion/pedagogical_examples/attention_comparison.py index 07831550..6981e092 100644 --- a/src/maxdiffusion/pedagogical_examples/attention_comparison.py +++ b/src/maxdiffusion/pedagogical_examples/attention_comparison.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import time diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 2300d0bd..cc547cc8 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import argparse import tensorflow as tf diff --git a/src/maxdiffusion/pedagogical_examples/parameter_count.py b/src/maxdiffusion/pedagogical_examples/parameter_count.py index cf9f8b8d..e9fe4542 100644 --- a/src/maxdiffusion/pedagogical_examples/parameter_count.py +++ b/src/maxdiffusion/pedagogical_examples/parameter_count.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import Sequence from absl import app import jax diff --git a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py index a0a38021..67cf6056 100644 --- a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py +++ b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py @@ -54,12 +54,14 @@ dl_manager = tfds.download.DownloadManager(download_dir="/tmp") tmp_dataset = "dataset" -TRANSFORMS = transforms.Compose([ - transforms.ToTensor(), - transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(size=512), - transforms.Normalize([0.5], [0.5]), -]) +TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(size=512), + transforms.Normalize([0.5], [0.5]), + ] +) def delete_files(path): diff --git a/src/maxdiffusion/pipelines/__init__.py b/src/maxdiffusion/pipelines/__init__.py index 019c79a8..e4298c05 100644 --- a/src/maxdiffusion/pipelines/__init__.py +++ b/src/maxdiffusion/pipelines/__init__.py @@ -51,14 +51,16 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) else: - _import_structure["stable_diffusion"].extend([ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ]) + _import_structure["stable_diffusion"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: if not is_flax_available(): @@ -80,14 +82,18 @@ _import_structure["controlnet"].extend( ["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"] ) - _import_structure["stable_diffusion"].extend([ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ]) - _import_structure["stable_diffusion_xl"].extend([ - "FlaxStableDiffusionXLPipeline", - ]) + _import_structure["stable_diffusion"].extend( + [ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ] + ) + _import_structure["stable_diffusion_xl"].extend( + [ + "FlaxStableDiffusionXLPipeline", + ] + ) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not is_onnx_available(): diff --git a/src/maxdiffusion/pipelines/controlnet/__init__.py b/src/maxdiffusion/pipelines/controlnet/__init__.py index 0cf92cd4..fe7070b0 100644 --- a/src/maxdiffusion/pipelines/controlnet/__init__.py +++ b/src/maxdiffusion/pipelines/controlnet/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index c39cc364..39ea05b5 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + _import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index cbef1d5f..72ec9aa1 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( @@ -84,11 +85,13 @@ StableDiffusionPix2PixZeroPipeline, ) - _dummy_objects.update({ - "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, - "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, - }) + _dummy_objects.update( + { + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, + "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, + } + ) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py index 13e201ae..1ae1b641 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index defe2551..f0621c8e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -521,43 +521,41 @@ def prepare_latents_i2v_base( dtype: jnp.dtype, last_image: Optional[jax.Array] = None, ) -> Tuple[jax.Array, jax.Array]: - """ - Encodes the initial image(s) into latents to be used as conditioning. - Returns: - latent_condition: The VAE encoded latents of the image(s). - video_condition: The input to the VAE. - """ - height, width = image.shape[-2:] - image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] - - if last_image is None: - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 - ) - else: - last_image = last_image[:, :, jnp.newaxis, :, :] - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], axis=2 - ) - - vae_dtype = getattr(self.vae, "dtype", jnp.float32) - video_condition = video_condition.astype(vae_dtype) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) - video_condition = jax.lax.with_sharding_constraint( - video_condition, - sharding_spec - ) - encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() - - # Normalize latents - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) - latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) - latent_condition = encoded_output - latent_condition = latent_condition.astype(dtype) - latent_condition = (latent_condition - latents_mean) / latents_std - - return latent_condition, video_condition + """ + Encodes the initial image(s) into latents to be used as conditioning. + Returns: + latent_condition: The VAE encoded latents of the image(s). + video_condition: The input to the VAE. + """ + height, width = image.shape[-2:] + image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] + + if last_image is None: + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + ) + else: + last_image = last_image[:, :, jnp.newaxis, :, :] + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], + axis=2, + ) + + vae_dtype = getattr(self.vae, "dtype", jnp.float32) + video_condition = video_condition.astype(vae_dtype) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec) + encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() + + # Normalize latents + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latent_condition = encoded_output + latent_condition = latent_condition.astype(dtype) + latent_condition = (latent_condition - latents_mean) / latents_std + + return latent_condition, video_condition def _denormalize_latents(self, latents: jax.Array) -> jax.Array: """Denormalizes latents using VAE statistics.""" diff --git a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py index c55a49c4..218117eb 100644 --- a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -528,11 +528,13 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: ) def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array([ - state.timesteps[step_index - 2], - state.timesteps[step_index - 1], - state.timesteps[step_index], - ]) + timestep_list = jnp.array( + [ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ] + ) return self.multistep_dpm_solver_third_order_update( state, state.model_outputs, diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index b2c7d96a..03a47fd4 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -136,11 +136,13 @@ def __init__( if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if ( - sum([ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ]) + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) > 1 ): raise ValueError( diff --git a/src/maxdiffusion/schedulers/scheduling_utils_flax.py b/src/maxdiffusion/schedulers/scheduling_utils_flax.py index d38f1446..e1690ba8 100644 --- a/src/maxdiffusion/schedulers/scheduling_utils_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_utils_flax.py @@ -262,8 +262,7 @@ def create(cls, scheduler): elif config.beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. betas = ( - jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) - ** 2 + jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) ** 2 ) elif config.beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py index ee761a38..df46ea75 100644 --- a/src/maxdiffusion/tests/configuration_utils_test.py +++ b/src/maxdiffusion/tests/configuration_utils_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import json import os diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index ca0a5020..a58d5dcc 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest from unittest.mock import Mock diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index 12bfe77b..4f174716 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index ff823010..e2b4d772 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index 2c5b783a..b6722b3a 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/utils/deprecation_utils.py b/src/maxdiffusion/utils/deprecation_utils.py index 265a60b5..a7077ed7 100644 --- a/src/maxdiffusion/utils/deprecation_utils.py +++ b/src/maxdiffusion/utils/deprecation_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect import warnings from typing import Any, Dict, Optional, Union diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 51b05d30..fa394129 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import io import random import struct diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 05ef72ec..d83596e8 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -512,27 +512,29 @@ def is_peft_available(): """ -BACKENDS_MAPPING = OrderedDict([ - ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), - ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), - ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), - ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), - ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), - ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), - ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), - ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), - ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), -]) +BACKENDS_MAPPING = OrderedDict( + [ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ] +) def requires_backends(obj, backends): diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index f2b72cbd..735d2261 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os from typing import Callable, List, Optional, Union diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py index a05aa47a..86d07c66 100644 --- a/src/maxdiffusion/utils/pil_utils.py +++ b/src/maxdiffusion/utils/pil_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import List import PIL.Image diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py index 6194a03a..55be62ac 100644 --- a/src/maxdiffusion/utils/testing_utils.py +++ b/src/maxdiffusion/utils/testing_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import functools import importlib import inspect From 498450467a644d949b48b6c6d86f8f7fb0e37a9f Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 21 Jan 2026 04:31:13 +0000 Subject: [PATCH 5/5] pyink check --- code_style.sh | 0 src/maxdiffusion/__init__.py | 378 +++++++++--------- .../transformers/transformer_flux_flax.py | 88 ++-- .../wan/transformers/transformer_wan_vace.py | 12 +- .../pedagogical_examples/to_tfrecords.py | 14 +- src/maxdiffusion/pipelines/__init__.py | 38 +- .../pipelines/stable_diffusion/__init__.py | 12 +- .../scheduling_dpmsolver_multistep_flax.py | 12 +- .../scheduling_unipc_multistep_flax.py | 12 +- .../schedulers/scheduling_utils_flax.py | 3 +- src/maxdiffusion/utils/import_utils.py | 44 +- 11 files changed, 289 insertions(+), 324 deletions(-) mode change 100644 => 100755 code_style.sh diff --git a/code_style.sh b/code_style.sh old mode 100644 new mode 100755 diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 2d084be6..a1a2c2f5 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -84,25 +84,23 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) + _import_structure["models"].extend([ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ]) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -113,56 +111,52 @@ "get_scheduler", ] - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) + _import_structure["pipelines"].extend([ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ]) + _import_structure["schedulers"].extend([ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ]) _import_structure["training_utils"] = ["EMAModel"] try: @@ -202,100 +196,98 @@ ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ]) try: if not (is_torch_available() and is_k_diffusion_available()): @@ -321,16 +313,14 @@ ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not (is_torch_available() and is_librosa_available()): @@ -376,19 +366,17 @@ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["schedulers"].extend([ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ]) try: @@ -403,16 +391,14 @@ else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ]) try: if not (is_note_seq_available()): diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index a4cfab1b..814e21ea 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -202,29 +202,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.img_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.img_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) self.txt_norm2 = nn.LayerNorm( use_bias=False, @@ -233,29 +231,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.txt_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.txt_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) # let chunk size default to None self._chunk_size = None diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index 5f10fa68..ce73ac5d 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -460,13 +460,11 @@ def __call__( control_hidden_states = self.vace_patch_embedding(control_hidden_states) control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1) - control_hidden_states_padding = jnp.zeros( - ( - batch_size, - control_hidden_states.shape[1], - hidden_states.shape[2] - control_hidden_states.shape[2], - ) - ) + control_hidden_states_padding = jnp.zeros(( + batch_size, + control_hidden_states.shape[1], + hidden_states.shape[2] - control_hidden_states.shape[2], + )) control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2) diff --git a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py index 67cf6056..a0a38021 100644 --- a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py +++ b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py @@ -54,14 +54,12 @@ dl_manager = tfds.download.DownloadManager(download_dir="/tmp") tmp_dataset = "dataset" -TRANSFORMS = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(size=512), - transforms.Normalize([0.5], [0.5]), - ] -) +TRANSFORMS = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(size=512), + transforms.Normalize([0.5], [0.5]), +]) def delete_files(path): diff --git a/src/maxdiffusion/pipelines/__init__.py b/src/maxdiffusion/pipelines/__init__.py index e4298c05..019c79a8 100644 --- a/src/maxdiffusion/pipelines/__init__.py +++ b/src/maxdiffusion/pipelines/__init__.py @@ -51,16 +51,14 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not is_flax_available(): @@ -82,18 +80,14 @@ _import_structure["controlnet"].extend( ["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"] ) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) - _import_structure["stable_diffusion_xl"].extend( - [ - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ]) + _import_structure["stable_diffusion_xl"].extend([ + "FlaxStableDiffusionXLPipeline", + ]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not is_onnx_available(): diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index 72ec9aa1..564b0dfa 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -85,13 +85,11 @@ StableDiffusionPix2PixZeroPipeline, ) - _dummy_objects.update( - { - "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, - "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, - } - ) + _dummy_objects.update({ + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, + "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, + }) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] diff --git a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py index 218117eb..c55a49c4 100644 --- a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -528,13 +528,11 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: ) def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array( - [ - state.timesteps[step_index - 2], - state.timesteps[step_index - 1], - state.timesteps[step_index], - ] - ) + timestep_list = jnp.array([ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ]) return self.multistep_dpm_solver_third_order_update( state, state.model_outputs, diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index 03a47fd4..b2c7d96a 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -136,13 +136,11 @@ def __init__( if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if ( - sum( - [ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ] - ) + sum([ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ]) > 1 ): raise ValueError( diff --git a/src/maxdiffusion/schedulers/scheduling_utils_flax.py b/src/maxdiffusion/schedulers/scheduling_utils_flax.py index e1690ba8..d38f1446 100644 --- a/src/maxdiffusion/schedulers/scheduling_utils_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_utils_flax.py @@ -262,7 +262,8 @@ def create(cls, scheduler): elif config.beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. betas = ( - jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) ** 2 + jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) + ** 2 ) elif config.beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index d83596e8..05ef72ec 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -512,29 +512,27 @@ def is_peft_available(): """ -BACKENDS_MAPPING = OrderedDict( - [ - ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), - ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), - ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), - ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), - ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), - ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), - ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), - ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), - ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), - ] -) +BACKENDS_MAPPING = OrderedDict([ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), +]) def requires_backends(obj, backends):