Optimización distribuida de Qwen2-0.5B con LoRA

En este cuaderno se muestra cómo ajustar finamente de manera eficiente el modelo de lenguaje grande Qwen2-0.5B mediante técnicas eficientes en el uso de parámetros en computación GPU sin servidor. Aprenderá a:

  • Aplicar LoRA (adaptación de bajo rango) para reducir los parámetros entrenables en un 99 % mientras se mantiene la calidad del modelo
  • Uso de kernels Liger para entrenamiento eficiente en memoria con kernels Triton optimizados
  • Utilizar TRL (Transformer Reinforcement Learning) para el ajuste fino supervisado
  • Registro del modelo optimizado en el Catálogo de Unity para la gobernanza y la implementación

Conceptos clave:

  • LoRA: técnica que inmoviliza el modelo base y entrena capas de adaptadores pequeños, lo que reduce drásticamente los requisitos de memoria y el tiempo de entrenamiento.
  • Kernels de Liger: kernels optimizados para GPU que reducen el uso de memoria hasta 80% a través de operaciones fusionadas
  • TRL: una biblioteca para entrenar modelos de lenguaje con aprendizaje de refuerzo y ajuste fino supervisado
  • Proceso de GPU sin servidor: proceso administrado de Databricks que escala automáticamente los recursos de GPU

Matriz de decisión de LoRA vs ajuste completo

LoRA (Low-Rank Adaptation) congela el modelo base y entrena solo las capas del adaptador pequeñas, reduciendo los parámetros entrenables en un ~99%. Esto hace que el entrenamiento sea más rápido y eficaz para la memoria.

Escenario Recomendación Motivo
Memoria de GPU limitada LoRA Se ajusta a modelos más grandes en memoria entrenando solo 1% de parámetros
Adaptación específica de la tarea LoRA Intercambio de adaptadores diferentes en el mismo modelo base para varias tareas
Cambio de comportamiento del modelo principal Ajuste completo Actualiza todos los parámetros para los cambios fundamentales en el comportamiento del modelo
Implementación en producción LoRA Archivos más pequeños (MB frente a GB), carga más rápida y control de versiones más fácil

Ventajas del kernel de Liger

Los kernels de Liger son operaciones optimizadas para GPU que fusionan varios pasos en kernels únicos, lo que reduce las transferencias de memoria y mejora la eficacia. El documento técnico proporciona pruebas comparativas detalladas que muestran mejoras significativas en el rendimiento.

  • Operaciones fusionadas: combina operaciones (por ejemplo, lineal y pérdida) para reducir la sobrecarga de memoria hasta 80%
  • Núcleos de Triton: Núcleos personalizados para GPU optimizados para operaciones de transformadores (RMSNorm, RoPE, SwiGLU, CrossEntropy)
  • Eficiencia de memoria: permite tamaños de lote o modelos más grandes que no caben en la memoria de GPU.
  • Optimización de GPU única: especialmente eficaz para escenarios de entrenamiento de gpu única A10/A100

Este cuaderno usa la librería TRL para simplificar la configuración del entrenamiento y aplicar automáticamente estas optimizaciones.

Conectar a Cómputo de GPU sin servidor

Este entorno requiere computación GPU sin servidor. Para conectarse:

  1. Haga clic en el selector de proceso del cuaderno en la parte superior derecha y seleccione GPU sin servidor.
  2. En el lado derecho, haga clic en el botón de entorno.
  3. Seleccione 8xH100 como acelerador.
  4. Elección de AI v4 en el entorno base
  5. Haga clic en Aplicar.

La función de entrenamiento aprovisionará automáticamente 8 GPU H100 para el entrenamiento distribuido.

Instalación de bibliotecas necesarias

La celda siguiente instala los paquetes de Python necesarios para la optimización distribuida:

Bibliotecas de entrenamiento principales:

  • trl: Biblioteca de aprendizaje de refuerzo de transformadores para el ajuste fino supervisado y RLHF
  • peft: biblioteca Parameter-Efficient Fine-Tuning que proporciona implementación de LoRA
  • liger-kernel: Kernels de GPU optimizados para el uso de memoria en el entrenamiento eficiente de transformadores

