Create scalable API endpoints for ML models
REST APIs are the standard way to serve ML models to applications. **Why REST APIs?** - **Language Agnostic**: Any language can make HTTP requests - **Scalable**: Easy to scale horizontally - **Standard**: Well-understood protocol - **Testable**: Easy to test with tools like Postman, curl **API Design Principles:** 1. **Versioning**: /api/v1/predict 2. **Clear Endpoints**: /predict, /batch-predict, /health 3. **Error Handling**: Proper HTTP status codes 4. **Documentation**: OpenAPI/Swagger specs 5. **Rate Limiting**: Prevent abuse 6. **Authentication**: API keys, OAuth **Frameworks:** - **FastAPI**: Modern, async, auto-documentation - **Flask**: Simple, flexible, mature - **Django REST**: Full-featured, for complex apps
Build a production-ready ML API with FastAPI:
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional
import numpy as np
import joblib
import uvicorn
from datetime import datetime
import logging
# Setup logging
# Configure the root logger once at import time; INFO level captures
# model-load and prediction-trace messages without debug noise.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by all endpoints and background tasks.
logger = logging.getLogger(__name__)
# Create FastAPI app; interactive documentation is served at /docs
# (Swagger UI) and /redoc (ReDoc).
_API_METADATA = {
    "title": "ML Prediction API",
    "description": "API for serving machine learning predictions",
    "version": "1.0.0",
    "docs_url": "/docs",
    "redoc_url": "/redoc",
}
app = FastAPI(**_API_METADATA)
# Request/Response models
class PredictionRequest(BaseModel):
    """Single prediction request"""
    # Exactly 10 numeric features, enforced at validation time.
    # NOTE(review): min_items/max_items and class Config/schema_extra are
    # pydantic v1 idioms (v2 renames them min_length/max_length and
    # model_config/json_schema_extra) — confirm the installed version.
    features: List[float] = Field(..., description="Input features", min_items=10, max_items=10)

    class Config:
        # Example payload shown in the OpenAPI /docs UI.
        schema_extra = {
            "example": {
                "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
            }
        }
class BatchPredictionRequest(BaseModel):
    """Batch prediction request"""
    # One inner list of features per instance; unlike PredictionRequest,
    # no per-instance length constraint is declared here.
    instances: List[List[float]] = Field(..., description="List of feature arrays")
class PredictionResponse(BaseModel):
    """Prediction response"""
    # Predicted class label (0 or 1 for this binary model).
    prediction: int
    # Probability of the predicted class, in [0, 1].
    probability: float
    # NOTE(review): fields named model_* trigger a protected-namespace
    # warning under pydantic v2 — verify against the installed version.
    model_version: str
    # ISO-8601 timestamp of when the prediction was produced.
    timestamp: str
class BatchPredictionResponse(BaseModel):
    """Batch prediction response"""
    # One PredictionResponse per input instance, in request order.
    predictions: List[PredictionResponse]
    # Number of predictions returned (== len(predictions)).
    batch_size: int
# Global model variable
# Populated by load_model() on application startup; None until then,
# which endpoints translate into an HTTP 503.
model = None
# Reported in every response and /model/info; bump on redeploy.
model_version = "1.0.0"
# Startup event: load the model once so requests never pay load latency.
@app.on_event("startup")
async def load_model():
    """Load the serving model into the module-level `model` global."""
    global model
    try:
        # In production, load from cloud storage or model registry,
        # e.g. joblib.load("model.joblib"); DummyModel stands in here.
        loaded = DummyModel()
        model = loaded
        logger.info(f"Model v{model_version} loaded successfully")
    except Exception as exc:
        # Re-raise so the app fails fast instead of serving 503s forever.
        logger.error(f"Failed to load model: {exc}")
        raise
# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
    """Report service liveness, model availability, and version."""
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "version": model_version,
        "timestamp": datetime.now().isoformat(),
    }
    return payload
# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse, tags=["Predictions"])
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    """
    Make a single prediction.

    - **features**: List of 10 numerical features

    Raises HTTP 503 if the model is not loaded, 400 on invalid input,
    500 on unexpected errors.
    """
    # Bug fix: the original raised the 503 inside the try block, where the
    # blanket `except Exception` caught it and remapped it to a 500.
    # Check model availability before entering the try block instead.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Prepare input as a single-row 2D array (sklearn convention).
        X = np.array(request.features).reshape(1, -1)
        # Make prediction.
        prediction = int(model.predict(X)[0])
        probability = float(model.predict_proba(X)[0][prediction])
        # Log asynchronously so logging never adds request latency.
        background_tasks.add_task(log_prediction, request.features, prediction)
        return PredictionResponse(
            prediction=prediction,
            probability=probability,
            model_version=model_version,
            timestamp=datetime.now().isoformat(),
        )
    except HTTPException:
        # Never remap deliberate HTTP errors to 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}")
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# Batch prediction endpoint
@app.post("/batch-predict", response_model=BatchPredictionResponse, tags=["Predictions"])
async def batch_predict(request: BatchPredictionRequest):
    """
    Make predictions for multiple instances.

    - **instances**: list of feature arrays, one per prediction

    Raises HTTP 503 if the model is not loaded, 500 on unexpected errors.
    """
    # Bug fix: the original raised the 503 inside the try block, where the
    # blanket `except Exception` caught it and remapped it to a 500.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        predictions = []
        for features in request.instances:
            # One row at a time: instances may have differing lengths, so
            # we deliberately avoid stacking into a single 2D array.
            X = np.array(features).reshape(1, -1)
            pred = int(model.predict(X)[0])
            prob = float(model.predict_proba(X)[0][pred])
            predictions.append(PredictionResponse(
                prediction=pred,
                probability=prob,
                model_version=model_version,
                timestamp=datetime.now().isoformat(),
            ))
        return BatchPredictionResponse(
            predictions=predictions,
            batch_size=len(predictions),
        )
    except HTTPException:
        # Preserve deliberate HTTP status codes.
        raise
    except Exception as e:
        logger.error(f"Batch prediction error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# Model info endpoint
@app.get("/model/info", tags=["Model"])
async def model_info():
    """Describe the currently served model (static metadata)."""
    info = dict(
        version=model_version,
        type="classification",
        input_features=10,
        output_classes=2,
        framework="scikit-learn",
    )
    return info
# Background task
def log_prediction(features, prediction):
    """Log prediction to database or monitoring system"""
    # Stub: only writes to the app logger. In production this would
    # persist (features, prediction) for monitoring / drift detection.
    logger.info(f"Prediction logged: {prediction}")
# Dummy model for demonstration — mimics the sklearn classifier interface
# (predict / predict_proba) so the API code path is fully exercisable.
class DummyModel:
    """Stand-in binary classifier returning random outputs."""

    def predict(self, X):
        """Return one random 0/1 label per input row."""
        n_rows = len(X)
        return np.random.randint(0, 2, size=n_rows)

    def predict_proba(self, X):
        """Return per-row class probabilities normalized to sum to 1."""
        raw = np.random.rand(len(X), 2)
        row_totals = raw.sum(axis=1, keepdims=True)
        return raw / row_totals
# Run with: uvicorn main:app --reload
# Startup banner: summarizes the available endpoints and a sample request.
print("FastAPI ML Service")
print("=" * 50)
print("\nEndpoints:")
print(" GET /health - Health check")
print(" GET /docs - API documentation")
print(" POST /predict - Single prediction")
print(" POST /batch-predict - Batch predictions")
print(" GET /model/info - Model information")
print("\nStart server:")
print(" uvicorn main:app --host 0.0.0.0 --port 8000")
print("\nExample request:")
print(' curl -X POST "http://localhost:8000/predict" \\')
print(' -H "Content-Type: application/json" \\')
# Bug fix: the original line had mismatched quotes around the JSON payload
# (a SyntaxError); escape the inner single quotes instead.
print(' -d \'{"features": [1,2,3,4,5,6,7,8,9,10]}\'')
# Bug fix: repaired mojibake ("ā") in the original check-mark character.
print("\n✓ FastAPI service configured!")
==================================================
Endpoints:
GET /health - Health check
GET /docs - API documentation
POST /predict - Single prediction
POST /batch-predict - Batch predictions
GET /model/info - Model information
Start server:
uvicorn main:app --host 0.0.0.0 --port 8000
Example request:
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{"features": [1,2,3,4,5,6,7,8,9,10]}'
✓ FastAPI service configured!

**Performance Optimization:** 1. **Async Processing** - Use async endpoints for I/O operations - Background tasks for logging - Queue systems (Celery, RQ) for heavy tasks 2. **Caching** - Cache predictions for identical inputs - Use Redis or Memcached - Set appropriate TTL 3. **Batching** - Batch multiple requests - Improves throughput - Reduces per-request overhead 4. **Model Loading** - Load model once at startup - Use model registry for version management - Implement hot-swapping for updates **Security:** - API key authentication - Rate limiting (per user/IP) - Input validation - HTTPS only in production - CORS configuration **Monitoring:** - Request/response times - Error rates - Model performance drift - Resource usage (CPU, memory) **Testing:** ```python # pytest example def test_predict_endpoint(): response = client.post("/predict", json={"features": [1,2,3,4,5,6,7,8,9,10]}) assert response.status_code == 200 assert "prediction" in response.json() ```