Conversation

@SamuelMarks
Collaborator

Description

Builder Pattern for HF Preprocessing Pipeline.

preprocessing_pipeline takes ~25 arguments (many of them booleans) and constructs a Grain pipeline procedurally, which is brittle. This PR implements a GrainPipelineBuilder that constructs the operations chain fluently, adds type safety, and increases doc coverage throughout the file.

Before vs After: preprocessing_pipeline Refactor

Before: A procedural function with dense logic, mixing configuration validation, dataset mapping, and operation appending in a single block.

def preprocessing_pipeline(
    dataloading_host_index, dataloading_host_count, global_mesh, dataset,
    data_column_names, tokenize, tokenizer_path, hf_access_token,
    global_batch_size, max_target_length, shuffle, data_shuffle_seed,
    chat_template_path="", add_bos=True, add_eos=True, packing=True,
    shift=True, num_threads=1, drop_remainder=True,
    generate_padding_batch=False, use_dpo=None, use_sft=None,
    sft_train_on_completion_only=True, grain_worker_count=1,
    max_segments_per_seq=None,
):
    # ... initial assertions ...

    if shuffle:
        dataset = dataset.shuffle(seed=data_shuffle_seed)

    # ... large block configuring tokenizer ... 
    
    if use_sft: # Complex SFT logic mixed in
        # ... select_columns ...
        # ... convert_to_conversational_format ...
        # ... apply_chat_template ...

    if tokenize:
        dataset = dataset.map(...)

    dataset = _input_pipeline_utils.HFDataSource(...)
    
    operations = []
    
    # ... logic branching for dpo vs sft vs default ...
    if use_sft:
        operations.append(_input_pipeline_utils.SFTPromptMasking(...))
    elif use_dpo:
        operations.append(grain.MapOperation(...))
    else:
        operations.append(_input_pipeline_utils.HFNormalizeFeatures(...))

    if packing and not use_dpo:
        operations.append(grain.experimental.PackAndBatchOperation(...))
    else:
        operations.append(grain.Batch(...))

    if shift and not use_dpo:
        operations.append(_input_pipeline_utils.ShiftData(...))

    # ... creation of sampler and dataloader ...
    return multihost_gen

After: A fluent interface where pipeline stages are explicit, reorderable, and encapsulated in the GrainPipelineBuilder.

def preprocessing_pipeline(
    dataloading_host_index, dataloading_host_count, global_mesh, dataset,
    data_column_names, tokenize, tokenizer_path, hf_access_token,
    global_batch_size, max_target_length, shuffle, data_shuffle_seed,
    chat_template_path="", add_bos=True, add_eos=True, packing=True,
    shift=True, num_threads=1, drop_remainder=True,
    generate_padding_batch=False, use_dpo=None, use_sft=None,
    sft_train_on_completion_only=True, grain_worker_count=1,
    max_segments_per_seq=None,
):
    builder = GrainPipelineBuilder(
        dataset=dataset,
        global_mesh=global_mesh,
        dataloading_host_index=dataloading_host_index,
        dataloading_host_count=dataloading_host_count,
    )

    return (
        builder
        .add_shuffling(shuffle, data_shuffle_seed)
        .add_tokenization(
            tokenizer_path=tokenizer_path,
            hf_access_token=hf_access_token,
            tokenize=tokenize,
            data_column_names=data_column_names,
            max_target_length=max_target_length,
            use_sft=use_sft or False,
            chat_template_path=chat_template_path,
            add_bos=add_bos,
            add_eos=add_eos,
        )
        .add_normalization(
            use_sft=use_sft or False,
            use_dpo=use_dpo or False,
            sft_train_on_completion_only=sft_train_on_completion_only,
            max_target_length=max_target_length,
        )
        .add_packing(
            packing=packing,
            global_batch_size=global_batch_size,
            max_target_length=max_target_length,
            max_segments_per_seq=max_segments_per_seq,
            drop_remainder=drop_remainder,
            use_dpo=use_dpo or False,
        )
        .add_shifting(
            shift=shift,
            use_dpo=use_dpo or False,
        )
        .build(
            num_threads=num_threads,
            grain_worker_count=grain_worker_count,
            generate_padding_batch=generate_padding_batch,
        )
    )
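For readers unfamiliar with the pattern, a minimal self-contained sketch of how such a fluent builder works is below. The class name and stage methods here are illustrative assumptions (a plain Python list stands in for Grain operations); they are not the actual MaxText implementation:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class PipelineBuilderSketch:
  """Illustrative fluent builder: each add_* method conditionally
  registers a stage and returns self, so calls can be chained.
  (Hypothetical sketch, not the MaxText GrainPipelineBuilder.)"""

  dataset: List[Any]
  operations: List[Callable[[Any], Any]] = field(default_factory=list)

  def add_shuffling(self, shuffle: bool, seed: int) -> "PipelineBuilderSketch":
    # Mirrors the `if shuffle:` branch in the procedural version.
    if shuffle:
      import random

      self.dataset = self.dataset[:]
      random.Random(seed).shuffle(self.dataset)
    return self

  def add_mapping(self, enabled: bool, fn: Callable[[Any], Any]) -> "PipelineBuilderSketch":
    # Conditional stage: appended only when enabled, like packing/shifting above.
    if enabled:
      self.operations.append(fn)
    return self

  def build(self) -> List[Any]:
    # Apply the accumulated operation chain in order.
    for op in self.operations:
      self.dataset = [op(x) for x in self.dataset]
    return self.dataset


result = (
    PipelineBuilderSketch(dataset=[1, 2, 3])
    .add_shuffling(shuffle=False, seed=0)
    .add_mapping(enabled=True, fn=lambda x: x * 2)
    .build()
)
print(result)  # [2, 4, 6]
```

The key property the refactor relies on is that each stage method owns its own enable/disable logic and returns `self`, so the top-level function reads as a linear list of named stages rather than interleaved branches.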

Tests

CI

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

…t interface for constructing Grain pipelines (introduce `GrainPipelineBuilder`); increase type-safety
@codecov

codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 87.35632% with 11 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/input_pipeline/_hf_data_processing.py 87.35% 6 Missing and 5 partials ⚠️

