ā±ļø 70 min

Building ML REST APIs

Create scalable API endpoints for ML models

Introduction to ML APIs

REST APIs are the standard way to serve ML models to applications.

**Why REST APIs?**

- **Language Agnostic**: any language can make HTTP requests
- **Scalable**: easy to scale horizontally
- **Standard**: well-understood protocol
- **Testable**: easy to test with tools like Postman and curl

**API Design Principles:**

1. **Versioning**: `/api/v1/predict`
2. **Clear Endpoints**: `/predict`, `/batch-predict`, `/health`
3. **Error Handling**: proper HTTP status codes
4. **Documentation**: OpenAPI/Swagger specs
5. **Rate Limiting**: prevent abuse
6. **Authentication**: API keys, OAuth

**Frameworks:**

- **FastAPI**: modern, async, auto-generated documentation
- **Flask**: simple, flexible, mature
- **Django REST Framework**: full-featured, for complex apps
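The design principles above do not depend on any particular framework. As a minimal sketch using only Python's standard library, the handler below wires up a versioned prediction endpoint, a health check, and proper status codes (the `TinyMLHandler` name and the hard-coded sum-based "model" are illustrative assumptions, not part of any real service):

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

class TinyMLHandler(BaseHTTPRequestHandler):
    """Illustrates versioned paths, a health endpoint, and HTTP status codes."""

    def _send(self, status, payload):
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        if self.path == "/health":
            self._send(200, {"status": "healthy"})
        else:
            self._send(404, {"detail": "Not found"})

    def do_POST(self):
        if self.path != "/api/v1/predict":  # versioned endpoint
            self._send(404, {"detail": "Not found"})
            return
        length = int(self.headers.get("Content-Length", 0))
        try:
            data = json.loads(self.rfile.read(length))
            features = data["features"]
        except (ValueError, KeyError):
            self._send(400, {"detail": "Invalid input"})  # client error, not 500
            return
        # Stand-in "model": predict 1 if the feature sum is positive
        self._send(200, {"prediction": int(sum(features) > 0)})
```

Running `HTTPServer(("127.0.0.1", 8000), TinyMLHandler).serve_forever()` would serve this. Frameworks like FastAPI layer input validation, auto-documentation, and async handling on top of exactly this request/response shape.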

FastAPI ML Service

Build a production-ready ML API with FastAPI:

```python
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional
import numpy as np
import joblib
import uvicorn
from datetime import datetime
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="ML Prediction API",
    description="API for serving machine learning predictions",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Request/Response models
class PredictionRequest(BaseModel):
    """Single prediction request"""
    features: List[float] = Field(..., description="Input features", min_items=10, max_items=10)

    # Pydantic v1-style config; Pydantic v2 uses model_config = {"json_schema_extra": ...}
    class Config:
        schema_extra = {
            "example": {
                "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
            }
        }

class BatchPredictionRequest(BaseModel):
    """Batch prediction request"""
    instances: List[List[float]] = Field(..., description="List of feature arrays")

class PredictionResponse(BaseModel):
    """Prediction response"""
    prediction: int
    probability: float
    model_version: str
    timestamp: str

class BatchPredictionResponse(BaseModel):
    """Batch prediction response"""
    predictions: List[PredictionResponse]
    batch_size: int

# Global model variable
model = None
model_version = "1.0.0"

# Startup event (newer FastAPI versions prefer the lifespan API for this)
@app.on_event("startup")
async def load_model():
    """Load model on startup"""
    global model
    try:
        # In production, load from cloud storage or model registry
        model = DummyModel()  # Replace with: joblib.load("model.joblib")
        logger.info(f"Model v{model_version} loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
    """Check if service is healthy"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "version": model_version,
        "timestamp": datetime.now().isoformat()
    }

# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse, tags=["Predictions"])
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    """
    Make a single prediction

    - **features**: List of 10 numerical features
    """
    try:
        # Validate model is loaded
        if model is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Prepare input
        X = np.array(request.features).reshape(1, -1)

        # Make prediction
        prediction = int(model.predict(X)[0])
        probability = float(model.predict_proba(X)[0][prediction])

        # Log prediction (async)
        background_tasks.add_task(log_prediction, request.features, prediction)

        return PredictionResponse(
            prediction=prediction,
            probability=probability,
            model_version=model_version,
            timestamp=datetime.now().isoformat()
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}")
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

# Batch prediction endpoint
@app.post("/batch-predict", response_model=BatchPredictionResponse, tags=["Predictions"])
async def batch_predict(request: BatchPredictionRequest):
    """Make predictions for multiple instances"""
    try:
        if model is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        predictions = []

        for features in request.instances:
            X = np.array(features).reshape(1, -1)
            pred = int(model.predict(X)[0])
            prob = float(model.predict_proba(X)[0][pred])

            predictions.append(PredictionResponse(
                prediction=pred,
                probability=prob,
                model_version=model_version,
                timestamp=datetime.now().isoformat()
            ))

        return BatchPredictionResponse(
            predictions=predictions,
            batch_size=len(predictions)
        )

    except Exception as e:
        logger.error(f"Batch prediction error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

# Model info endpoint
@app.get("/model/info", tags=["Model"])
async def model_info():
    """Get model information"""
    return {
        "version": model_version,
        "type": "classification",
        "input_features": 10,
        "output_classes": 2,
        "framework": "scikit-learn"
    }

# Background task
def log_prediction(features, prediction):
    """Log prediction to database or monitoring system"""
    logger.info(f"Prediction logged: {prediction}")

# Dummy model for demonstration
class DummyModel:
    def predict(self, X):
        return np.random.randint(0, 2, size=len(X))

    def predict_proba(self, X):
        probs = np.random.rand(len(X), 2)
        return probs / probs.sum(axis=1, keepdims=True)

# Run with: uvicorn main:app --reload
print("FastAPI ML Service")
print("=" * 50)
print("\nEndpoints:")
print("  GET  /health         - Health check")
print("  GET  /docs           - API documentation")
print("  POST /predict        - Single prediction")
print("  POST /batch-predict  - Batch predictions")
print("  GET  /model/info     - Model information")

print("\nStart server:")
print("  uvicorn main:app --host 0.0.0.0 --port 8000")

print("\nExample request:")
print('  curl -X POST "http://localhost:8000/predict" \\')
print('    -H "Content-Type: application/json" \\')
print('    -d \'{"features": [1,2,3,4,5,6,7,8,9,10]}\'')

print("\n✓ FastAPI service configured!")
```

Output:

```
FastAPI ML Service
==================================================

Endpoints:
  GET  /health         - Health check
  GET  /docs           - API documentation
  POST /predict        - Single prediction
  POST /batch-predict  - Batch predictions
  GET  /model/info     - Model information

Start server:
  uvicorn main:app --host 0.0.0.0 --port 8000

Example request:
  curl -X POST "http://localhost:8000/predict" \
    -H "Content-Type: application/json" \
    -d '{"features": [1,2,3,4,5,6,7,8,9,10]}'

✓ FastAPI service configured!
```
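The `/batch-predict` endpoint above still invokes the model once per instance inside a loop; passing the whole batch to a single `predict` call is usually the bigger throughput win. A minimal sketch of the difference, using a hypothetical `CountingModel` that records how often it is invoked:

```python
class CountingModel:
    """Hypothetical stand-in model that counts its predict() invocations."""
    def __init__(self):
        self.calls = 0

    def predict(self, rows):
        self.calls += 1  # one call, regardless of batch size
        return [int(sum(row) > 0) for row in rows]

def predict_per_row(model, instances):
    # Shape of the loop in the endpoint above: one model call per instance
    return [model.predict([row])[0] for row in instances]

def predict_batched(model, instances):
    # Single call over the whole batch
    return model.predict(instances)

batch = [[1.0, 2.0], [-3.0, 1.0], [0.5, 0.5]]

m1 = CountingModel()
per_row = predict_per_row(m1, batch)   # m1.calls == 3

m2 = CountingModel()
batched = predict_batched(m2, batch)   # m2.calls == 1

assert per_row == batched              # same results, fewer model calls
```

With NumPy- or GPU-backed models, the batched call also amortizes per-call overhead across the whole array, which is where most of the speedup comes from.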

API Best Practices

**Performance Optimization:**

1. **Async Processing**
   - Use async endpoints for I/O operations
   - Background tasks for logging
   - Queue systems (Celery, RQ) for heavy tasks
2. **Caching**
   - Cache predictions for identical inputs
   - Use Redis or Memcached
   - Set an appropriate TTL
3. **Batching**
   - Batch multiple requests together
   - Improves throughput and reduces per-request overhead
4. **Model Loading**
   - Load the model once at startup
   - Use a model registry for version management
   - Implement hot-swapping for updates

**Security:**

- API key authentication
- Rate limiting (per user/IP)
- Input validation
- HTTPS only in production
- CORS configuration

**Monitoring:**

- Request/response times
- Error rates
- Model performance drift
- Resource usage (CPU, memory)

**Testing:**

```python
# pytest example using FastAPI's TestClient
from fastapi.testclient import TestClient
from main import app  # the service defined above

client = TestClient(app)

def test_predict_endpoint():
    response = client.post("/predict", json={"features": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
    assert response.status_code == 200
    assert "prediction" in response.json()
```
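The caching point above can be sketched in-process; in production a shared store such as Redis or Memcached would replace the dict, but the key/TTL logic is the same (the helper names here are illustrative, not a standard API):

```python
import time

CACHE = {}           # stands in for Redis/Memcached
TTL_SECONDS = 60.0

def cached_predict(model, features, now=None):
    """Return (prediction, cache_hit) for identical inputs, with a TTL."""
    now = time.monotonic() if now is None else now
    key = tuple(features)                     # inputs must be hashable to key on
    hit = CACHE.get(key)
    if hit is not None and now - hit[1] < TTL_SECONDS:
        return hit[0], True                   # fresh entry: serve from cache
    pred = model.predict([features])[0]       # miss or expired: recompute
    CACHE[key] = (pred, now)
    return pred, False

class ToyModel:
    """Illustrative model: predicts 1 if the feature sum is positive."""
    def predict(self, rows):
        return [int(sum(r) > 0) for r in rows]

model = ToyModel()
p1, hit1 = cached_predict(model, [1.0, 2.0])  # first call: computed and stored
p2, hit2 = cached_predict(model, [1.0, 2.0])  # second call: served from cache
```

Note the TTL matters for ML specifically: after a model redeploy, stale cached predictions from the old version must be allowed to expire (or the cache flushed, e.g. by including the model version in the key).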
