Concepto clave
Integrar un modelo de Machine Learning en FastAPI es el puente entre el desarrollo de modelos y su despliegue en producción. Piensa en FastAPI como el restaurante que sirve platos (predicciones) preparados por el chef (modelo ML). El modelo, entrenado previamente, es cargado una vez al iniciar la aplicación para maximizar eficiencia, evitando recargas costosas por cada solicitud. Esta arquitectura separa claramente la lógica de inferencia de la API, permitiendo escalabilidad y mantenimiento sencillo.
En producción, no solo se trata de hacer predicciones, sino de garantizar confiabilidad y rendimiento. FastAPI maneja las solicitudes HTTP, valida los datos de entrada con Pydantic, y pasa los datos procesados al modelo. El modelo, a su vez, debe ser serializable (usando formatos como pickle o joblib) y optimizado para inferencia rápida. Un error común es tratar el modelo como un componente estático; en realidad, requiere versionado y monitoreo continuo para detectar desviaciones en los datos.
Cómo funciona en la práctica
El proceso sigue un flujo estructurado: primero, se carga el modelo desde un archivo al iniciar FastAPI, típicamente en un evento de startup. Luego, se define un endpoint que recibe datos, los preprocesa si es necesario, y llama al modelo para obtener una predicción. Veamos un ejemplo paso a paso con un modelo de clasificación de texto.
- Prepara tu entorno: instala FastAPI, uvicorn, y scikit-learn.
- Entrena y guarda un modelo simple, por ejemplo, un clasificador de spam, usando joblib.
- Crea un archivo main.py con FastAPI, carga el modelo en una variable global al inicio.
- Define un esquema Pydantic para validar la entrada, como un campo de texto.
- Implementa un endpoint POST que use el modelo para predecir y devuelva el resultado en JSON.
Este enfoque asegura que el modelo esté listo para usarse, reduciendo latencia y permitiendo manejar múltiples solicitudes concurrentes. En producción, añadirías logging, manejo de errores, y tal vez un sistema de caché para respuestas frecuentes.
Codigo en accion
Aquí tienes un ejemplo funcional que carga un modelo de regresión lineal preentrenado. Antes de ejecutar, asegúrate de tener un modelo guardado como 'model.pkl'.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
# Cargar el modelo al iniciar la app
model = joblib.load('model.pkl')
app = FastAPI()
# Esquema para validar entrada
class PredictionInput(BaseModel):
features: list[float]
@app.post("/predict")
async def predict(input_data: PredictionInput):
try:
# Convertir a array numpy y predecir
features_array = np.array(input_data.features).reshape(1, -1)
prediction = model.predict(features_array)
return {"prediction": prediction.tolist()[0]}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
Ahora, mejoremos este código añadiendo logging y manejo de errores más robusto, clave en producción.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
import logging
# Configurar logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Cargar el modelo con manejo de errores
try:
model = joblib.load('model.pkl')
logger.info("Modelo cargado exitosamente")
except FileNotFoundError:
logger.error("Archivo del modelo no encontrado")
raise
app = FastAPI()
class PredictionInput(BaseModel):
features: list[float]
min_items = 1 # Validación adicional
@app.post("/predict")
async def predict(input_data: PredictionInput):
logger.info(f"Solicitud recibida: {input_data.features}")
try:
if len(input_data.features) != model.n_features_in_:
raise ValueError("Número de características incorrecto")
features_array = np.array(input_data.features).reshape(1, -1)
prediction = model.predict(features_array)
logger.info(f"Predicción realizada: {prediction[0]}")
return {"prediction": prediction.tolist()[0]}
except ValueError as e:
logger.warning(f"Error de validación: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Error interno: {e}")
raise HTTPException(status_code=500, detail="Error interno del servidor")
Errores comunes
- Cargar el modelo en cada solicitud: Esto ralentiza la API drásticamente. Solución: Carga el modelo una vez al inicio, usando eventos de startup de FastAPI o variables globales, como se muestra en el código.
- No validar la entrada: Enviar datos mal formados puede crashear el modelo. Usa esquemas Pydantic para validar tipos, rangos y estructuras, evitando errores inesperados.
- Ignorar el versionado del modelo: Desplegar un nuevo modelo sin control puede romper la API. Implementa un sistema de versionado, por ejemplo, con endpoints separados o metadatos en la respuesta.
- Falta de manejo de errores: Un error no manejado devuelve respuestas genéricas al cliente. Incluye bloques try-except y logging para debuggear y proporcionar mensajes claros.
- No optimizar para inferencia: Modelos grandes o lentos afectan el rendimiento. Considera técnicas como cuantización o uso de bibliotecas optimizadas (e.g., ONNX Runtime) en producción.
Checklist de dominio
- ¿Puedes cargar un modelo serializado (pickle/joblib) en FastAPI sin errores?
- ¿Implementas validación de entrada con Pydantic para todos los endpoints?
- ¿Manejas excepciones y registras logs para monitorizar el comportamiento?
- ¿Has probado el endpoint con datos reales y simulados para verificar precisión y rendimiento?
- ¿Consideras el versionado del modelo y su impacto en clientes existentes?
- ¿Optimizas la inferencia, por ejemplo, reduciendo el tamaño del modelo o usando caché?
- ¿Documentas la API, incluyendo ejemplos de solicitud y respuesta para facilitar el uso?
Implementa un endpoint de predicción con validación y logging
Sigue estos pasos para crear una API funcional que cargue un modelo de ML y responda a solicitudes:
- Prepara un modelo: Si no tienes uno, usa scikit-learn para entrenar un modelo simple (e.g., clasificación de iris) y guárdalo con joblib como 'mi_modelo.pkl'.
- Crea un archivo Python (app.py) e importa FastAPI, Pydantic, joblib, numpy y logging.
- En el evento de startup de FastAPI, carga el modelo desde 'mi_modelo.pkl' y configura logging a nivel INFO.
- Define un esquema Pydantic llamado
IrisInputcon campos para las características (sepal_length, sepal_width, petal_length, petal_width) como floats. - Implementa un endpoint POST en '/predict' que:
- Valide la entrada usando
IrisInput. - Registre la solicitud recibida con logging.
- Convierta las características a un array numpy y haga la predicción.
- Maneje errores: si el número de características no coincide, devuelve error 422; otros errores, devuelve 500.
- Devuelva la predicción como JSON, por ejemplo, {'class': 'setosa'}.
- Valide la entrada usando
- Ejecuta la API con uvicorn y prueba con una herramienta como curl o Postman, enviando datos válidos e inválidos.
- Usa el decorador @app.on_event('startup') para cargar el modelo al inicio.
- Asegúrate de que el esquema Pydantic tenga tipos correctos para evitar errores de validación.
- En logging, usa logger.info() para mensajes normales y logger.error() para fallos.
Evalua tu comprension
Completa el quiz interactivo de arriba para ganar XP.