diff --git a/code_style.sh b/code_style.sh old mode 100644 new mode 100755 diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index ceede943..93b16851 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect from importlib import import_module from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/pedagogical_examples/attention_comparison.py b/src/maxdiffusion/pedagogical_examples/attention_comparison.py index 07831550..6981e092 100644 --- a/src/maxdiffusion/pedagogical_examples/attention_comparison.py +++ b/src/maxdiffusion/pedagogical_examples/attention_comparison.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import time diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 2300d0bd..cc547cc8 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import argparse import tensorflow as tf diff --git a/src/maxdiffusion/pedagogical_examples/parameter_count.py b/src/maxdiffusion/pedagogical_examples/parameter_count.py index cf9f8b8d..e9fe4542 100644 --- a/src/maxdiffusion/pedagogical_examples/parameter_count.py +++ b/src/maxdiffusion/pedagogical_examples/parameter_count.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import Sequence from absl import app import jax diff --git a/src/maxdiffusion/pipelines/controlnet/__init__.py b/src/maxdiffusion/pipelines/controlnet/__init__.py index 0cf92cd4..fe7070b0 100644 --- a/src/maxdiffusion/pipelines/controlnet/__init__.py +++ b/src/maxdiffusion/pipelines/controlnet/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index c39cc364..39ea05b5 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + _import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index cbef1d5f..564b0dfa 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py index 13e201ae..1ae1b641 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 415bcfea..f0621c8e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -543,8 +543,9 @@ def prepare_latents_i2v_base( vae_dtype = getattr(self.vae, "dtype", jnp.float32) video_condition = video_condition.astype(vae_dtype) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec) encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() # Normalize latents diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py index ee761a38..df46ea75 100644 --- a/src/maxdiffusion/tests/configuration_utils_test.py +++ b/src/maxdiffusion/tests/configuration_utils_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import json import os diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index ca0a5020..a58d5dcc 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest from unittest.mock import Mock diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index 12bfe77b..4f174716 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index ff823010..e2b4d772 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index 2c5b783a..b6722b3a 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/utils/deprecation_utils.py b/src/maxdiffusion/utils/deprecation_utils.py index 265a60b5..a7077ed7 100644 --- a/src/maxdiffusion/utils/deprecation_utils.py +++ b/src/maxdiffusion/utils/deprecation_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect import warnings from typing import Any, Dict, Optional, Union diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 51b05d30..fa394129 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import io import random import struct diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index f2b72cbd..735d2261 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os from typing import Callable, List, Optional, Union diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py index a05aa47a..86d07c66 100644 --- a/src/maxdiffusion/utils/pil_utils.py +++ b/src/maxdiffusion/utils/pil_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import List import PIL.Image diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py index 6194a03a..55be62ac 100644 --- a/src/maxdiffusion/utils/testing_utils.py +++ b/src/maxdiffusion/utils/testing_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import functools import importlib import inspect