Práctica: Crear un Endpoint para Clasificación de Imágenes

Video
25 min~5 min lectura

Reproductor de video

Concepto clave

En producción, un endpoint de clasificación de imágenes no es solo un modelo ML ejecutándose. Es un sistema que recibe imágenes, las procesa, aplica el modelo y devuelve resultados de forma confiable. Piensa en una línea de ensamblaje: la imagen entra como materia prima, pasa por etapas de control de calidad (validación), transformación (preprocesamiento), análisis (modelo) y empaquetado (respuesta).

FastAPI actú como el coordinador de esta línea. Su sistema de tipos y validación garantiza que solo entren imágenes válidas, mientras que su asincronía permite manejar múltiples solicitudes sin bloquear el sistema. La clave está en separar claramente la lógica de la API (FastAPI) de la lógica del modelo (ML), manteniendo interfaces limpias entre ambos.

Cómo funciona en la práctica

Vamos a construir un endpoint para clasificar imágenes de perros y gatos. El flujo completo tiene estos pasos:

  1. Recibir la imagen como archivo en una petición POST
  2. Validar que sea una imagen válida (formato, tamaño)
  3. Preprocesarla para que coincida con lo que espera el modelo
  4. Cargar y ejecutar el modelo preentrenado
  5. Interpretar las predicciones del modelo
  6. Devolver una respuesta estructurada con la clase y confianza

Antes de codificar, definamos la estructura del proyecto: un archivo main.py para la API, un directorio models/ para el código ML, y requirements.txt para dependencias. Esta separación facilita el mantenimiento y testing.

Código en acción

Primero, el endpoint básico sin validaciones robustas:

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import numpy as np
import tensorflow as tf

app = FastAPI()
model = tf.keras.models.load_model('models/cats_dogs.h5')

@app.post("/predict/")
async def predict(image: UploadFile = File(...)):
    img = Image.open(image.file)
    img = img.resize((224, 224))
    img_array = np.array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    prediction = model.predict(img_array)
    class_idx = np.argmax(prediction[0])
    confidence = float(prediction[0][class_idx])
    
    return {"class": "dog" if class_idx == 0 else "cat", "confidence": confidence}

Ahora, la versión mejorada con validaciones y separación de responsabilidades:

from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from typing import Optional
from PIL import Image
import numpy as np
import tensorflow as tf
import io

app = FastAPI(title="Image Classification API")

# Modelo cargado una sola vez al iniciar
MODEL = tf.keras.models.load_model('models/cats_dogs.h5')
CLASS_NAMES = ["dog", "cat"]

class PredictionResponse(BaseModel):
    class_name: str
    confidence: float
    processing_time_ms: Optional[float] = None

def validate_image(file: UploadFile) -> Image.Image:
    """Valida y carga la imagen"""
    if file.content_type not in ["image/jpeg", "image/png"]:
        raise HTTPException(400, "Formato no soportado. Use JPEG o PNG")
    
    try:
        contents = file.file.read()
        if len(contents) > 5_000_000:  # 5MB límite
            raise HTTPException(400, "Imagen demasiado grande")
        
        img = Image.open(io.BytesIO(contents))
        img.verify()  # Verifica integridad
        img = Image.open(io.BytesIO(contents))  # Reabrir después de verify
        return img
    except Exception as e:
        raise HTTPException(400, f"Error procesando imagen: {str(e)}")

def preprocess_image(img: Image.Image) -> np.ndarray:
    """Preprocesa para el modelo"""
    img = img.resize((224, 224))
    img_array = np.array(img)
    
    # Si es PNG con canal alpha, convertir a RGB
    if img_array.shape[-1] == 4:
        img_array = img_array[..., :3]
    
    # Normalización específica del modelo
    img_array = img_array / 255.0
    return np.expand_dims(img_array, axis=0)

@app.post("/predict/", response_model=PredictionResponse)
async def predict(image: UploadFile = File(..., description="Imagen JPEG o PNG hasta 5MB")):
    import time
    start_time = time.time()
    
    # 1. Validación
    img = validate_image(image)
    
    # 2. Preprocesamiento
    img_array = preprocess_image(img)
    
    # 3. Predicción
    predictions = MODEL.predict(img_array, verbose=0)
    class_idx = np.argmax(predictions[0])
    confidence = float(predictions[0][class_idx])
    
    # 4. Respuesta
    processing_time = (time.time() - start_time) * 1000
    
    return PredictionResponse(
        class_name=CLASS_NAMES[class_idx],
        confidence=round(confidence, 4),
        processing_time_ms=round(processing_time, 2)
    )

Errores comunes

  • No validar el tipo de archivo: Aceptar cualquier archivo puede causar caídas del servidor. Siempre verifica content_type y usa PIL.Image.verify().
  • Cargar el modelo en cada request: Esto hace la API extremadamente lenta. Carga el modelo una vez al inicio y reúsalo.
  • Ignorar el canal alpha en PNGs: Las imágenes PNG pueden tener 4 canales (RGBA), mientras que muchos modelos esperan 3 (RGB). Convierte explícitamente.
  • No limitar el tamaño de archivo: Un usuario podría subir una imagen de 1GB y saturar la memoria. Establece un límite razonable (ej: 5-10MB).
  • Devolver respuestas inconsistentes: Sin un schema definido (Pydantic), diferentes endpoints pueden devolver formatos distintos. Usa siempre response_model.

Checklist de dominio

  1. El endpoint valida formato y tamaño de imagen antes de procesar
  2. El modelo se carga una sola vez, no por cada petición
  3. El preprocesamiento maneja diferentes formatos (JPEG, PNG) y canales
  4. La respuesta incluye confianza y tiempo de procesamiento para monitoring
  5. Los errores devuelven códigos HTTP apropiados (400 para mal input, 500 para errores internos)
  6. El código está separado en funciones con responsabilidades únicas
  7. Se usa response_model de Pydantic para garantizar consistencia en respuestas

Implementa un endpoint para clasificación de dígitos manuscritos

Extiende la API para clasificar dígitos manuscritos (0-9) usando el dataset MNIST. Sigue estos pasos:

  1. Crea un nuevo endpoint /predict-digit/ que acepte imágenes en escala de grises
  2. Descarga un modelo preentrenado de MNIST (puedes usar TensorFlow o PyTorch) o entrena uno simple
  3. Implementa validación específica: solo imágenes 28x28 píxeles, en escala de grises
  4. El preprocesamiento debe normalizar los valores de píxel a [0,1] y asegurar las dimensiones correctas (1, 28, 28, 1)
  5. Devuelve el dígito predicho (0-9), la confianza, y opcionalmente las probabilidades para todas las clases
  6. Agrega un endpoint /health que devuelva el estado del modelo (cargado o no) y memoria usada
  7. Prueba con imágenes de ejemplo del dataset MNIST

Estructura recomendada:

  • main.py: endpoints de FastAPI
  • models/digit_classifier.py: código para cargar y usar el modelo
  • utils/validation.py: funciones de validación de imágenes
  • requirements.txt: dependencias específicas
Pistas
  • Usa `tf.keras.datasets.mnist.load_data()` para obtener datos de prueba si no tienes imágenes
  • Para escala de grises, verifica que la imagen tenga un solo canal o conviértela con `Image.open(...).convert('L')`
  • Considera agregar un parámetro opcional `return_probabilities: bool = False` para controlar la respuesta detallada

Evalua tu comprension

Completa el quiz interactivo de arriba para ganar XP.