Bibliotecas compatibles:

  • hf_transfer: descargas aceleradas de Hugging Face Hub mediante la transferencia basada en Rust
  • mlflow>=3.0: integración del registro de modelos y seguimiento de experimentos

El comando %restart_python reinicia el intérprete de Python para asegurarse de que los paquetes recién instalados se cargan correctamente.

%pip install --upgrade peft==0.17.1
%pip install --upgrade hf_transfer==0.1.9
%pip install --upgrade transformers==4.56.1
%pip install trl==0.18.1
%pip install liger-kernel
%pip install mlflow==3.7.0
%restart_python

Configuración de la instalación

Integración de Unity Catalog

La celda siguiente configura dónde se almacenará y registrará el modelo ajustado:

  • Catálogo y esquema: organice los modelos dentro del espacio de nombres del catálogo de Unity (valor predeterminado: main.default)
  • Nombre del modelo: el nombre del modelo registrado en el catálogo de Unity para la gobernanza y la implementación
  • Volumen: Volumen del catálogo de Unity para almacenar checkpoints del modelo durante el entrenamiento

Estos widgets permiten personalizar la ubicación de almacenamiento sin editar código. El modelo se registrará como {catalog}.{schema}.{model_name} para facilitar el acceso y el control de versiones.

Hiperparámetros de entrenamiento

La celda también define los parámetros de entrenamiento clave:

  • Modelo y conjunto de datos: Qwen2-0.5B con el conjunto de datos conversacional de Capybara
  • Tamaño del lote (8): número de ejemplos por GPU por paso de entrenamiento
  • Acumulación de gradientes (4): acumula gradientes a lo largo de 4 lotes para un tamaño de lote efectivo de 32
  • Velocidad de aprendizaje (1e-4): tasa conservadora, multiplicada automáticamente por 10 para el entrenamiento de LoRA
  • Épocas (1): paso único por el conjunto de datos para evitar el sobreajuste
  • Registro y punto de comprobación: guarda el progreso cada 100 pasos, registra las métricas cada 25 pasos.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "qwen2_liger_lora_assistant")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")

# MLflow and Unity Catalog configuration

# Model selection - Choose based on your compute constraints
MODEL_NAME = "Qwen/Qwen2-0.5B"
DATASET_NAME = "trl-lib/Capybara"
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/qwen2-0.5b-lora"

# Training hyperparameters
BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1
EVAL_STEPS = 100
LOGGING_STEPS = 25
SAVE_STEPS = 100

Configuración de LoRA

La siguiente celda configura los parámetros de LoRA (Adaptación de Bajo Rango) que controlan el ajuste del modelo. LoRA congela los pesos del modelo base y solo entrena pequeñas matrices de adaptadores, reduciendo drásticamente los requisitos de memoria.

Selección de parámetros

  • Rank (r=8): proporciona un buen equilibrio entre el rendimiento y los parámetros.
  • Alfa (32): factor de escalado, normalmente de 2 a 4 veces el rango
  • Dropout (0.1): regularización para evitar el sobreajuste

Módulos objetivo para Qwen2

Este ejemplo tiene como destino todas las capas de transformación clave:

  • Atención: q_proj, k_proj, v_proj, , o_proj
  • MLP: gate_proj, up_proj, down_proj
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj"
]

Definición de la función de entrenamiento

La celda siguiente crea la función de entrenamiento distribuida que se ejecutará en varias GPU. Esto es lo que hace:

Configuración de entrenamiento distribuido

El @distributed decorador configura la Computación en GPU sin Servidor:

  • 8 GPU: distribuye el entrenamiento entre 8 GPU H100 para un entrenamiento más rápido
  • Orquestación automática: controla el aprovisionamiento de GPU, la distribución de datos y la sincronización.

Flujo de trabajo de entrenamiento

