Multi-GPU workload

Important

This feature is in Beta. Workspace admins can control access to this feature from the Previews page. See Manage Azure Databricks previews.

You can launch distributed workloads across multiple GPUs on a single node using the Serverless GPU Python API. The API provides a simple, unified interface that abstracts away the details of GPU provisioning, environment setup, and workload distribution. With minimal code changes, you can seamlessly move from single-GPU training to multi-GPU distributed execution from the same notebook.

Supported frameworks

The @distributed API integrates with major distributed training libraries:

  • PyTorch Distributed Data Parallel (DDP) — Standard multi-GPU data parallelism.
  • Fully Sharded Data Parallel (FSDP) — Memory-efficient training for large models.
  • DeepSpeed — Microsoft's optimization library for large model training.

serverless_gpu API vs. TorchDistributor

The following table compares the serverless_gpu @distributed API with TorchDistributor:

| Feature | serverless_gpu @distributed API | TorchDistributor |
| --- | --- | --- |
| Infrastructure | Fully serverless, no cluster management | Requires a Spark cluster with GPU workers |
| Setup | Single decorator, minimal configuration | Requires Spark cluster and TorchDistributor setup |
| Framework support | PyTorch DDP, FSDP, DeepSpeed | Primarily PyTorch DDP |
| Data loading | Inside the decorated function, uses Unity Catalog Volumes | Via Spark or the filesystem |

The serverless_gpu API is the recommended approach for new deep learning workloads on Databricks. TorchDistributor remains available for workloads tightly coupled with Spark clusters.

Quick start

The serverless GPU API for distributed training is preinstalled in Serverless GPU Compute environments for Databricks notebooks. We recommend GPU environment version 4 or above. To run distributed training, import the distributed decorator and apply it to your training function.

Wrap the model training code in a function and decorate the function with the @distributed decorator. The decorated function becomes the entrypoint for distributed execution — all training logic, data loading, and model initialization should be defined inside this function.

Warning

The gpu_type parameter in @distributed must match the accelerator type your notebook is connected to. For example, @distributed(gpus=8, gpu_type='H100') requires that your notebook is connected to an H100 accelerator. Using a mismatched accelerator type (such as connecting to A10 while specifying H100) will cause the workload to fail.

The code snippet below shows the basic usage of @distributed:

# Import the distributed decorator
from serverless_gpu import distributed

# Decorate your training function with @distributed and specify the number of GPUs and GPU type
@distributed(gpus=8, gpu_type='H100')
def run_train():
    ...

Below is a full example that trains a multilayer perceptron (MLP) model on 8 H100 GPUs from a notebook:

  1. Set up your model and define utility functions.

    
    # Define the model
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    
    def setup():
        dist.init_process_group("nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    
    def cleanup():
        dist.destroy_process_group()
    
    class SimpleMLP(nn.Module):
        def __init__(self, input_dim=10, hidden_dim=64, output_dim=1):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, output_dim)
            )
    
        def forward(self, x):
            return self.net(x)
    
  2. Import the serverless_gpu library and the distributed decorator.

    import serverless_gpu
    from serverless_gpu import distributed
    
  3. Wrap the model training code in a function and decorate the function with the @distributed decorator.

    @distributed(gpus=8, gpu_type='H100')
    def run_train(num_epochs: int, batch_size: int) -> None:
        import mlflow
        import torch.optim as optim
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
    
        # 1. Set up multi-GPU environment
        setup()
        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    
        # 2. Apply the Torch distributed data parallel (DDP) library for data-parallel training.
        model = SimpleMLP().to(device)
        model = DDP(model, device_ids=[device])
    
        # 3. Create and load dataset.
        x = torch.randn(5000, 10)
        y = torch.randn(5000, 1)
    
        dataset = TensorDataset(x, y)
        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    
        # 4. Define the training loop.
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        loss_fn = nn.MSELoss()
    
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)
            model.train()
            total_loss = 0.0
            for step, (xb, yb) in enumerate(dataloader):
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(xb), yb)
                # Log loss to MLflow metric
                mlflow.log_metric("loss", loss.item(), step=step)
    
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * xb.size(0)
    
            mlflow.log_metric("total_loss", total_loss)
            print(f"Total loss for epoch {epoch}: {total_loss}")
    
        cleanup()
    
  4. Execute the distributed training by calling the decorated function's .distributed method with user-defined arguments.

    run_train.distributed(num_epochs=3, batch_size=1)
    
  5. When executed, an MLflow run link is generated in the notebook cell output. Click the link, or find it in the Experiment panel, to view the run results.

    Output in the notebook cell
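In the example above, DistributedSampler splits the 5,000-sample dataset across the 8 workers so that each rank trains on a disjoint shard. The stdlib sketch below is an illustration of the sampler's default (non-shuffled) partitioning scheme, not the actual torch implementation: pad the index list to a multiple of the world size, then give each rank every world_size-th index starting at its own rank.

```python
import math

def shard_indices(dataset_len: int, world_size: int, rank: int) -> list[int]:
    """Mimic DistributedSampler's default (non-shuffled) partitioning."""
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    # Pad by wrapping around so every rank gets the same number of samples
    indices = [i % dataset_len for i in range(total)]
    # Each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total:world_size]

shards = [shard_indices(5000, 8, r) for r in range(8)]
print([len(s) for s in shards])  # each of the 8 ranks gets 625 samples
# Together the shards cover the whole dataset exactly once
print(sorted(i for s in shards for i in s) == list(range(5000)))
```

Because 5,000 divides evenly by 8, no padding occurs here; with an uneven dataset length, a few samples would be repeated so that every rank processes the same number of batches per epoch.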

Distributed execution details

The Serverless GPU API consists of several key components:

  • Compute manager: Handles resource allocation and management
  • Runtime environment: Manages Python environments and dependencies
  • Launcher: Orchestrates job execution and monitoring

When running in distributed mode:

  • The function is serialized and distributed across the specified number of GPUs
  • Each GPU runs a copy of the function with the same parameters
  • The environment is synchronized across all GPUs
  • Results are collected and returned from all GPUs
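Because each GPU runs a copy of the same function, per-rank behavior is typically driven by environment variables such as RANK and WORLD_SIZE that the launcher sets for each process. The sketch below simulates this in a single process to show the pattern; it is an illustration only — in a real run, the launcher sets these variables and each copy runs on its own GPU.

```python
import os

def train_step() -> str:
    # Each copy reads its identity from the environment set by the launcher
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Common pattern: only rank 0 logs metrics or saves checkpoints,
    # so artifacts are not duplicated once per GPU
    role = "logger" if rank == 0 else "worker"
    return f"rank {rank}/{world_size} acting as {role}"

# Simulate 4 copies of the function in one process (illustration only)
results = []
for r in range(4):
    os.environ["RANK"], os.environ["WORLD_SIZE"] = str(r), "4"
    results.append(train_step())
print(results)
```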

The API supports popular parallel training libraries such as Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), and DeepSpeed.

You can find more realistic distributed training scenarios using these libraries in the notebook examples.

FAQs

Where should the data loading code be placed?

When using the Serverless GPU API for distributed training, move data loading code inside the function decorated with @distributed. Objects defined outside the function are captured and serialized with pickle when the function is distributed, and a large dataset can exceed the maximum size pickle allows. Generate or load the dataset inside the decorated function instead, as shown below:

from serverless_gpu import distributed

# Defining the dataset here, outside the function, may cause a pickle error:
# dataset = get_dataset(file_path)

@distributed(gpus=8, gpu_type='H100')
def run_train():
    # Good practice: load the dataset inside the decorated function
    dataset = get_dataset(file_path)
    ...
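The reason for this guidance: anything the decorated function references from the enclosing scope gets pickled and shipped to every worker. The stdlib sketch below compares the pickled size of an in-memory dataset with that of a file path; the path and dataset shown are hypothetical, and the point is that shipping only the path and loading the data inside the function keeps the payload tiny.

```python
import pickle

# An in-memory "dataset": pickling it ships every element to each worker
dataset = [[float(i)] * 10 for i in range(100_000)]
payload_data = pickle.dumps(dataset)

# A path: pickling ships only the string; each worker loads the data itself
file_path = "/Volumes/catalog/schema/volume/train.parquet"
payload_path = pickle.dumps(file_path)

print(len(payload_data), len(payload_path))
print(len(payload_path) < len(payload_data))  # the path payload is far smaller
```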

Learn more

For the API reference, refer to the Serverless GPU Python API documentation.