
Conversation

@SamuelMarks
Collaborator

Description

Builder Pattern for TFDS Preprocessing Pipeline: `preprocessing_pipeline` accepts 17 parameters (e.g., `shuffle`, `tokenize`, `pack_examples`, `use_dpo`, `add_bos`, `add_eos`, `shift`) and applies transformations to a `tf.data.Dataset` procedurally; increase type safety and docstring coverage; [src/MaxText/tokenizer.py] introduce `TokenizerType`.

TL;DR: this PR refactors the monolithic `preprocessing_pipeline` into a fluent `TfdsPipelineBuilder` class, decoupling the transformation logic and improving readability.

Before (Procedural / Monolithic)

```python
def preprocessing_pipeline(dataset, tokenizer_path, ..., pack_examples, use_dpo):
  if not use_dpo:
    dataset = dataset.map(lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]))
  
  tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, ...)

  if tokenize:
    dataset = dataset.map(lambda x: tokenizer.TokenizeOp(tokenizer=tokenizer_model, ...))

  if max_target_length > 0:
    dataset = dataset.map(lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, ...))

  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)

  dataset = dataset.repeat(num_epochs)

  if shift and not use_dpo:
    dataset = dataset.map(_input_pipeline_utils.shift_data_by_truncation, ...)

  if pack_examples and not use_dpo:
    dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id)
    dataset = dataset.batch(global_batch_size // jax.process_count(), drop_remainder=drop_remainder)
  else:
    dataset = dataset.padded_batch(...)
    
  if prefetch_size:
    dataset = dataset.prefetch(prefetch_size)

  return dataset
```

After (Builder Pattern)

```python
def preprocessing_pipeline(dataset, tokenizer_path, ..., pack_examples, use_dpo):
  builder = TfdsPipelineBuilder(dataset, data_column_names, use_dpo)

  tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, ...)

  if tokenize:
    builder.with_tokenization(tokenizer_model)

  builder.with_truncation(max_target_length)

  if shuffle:
    builder.with_shuffling(shuffle_buffer_size, data_shuffle_seed)

  builder.with_repeat(num_epochs)

  if shift:
    builder.with_shift()

  global_per_process_batch_size = global_batch_size // jax.process_count()
  builder.with_batching(
      global_per_process_batch_size, pack_examples, max_target_length, pad_id, drop_remainder
  )

  builder.with_prefetch(prefetch_size)

  return builder.build()
```
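
For reviewers: the PR body above shows only the call sites, not the builder itself. Below is a minimal sketch of what a `TfdsPipelineBuilder` with these methods could look like, reconstructed from the before/after comparison. The attribute names, the placement of the `use_dpo` guards, and the helper signatures (`TokenizeOp`'s `features` argument, the second argument of `truncate_to_max_allowable_length`) are assumptions, not the actual diff.

```python
# Hedged sketch only -- the real class lives in
# src/MaxText/input_pipeline/_tfds_data_processing.py and may differ.
# Assumes the same MaxText helpers used in the snippets above
# (_input_pipeline_utils, sequence_packing, tokenizer) are imported.
import tensorflow as tf


class TfdsPipelineBuilder:
  """Accumulates tf.data transformations; build() returns the dataset."""

  def __init__(self, dataset: tf.data.Dataset, data_column_names, use_dpo: bool):
    self._ds = dataset
    self._use_dpo = use_dpo
    if not use_dpo:
      # Mirrors the normalize_features step of the procedural version.
      self._ds = self._ds.map(
          lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]))

  def with_tokenization(self, tokenizer_model) -> "TfdsPipelineBuilder":
    # The `features=x` keyword is an assumption; the original elides it.
    self._ds = self._ds.map(
        lambda x: tokenizer.TokenizeOp(tokenizer=tokenizer_model, features=x))
    return self

  def with_truncation(self, max_target_length: int) -> "TfdsPipelineBuilder":
    # The `> 0` guard moves inside the builder, so call sites stay flat.
    if max_target_length > 0:
      self._ds = self._ds.map(
          lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(
              x, max_target_length))
    return self

  def with_shuffling(self, buffer_size: int, seed: int) -> "TfdsPipelineBuilder":
    self._ds = self._ds.shuffle(buffer_size, seed=seed)
    return self

  def with_repeat(self, num_epochs) -> "TfdsPipelineBuilder":
    self._ds = self._ds.repeat(num_epochs)
    return self

  def with_shift(self) -> "TfdsPipelineBuilder":
    # The DPO guard also moves inside, which is why the call site no
    # longer checks `not use_dpo`.
    if not self._use_dpo:
      self._ds = self._ds.map(_input_pipeline_utils.shift_data_by_truncation)
    return self

  def with_batching(self, batch_size, pack_examples, max_target_length,
                    pad_id, drop_remainder) -> "TfdsPipelineBuilder":
    if pack_examples and not self._use_dpo:
      self._ds = sequence_packing.pack_dataset(self._ds, max_target_length, pad_id)
      self._ds = self._ds.batch(batch_size, drop_remainder=drop_remainder)
    else:
      # Padded shapes/values are elided in the original; defaults assumed.
      self._ds = self._ds.padded_batch(batch_size, drop_remainder=drop_remainder)
    return self

  def with_prefetch(self, prefetch_size) -> "TfdsPipelineBuilder":
    if prefetch_size:
      self._ds = self._ds.prefetch(prefetch_size)
    return self

  def build(self) -> tf.data.Dataset:
    return self._ds
```

Since every `with_*` method returns `self`, the unconditional steps in the example above could equally be chained, e.g. `builder.with_truncation(max_target_length).with_repeat(num_epochs)`.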

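The description also mentions introducing `TokenizerType` in src/MaxText/tokenizer.py, which the PR body does not show. One plausible shape, purely as an assumed sketch, is an enum so the tokenizer backend becomes an explicit type rather than a bare string:

```python
# Hypothetical sketch -- the actual TokenizerType in src/MaxText/tokenizer.py
# may differ; the member names below are assumptions.
import enum


class TokenizerType(enum.Enum):
  """Type-safe identifier for the tokenizer backend."""
  SENTENCEPIECE = "sentencepiece"
  TIKTOKEN = "tiktoken"
  HUGGINGFACE = "huggingface"
```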
Tests

CI

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@SamuelMarks changed the title [src/MaxText/input_pipeline/_tfds_data_processing.py] Builder Patternfor TFDS Preprocessing Pipeline: preprocessing_pipeline → [src/MaxText/input_pipeline/_tfds_data_processing.py] Builder Pattern for TFDS Preprocessing Pipeline: preprocessing_pipeline on Jan 21, 2026
codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 88.88889% with 6 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...rc/MaxText/input_pipeline/_tfds_data_processing.py | 88.67% | 1 Missing and 5 partials ⚠️ |