La función ejecuta estos pasos:

  1. Cargar conjunto de datos: descarga y prepara el conjunto de datos conversacional de Capybara
  2. Inicializar modelo: carga el modelo Qwen2-0.5B y el tokenizador con formato de chat
  3. Aplicar LoRA: asocia capas de adaptador para reducir los parámetros entrenables en ~99%
  4. Configuración del entrenamiento: configura el tamaño del lote, la velocidad de aprendizaje y las optimizaciones del kernel de Liger.
  5. Modelo de entrenamiento: ejecuta el bucle de entrenamiento con puntos de comprobación y registro automáticos
  6. Guardar artefactos: almacena adaptadores de LoRA y tokenizador en el volumen del catálogo de Unity
  7. Devolver el identificador de ejecución de MLflow: proporciona el identificador de ejecución para el registro del modelo.

Optimizaciones clave habilitadas

  • Kernels de Liger: las operaciones de GPU fusionadas reducen el uso de memoria hasta 80%
  • Precisión mixta (FP16): cálculo más rápido con una superficie de memoria inferior
  • Control de puntos de gradiente: Intercambia el cálculo por memoria para adaptar lotes más grandes.
  • Acumulación de gradiente: simula tamaños de lote más grandes para un entrenamiento más estable
from serverless_gpu import distributed
from serverless_gpu import runtime as rt

@distributed(gpus=8, gpu_type="H100")
def run_train(use_lora=True):
    import logging
    from datasets import load_dataset
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, TaskType, get_peft_model
    from trl import (
        SFTConfig,
        SFTTrainer,
        setup_chat_format
    )
    import json
    import os
    import mlflow

    dataset = load_dataset(DATASET_NAME)
    logging.info(f"✓ Dataset loaded: {dataset}")

    if "test" not in dataset:
        logging.info("Creating validation split from training data...")
        dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
        logging.info("✓ Data split: 90% train, 10% validation")

    # model and tokenizer initialization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        use_fast=True
    )

    # Chat template formatting for conversational fine-tuning
    if tokenizer.chat_template is None:
        logging.info("Adding chat template for proper conversation formatting...")
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
        logging.info("✓ ChatML format applied for structured conversations")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logging.info("✓ Padding token set to EOS token")

    logging.info("✓ Model and tokenizer loaded successfully")

    # PEFT
    peft_config = None
    if use_lora:
        try:
            logging.info("Configuring LoRA for parameter-efficient fine-tuning...")

            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=LORA_R,
                lora_alpha=LORA_ALPHA,
                lora_dropout=LORA_DROPOUT,
                target_modules=LORA_TARGET_MODULES,
                bias="none",
                use_rslora=False,
                modules_to_save=None,
            )

            logging.info(f"LoRA configuration: rank={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}")
            logging.info(f"Target modules: {', '.join(LORA_TARGET_MODULES)}")

            original_params = model.num_parameters()
            model = get_peft_model(model, peft_config)

            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in model.parameters())
            efficiency_ratio = 100 * trainable_params / total_params

            logging.info(f"✓ LoRA applied successfully:")
            logging.info(f"  • Original parameters: {original_params:,}")
            logging.info(f"  • Trainable parameters: {trainable_params:,}")
            logging.info(f"  • Training efficiency: {efficiency_ratio:.2f}% of parameters")
            logging.info(f"  • Memory savings: ~{100-efficiency_ratio:.1f}% reduction in gradient memory")

        except Exception as e:
            logging.info(f"✗ LoRA configuration failed: {e}")
            logging.info("Falling back to full fine-tuning...")
            peft_config = None
    else:
        logging.info("Full fine-tuning mode selected")
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info(f"Trainable parameters: {trainable_params:,} (100% of model)")

    # Learning rate adjustment for LoRA
    adjusted_lr = LEARNING_RATE * 10 if use_lora else LEARNING_RATE
    logging.info(f"Learning rate: {adjusted_lr} ({'LoRA-adjusted' if use_lora else 'standard'})")

    training_args_dict = {
        "output_dir": OUTPUT_DIR,
        "per_device_train_batch_size": BATCH_SIZE,
        "per_device_eval_batch_size": BATCH_SIZE,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "learning_rate": adjusted_lr,
        "num_train_epochs": NUM_EPOCHS,
        "eval_steps": EVAL_STEPS,
        "logging_steps": LOGGING_STEPS,
        "save_steps": SAVE_STEPS,
        "save_total_limit": 2,
        "report_to": "mlflow",
        "run_name": f"{MODEL_NAME}_fine-tuning",
        "warmup_steps": 50,
        "weight_decay": 0.01,
        "metric_for_best_model": "eval_loss",
        "greater_is_better": False,
        "dataloader_pin_memory": False,
        "remove_unused_columns": False,
        "use_liger_kernel": True,  # Enable Liger kernel optimizations
        "fp16": True,  # Mixed precision training
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False}, # Required for LORA with DDP
    }

    logging.info("✓ Liger kernel optimizations enabled")

    training_args = SFTConfig(**training_args_dict)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    logging.info("\n" + "="*50)
    logging.info("STARTING TRAINING")
    logging.info("="*50)

    logging.info("🚀 Training with Liger kernels for memory-efficient single GPU training")
    if use_lora:
        logging.info("🎯 Using LoRA for parameter-efficient fine-tuning")

    trainer.train()
    logging.info("\n✓ Training completed successfully!")
    if rt.get_global_rank() == 0:
        logging.info("\nSaving trained model...")

        logging.info("Saving LoRA adapter weights...")
        trainer.save_model(training_args.output_dir)
        logging.info("✓ LoRA adapters saved - use with base model for inference")
        tokenizer.save_pretrained(training_args.output_dir)
        logging.info("✓ Tokenizer saved with model")
        logging.info(f"\n🎉 All artifacts saved to: {training_args.output_dir}")

    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id

    return mlflow_run_id

