Concepto clave
La Optimización Directa de Preferencias (DPO) es un método de fine-tuning que entrena modelos de lenguaje directamente sobre preferencias humanas, sin necesidad de un modelo de recompensa explícito. A diferencia de RLHF, que requiere entrenar un modelo de recompensa separado y luego usar aprendizaje por refuerzo, DPO reformula el problema como una clasificación directa entre respuestas preferidas y no preferidas.
Imagina que estás entrenando a un asistente virtual para escribir correos profesionales. En lugar de darle puntuaciones numéricas por cada correo (como en RLHF), simplemente le muestras pares de correos y le dices "este es mejor que este". DPO aprende directamente de estas comparaciones, optimizando la probabilidad de generar respuestas preferidas sobre las no preferidas.
La fórmula clave de DPO transforma el problema de optimización de políticas en RLHF en un problema de clasificación simple:
L_DPO(π_θ) = -E_(x,y_w,y_l)∼D [log σ(β log(π_θ(y_w|x)/π_ref(y_w|x)) - β log(π_θ(y_l|x)/π_ref(y_l|x)))]
Donde π_θ es la política que estamos entrenando, π_ref es el modelo de referencia, y_w es la respuesta preferida, y_l es la respuesta no preferida, β controla la fuerza de la penalización por desviarse del modelo de referencia, y σ es la función sigmoide.
Cómo funciona en la práctica
Veamos un ejemplo paso a paso de cómo implementar DPO:
- Preparación de datos: Recolectas pares de respuestas (preferida vs no preferida) para cada prompt. Por ejemplo:
- Prompt: "Escribe un resumen de 100 palabras sobre machine learning"
- Respuesta preferida: "Machine learning es un campo de la inteligencia artificial que permite a los sistemas aprender y mejorar automáticamente a partir de la experiencia sin ser programados explícitamente. Se basa en algoritmos que identifican patrones en datos para hacer predicciones o decisiones. Sus aplicaciones incluyen reconocimiento de voz, recomendaciones y diagnóstico médico."
- Respuesta no preferida: "ML es algo de computadoras que aprende cosas. Es como cuando un niño aprende pero con números. Se usa en muchas apps."
- Configuración del modelo: Tomas un modelo base (como Llama 2 o Mistral) como π_ref. Inicializas π_θ con los mismos pesos.
- Entrenamiento: Para cada batch de datos:
- Calculas los logits para ambas respuestas usando π_θ
- Calculas los logits para ambas respuestas usando π_ref
- Aplicas la fórmula DPO para calcular la pérdida
- Actualizas los parámetros de π_θ usando backpropagation
- Evaluación: Mides la tasa de preferencia humana en respuestas generadas después del entrenamiento.
Una implementación básica en pseudocódigo:
def dpo_loss(policy_logits, ref_logits, preferred_idx, beta=0.1):
# policy_logits: logits del modelo que entrenamos
# ref_logits: logits del modelo de referencia
# preferred_idx: índice de la respuesta preferida (0 o 1)
# Calculamos log-ratios
log_ratio_policy = policy_logits[preferred_idx] - policy_logits[1-preferred_idx]
log_ratio_ref = ref_logits[preferred_idx] - ref_logits[1-preferred_idx]
# Pérdida DPO
loss = -torch.log(torch.sigmoid(beta * (log_ratio_policy - log_ratio_ref)))
return lossCaso de estudio
Proyecto: Fine-tuning de un modelo para generar respuestas de soporte técnico más útiles
Contexto: Una empresa de SaaS tenía un modelo base que generaba respuestas técnicas correctas pero a menudo demasiado técnicas o poco empáticas. Los usuarios calificaban las respuestas como "útiles" o "no útiles" basándose en si resolvían su problema y eran fáciles de entender.
Implementación:
| Etapa | Acción | Resultado |
|---|---|---|
| Recolección de datos | 5000 pares de respuestas (preferida/no preferida) de tickets reales | Dataset balanceado con diversidad de problemas |
| Modelo base | Mistral 7B | Precisión técnica alta pero empatía baja |
| Entrenamiento DPO | 3 épocas, β=0.1, lr=1e-5 | 16 horas en 4x A100 |
| Evaluación | 100 prompts nuevos con evaluación humana ciega | 85% de respuestas preferidas vs 45% del modelo base |
Lección aprendida: El parámetro β fue crucial. Con β=0.01, el modelo casi no cambiaba. Con β=1.0, el modelo generaba respuestas demasiado simples. β=0.1 dio el mejor balance entre alineamiento y diversidad.
Errores comunes
- Dataset de baja calidad: Usar pares donde la diferencia entre respuestas es mínima o ambigua. Solución: Asegurar que los anotadores tengan criterios claros y entrenamiento adecuado. Filtrar pares donde la preferencia no sea clara (ej., menos del 70% de acuerdo entre anotadores).
- β mal configurado: Un β demasiado alto hace que el modelo colapse a generar siempre la misma respuesta. Un β demasiado bajo no produce cambios significativos. Solución: Hacer una búsqueda en grid: probar [0.01, 0.05, 0.1, 0.5, 1.0] con un conjunto de validación pequeño.
- Olvidar el modelo de referencia: No congelar los pesos del modelo de referencia durante el entrenamiento, causando deriva. Solución: Siempre usar .detach() o .requires_grad_(False) en el modelo de referencia.
- Sobreajuste a preferencias específicas: El modelo aprende a generar respuestas que coinciden exactamente con el estilo del dataset, perdiendo creatividad. Solución: Usar regularización (como weight decay) y limitar épocas de entrenamiento (2-4 épocas típicamente).
- Ignorar el balance de datos: Tener muchos más ejemplos de cierto tipo de preferencia. Solución: Analizar la distribución de características en los datos y balancear mediante muestreo o ponderación de pérdida.
Checklist de dominio
- ✓ Puedo explicar la diferencia matemática entre RLHF y DPO sin consultar referencias
- ✓ He implementado DPO desde cero o usando una librería como TRL
- ✓ Sé cómo seleccionar y ajustar el parámetro β para mi caso de uso específico
- ✓ Puedo diseñar un proceso de recolección de datos de preferencias que minimice sesgos
- ✓ Entiendo cuándo usar DPO vs RLHF (DPO para fine-tuning final, RLHF para alineamiento complejo)
- ✓ Sé cómo evaluar un modelo entrenado con DPO más allá de la pérdida de entrenamiento
- ✓ Puedo debuggear problemas comunes como colapso de modo o sobreajuste
Implementación de DPO para mejorar respuestas de un chatbot
En este ejercicio, implementarás DPO para mejorar un modelo que genera respuestas a preguntas sobre programación. Trabajarás con un dataset simulado basado en Stack Overflow.
- Prepara el entorno:
- Instala las librerías necesarias: transformers, torch, datasets
- Descarga el modelo base "microsoft/DialoGPT-small"
- Carga el dataset de pares de preferencias (disponible en el repositorio del curso)
- Implementa la función de pérdida DPO:
- Crea una función que tome los logits del modelo policy, los logits del modelo reference, y los índices de respuestas preferidas
- Implementa la fórmula L_DPO usando torch operations
- Asegúrate de manejar correctamente el broadcasting para batches
- Configura el entrenamiento:
- Congela los pesos del modelo de referencia
- Configura el optimizer (AdamW con lr=5e-5)
- Define el loop de entrenamiento por 2 épocas
- Evalúa los resultados:
- Genera respuestas para 10 prompts de prueba antes y después del entrenamiento
- Compara cualitativamente la mejora en claridad y precisión
- Calcula la tasa de preferencia usando un criterio simple (longitud adecuada, presencia de ejemplos de código)
- Experimenta con β:
- Entrena con β=0.05, 0.1, y 0.2
- Compara cómo cambia el comportamiento del modelo
- Documenta tus observaciones
- Recuerda que el modelo de referencia no debe actualizar sus gradientes durante el entrenamiento
- Para calcular log-ratios eficientemente, usa torch.gather para seleccionar los logits de las respuestas preferidas y no preferidas
- Si la pérdida no disminuye, verifica que estés usando .detach() en los logits del modelo de referencia
Evalua tu comprension
Completa el quiz interactivo de arriba para ganar XP.