diff --git a/Dockerfile b/Dockerfile index b88cad5..2a57aa7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,24 +1,31 @@ -# Use an official Python 3.11 runtime as a parent image -FROM python:3.11-slim +# Use an official Python 3.12 runtime as a parent image +FROM python:3.12-slim # Set the working directory WORKDIR /app -# Copy the current directory contents into the container -COPY . /app - # This tells Python to look in /app for the 'recml' package ENV PYTHONPATH="${PYTHONPATH}:/app" # Install system tools if needed (e.g., git) RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* +# Install the pinned jax-tpu-embedding wheel from the build context (--no-cache-dir keeps pip's cache out of the layer) +COPY jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl ./ +RUN pip install --no-cache-dir ./jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl + +# Copy only requirements.txt first so the dependency layers stay cached when other source files change +COPY requirements.txt ./ + # Install dependencies RUN pip install --upgrade pip -RUN pip install -r requirements.txt +RUN pip install --no-cache-dir -r ./requirements.txt # Force install the specific protobuf version RUN pip install "protobuf>=6.31.1" --no-deps +# Copy the current directory contents into the container +COPY . 
/app + # Default command to run the training script CMD ["python", "recml/examples/dlrm_experiment_test.py"] diff --git a/jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl b/jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl new file mode 100644 index 0000000..3167f2c Binary files /dev/null and b/jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl differ diff --git a/requirements.txt b/requirements.txt index 998ee15..2eba034 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,7 +34,6 @@ importlib-resources==6.5.2 iniconfig==2.1.0 isort==6.0.1 jax==0.8.2 -jax-tpu-embedding==0.1.0.dev20251208 jaxlib==0.8.2 jaxtyping==0.3.1 Jinja2==3.1.6 @@ -123,4 +122,4 @@ wadler-lindig==0.1.5 Werkzeug==3.1.3 wheel==0.45.1 wrapt==1.17.2 -zipp==3.21.0 \ No newline at end of file +zipp==3.21.0 diff --git a/training.md b/training.md index bf36dbc..b9bcc31 100644 --- a/training.md +++ b/training.md @@ -7,7 +7,7 @@ This guide explains how to set up the environment and train the HSTU/DLRM models If you are developing on a TPU VM directly, use a virtual environment to avoid conflicts with the system-level Python packages. ### 1. Prerequisites -Ensure you have **Python 3.11+** installed. +Ensure you have **Python 3.12** installed (the bundled jax-tpu-embedding wheel is a CPython 3.12-only build). ```bash python3 --version ``` @@ -23,6 +23,11 @@ source venv/bin/activate ``` ### 3. Install Dependencies + +Install the jax-tpu-embedding wheel bundled in the repository root: +```bash +pip install ./jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl +``` ```bash pip install -r requirements.txt ``` @@ -41,39 +46,7 @@ python dlrm_experiment_test.py If you prefer not to manage a virtual environment or want to deploy this as a container, you can build a Docker image. -### 1. 
Create a Dockerfile -Create a file named `Dockerfile` in the root of the repository: - -```dockerfile -# Use an official Python 3.11 runtime as a parent image -FROM python:3.11-slim - -# Set the working directory -WORKDIR /app - -# Copy the current directory contents into the container -COPY . /app - -# This tells Python to look in /app for the 'recml' package -ENV PYTHONPATH="${PYTHONPATH}:/app" - -# Install system tools if needed (e.g., git) -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - -# Install dependencies -RUN pip install --upgrade pip -RUN pip install -r requirements.txt - -# Force install the specific protobuf version -RUN pip install "protobuf>=6.31.1" --no-deps - -# Default command to run the training script -CMD ["python", "recml/examples/dlrm_experiment_test.py"] -``` - -You can use this dockerfile to run the DLRM model experiment from this repo in your own environment. - -### 2. Build the Image +### 1. Build the Image Run this command from the root of the repository. It reads the `Dockerfile`, installs all dependencies, and creates a ready-to-run image. @@ -81,7 +54,9 @@ Run this command from the root of the repository. It reads the `Dockerfile`, ins docker build -t recml-training . ``` -### 3. Run the Image +### 2. Run the Image + +This starts a container from the image and executes the default command set by `CMD` in the `Dockerfile`, which currently runs the DLRM experiment script `recml/examples/dlrm_experiment_test.py`. ```bash docker run --rm --privileged \ @@ -90,9 +65,3 @@ docker run --rm --privileged \ --name recml-experiment \ recml-training ``` - -### What is happening here? -* **`--rm`**: Automatically deletes the container after the script finishes to keep your disk clean. -* **`--privileged`**: Grants the container direct access to the host's hardware devices, which is required to see the physical TPU chips. -* **`--net=host`**: Removes the container's network isolation, allowing the script to connect to the TPU runtime listening on local ports (e.g., 8353). 
-* **`--ipc=host`**: Allows the container to use the host's Shared Memory (IPC), which is critical for high-speed data transfer between the CPU and TPU. \ No newline at end of file