Concepto clave
En producción, un endpoint de clasificación de imágenes no es solo un modelo ML ejecutándose. Es un sistema que recibe imágenes, las procesa, aplica el modelo y devuelve resultados de forma confiable. Piensa en una línea de ensamblaje: la imagen entra como materia prima, pasa por etapas de control de calidad (validación), transformación (preprocesamiento), análisis (modelo) y empaquetado (respuesta).
FastAPI actú como el coordinador de esta línea. Su sistema de tipos y validación garantiza que solo entren imágenes válidas, mientras que su asincronía permite manejar múltiples solicitudes sin bloquear el sistema. La clave está en separar claramente la lógica de la API (FastAPI) de la lógica del modelo (ML), manteniendo interfaces limpias entre ambos.
Cómo funciona en la práctica
Vamos a construir un endpoint para clasificar imágenes de perros y gatos. El flujo completo tiene estos pasos:
- Recibir la imagen como archivo en una petición POST
- Validar que sea una imagen válida (formato, tamaño)
- Preprocesarla para que coincida con lo que espera el modelo
- Cargar y ejecutar el modelo preentrenado
- Interpretar las predicciones del modelo
- Devolver una respuesta estructurada con la clase y confianza
Antes de codificar, definamos la estructura del proyecto: un archivo main.py para la API, un directorio models/ para el código ML, y requirements.txt para dependencias. Esta separación facilita el mantenimiento y testing.
Código en acción
Primero, el endpoint básico sin validaciones robustas:
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import numpy as np
import tensorflow as tf
app = FastAPI()
model = tf.keras.models.load_model('models/cats_dogs.h5')
@app.post("/predict/")
async def predict(image: UploadFile = File(...)):
img = Image.open(image.file)
img = img.resize((224, 224))
img_array = np.array(img) / 255.0
img_array = np.expand_dims(img_array, axis=0)
prediction = model.predict(img_array)
class_idx = np.argmax(prediction[0])
confidence = float(prediction[0][class_idx])
return {"class": "dog" if class_idx == 0 else "cat", "confidence": confidence}Ahora, la versión mejorada con validaciones y separación de responsabilidades:
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from typing import Optional
from PIL import Image
import numpy as np
import tensorflow as tf
import io
app = FastAPI(title="Image Classification API")
# Modelo cargado una sola vez al iniciar
MODEL = tf.keras.models.load_model('models/cats_dogs.h5')
CLASS_NAMES = ["dog", "cat"]
class PredictionResponse(BaseModel):
class_name: str
confidence: float
processing_time_ms: Optional[float] = None
def validate_image(file: UploadFile) -> Image.Image:
"""Valida y carga la imagen"""
if file.content_type not in ["image/jpeg", "image/png"]:
raise HTTPException(400, "Formato no soportado. Use JPEG o PNG")
try:
contents = file.file.read()
if len(contents) > 5_000_000: # 5MB límite
raise HTTPException(400, "Imagen demasiado grande")
img = Image.open(io.BytesIO(contents))
img.verify() # Verifica integridad
img = Image.open(io.BytesIO(contents)) # Reabrir después de verify
return img
except Exception as e:
raise HTTPException(400, f"Error procesando imagen: {str(e)}")
def preprocess_image(img: Image.Image) -> np.ndarray:
"""Preprocesa para el modelo"""
img = img.resize((224, 224))
img_array = np.array(img)
# Si es PNG con canal alpha, convertir a RGB
if img_array.shape[-1] == 4:
img_array = img_array[..., :3]
# Normalización específica del modelo
img_array = img_array / 255.0
return np.expand_dims(img_array, axis=0)
@app.post("/predict/", response_model=PredictionResponse)
async def predict(image: UploadFile = File(..., description="Imagen JPEG o PNG hasta 5MB")):
import time
start_time = time.time()
# 1. Validación
img = validate_image(image)
# 2. Preprocesamiento
img_array = preprocess_image(img)
# 3. Predicción
predictions = MODEL.predict(img_array, verbose=0)
class_idx = np.argmax(predictions[0])
confidence = float(predictions[0][class_idx])
# 4. Respuesta
processing_time = (time.time() - start_time) * 1000
return PredictionResponse(
class_name=CLASS_NAMES[class_idx],
confidence=round(confidence, 4),
processing_time_ms=round(processing_time, 2)
)Errores comunes
- No validar el tipo de archivo: Aceptar cualquier archivo puede causar caídas del servidor. Siempre verifica content_type y usa PIL.Image.verify().
- Cargar el modelo en cada request: Esto hace la API extremadamente lenta. Carga el modelo una vez al inicio y reúsalo.
- Ignorar el canal alpha en PNGs: Las imágenes PNG pueden tener 4 canales (RGBA), mientras que muchos modelos esperan 3 (RGB). Convierte explícitamente.
- No limitar el tamaño de archivo: Un usuario podría subir una imagen de 1GB y saturar la memoria. Establece un límite razonable (ej: 5-10MB).
- Devolver respuestas inconsistentes: Sin un schema definido (Pydantic), diferentes endpoints pueden devolver formatos distintos. Usa siempre response_model.
Checklist de dominio
- El endpoint valida formato y tamaño de imagen antes de procesar
- El modelo se carga una sola vez, no por cada petición
- El preprocesamiento maneja diferentes formatos (JPEG, PNG) y canales
- La respuesta incluye confianza y tiempo de procesamiento para monitoring
- Los errores devuelven códigos HTTP apropiados (400 para mal input, 500 para errores internos)
- El código está separado en funciones con responsabilidades únicas
- Se usa response_model de Pydantic para garantizar consistencia en respuestas
Implementa un endpoint para clasificación de dígitos manuscritos
Extiende la API para clasificar dígitos manuscritos (0-9) usando el dataset MNIST. Sigue estos pasos:
- Crea un nuevo endpoint
/predict-digit/que acepte imágenes en escala de grises - Descarga un modelo preentrenado de MNIST (puedes usar TensorFlow o PyTorch) o entrena uno simple
- Implementa validación específica: solo imágenes 28x28 píxeles, en escala de grises
- El preprocesamiento debe normalizar los valores de píxel a [0,1] y asegurar las dimensiones correctas (1, 28, 28, 1)
- Devuelve el dígito predicho (0-9), la confianza, y opcionalmente las probabilidades para todas las clases
- Agrega un endpoint
/healthque devuelva el estado del modelo (cargado o no) y memoria usada - Prueba con imágenes de ejemplo del dataset MNIST
Estructura recomendada:
- main.py: endpoints de FastAPI
- models/digit_classifier.py: código para cargar y usar el modelo
- utils/validation.py: funciones de validación de imágenes
- requirements.txt: dependencias específicas
- Usa `tf.keras.datasets.mnist.load_data()` para obtener datos de prueba si no tienes imágenes
- Para escala de grises, verifica que la imagen tenga un solo canal o conviértela con `Image.open(...).convert('L')`
- Considera agregar un parámetro opcional `return_probabilities: bool = False` para controlar la respuesta detallada
Evalua tu comprension
Completa el quiz interactivo de arriba para ganar XP.