
@shuningjin shuningjin commented Jan 20, 2026

Description

Main change: to_maxtext.py adds an option to shard weights before saving the Orbax checkpoint.

  • Controlled by --simulated_cpu_devices_count, which defaults to 16; that is, the checkpoint is sharded across an array of 16 simulated CPU devices.
    • Setting --simulated_cpu_devices_count=1 skips sharding.
  • Reuse save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt.
    • Refine save_weights_to_checkpoint: add comments, use pop() rather than pop(0), and log the time taken for sharding and saving.
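For context, sharding across simulated CPU devices in JAX generally works by forcing the host platform to expose multiple devices before JAX is first imported, then placing each array with a NamedSharding over a 1-D mesh. A minimal sketch under that assumption (the axis label mirrors the `checkpoint_sharding_axis` seen in the metadata from the test logs; the actual logic inside save_weights_to_checkpoint may differ):

```python
import os

# Must be set before JAX is first imported: expose 16 simulated CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices("cpu"))
mesh = Mesh(devices, axis_names=("checkpoint_sharding_axis",))

# Shard a toy weight along axis 0, matching the partition_spec
# ('checkpoint_sharding_axis',) reported in the checkpoint metadata.
weight = np.arange(160 * 8, dtype=np.float32).reshape(160, 8)
sharded = jax.device_put(
    weight, NamedSharding(mesh, PartitionSpec("checkpoint_sharding_axis"))
)
print(len(sharded.addressable_shards))           # one shard per device
print(sharded.addressable_shards[0].data.shape)  # per-shard chunk of axis 0
```

An array saved in this layout is written by Orbax as per-device chunks rather than one monolithic file.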

Why?

Additional changes:

  • Print peak memory.
  • Move MemoryMonitorTqdm from to_maxtext to utils.
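MemoryMonitorTqdm's implementation isn't shown here; as a rough stdlib-only analogue of the peak-memory reporting, the process high-water mark can be read via resource.getrusage (Unix only, and the units differ between Linux and macOS):

```python
import resource
import sys

def peak_memory_gb() -> float:
    """Peak resident set size of this process in GB (Unix only)."""
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        rss /= 1024  # macOS reports bytes; Linux reports KiB
    return rss / (1024 ** 2)

print(f"Peak Memory: {peak_memory_gb():.2f} GB")
```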

Tests

model: qwen3-0.6b

Auxiliary script check_orbax.py to verify checkpoint sharding: https://paste.googleplex.com/4961902692270080

eager mode, simulated_cpu_devices_count=16 (default)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--simulated_cpu_devices_count=16

log: https://paste.googleplex.com/5225183650643968
gs://runner-maxtext-logs/2026-01-21-01-22/0/items
INFO:absl:Peak Memory: 6.63 GB

INFO:absl:shard weights across 16 devices
  0%|                                                                                                                                                         | 0/13 [00:00<?, ?it/s]
INFO:absl:sharding axis 0
INFO:absl:Elapse for checkpoint sharding: 0.15 min

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-21-01-22/0/items

https://paste.googleplex.com/5576907204722688

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-21-01-22/0/items,  shape=(151936, 1024),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], axis_types=(Auto,), partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=float32,  storage=StorageMetadata(chunk_shape=(9496, 1024), write_shape=(9496, 1024)),
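The reported chunk_shape follows from splitting the embedding's first axis evenly across the 16 devices; a quick sanity check of the arithmetic:

```python
# Shape and device count taken from the ArrayMetadata above.
embedding_rows, n_devices = 151936, 16
chunk_rows, remainder = divmod(embedding_rows, n_devices)
assert remainder == 0  # axis 0 divides evenly across the devices
print((chunk_rows, 1024))  # matches chunk_shape=(9496, 1024)
```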

eager mode, simulated_cpu_devices_count=1 (no shard)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--simulated_cpu_devices_count=1

log: https://paste.googleplex.com/4754761175924736

gs://runner-maxtext-logs/2026-01-21-01-24/0/items
INFO:absl:Peak Memory: 7.01 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-21-01-24/0/items

https://paste.googleplex.com/5627354883948544

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-21-01-24/0/items,  shape=(151936, 1024),  sharding=None,  dtype=float32,  storage=StorageMetadata(chunk_shape=(151936, 1024), write_shape=None),

lazy mode, simulated_cpu_devices_count=16 (default)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--lazy_load_tensors=true --simulated_cpu_devices_count=16

log: https://paste.googleplex.com/6120388708925440
gs://runner-maxtext-logs/2026-01-21-01-15/0/items
INFO:absl:Peak Memory: 3.68 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-21-01-15/0/items

https://paste.googleplex.com/6244107053826048

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-21-01-15/0/items,  shape=(151936, 1024),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], axis_types=(Auto,), partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=bfloat16,  storage=StorageMetadata(chunk_shape=(9496, 1024), write_shape=(9496, 1024)),

lazy mode, simulated_cpu_devices_count=1 (no shard)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--lazy_load_tensors=true --simulated_cpu_devices_count=1

log: https://paste.googleplex.com/5808230754287616
gs://runner-maxtext-logs/2026-01-21-01-17/0/items
INFO:absl:Peak Memory: 3.01 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-21-01-17/0/items

https://paste.googleplex.com/6262683827568640

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-21-01-17/0/items,  shape=(151936, 1024),  sharding=None,  dtype=float32,  storage=StorageMetadata(chunk_shape=(151936, 1024), write_shape=None),

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin changed the title checkpoint utility: save shard checkpoint and improve mem monitor checkpoint utility: shard checkpoint, monitor peak Jan 20, 2026

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 0% with 38 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/utils/ckpt_conversion/utils/utils.py 0.00% 21 Missing ⚠️
src/MaxText/utils/ckpt_conversion/to_maxtext.py 0.00% 15 Missing ⚠️
...rc/MaxText/utils/ckpt_conversion/to_huggingface.py 0.00% 2 Missing ⚠️


@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt2 branch from dd5bad1 to 30418aa Compare January 21, 2026 06:25