Del via


Fine-tune Olmo3 7B with Axolotl on multi-GPU serverless compute

This notebook demonstrates how to fine-tune the Olmo3 7B Instruct model using Axolotl on Databricks serverless GPU compute. Axolotl provides a high-performance framework for LLM post-training with QLoRA (Quantized Low-Rank Adaptation), enabling efficient fine-tuning on multi-GPU infrastructure. The trained model is logged to MLflow and registered in Unity Catalog for deployment.

Install required dependencies

Installs Axolotl with Flash Attention support, MLflow for experiment tracking, and compatible versions of transformers and optimization libraries. The cut-cross-entropy package provides memory-efficient loss computation for large language models.

%pip install -U packaging setuptools wheel ninja
%pip install mlflow>=3.6
%pip install --no-build-isolation axolotl[flash-attn]>=0.12.0
%pip install transformers==4.57.3
%pip uninstall -y awq autoawq
%pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
dbutils.library.restartPython()

Retrieve HuggingFace token

Retrieves the HuggingFace authentication token from Databricks secrets. This token is required to download the Olmo3 7B base model from the HuggingFace Hub.

# Read the Hugging Face access token from the "sgc-nightly-notebook" Databricks
# secret scope (key "hf_token"). The training function later exports it as the
# HF_TOKEN environment variable on the workers.
HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")

Configure training parameters

Sets up the Axolotl training configuration based on the olmo3-7b-qlora.yaml example. Key modifications include:

  • MLflow integration for experiment tracking
  • Unity Catalog volume path for checkpoint storage
  • SDPA (Scaled Dot Product Attention) instead of Flash Attention for broader GPU compatibility

Define Unity Catalog paths

Creates widgets to specify the Unity Catalog location for storing model checkpoints. The output directory combines the catalog, schema, volume, and model name into a fully qualified path.

# Notebook widgets selecting the Unity Catalog location used for checkpoints.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "checkpoints")
# The model name is used both as the last path component of the checkpoint
# directory and as the Unity Catalog model name, so the default must not be a
# Hugging Face repo id: the previous default ("openai/gpt-oss-20b") referred to
# an unrelated model and its "/" produced a nested directory and an invalid
# three-level Unity Catalog name.
dbutils.widgets.text("model", "olmo3_7b_qlora")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_MODEL_NAME = dbutils.widgets.get("model")

# Echo the resolved values so the run output records where artifacts will go.
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

# Checkpoint directory inside the Unity Catalog volume; passed to Axolotl as
# `output_dir` and read back when the adapter is loaded for registration.
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
import os

# Opt this process out of Axolotl usage tracking.
os.environ["AXOLOTL_DO_NOT_TRACK"] = "1"

Disable telemetry

Disables Axolotl's usage tracking by setting the environment variable.

Create Axolotl configuration

Defines the complete training configuration using Axolotl's DictDefault format. This includes model settings (QLoRA with 4-bit quantization), dataset configuration (the alpaca_messages_2k_test dataset, loaded with the chat_template type), LoRA hyperparameters (rank 32, alpha 16), training parameters (1 epoch, micro-batch size 2, gradient accumulation 4), and MLflow integration for experiment tracking.

from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

# Starting point for this configuration is the upstream Axolotl example below,
# with some changes to fit the GPU types used here:
# https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/olmo3/olmo3-7b-qlora.yaml

# Every model and training setting is spelled out explicitly in one mapping.
config = DictDefault(
    {
        # --- Base model, plugins and quantization ---
        "base_model": "allenai/Olmo-3-7B-Instruct-SFT",
        "plugins": ["axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"],
        "load_in_8bit": False,
        "load_in_4bit": True,
        # --- Data ---
        "datasets": [
            {"path": "fozziethebeat/alpaca_messages_2k_test", "type": "chat_template"},
        ],
        "dataset_prepared_path": "last_run_prepared",
        "val_set_size": 0.1,
        # Checkpoints are written to the Unity Catalog volume path.
        "output_dir": OUTPUT_DIR,
        # --- Adapter (QLoRA) ---
        "adapter": "qlora",
        "lora_model_dir": None,
        "sequence_len": 2048,
        "sample_packing": True,
        "lora_r": 32,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        "lora_target_modules": [
            "gate_proj",
            "down_proj",
            "up_proj",
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",
        ],
        # --- Weights & Biases (unused; tracking goes to MLflow instead) ---
        "wandb_project": None,
        "wandb_entity": None,
        "wandb_watch": None,
        "wandb_name": None,
        "wandb_log_model": None,
        # --- Optimization ---
        "gradient_accumulation_steps": 4,
        "micro_batch_size": 2,
        "num_epochs": 1,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "learning_rate": 0.0002,
        "bf16": "auto",
        "tf32": False,
        "gradient_checkpointing": True,
        "resume_from_checkpoint": None,
        "logging_steps": 1,
        "flash_attention": False,
        "warmup_ratio": 0.1,
        "evals_per_epoch": 1,
        "saves_per_epoch": 1,
        # The evaluation split is too small for sample packing.
        "eval_sample_packing": False,
        # --- Experiment tracking: send metrics to MLflow ---
        "use_mlflow": True,
        "mlflow_tracking_uri": "databricks",
        "mlflow_run_name": "olmo3-7b-qlora-axolotl",
        "hf_mlflow_log_artifacts": False,
        "wandb_mode": "disabled",
        # --- Attention backend: SDPA rather than Flash Attention ---
        "attn_implementation": "sdpa",
        "sdpa_attention": True,
        "save_first_step": True,
        "device_map": None,
    }
)
from axolotl.utils import set_pytorch_cuda_alloc_conf

