Ajuste del modelo GPT-OSS 120B de OpenAI mediante el entrenamiento distribuido

En este cuaderno se muestra el ajuste fino supervisado (SFT) del modelo GPT-OSS de 120 mil millones de parámetros en 8 GPUs H100 utilizando Databricks Serverless GPU Compute. El entrenamiento aprovecha:

  • FSDP (Fully Sharded Data Parallel): divide los parámetros del modelo, los gradientes y los estados del optimizador entre varias GPUs para permitir el entrenamiento de modelos grandes que no caben en una sola GPU.
  • DDP (Distributed Data Parallel): distribuye el entrenamiento entre varias GPU para un entrenamiento más rápido.
  • LoRA (Low-Rank Adaptación):reduce el número de parámetros entrenables agregando capas de adaptadores pequeñas, lo que hace que el ajuste sea más eficaz.
  • TRL (Aprendizaje de refuerzo de transformadores): proporciona el SFTTrainer para el ajuste fino supervisado.

Al establecer remote=False y especificar 16 GPU, esto se puede extender al entrenamiento de varios nodos en 16 GPU.

Instalación de paquetes necesarios

Instale las bibliotecas necesarias para el entrenamiento distribuido y la optimización del modelo:

  • trl: Biblioteca de aprendizaje por refuerzo con transformadores para el entrenamiento de SFT
  • peft: Ajuste fino eficiente en parámetros para adaptadores de LoRA
  • transformers: Biblioteca de transformers de Hugging Face
  • datasets: para cargar conjuntos de datos de entrenamiento
  • accelerate: Para la orquestación de entrenamiento distribuido
  • hf_transfer: para descargas de modelos más rápidas de Hugging Face
%pip install "trl==1.1.0"
%pip install "peft==0.19.1"
%pip install "transformers==5.5.4"
%pip install "fsspec==2024.9.0"
%pip install "huggingface_hub==1.11.0"
%pip install "datasets==3.2.0"
%pip install "accelerate==1.13.0"
%restart_python

Definición de la función de entrenamiento distribuido con FSDP

Esta celda define la función de entrenamiento que se ejecutará mediante el decorador @distributed en 8 GPUs H100. La función incluye:

  • Carga del modelo: carga el parámetro 120B GPT-OSS modelo en precisión bfloat16
  • Configuración de LoRA: se aplica adaptación de bajo rango con el rango 16 para reducir los parámetros que se pueden entrenar
  • Configuración de FSDP: configura datos totalmente particionados paralelos con ajuste automático de capas y puntos de control de activación
  • Configuración de entrenamiento: establece el tamaño del lote, la velocidad de aprendizaje, la acumulación de degradado y otros hiperparámetros.
  • Conjunto de datos: usa el conjunto de datos HuggingFaceH4/Multilingual-Thinking para ajustar correctamente

La función detecta automáticamente las clases de bloques de transformadores para la envoltura FSDP y gestiona la coordinación de entrenamiento distribuido en todas las GPU.

dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-120b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-120b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
from serverless_gpu import distributed

@distributed(gpus=8, gpu_type='H100')
def train_gpt_oss_fsdp_120b():
    """
    Fine-tune a 120B-class model with TRL SFTTrainer + FSDP2 on H100s.
    Uses LoRA + activation ckpt + full_shard auto_wrap.
    """

    # --- imports inside for pickle safety ---
    import os, torch, torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
    from trl import SFTTrainer, SFTConfig
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model

    # ---------- DDP / CUDA binding ----------
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")

    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")  # replaces NCCL_ASYNC_ERROR_HANDLING

    # ---------- Config ----------
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1                 # start conservative for 120B
    GRAD_ACCUM = 4                       # tune for throughput
    LR = 1.5e-4
    EPOCHS = 1

    is_main  = int(os.environ.get("RANK", "0")) == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)

    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"

    # ---------- Model ----------
    # IMPORTANT: no device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
    # low_cpu_mem_usage helps with massive checkpoints (still needs decent host RAM)
    quantization_config = Mxfp4Config(dequantize=True)
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_NAME,
        dtype=torch.bfloat16,
        attn_implementation="eager",
        quantization_config=quantization_config,
        use_cache=False,                  # needed for grad ckpt
        low_cpu_mem_usage=True,
    )

    # ---------- LoRA ----------
    # the following config works
    # include MoE layers as well.
    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules="all-linear",
        rank_pattern={
            "mlp.experts.gate_up_proj": 8,
            "mlp.experts.down_proj": 8
        },
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, peft_config)

    # Cast all parameters to bfloat16 so FSDP sees a uniform dtype
    # (LoRA adapters are initialized in float32 by default)
    model = model.to(torch.bfloat16)

    if is_main:
        model.print_trainable_parameters()

    # ---------- Data ----------
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")

    # ---------- FSDP settings ----------
    def infer_transformer_blocks_for_fsdp(model):
        COMMON = {
            "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
            "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
            "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
            "DecoderLayer", "GPTJBlock", "OPTDecoderLayer"
        }
        hits = set()
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if name in COMMON:
                hits.add(name)
        # Fallback: grab anything that *looks* like a decoder block
        if not hits:
            for _, m in model.named_modules():
                name = m.__class__.__name__
                if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
                    hits.add(name)
        return sorted(hits)


    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
                       "print(model) and add the block class explicitly.")


    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        gradient_checkpointing=False,

        # ---- FSDP2 knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "version": 2,
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "reshard_after_forward": True,
            "activation_checkpointing": True,    # <- use activation ckpt (not gradient)
            "xla": False,
            "limit_all_gathers": True,
        },
    )

    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # verify distributed init & FSDP
    rank = int(os.getenv("RANK", "0"))
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = {getattr(getattr(acc,'state',None),'distributed_type','n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = {getattr(acc, 'num_processes', 'n/a')}")

    # ---------- Train ----------
    result = trainer.train()

    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)

Ejecuta el trabajo de entrenamiento distribuido

Ejecute la función de entrenamiento en 8 GPU H100. El @distributed decorador gestiona la orquestación del inicio del entrenamiento en todas las GPU con una configuración distribuida adecuada.

train_gpt_oss_fsdp_120b.distributed()

Pasos siguientes

Cuaderno de ejemplo

Ajuste del modelo GPT-OSS 120B de OpenAI mediante el entrenamiento distribuido

Obtener el cuaderno