How to Serve a Machine Learning Model With FastAPI in Python

A machine learning model that lives inside a Jupyter notebook is not helping anyone. It only becomes useful once other applications can send it data and get predictions back. FastAPI is one of the fastest and simplest ways to wrap a trained Python model in an API endpoint, complete with automatic input validation, interactive documentation, and performance that can handle production traffic. This guide walks through the full process: training a model, saving it, loading it into a FastAPI application at startup, and exposing a prediction endpoint that validates every request with Pydantic.

Why FastAPI for Model Serving

There are several ways to deploy a machine learning model: managed platforms like AWS SageMaker, dedicated serving frameworks like TensorFlow Serving, or general-purpose web frameworks. FastAPI sits in a sweet spot for many teams because it requires no specialized infrastructure, works with any Python ML library (scikit-learn, XGBoost, PyTorch, TensorFlow), and generates interactive API documentation automatically from your code.

FastAPI also provides Pydantic-based request validation out of the box. That means every prediction request is checked against a defined schema before your model ever sees it. Invalid inputs get rejected with clear error messages instead of causing cryptic NumPy errors deep in your inference code. For teams that need to hand an API to frontend developers, mobile engineers, or external partners, that level of self-documentation and validation saves significant back-and-forth.

Project Setup and Dependencies

The project uses a straightforward file layout. The training script runs once to produce a saved model file. The FastAPI application loads that file at startup and serves predictions.

ml-api/
    train_model.py         # Train and save the model
    app/
        __init__.py
        main.py            # FastAPI application
        schemas.py         # Pydantic request/response models
    models/
        classifier.joblib  # Saved model artifact
    requirements.txt

Install the dependencies:

pip install fastapi uvicorn scikit-learn joblib pydantic

This example uses scikit-learn for training, but the pattern applies to any ML library. The serving approach (save the artifact, load it once, expose an endpoint) works the same whether the model behind it is a random forest, an XGBoost classifier, or a PyTorch neural network.

Step 1: Train and Save Your Model

First, train a classifier and save it to disk. This script runs once, outside of the API. The output is a .joblib file that the FastAPI application will load later.

# train_model.py

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import joblib
import os

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a random forest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")

# Save the trained model
os.makedirs("models", exist_ok=True)
joblib.dump(clf, "models/classifier.joblib")
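print("Model saved to models/classifier.joblib")

Note

joblib is the usual choice over plain pickle for scikit-learn models because it handles objects that hold large NumPy arrays more efficiently. Files are not compressed by default; pass the compress argument to joblib.dump() if file size matters. joblib files are pickle-based, so only load artifacts from sources you trust, and load them with the same scikit-learn version that was used for training. For PyTorch models, use torch.save() and torch.load() instead. For TensorFlow/Keras, use model.save(); recent Keras versions default to the native .keras format, while older TensorFlow 2 releases default to SavedModel.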
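A quick sanity check after saving is to reload the artifact and confirm it predicts the same way as the in-memory model. The following is a minimal sketch, not part of the original training script; the temporary directory and the smaller n_estimators=10 are assumptions chosen only to keep the check fast.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Train a small model (fewer trees than the article's script, for speed)
X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=10, random_state=42).fit(X, y)

# Save to a throwaway directory, then load the artifact back
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "classifier.joblib")
    joblib.dump(clf, path)
    restored = joblib.load(path)

# The restored model should reproduce the original predictions exactly
same_predictions = bool(np.array_equal(clf.predict(X), restored.predict(X)))
print("Round-trip predictions match:", same_predictions)
```

Running a check like this in the training pipeline catches a broken or truncated artifact before the API ever tries to load it.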

Step 2: Define Input and Output Schemas

Pydantic schemas define the exact shape of data your API accepts and returns. Every incoming request gets validated against these schemas before reaching your model. This prevents malformed data from triggering confusing errors during inference.

# app/schemas.py

from pydantic import BaseModel, Field

class PredictionRequest(BaseModel):
    sepal_length: float = Field(..., gt=0, description="Sepal length in cm")
    sepal_width: float = Field(..., gt=0, description="Sepal width in cm")
    petal_length: float = Field(..., gt=0, description="Petal length in cm")
    petal_width: float = Field(..., gt=0, description="Petal width in cm")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "sepal_length": 5.1,
                    "sepal_width": 3.5,
                    "petal_length": 1.4,
                    "petal_width": 0.2,
                }
            ]
        }
    }


class PredictionResponse(BaseModel):
    predicted_class: int = Field(..., description="Numeric class label (0, 1, or 2)")
    predicted_class_name: str = Field(..., description="Human-readable species name")
    confidence: float = Field(..., ge=0, le=1, description="Prediction confidence score")

The Field(..., gt=0) constraint ensures that negative or zero values get rejected automatically with a clear error message. The json_schema_extra block provides an example payload that appears in the auto-generated Swagger UI, making it easy for anyone consuming your API to test it without reading external documentation. Note that the model_config dictionary is Pydantic v2 syntax; on Pydantic v1 the equivalent is an inner Config class with a schema_extra attribute.
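To see what the schema rejects without starting a server, you can instantiate it directly. This sketch repeats a trimmed copy of PredictionRequest (descriptions and the example payload are left out) and adds a hypothetical helper, validation_problems, that is not part of the article's code.

```python
from pydantic import BaseModel, Field, ValidationError


class PredictionRequest(BaseModel):
    # Trimmed copy of the schema above: same fields, same gt=0 constraint
    sepal_length: float = Field(..., gt=0)
    sepal_width: float = Field(..., gt=0)
    petal_length: float = Field(..., gt=0)
    petal_width: float = Field(..., gt=0)


def validation_problems(payload: dict) -> list:
    """Return the names of fields that fail validation (empty if valid)."""
    try:
        PredictionRequest(**payload)
    except ValidationError as exc:
        return [str(error["loc"][0]) for error in exc.errors()]
    return []


good = {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
bad = {"sepal_length": -1, "sepal_width": 3.5, "petal_length": 1.4}  # negative value, missing field

print(validation_problems(good))
print(validation_problems(bad))
```

The second payload fails twice: sepal_length violates gt=0 and petal_width is missing. FastAPI reports the same kind of per-field errors in the body of its 422 response.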

Step 3: Load the Model at Startup With Lifespan

The single most important rule of model serving is: load the model once, not on every request. Reading a model file from disk on every prediction adds latency that compounds under load. FastAPI provides a lifespan context manager that runs code when the application starts and cleans up when it shuts down.

# app/main.py

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
import joblib
import numpy as np

from app.schemas import PredictionRequest, PredictionResponse

# Store loaded model artifacts
ml_artifacts = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load model into memory
    try:
        ml_artifacts["model"] = joblib.load("models/classifier.joblib")
        # Warm up the model with a dummy prediction
        dummy = np.zeros((1, 4))
        ml_artifacts["model"].predict(dummy)
        print("Model loaded and warmed up successfully.")
    except FileNotFoundError:
        print("ERROR: Model file not found at models/classifier.joblib")
        ml_artifacts["model"] = None
    yield
    # Shutdown: clean up
    ml_artifacts.clear()

app = FastAPI(
    title="Iris Prediction API",
    version="1.0.0",
    lifespan=lifespan,
)

Pro Tip

The warm-up step is cheap insurance. The first prediction after loading a model can be slower than later ones because of one-time setup costs. The effect is usually small for a scikit-learn model like this one and much larger for deep learning frameworks. Running a single dummy prediction at startup moves that cost out of the first real request and also confirms that the loaded object can actually predict.

Common Mistake

Never call joblib.load() or pickle.load() inside your route handler. Loading the model on every request introduces disk I/O into your hot path and will destroy your throughput under any real traffic.
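The load-once behavior can be illustrated without FastAPI by driving the same kind of context manager by hand. This is a stdlib-only sketch: fake_load_model is a hypothetical stand-in for joblib.load(), and handle_request stands in for a route handler.

```python
import asyncio
from contextlib import asynccontextmanager

ml_artifacts = {}
events = []


def fake_load_model():
    """Hypothetical stand-in for joblib.load(); records each call."""
    events.append("load")
    return lambda features: 0  # pretend model: always predicts class 0


@asynccontextmanager
async def lifespan(app):
    ml_artifacts["model"] = fake_load_model()  # startup: runs once
    yield                                      # application serves requests here
    ml_artifacts.clear()                       # shutdown: runs once
    events.append("cleanup")


def handle_request(features):
    """Stand-in for a route handler: reads the in-memory model only."""
    events.append("predict")
    return ml_artifacts["model"](features)


async def serve_three_requests():
    async with lifespan(app=None):
        return [handle_request([5.1, 3.5, 1.4, 0.2]) for _ in range(3)]


results = asyncio.run(serve_three_requests())
print(events)  # ['load', 'predict', 'predict', 'predict', 'cleanup']
```

Three requests are served, but the load step appears only once in the event list. FastAPI enters and exits the lifespan context in the same way around the lifetime of the server.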

Step 4: Build the Prediction Endpoint

The prediction endpoint receives a validated request, converts it into a NumPy array, passes it to the model, and returns a structured response.

# Add this to app/main.py below the app instance

IRIS_CLASS_NAMES = ["setosa", "versicolor", "virginica"]

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    model = ml_artifacts.get("model")
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model is not available. Check server logs.",
        )

    # Convert input to NumPy array
    features = np.array([[
        request.sepal_length,
        request.sepal_width,
        request.petal_length,
        request.petal_width,
    ]])

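    # Compute class probabilities once, then derive the label and confidence
    probabilities = model.predict_proba(features)[0]
    best_index = int(probabilities.argmax())
    predicted_class = int(model.classes_[best_index])
    confidence = float(probabilities[best_index])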

    return PredictionResponse(
        predicted_class=predicted_class,
        predicted_class_name=IRIS_CLASS_NAMES[predicted_class],
        confidence=round(confidence, 4),
    )

Notice that this is a regular def function, not async def. This is intentional. The scikit-learn predict() call is CPU-bound and blocking. When you define a route with def, FastAPI automatically runs it in a thread pool so it does not block the event loop. If you used async def here, the blocking call would prevent other requests from being processed during inference. Use async def only when your inference call is genuinely non-blocking, such as when calling an external ML service over HTTP.
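The difference is easy to measure with plain asyncio. In this sketch, time.sleep is an assumed stand-in for a blocking predict() call, and asyncio.to_thread (Python 3.9+) plays a role roughly analogous to the thread pool FastAPI uses for def routes; it is an illustration, not FastAPI's actual internals.

```python
import asyncio
import time


def blocking_inference(delay=0.2):
    """Stand-in for model.predict(); time.sleep simulates a blocking call."""
    time.sleep(delay)
    return 0


async def handler_blocking():
    # Like calling predict() directly inside an async def route:
    # the event loop cannot do anything else until this returns.
    return blocking_inference()


async def handler_threaded():
    # Like a def route: the blocking call runs in a worker thread,
    # so the event loop stays free to start other requests.
    return await asyncio.to_thread(blocking_inference)


async def time_two_requests(handler):
    start = time.perf_counter()
    await asyncio.gather(handler(), handler())
    return time.perf_counter() - start


blocking_elapsed = asyncio.run(time_two_requests(handler_blocking))
threaded_elapsed = asyncio.run(time_two_requests(handler_threaded))
print(f"blocking inside the event loop: {blocking_elapsed:.2f}s")
print(f"offloaded to worker threads:    {threaded_elapsed:.2f}s")
```

The two blocking handlers run back to back, while the threaded ones overlap. One caveat: time.sleep releases the GIL, so it overlaps perfectly. Real inference overlaps only as far as the underlying library releases the GIL, which is why multiple worker processes are still needed for multi-core throughput.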

Step 5: Add a Health Check and Model Metadata

A health check endpoint lets monitoring tools and load balancers verify that your service is running and that the model loaded successfully. A metadata endpoint reports which model is loaded and its basic properties.

# Add these to app/main.py

@app.get("/health")
def health_check():
    model_loaded = ml_artifacts.get("model") is not None
    return {
        "status": "healthy" if model_loaded else "degraded",
        "model_loaded": model_loaded,
    }

@app.get("/model-info")
def model_info():
    model = ml_artifacts.get("model")
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {
        "model_type": type(model).__name__,
        "n_features": model.n_features_in_,
        "n_classes": len(model.classes_),
        "class_labels": IRIS_CLASS_NAMES,
    }
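A deployment script or smoke test can poll /health until the model reports as loaded. The following is a stdlib-only sketch; fetch_health and wait_until_healthy are hypothetical helper names, and the retry counts are arbitrary defaults.

```python
import json
import time
from urllib.request import urlopen


def fetch_health(base_url, timeout=2.0):
    """GET <base_url>/health and return the decoded JSON body."""
    with urlopen(f"{base_url}/health", timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def wait_until_healthy(fetch, attempts=5, delay=0.5):
    """Call fetch() up to `attempts` times; True once the model is loaded."""
    for _ in range(attempts):
        try:
            payload = fetch()
        except (OSError, ValueError):
            # Connection refused, timeout, or a non-JSON body: not ready yet
            payload = None
        if payload and payload.get("status") == "healthy" and payload.get("model_loaded"):
            return True
        time.sleep(delay)
    return False
```

With the server running, wait_until_healthy(lambda: fetch_health("http://127.0.0.1:8000")) returns True once the lifespan startup has finished. Passing the fetch function in as an argument keeps the retry logic testable without a network.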

Testing Your API

Start the server:

uvicorn app.main:app --reload

Open your browser and navigate to http://127.0.0.1:8000/docs to see the interactive Swagger UI. FastAPI generates this documentation automatically from your route definitions and Pydantic schemas. You can send test requests directly from the browser.

To test from the command line:

curl -X POST http://127.0.0.1:8000/predict \
  -H "Content-Type: application/json" \
  -d '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'

Expected response:

{
    "predicted_class": 0,
    "predicted_class_name": "setosa",
    "confidence": 1.0
}

Try sending invalid data to see the automatic validation in action:

curl -X POST http://127.0.0.1:8000/predict \
  -H "Content-Type: application/json" \
  -d '{"sepal_length": -1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'

FastAPI will return a 422 response explaining that sepal_length must be greater than 0. Your model never sees invalid data.
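The same request can be sent from Python. This sketch uses only the standard library; build_predict_request and predict are hypothetical helper names, and the base URL assumes the local development server started above.

```python
import json
from urllib.request import Request, urlopen


def build_predict_request(base_url, features):
    """Build the POST request for /predict without sending it."""
    body = json.dumps(features).encode("utf-8")
    return Request(
        url=f"{base_url}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(base_url, features, timeout=5.0):
    """Send the request and return the decoded JSON response."""
    with urlopen(build_predict_request(base_url, features), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


sample = {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
request = build_predict_request("http://127.0.0.1:8000", sample)
print(request.get_method(), request.full_url)  # POST http://127.0.0.1:8000/predict
```

With the server running, predict("http://127.0.0.1:8000", sample) returns the response as a dictionary. Note that urlopen raises urllib.error.HTTPError for 4xx and 5xx responses, so the 422 for invalid input surfaces as an exception on the client side.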

Production Considerations

The example above covers the core pattern. Moving to production adds a few more concerns.

Concern | Approach
--- | ---
Concurrency | Run Uvicorn behind Gunicorn with multiple workers: gunicorn app.main:app -w 4 -k uvicorn.workers.UvicornWorker. Each worker process loads its own copy of the model, so memory use grows with the worker count.
Containerization | Package the app and model file in a Docker image. Copy the model artifact during the build stage so the container is self-contained.
Model versioning | Store models with version-tagged filenames (e.g., classifier_v2.joblib) and expose the version through the /model-info endpoint.
Batch predictions | Add a /predict/batch endpoint that accepts a list of inputs, runs inference in a single call, and returns a list of results.
Heavy models | For large models (deep learning, LLMs), run inference in a dedicated worker, such as a ThreadPoolExecutor, a separate process, or a task queue, so slow calls do not tie up request threads. FastAPI's BackgroundTasks runs after the response has been sent, so it only fits cases where the client does not need the prediction in that response.
Monitoring | Log prediction latency and input distributions. Tools like Prometheus, Evidently, or custom middleware can track model drift over time.
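As one illustration of the batch row, the core of a batch endpoint is a helper that stacks all inputs into one array and calls the model once. This is a sketch under assumptions: predict_batch, FEATURE_ORDER, and StubModel are hypothetical names, and the stub only mimics the predict_proba method and classes_ attribute of a scikit-learn classifier.

```python
import numpy as np

IRIS_CLASS_NAMES = ["setosa", "versicolor", "virginica"]
FEATURE_ORDER = ["sepal_length", "sepal_width", "petal_length", "petal_width"]


def predict_batch(model, rows):
    """Run a single predict_proba call for many rows and format each result."""
    if not rows:
        return []
    features = np.array(
        [[row[name] for name in FEATURE_ORDER] for row in rows], dtype=float
    )
    probabilities = model.predict_proba(features)  # shape: (n_rows, n_classes)
    best = probabilities.argmax(axis=1)
    results = []
    for index, proba in zip(best, probabilities):
        label = int(model.classes_[index])
        results.append({
            "predicted_class": label,
            "predicted_class_name": IRIS_CLASS_NAMES[label],
            "confidence": round(float(proba[index]), 4),
        })
    return results


class StubModel:
    """Hypothetical stand-in for a trained classifier (made-up decision rule)."""
    classes_ = np.array([0, 1, 2])

    def predict_proba(self, features):
        # Short petals -> class 0, everything else -> class 2
        return np.array(
            [[0.9, 0.1, 0.0] if row[2] < 2.5 else [0.0, 0.2, 0.8] for row in features]
        )


rows = [
    {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2},
    {"sepal_length": 6.7, "sepal_width": 3.0, "petal_length": 5.2, "petal_width": 2.3},
]
batch_results = predict_batch(StubModel(), rows)
print(batch_results)
```

In the real application, a /predict/batch route would accept a list of PredictionRequest objects, convert them to dictionaries, and pass them to a helper like this together with ml_artifacts["model"]. Capping the list length in the schema is a sensible guard against oversized requests.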

Frequently Asked Questions

How do I load a machine learning model in FastAPI without reloading it on every request?

Use FastAPI's lifespan context manager to load the model once when the server starts. Store it in a dictionary or module-level variable. Every request handler then reads from that in-memory reference instead of hitting the filesystem. The lifespan approach also gives you a clean shutdown hook for releasing resources when the server stops.

Should I use pickle or joblib to save a scikit-learn model for FastAPI?

joblib is the usual choice for scikit-learn models. It is built on pickle but handles objects containing large NumPy arrays more efficiently, and it supports optional compression when file size matters. For deep learning frameworks, use the library's native save format instead: torch.save() for PyTorch, model.save() for TensorFlow/Keras.

Can FastAPI handle concurrent prediction requests for a machine learning model?

Yes. When you define your prediction route as a regular def function, FastAPI runs it in a thread pool automatically. This allows multiple requests to be processed concurrently without blocking the async event loop. For true parallelism (multi-core), run multiple Uvicorn workers behind Gunicorn.

Should I use async def or def for my prediction endpoint?

Use def for endpoints where the model inference is CPU-bound and blocking (scikit-learn, XGBoost, most traditional ML). FastAPI handles threading for you. Use async def only when the inference itself is non-blocking—for example, when you call an external prediction service over HTTP using an async client like httpx.

Key Takeaways

  1. Load once, serve forever: Use FastAPI's lifespan context manager to load your model into memory at startup. Never load a model file inside a route handler.
  2. Validate inputs with Pydantic: Define request schemas that enforce data types, value ranges, and required fields. Invalid requests get rejected with clear error messages before they reach your model.
  3. Use def, not async def, for blocking inference: scikit-learn's predict() is CPU-bound. A regular def route lets FastAPI run it in a thread pool. Using async def would block the event loop.
  4. Warm up your model: Run a dummy prediction after loading. This moves any one-time setup cost out of the first real request and confirms the loaded model can predict.
  5. Separate training from serving: The training script and the serving application are different codebases with different lifecycles. The model file is the contract between them. Save it with joblib, version it, and load it at startup.
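For the versioning part of the last takeaway, one possible approach is to pick the highest-numbered artifact at startup. This is a stdlib-only sketch; the classifier_v<N>.joblib naming scheme follows the example in the production table, while latest_model_path and the temporary directory are assumptions made for illustration.

```python
import re
import tempfile
from pathlib import Path

VERSION_PATTERN = re.compile(r"^classifier_v(\d+)\.joblib$")


def latest_model_path(models_dir):
    """Return (version, path) for the highest-numbered artifact, or None."""
    candidates = []
    for path in Path(models_dir).iterdir():
        match = VERSION_PATTERN.match(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    # Compare versions numerically so that v10 sorts after v2
    return max(candidates, key=lambda item: item[0]) if candidates else None


with tempfile.TemporaryDirectory() as tmp_dir:
    # Empty placeholder files stand in for real model artifacts
    for name in ("classifier_v1.joblib", "classifier_v2.joblib",
                 "classifier_v10.joblib", "notes.txt"):
        (Path(tmp_dir) / name).touch()
    version, path = latest_model_path(tmp_dir)
    chosen_name = path.name
    print(version, chosen_name)  # 10 classifier_v10.joblib
```

The lifespan function could call a helper like this, load the returned path with joblib, and store the version number next to the model so that /model-info can report it.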

FastAPI turns a trained model into a production-ready API with surprisingly little code. Pydantic handles validation. The lifespan system handles loading and cleanup. Swagger UI handles documentation. What you are left with is a clean prediction endpoint that accepts structured input, returns structured output, and stays fast under load. From here you can containerize it with Docker, add authentication, set up monitoring, or scale it out with multiple workers—all without changing the core serving pattern.
