checkpoint utility: shard checkpoint, monitor peak #2974
+190
−113
Description
Main change:
to_maxtext.py: add an option to shard weights before saving the Orbax checkpoint.
- --simulated_cpu_devices_count defaults to 16, i.e., the checkpoint is sharded across 16 simulated CPU devices.
- --simulated_cpu_devices_count=1 skips sharding.
- save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt
- save_weights_to_checkpoint: add comments, use pop() rather than pop(0), log time for shard and save.
Why?
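On the pop() versus pop(0) change: the description gives no details, so what follows is only a minimal sketch of the general Python behavior, assuming the weights are held in a plain list. The function names and the string stand-ins for weight arrays are hypothetical and are not taken from save_weights_to_checkpoint. Both calls drop the list's reference to the item, but pop(0) shifts every remaining element on each call, while pop() removes the last element in constant time and therefore consumes the list in reverse order.

```python
def drain_front(weights):
    """Consume a list front-to-back; each pop(0) shifts the remaining items (O(n))."""
    order = []
    while weights:
        order.append(weights.pop(0))
    return order


def drain_back(weights):
    """Consume a list back-to-front; each pop() is O(1)."""
    order = []
    while weights:
        order.append(weights.pop())
    return order


# Hypothetical stand-ins for per-layer weight arrays.
front = drain_front(["layer0", "layer1", "layer2"])
back = drain_back(["layer0", "layer1", "layer2"])
print(front)  # ['layer0', 'layer1', 'layer2']
print(back)   # ['layer2', 'layer1', 'layer0']
```

Because pop() reverses the consumption order, any code that pairs popped weights with another sequence (for example, a list of shardings) would need to iterate that sequence in reverse as well.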
Additional change:
Move MemoryMonitorTqdm from to_maxtext to utils
Tests
model: qwen3-0.6b
auxiliary script to check checkpoint sharding
check_orbax.py: https://paste.googleplex.com/4961902692270080
eager mode, simulated_cpu_devices_count=16 (default)
log: https://paste.googleplex.com/5225183650643968
gs://runner-maxtext-logs/2026-01-21-01-22/0/items
INFO:absl:Peak Memory: 6.63 GB
check sharding
https://paste.googleplex.com/5576907204722688
eager mode, simulated_cpu_devices_count=1 (no shard)
log: https://paste.googleplex.com/4754761175924736
gs://runner-maxtext-logs/2026-01-21-01-24/0/items
INFO:absl:Peak Memory: 7.01 GB
check sharding
https://paste.googleplex.com/5627354883948544
lazy mode, simulated_cpu_devices_count=16 (default)
log: https://paste.googleplex.com/6120388708925440
gs://runner-maxtext-logs/2026-01-21-01-15/0/items
INFO:absl:Peak Memory: 3.68 GB
check sharding
https://paste.googleplex.com/6244107053826048
lazy mode, simulated_cpu_devices_count=1 (no shard)
log: https://paste.googleplex.com/5808230754287616
gs://runner-maxtext-logs/2026-01-21-01-17/0/items
INFO:absl:Peak Memory: 3.01 GB
check sharding
https://paste.googleplex.com/6262683827568640
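The peak-memory figures above come from the PR's own logs. As a rough illustration of how a number in that format could be collected, here is a minimal sketch using only the standard library. It is not the PR's MemoryMonitorTqdm: PeakMemoryTracker and format_peak are hypothetical names, tracemalloc only sees allocations made through Python's allocators (not total process RSS), and treating 1 GB as 1024**3 bytes is an assumption.

```python
import tracemalloc


class PeakMemoryTracker:
    """Hypothetical context manager that records peak Python-level allocations."""

    def __enter__(self):
        tracemalloc.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        # get_traced_memory() returns (current_bytes, peak_bytes).
        _, self.peak_bytes = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return False  # do not swallow exceptions


def format_peak(peak_bytes):
    """Format a byte count like the log lines above (assumes 1 GB = 1024**3 bytes)."""
    return f"Peak Memory: {peak_bytes / 1024**3:.2f} GB"


with PeakMemoryTracker() as tracker:
    buffer = bytearray(8 * 1024 * 1024)  # stand-in for loading a weight
    del buffer  # released, but still counted toward the peak

print(format_peak(tracker.peak_bytes))
```

A process-level monitor (for example, one that samples resident set size) would report larger values than this sketch, since it also counts memory allocated outside the Python allocators.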
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.