Ejecute el entrenamiento distribuido

Esta celda ejecuta la función de entrenamiento en 8 GPU H100. El distributed() método controla:

  • Aprovisionamiento de recursos de proceso de GPU sin servidor
  • Distribución de la carga de trabajo de entrenamiento entre varias GPU
  • Recopilación del identificador de ejecución de MLflow para el registro de modelos

El entrenamiento suele tardar entre 15 y 30 minutos en función del tamaño del conjunto de datos y la disponibilidad de proceso.

mlflow_run_id = run_train.distributed(use_lora=True)[0]
print(mlflow_run_id)

Registro del catálogo de MLflow y Unity

Estrategia de registro de modelos

  • Seguimiento de MLflow: artefactos y metadatos del modelo de registro
  • Catálogo de Unity: Registro del modelo para la gobernanza y la implementación
  • Control de versiones del modelo: control de versiones automático para la administración del ciclo de vida del modelo
  • Metadatos: información completa del modelo para la reproducibilidad
print("\nRegistering model with MLflow and Unity Catalog...")

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers

try:
    # Load the trained model for registration
    print("Loading LoRA model for registration...")
    # For LoRA models, we need both base model and adapter
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    adapter_dir = OUTPUT_DIR
    peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
    # Merge LoRA into base and drop PEFT wrappers
    merged_model = peft_model.merge_and_unload()

    components = {
        "model": merged_model,
        "tokenizer": tokenizer,
    }

    # Create Unity Catalog model name
    full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

    print(f"Registering model as: {full_model_name}")

    # Start MLflow run and log model
    task = "llm/v1/chat"
    with mlflow.start_run(run_id=mlflow_run_id):
        model_info = mlflow.transformers.log_model(
            transformers_model=components,
            artifact_path="model",
            task=task,
            registered_model_name=full_model_name,
            metadata={
                "task": task,
                "pretrained_model_name": MODEL_NAME,
                "databricks_model_family": "QwenForCausalLM",
            },
        )

    print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
    print(f"✓ MLflow model URI: {model_info.model_uri}")

    # Print deployment information
    print(f"\n📦 Model Registration Complete!")
    print(f"Unity Catalog Path: {full_model_name}")
    print(f"Model Type: {model_type}")
    print(f"Optimization: Liger Kernels + LoRA")

except Exception as e:
    print(f"✗ Model registration failed: {e}")
    print("Model is still saved locally and can be registered manually")
    print(f"Local model path: {OUTPUT_DIR}")

Pasos siguientes

Ahora que ha ajustado y registrado el modelo, puede hacer lo siguiente:

Cuaderno de ejemplo

Optimización distribuida de Qwen2-0.5B con LoRA

Obtener el cuaderno