Hinweis
Für den Zugriff auf diese Seite ist eine Autorisierung erforderlich. Sie können versuchen, sich anzumelden oder das Verzeichnis zu wechseln.
Für den Zugriff auf diese Seite ist eine Autorisierung erforderlich. Sie können versuchen, das Verzeichnis zu wechseln.
Dieses Notizbuch veranschaulicht, wie Sie ein Transformatormodell mithilfe von verteilten Schulungen mit pyTorchs Fully Sharded Data Parallel (FSDP) auf Databricks serverlosen GPU-Compute trainieren. FSDP ist eine Datenparallelitätstechnik, die Modellparameter, Gradienten und Optimiererzustände über mehrere GPUs hinweg aufteilt und das effiziente Training großer Modelle ermöglicht, die nicht auf eine einzelne GPU passen würden.
In diesem Beispiel erfahren Sie, wie Sie:
- Einrichten einer verteilten Schulung mit der verteilten GPU-Schulungs-API ohne Server
- Definieren und Trainieren eines 10M-Parametertransformatorenmodells mithilfe von FSDP
- Speichern verteilter Prüfpunkte während der Schulung
- Nachverfolgen von Experimenten mit MLflow
- Laden von Prüfpunkten für Rückschlüsse oder Weiterbildung
Dieses Notizbuch verwendet synthetische Daten, um es eigenständig zu halten, Sie können es jedoch anpassen, um mit Ihren eigenen Datasets zu arbeiten.
Schlüsselkonzepte:
- FSDP (Fully Sharded Data Parallel): Eine pyTorch verteilte Trainingsstrategie, die Modellparameter über GPUs hinweg verteilt, um die Speicherauslastung zu reduzieren und die Schulung größerer Modelle zu ermöglichen.
- Serverlose GPU-Compute: Vom Databricks verwaltete GPU-Compute, der Ressourcen für Ihre Workloads automatisch skaliert und bereitstellt.
Weitere Informationen finden Sie unter Verteiltes Training mit Multi-GPU und mehreren Knoten.
Abhängigkeiten installieren
Installieren Sie die neueste Version von MLflow zum Nachverfolgen und Modellprotokollieren von Experimenten.
%pip install -U mlflow
%restart_python
Konfigurieren von Unity-Katalog-Speicherorten
Richten Sie die Unity-Katalogspeicherorte ein, an denen das Modell und die Prüfpunkte gespeichert werden. Aktualisieren Sie diese Werte so, dass sie ihrer Arbeitsbereichskonfiguration entsprechen. Sie benötigen USE CATALOG und USE SCHEMA berechtigungen für den angegebenen Katalog und das angegebene Schema.
# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
Definieren von Hilfsfunktionen und synthetischem Dataset
In diesem Abschnitt werden Hilfsfunktionen für verteilte Schulungseinrichtung und eine synthetische Datasetklasse für Demonstrationszwecke definiert. Im Produktionsumfeld würden Sie das SyntheticDataset durch Ihre eigene Datenladelogik ersetzen.
Wichtige Komponenten:
-
setup(): Initialisiert die verteilte Schulungsprozessgruppe und konfiguriert GPU-Geräte. -
cleanup(): Bereinigt die verteilte Prozessgruppe nach dem Training -
AppState: Eine Wrapperklasse für prüfpunktmodell und Optimiererstatus, die mit der verteilten Prüfpunkt-API von PyTorch kompatibel ist -
SyntheticDataset: Generiert zufällige Daten für die Schulungsdemonstration
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time
# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
def setup():
"""Initialize the distributed training process group"""
# Check if we're in a distributed environment
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ.get('LOCAL_RANK', 0))
else:
# Fallback for single GPU
rank = 0
world_size = 1
local_rank = 0
# Initialize process group
if world_size > 1:
if not dist.is_initialized():
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
# Set device
if torch.cuda.is_available():
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(device)
else:
device = torch.device('cpu')
return rank, world_size, device
def cleanup():
"""Clean up the distributed training process group"""
if dist.is_initialized():
dist.destroy_process_group()
class SyntheticDataset(Dataset):
"""Simple synthetic dataset for demo purposes"""
def __init__(self, size=10000, input_dim=512, num_classes=10):
self.size = size
self.input_dim = input_dim
self.num_classes = num_classes
# Generate synthetic data
np.random.seed(42) # For reproducible results
self.data = torch.randn(size, input_dim)
# Create labels with some pattern
self.labels = torch.randint(0, num_classes, (size,))
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
Definieren des Transformatormodells mit FSDP
In diesem Abschnitt wird ein einfaches Transformatormodell für die Klassifizierung und die Logik zum Anwenden von FSDP-Sharding definiert. Während FSDP in der Regel für große Sprachmodelle mit 7B+-Parametern verwendet wird, veranschaulicht dieses Beispiel die Technik mit einem kleineren 10M-Parametermodell, das über mehrere H100-GPUs verteilt ist.
Modellarchitektur:
-
TransformerBlock: Eine einzige Transformatorschicht mit Mehrkopf-Aufmerksamkeit und MLP -
SimpleTransformer: Ein Stapel von Transformatorblöcken mit Eingabeprojektion und Klassifizierungskopf -
apply_fsdp(): Verpackt Modellebenen mit FSDP für verteiltes Training
FSDP teilt die Modellparameter, Gradienten und Optimiererzustände auf GPUs auf, reduziert damit den Speicherbedarf pro GPU und ermöglicht das Training größerer Modelle.
class TransformerBlock(nn.Module):
"""Simple transformer block for testing FSDP"""
def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, dim),
)
def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# MLP
mlp_out = self.mlp(x)
x = self.norm2(x + mlp_out)
return x
class SimpleTransformer(nn.Module):
"""Simple transformer model for classification with FSDP"""
def __init__(self, input_dim=512, num_layers=64, num_classes=10):
super().__init__()
self.input_projection = nn.Linear(input_dim, input_dim)
self.layers = nn.ModuleList([
TransformerBlock(dim=input_dim) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(input_dim)
self.classifier = nn.Linear(input_dim, num_classes)
def forward(self, x):
# Add sequence dimension for transformer
x = x.unsqueeze(1) # [batch, 1, input_dim]
x = self.input_projection(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# Global average pooling
x = x.mean(dim=1) # [batch, input_dim]
return self.classifier(x)
def apply_fsdp(model, world_size):
"""Apply FSDP to the model"""
if world_size > 1:
print("Applying FSDP to model layers...")
# Apply fsdp to each transformer layer
for i, layer in enumerate(model.layers):
fully_shard(layer)
print(f"Applied FSDP to layer {i}")
# Apply FSDP to the entire model
fully_shard(model)
print("Applied FSDP to entire model")
else:
print("Single GPU detected, skipping FSDP setup")
return model
Definieren der verteilten Schulungsfunktion
Die Schulungsfunktion wird mit dem @distributed Dekorierer aus der serverlosen GPU-API umschlossen. Dieser Dekorierer behandelt:
- Bereitstellen der angegebenen Anzahl von GPUs (8 H100 GPUs in diesem Beispiel)
- Einrichten der verteilten Schulungsumgebung
- Verwalten des Lebenszyklus von Remote-Computeressourcen
Die Schulungsfunktion umfasst:
- Modellinitialisierung und FSDP-Verpackung
- Datenladen mit
DistributedSamplerzur parallelen Datenverarbeitung - Trainingsschleife mit Gradientenaktualisierungen
- Regelmäßiges Speichern von Prüfpunkten mithilfe der verteilten Prüfpunkt-API von PyTorch
- MLflow-Logging zur Experimentverfolgung
Prüfpunkte werden in einem Unity-Katalogvolume gespeichert und als MLflow-Artefakte zur Versionsverwaltung und Reproduzierbarkeit protokolliert.
from serverless_gpu import distributed
from serverless_gpu.compute import GPUType
NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
"""
Self-contained FSDP training demo using PyTorch 2.0+
Trains a simple neural network on synthetic data using FSDP
"""
import mlflow
mlflow.start_run(run_name='fsdp_example')
def main_training():
"""Main training function"""
print("Starting FSDP Training Demo...")
# Setup distributed training
rank, world_size, device = setup()
print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Current CUDA device: {torch.cuda.current_device()}")
# Create dataset and data loader
dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)
# Use DistributedSampler if we have multiple processes
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
shuffle = False
else:
sampler = None
shuffle = True
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=shuffle,
sampler=sampler,
num_workers=num_workers,
pin_memory=True
)
# Create model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)
# Apply FSDP
model = apply_fsdp(model, world_size)
print(f"Model created and moved to device: {device}")
if rank == 0:
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Training loop
num_epochs = 5
loss_history = []
print(f"Training for {num_epochs} epochs...")
writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
for epoch in range(num_epochs):
if sampler:
sampler.set_epoch(epoch)
model.train()
total_loss = 0.0
num_batches = 0
epoch_start_time = time.time()
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
mlflow.log_metric(
key='loss',
value=loss.item(),
step=batch_idx,
)
# Update weights
optimizer.step()
total_loss += loss.item()
num_batches += 1
if batch_idx % 10 == 0:
print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
state_dict = { 'app': AppState(model, optimizer) }
ckpt_start_time = time.time()
dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
ckpt_time = time.time() - ckpt_start_time
print(f'Checkpointing took {ckpt_time:.2f}s')
mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# Calculate average loss for this epoch
avg_loss = total_loss / num_batches
mlflow.log_metric(key='avg_loss', value=avg_loss)
loss_history.append(avg_loss)
epoch_time = time.time() - epoch_start_time
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')
# Verify loss is decreasing
if rank == 0:
print("\n=== FSDP Training Results ===")
print("Loss history:")
for i, loss in enumerate(loss_history):
print(f"Epoch {i+1}: {loss:.6f}")
# Check if loss is generally decreasing
initial_loss = loss_history[0]
final_loss = loss_history[-1]
loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100
print(f"\nInitial Loss: {initial_loss:.6f}")
print(f"Final Loss: {final_loss:.6f}")
print(f"Loss Reduction: {loss_reduction:.2f}%")
if final_loss < initial_loss:
print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
else:
print("❌ WARNING: Loss did not decrease. Check training configuration.")
print(f"\nFSDP training completed successfully on {world_size} GPU(s)")
# Cleanup
cleanup()
mlflow.end_run()
return {
'initial_loss': loss_history[0] if loss_history else None,
'final_loss': loss_history[-1] if loss_history else None,
'loss_history': loss_history,
'world_size': world_size,
'device': str(device),
'fsdp_enabled': world_size > 1
}
# Run the training
return main_training()
Führen Sie das verteilte Training aus
Führen Sie die Schulungsfunktion aus, um die verteilte Schulung über 8 H100 GPUs zu starten. Die .distributed() Methode löst die Remoteausführung auf serverlosem GPU-Compute aus. Schulungsfortschritt, Verlustmetriken und Prüfpunkte werden beim MLflow protokolliert.
Diese Zelle kann mehrere Minuten in Anspruch nehmen, da die GPU-Ressourcen bereitgestellt werden, das Modell für 5 Epochen trainiert wird und Checkpoints gespeichert werden.
print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")
Laden eines Modellprüfpunkts
In diesem Abschnitt wird veranschaulicht, wie Sie einen gespeicherten Prüfpunkt für Rückschlüsse oder weiterbildungen laden. Der Prüfpunkt enthält die Modellgewichte und den Optimiererzustand, der während der Schulung gespeichert wurde.
Beachten Sie, dass beim Laden von Prüfpunkten außerhalb eines verteilten Schulungskontexts (keine Prozessgruppe initialisiert) die verteilte Prüfpunkt-API von PyTorch automatisch kollektive Vorgänge deaktiviert und den Prüfpunkt auf einem einzelnen Gerät lädt.
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
state_dict = { 'app': AppState(model, optimizer)}
# print(state_dict)
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=f'{CHECKPOINT_DIR}/step0',
)
model.load_state_dict(state_dict['app'].state_dict()['model'])
run_checkpoint_load_example()
Nächste Schritte
Nachdem Sie nun gelernt haben, wie Sie PyTorch FSDP für verteilte Schulungen auf serverlosen GPU-Compute verwenden können, erkunden Sie diese Ressourcen, um mehr zu erfahren:
- Verteiltes Training mit mehreren GPU- und Multiknoten – Erfahren Sie mehr über verschiedene verteilte Schulungsstrategien
- Bewährte Methoden für serverlose GPU-Compute – Optimieren Ihrer GPU-Workloads
- Behandeln von Problemen auf serverlosem GPU-Compute – Häufige Probleme und Lösungen
- PyTorch FSDP-Dokumentation – Deep dive into FSDP features and configuration