diff --git a/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb b/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb new file mode 100644 index 0000000000..0acd6fdf54 --- /dev/null +++ b/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb @@ -0,0 +1,578 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SageMaker V3 Train-to-Inference E2E with MLflow Integration\n", + "\n", + "This notebook demonstrates the complete end-to-end workflow from training a custom PyTorch model to deploying it for inference on SageMaker cloud infrastructure, with MLflow 3.x tracking and model registry integration." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prerequisites\n", + "- SageMaker MLflow App created (tracking server ARN required)\n", + "- IAM permissions for MLflow tracking and model registry\n", + "- AWS credentials configured" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 0: Install Dependencies\n", + "\n", + "**Note:** There are known issues with MLflow model path resolution. Install the latest published SDK from GitHub for the latest fixes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install from local SDK for development (includes fixes for MLflow path resolution issues)\n", + "%pip install -e ../../sagemaker-core -e ../../sagemaker-train -e ../../sagemaker-serve -e ../../sagemaker-mlops -e ../../. \"mlflow==3.4.0\" --upgrade" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### NOTE: You must restart your kernel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Configuration\n", + "\n", + "Set up MLflow tracking server and training configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "from sagemaker.core import image_uris\n", + "\n", + "# =============================================================================\n", + "# MLflow Configuration - UPDATE THIS WITH YOUR TRACKING SERVER ARN\n", + "# =============================================================================\n", + "# Eg. \"arn:aws:sagemaker:us-east-1:12345678:mlflow-app/app-ABCDEFGH123\"\n", + "MLFLOW_TRACKING_ARN = \"XXXXX\"\n", + "\n", + "# AWS Configuration\n", + "AWS_REGION = \"us-east-1\"\n", + "\n", + "# Get PyTorch training image dynamically\n", + "PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n", + " framework=\"pytorch\",\n", + " region=AWS_REGION,\n", + " version=\"2.5\",\n", + " py_version=\"py311\",\n", + " instance_type=\"ml.m5.xlarge\",\n", + " image_scope=\"training\"\n", + ")\n", + "print(f\"Using PyTorch training image: {PYTORCH_TRAINING_IMAGE}\")\n", + "\n", + "# Naming prefixes\n", + "MODEL_NAME_PREFIX = \"mlflow-e2e-model\"\n", + "ENDPOINT_NAME_PREFIX = \"mlflow-e2e-endpoint\"\n", + "TRAINING_JOB_PREFIX = \"mlflow-e2e-pytorch\"\n", + "MLFLOW_EXPERIMENT_NAME = \"sagemaker-v3-e2e-training\"\n", + "MLFLOW_REGISTERED_MODEL_NAME = \"pytorch-simple-classifier\"\n", + "\n", + "# Generate unique identifiers\n", + "unique_id = str(uuid.uuid4())[:8]\n", + "training_job_name = f\"{TRAINING_JOB_PREFIX}-{unique_id}\"\n", + "model_name = f\"{MODEL_NAME_PREFIX}-{unique_id}\"\n", + "endpoint_name = f\"{ENDPOINT_NAME_PREFIX}-{unique_id}\"\n", + "\n", + "print(f\"Training job name: {training_job_name}\")\n", + "print(f\"Model name: {model_name}\")\n", + "print(f\"Endpoint name: {endpoint_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Connect to MLflow Tracking Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mlflow\n", + "\n", + "# Connect to SageMaker MLflow tracking server\n", + "mlflow.set_tracking_uri(MLFLOW_TRACKING_ARN)\n", + "\n", + "# Create or get experiment\n", + "mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)\n", + "\n", + "print(f\"Connected to MLflow tracking server\")\n", + "print(f\"Experiment: {MLFLOW_EXPERIMENT_NAME}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Create Training Code with MLflow Logging\n", + "\n", + "Create a PyTorch training script that logs metrics and registers the model to MLflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "import os\n", + "\n", + "def create_pytorch_training_code_with_mlflow(mlflow_tracking_arn, experiment_name, registered_model_name):\n", + " \"\"\"Create PyTorch training script with MLflow integration.\"\"\"\n", + " temp_dir = tempfile.mkdtemp()\n", + " \n", + " train_script = f'''import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "import os\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from mlflow.models import infer_signature\n", + "\n", + "class SimpleModel(nn.Module):\n", + " def __init__(self, input_dim=4, output_dim=2):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " \n", + " def forward(self, x):\n", + " return torch.softmax(self.linear(x), dim=1)\n", + "\n", + "def train():\n", + " # MLflow setup\n", + " mlflow.set_tracking_uri(\"{mlflow_tracking_arn}\")\n", + " mlflow.set_experiment(\"{experiment_name}\")\n", + " \n", + " # Hyperparameters\n", + " learning_rate = 0.01\n", + " epochs = 10\n", + " batch_size = 32\n", + " input_dim = 4\n", + " output_dim = 2\n", + " \n", + " with mlflow.start_run() as run:\n", + " # Log hyperparameters\n", + " mlflow.log_params({{\n", + " \"learning_rate\": learning_rate,\n", + " \"epochs\": epochs,\n", + " \"batch_size\": batch_size,\n", + " \"input_dim\": input_dim,\n", + " \"output_dim\": output_dim,\n", + " \"optimizer\": \"Adam\",\n", + " \"loss_function\": \"CrossEntropyLoss\"\n", + " }})\n", + " \n", + " model = SimpleModel(input_dim, output_dim)\n", + " optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", + " criterion = nn.CrossEntropyLoss()\n", + " \n", + " # Synthetic data\n", + " X = torch.randn(100, input_dim)\n", + " y = torch.randint(0, output_dim, (100,))\n", + " dataset = TensorDataset(X, y)\n", + " dataloader = DataLoader(dataset, batch_size=batch_size)\n", + " \n", + " # Training loop with metric logging\n", + " model.train()\n", + " for epoch in range(epochs):\n", + " epoch_loss = 0.0\n", + " correct = 0\n", + " total = 0\n", + " \n", + " for batch_x, batch_y in dataloader:\n", + " optimizer.zero_grad()\n", + " outputs = model(batch_x)\n", + " loss = criterion(outputs, batch_y)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " epoch_loss += loss.item()\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += batch_y.size(0)\n", + " correct += (predicted == batch_y).sum().item()\n", + " \n", + " avg_loss = epoch_loss / len(dataloader)\n", + " accuracy = correct / total\n", + " \n", + " # Log metrics per epoch\n", + " mlflow.log_metrics({{\n", + " \"train_loss\": avg_loss,\n", + " \"train_accuracy\": accuracy\n", + " }}, step=epoch)\n", + " \n", + " print(f\"Epoch {{epoch+1}}/{{epochs}} - Loss: {{avg_loss:.4f}}, Accuracy: {{accuracy:.4f}}\")\n", + " \n", + " # Log final metrics\n", + " mlflow.log_metrics({{\n", + " \"final_loss\": avg_loss,\n", + " \"final_accuracy\": accuracy\n", + " }})\n", + " \n", + " # Infer signature and register model to MLflow\n", + " model.eval()\n", + " signature = infer_signature(\n", + " X.numpy(),\n", + " model(X).detach().numpy()\n", + " )\n", + " \n", + " # Log and register model in one step\n", + " mlflow.pytorch.log_model(\n", + " model,\n", + " name=\"{registered_model_name}\",\n", + " signature=signature,\n", + " registered_model_name=\"{registered_model_name}\"\n", + " )\n", + " \n", + " print(f\"Model registered to MLflow: {registered_model_name}\")\n", + " print(f\"Run ID: {{run.info.run_id}}\")\n", + " \n", + " print(\"Training completed!\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " train()\n", + "'''\n", + " \n", + " with open(os.path.join(temp_dir, 'train.py'), 'w') as f:\n", + " f.write(train_script)\n", + " \n", + " with open(os.path.join(temp_dir, 'requirements.txt'), 'w') as f:\n", + " f.write('mlflow==3.4.0\\nsagemaker-mlflow==0.2.0\\ncloudpickle==3.1.2\\n')\n", + " \n", + " return temp_dir\n", + "\n", + "# Create training code\n", + "training_code_dir = create_pytorch_training_code_with_mlflow(\n", + " MLFLOW_TRACKING_ARN, \n", + " MLFLOW_EXPERIMENT_NAME,\n", + " MLFLOW_REGISTERED_MODEL_NAME\n", + ")\n", + "print(f\"Training code created in: {training_code_dir}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Create ModelTrainer and Start Training\n", + "\n", + "Use ModelTrainer to run the training script on SageMaker managed infrastructure. The training job will log metrics to MLflow and register the model to the MLflow registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.train.model_trainer import ModelTrainer\n", + "from sagemaker.train.configs import SourceCode\n", + "\n", + "# Training on SageMaker managed infrastructure\n", + "model_trainer = ModelTrainer(\n", + " training_image=PYTORCH_TRAINING_IMAGE,\n", + " source_code=SourceCode(\n", + " source_dir=training_code_dir,\n", + " entry_script=\"train.py\",\n", + " requirements=\"requirements.txt\",\n", + " ),\n", + " base_job_name=training_job_name,\n", + ")\n", + "\n", + "# Start training job\n", + "print(f\"Starting training job: {training_job_name}\")\n", + "print(\"Metrics will be logged to MLflow during training...\")\n", + "\n", + "model_trainer.train() \n", + "print(\"Training completed! Check MLflow UI for metrics and registered model.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Get Registered Model from MLflow\n", + "\n", + "Retrieve the registered model from MLflow to get the model URI (`models://`) needed for deployment with ModelBuilder." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the latest version of the registered model\n", + "from mlflow import MlflowClient\n", + "\n", + "client = MlflowClient()\n", + "registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n", + "\n", + "latest_version = registered_model.latest_versions[0]\n", + "model_version = latest_version.version\n", + "model_source = latest_version.source\n", + "\n", + "# Get S3 URL of model files (for info only)\n", + "artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n", + "\n", + "# MLflow model registry path to use with ModelBuilder\n", + "mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n", + "\n", + "print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n", + "print(f\"Latest Version: {model_version}\")\n", + "print(f\"Source: {model_source}\")\n", + "print(f\"Model artifacts location: {artifact_uri}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Deploy from MLflow Model Registry\n", + "\n", + "Use ModelBuilder to deploy the model directly from MLflow registry to a SageMaker endpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import torch\n", + "from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n", + "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n", + "\n", + "# =============================================================================\n", + "# Custom translators for PyTorch tensor conversion\n", + "# \n", + "# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n", + "# These translators handle the conversion between JSON payloads and PyTorch tensors.\n", + "# =============================================================================\n", + "\n", + "class PyTorchInputTranslator(CustomPayloadTranslator):\n", + " \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n", + " def __init__(self):\n", + " super().__init__(content_type='application/json', accept_type='application/json')\n", + " \n", + " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", + " if isinstance(payload, torch.Tensor):\n", + " return json.dumps(payload.tolist()).encode('utf-8')\n", + " return json.dumps(payload).encode('utf-8')\n", + " \n", + " def deserialize_payload_from_stream(self, stream) -> object:\n", + " data = json.load(stream)\n", + " return torch.tensor(data, dtype=torch.float32)\n", + "\n", + "class PyTorchOutputTranslator(CustomPayloadTranslator):\n", + " \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n", + " def __init__(self):\n", + " super().__init__(content_type='application/json', accept_type='application/json')\n", + " \n", + " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", + " if isinstance(payload, torch.Tensor):\n", + " return json.dumps(payload.tolist()).encode('utf-8')\n", + " return json.dumps(payload).encode('utf-8')\n", + " \n", + " def deserialize_payload_from_stream(self, stream) -> object:\n", + " return json.load(stream)\n", + "\n", + "# Sample input/output for schema inference\n", + "sample_input = [[0.1, 0.2, 0.3, 0.4]]\n", + "sample_output = [[0.8, 0.2]]\n", + "\n", + "schema_builder = SchemaBuilder(\n", + " sample_input=sample_input,\n", + " sample_output=sample_output,\n", + " input_translator=PyTorchInputTranslator(),\n", + " output_translator=PyTorchOutputTranslator()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.serve.model_builder import ModelBuilder\n", + "from sagemaker.serve.mode.function_pointers import Mode\n", + "\n", + "# Cloud deployment to SageMaker endpoint\n", + "model_builder = ModelBuilder(\n", + " mode=Mode.SAGEMAKER_ENDPOINT,\n", + " schema_builder=schema_builder,\n", + " model_metadata={\n", + " \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n", + " \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n", + " },\n", + " dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n", + ")\n", + "\n", + "print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build the model\n", + "core_model = model_builder.build(model_name=model_name, region=AWS_REGION)\n", + "print(f\"Model built: {core_model.model_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Deploy to SageMaker endpoint\n", + "core_endpoint = model_builder.deploy(\n", + " endpoint_name=endpoint_name,\n", + " initial_instance_count=1\n", + ")\n", + "\n", + "print(f\"Endpoint deployed: {core_endpoint.endpoint_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Test the Deployed Model\n", + "\n", + "Invoke the endpoint with a sample input. The model returns class probabilities (2 classes) as a softmax output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "\n", + "# Test with JSON input\n", + "test_data = [[0.1, 0.2, 0.3, 0.4]]\n", + "\n", + "runtime_client = boto3.client('sagemaker-runtime')\n", + "response = runtime_client.invoke_endpoint(\n", + " EndpointName=core_endpoint.endpoint_name,\n", + " Body=json.dumps(test_data),\n", + " ContentType='application/json'\n", + ")\n", + "\n", + "prediction = json.loads(response['Body'].read().decode('utf-8'))\n", + "print(f\"Input: {test_data}\")\n", + "print(f\"Prediction: {prediction}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Clean Up Resources" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from sagemaker.core.resources import EndpointConfig\n", + "\n", + "# Clean up AWS resources\n", + "core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)\n", + "core_model.delete()\n", + "core_endpoint.delete()\n", + "core_endpoint_config.delete()\n", + "print(\"AWS resources cleaned up!\")\n", + "\n", + "# Clean up training code directory\n", + "try:\n", + " shutil.rmtree(training_code_dir)\n", + " print(\"Cleaned up training code directory\")\n", + "except Exception as e:\n", + " print(f\"Could not clean up training code: {e}\")\n", + "\n", + "print(\"Note: MLflow experiment runs and registered models are preserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrates cloud deployment of a PyTorch model with MLflow integration:\n", + "\n", + "1. **Training**: Runs on SageMaker managed infrastructure with ModelTrainer\n", + "2. **MLflow Integration**: Logs metrics, parameters, and registers model to MLflow registry\n", + "3. **Deployment**: Uses ModelBuilder to deploy directly from MLflow registry to a SageMaker endpoint\n", + "4. **Inference**: Invokes the endpoint with JSON payloads\n", + "\n", + "Key MLflow integration points:\n", + "- `mlflow.log_params()` - hyperparameters\n", + "- `mlflow.log_metrics()` - training metrics per epoch\n", + "- `mlflow.pytorch.log_model()` - model artifact with registry\n", + "- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n", + "\n", + "Key patterns:\n", + "- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}