# Apply Axolotl's PyTorch CUDA allocator configuration before the training job
# is defined and launched.
# NOTE(review): judging by its name this helper sets PYTORCH_CUDA_ALLOC_CONF —
# confirm the exact value against the Axolotl source.
set_pytorch_cuda_alloc_conf()

Configure PyTorch CUDA memory allocation

Optimizes GPU memory management for efficient training on multi-GPU setups.

Run distributed training on serverless GPU compute

Uses the @distributed decorator from the serverless GPU API to distribute the Axolotl training job across 8 H100 GPUs. The decorator handles multi-GPU orchestration, allowing the training function to run in a distributed environment without manual cluster setup.

from serverless_gpu.launcher import distributed
from serverless_gpu.compute import GPUType

@distributed(gpus=8, gpu_type=GPUType.H100)
def run_train(cfg: DictDefault):
    """Validate ``cfg``, prepare its datasets and run a short Axolotl training job.

    The function is wrapped by ``@distributed`` with a request for 8 H100 GPUs.

    Args:
        cfg: Raw Axolotl configuration. It is passed through ``load_cfg`` before
            being used for dataset loading and training.

    Returns:
        The run ID of MLflow's last active run, or ``None`` when MLflow reports
        no such run.
    """
    import os

    # Export the Hugging Face token before any Axolotl code runs in this process.
    os.environ['HF_TOKEN'] = HF_TOKEN

    from axolotl.common.datasets import load_datasets
    from axolotl.train import train

    # load_cfg validates the raw configuration and returns the object used below.
    cfg = load_cfg(cfg)

    # Load and preprocess the datasets declared in the configuration.
    # NOTE(review): chat-template formatting, tokenization and dropping of
    # over-long samples presumably happen inside load_datasets — confirm.
    dataset_meta = load_datasets(cfg=cfg)

    # Demo run: cap training at 16 optimizer steps. The cap is applied after the
    # datasets are loaded and before training starts.
    cfg.max_steps = 16
    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

    import mlflow

    # Hand the run ID back so the caller can attach the model to the same run.
    last_run = mlflow.last_active_run()
    return None if last_run is None else last_run.info.run_id
# Launch the decorated training function through its `.distributed` entry
# point, passing the Axolotl configuration; the return value is kept so the
# MLflow run ID can be read from it afterwards.
result = run_train.distributed(config)

Execute the training job

Launches the distributed training job. The function validates the configuration, loads the datasets, trains the model for 16 steps, and returns the MLflow run ID for tracking.

# Use the first element of the launcher's return value as the MLflow run ID.
# NOTE(review): `result` presumably holds one return value per worker process —
# confirm against the serverless GPU API documentation.
run_id = result[0]
print(run_id)

Extract MLflow run ID

Retrieves the MLflow run ID from the training results for model registration and experiment tracking.

Register the fine-tuned model to Unity Catalog

Loads the trained LoRA adapter, merges it with the base model, and registers the combined model to Unity Catalog via MLflow. This makes the model available for deployment and inference.

Note: This step requires H100 GPU compute to load the model checkpoint. Running on smaller GPUs may result in CUDA out-of-memory errors.

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
try:
    from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
except ImportError:
    from transformers.activations import NewGELUActivation, GELUTanh as PytorchGELUTanh, GELUActivation

from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch

HF_MODEL_NAME = "allenai/Olmo-3-7B-Instruct-SFT"

# Release cached GPU memory before loading the full model.
torch.cuda.empty_cache()

print("Loading LoRA model for registration...")
# The training output only contains the LoRA adapter, so the base model is
# loaded first and the adapter is attached on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_NAME,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

# OUTPUT_DIR is the Axolotl `output_dir` the training job wrote to.
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Fold the LoRA weights into the base weights and drop the PEFT wrappers.
merged_model = peft_model.merge_and_unload()

# Unity Catalog object names may not contain periods, spaces or forward
# slashes; map them to underscores so a Hugging Face style name such as
# "org/model" still yields a valid three-level name. Valid names are unchanged.
uc_model_name = UC_MODEL_NAME.translate(str.maketrans("./ ", "___"))
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{uc_model_name}"

print(f"Registering model as: {full_model_name}")

# Build the pipeline from the merged model (not the PEFT wrapper) so the logged
# artifact is the standalone fine-tuned model described above.
text_gen_pipe = pipeline(
    task="text-generation",
    model=merged_model,
    tokenizer=tokenizer,
)

# Supplying an input example lets MLflow infer the model signature.
input_example = ["Hello, world!"]

# Register into Unity Catalog rather than the workspace model registry.
mlflow.set_registry_uri("databricks-uc")

# Re-open the training run so the model is logged next to its metrics.
with mlflow.start_run(run_id=run_id):
    model_info = mlflow.transformers.log_model(
        transformers_model=text_gen_pipe,  # log the whole pipeline, not just the model
        artifact_path="model",
        input_example=input_example,
        # Without this argument the model is only logged to the run and never
        # registered, leaving `registered_model_version` empty.
        registered_model_name=full_model_name,
    )

print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print("\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
# The training config enables the Cut Cross Entropy plugin with a QLoRA adapter.
print("Optimization: Cut Cross Entropy + QLoRA")

Next steps

Example notebook

Fine-tune Olmo3 7B with Axolotl on multi-GPU serverless compute

Get notebook