Distributed training of a two-tower recommendation model using Lightning

This notebook demonstrates how to build a two-tower recommendation model with the PyTorch Lightning Trainer API, running distributed training across 8 H100 GPUs on a single node.

Reminders:

  • For this demo, connect to 8xH100 GPU compute to take advantage of distributed training across multiple GPUs.
  • The @distributed decorator from the serverless_gpu Python library distributes the PyTorch Lightning training function across the 8 H100 GPUs (see the sketch after this list).
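
The decorator pattern looks like this: a minimal sketch of what section 5 implements in full, with the training body elided.

from serverless_gpu import distributed

@distributed(gpus=8, gpu_type="H100")
def training_function():
    ...  # PyTorch Lightning training code; runs on each of the 8 H100 GPUs

result = training_function.distributed()  # launch the run across all GPUs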

To get started, configure your notebook to use serverless GPU:

  1. Click the Connect dropdown at the top to open the compute selector.
  2. Select Serverless GPU.
  3. Open the Environment panel on the right side.
  4. Select 8xH100 as the accelerator.
  5. No package dependencies need to be configured in the environment panel to continue. Click Apply, then Confirm.

Your notebook is now connected to serverless GPU compute. The @distributed decorator will handle launching your training across all 8 GPUs.

Prerequisites

Before running this demo, set the widget variables at the top of this notebook:

  • catalog: The Unity Catalog catalog where the trained model will be registered.
  • schema: The schema within the catalog above where the trained model will be registered.

The dataset is downloaded automatically when the notebook runs. The model is saved to <catalog>.<schema>.<model_name_in_registry>.

Two-tower recommendation model

For more information about the two-tower recommendation model, see the following resources:

### Instructions

The code below walks you through how to:

  1. Install dependencies
  2. Download and prepare the dataset
  3. Set the required training configuration
  4. Define the two-tower recommendation model
  5. Create the main training function
  6. Train the two-tower model
  7. Run inference
  8. Register your model in MLflow for serving

## 1) Install dependencies

To get started, first install all the required libraries and make sure the environment is ready to go.

Install dependencies with pip

Because TorchRec and its associated dependencies are neither officially released with nor installed as part of serverless GPU compute, these libraries must be installed manually.

%pip install mlflow==3.7
%pip install -q iopath==0.1.10 pyre_extensions
%pip install -q --upgrade --no-deps --force-reinstall fbgemm-gpu==0.8.0 torchrec==0.8.0 torchmetrics==1.0.3 --index-url https://download.pytorch.org/whl/cu124
%pip install lightning==2.6.1
dbutils.library.restartPython()

#### Import packages

A common set of imports is used throughout the code. The following cell consolidates all the required imports.

# General Imports
import os
import urllib.request
import zipfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# Data Processing Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Databricks Specific Imports
import mlflow
from mlflow import MlflowClient
from mlflow.models.signature import infer_signature
from mlflow.pyfunc import PythonModel

# Torch Specific Imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import AUROC

# PyTorch Lightning
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
from lightning.pytorch.loggers import MLFlowLogger

# TorchRec Specific Imports
from torchrec.datasets.utils import Batch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

## 2) Download and prepare the dataset

Download the Learning from Sets dataset, preprocess it, and split it into train/validation/test sets.

dbutils.widgets.text("catalog", "main")
dbutils.widgets.text("schema", "default")
dbutils.widgets.text("volume", "recsys")

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
volume = dbutils.widgets.get("volume")
DATASET_URL = "https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip"

DATASET_PATH = f"/Volumes/{catalog}/{schema}/{volume}/dataset"
ZIP_PATH = f"{DATASET_PATH}/learning-from-sets-2019.zip"
CSV_PATH = f"{DATASET_PATH}/learning-from-sets-2019/item_ratings.csv"

# Download and extract
if not os.path.exists(CSV_PATH):
    os.makedirs(DATASET_PATH, exist_ok=True)
    print("Downloading dataset...")
    urllib.request.urlretrieve(DATASET_URL, ZIP_PATH)
    with zipfile.ZipFile(ZIP_PATH, "r") as zf:
        zf.extractall(DATASET_PATH)
    print("Download complete.")

# Load and preprocess
df = pd.read_csv(CSV_PATH)
df = df.sort_values(["userId", "movieId"]).head(100_000)

# Encode userId to contiguous integers
user_encoder = LabelEncoder()
df["userId"] = user_encoder.fit_transform(df["userId"])

# Binarize ratings: 1 if >= mean, else 0
mean_rating = df["rating"].mean()
df["label"] = (df["rating"] >= mean_rating).astype(np.int64)
df = df[["userId", "movieId", "label"]]

# Compute embedding table sizes from data
num_users = int(df["userId"].nunique())
num_movies = int(df["movieId"].nunique())
print(f"Dataset: {len(df)} rows, {num_users} users, {num_movies} movies")

# Split: 70% train, ~20% validation, ~10% test
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.33, random_state=42)
print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

class RecDataset(Dataset):
    """Wraps a DataFrame with columns [userId, movieId, label] as a PyTorch Dataset."""
    def __init__(self, dataframe: pd.DataFrame):
        self.users = dataframe["userId"].values.astype(np.int64)
        self.movies = dataframe["movieId"].values.astype(np.int64)
        self.labels = dataframe["label"].values.astype(np.int64)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        return {"userId": self.users[idx], "movieId": self.movies[idx], "label": self.labels[idx]}


def get_dataloader(dataframe: pd.DataFrame, batch_size: int = 1024, shuffle: bool = True) -> DataLoader:
    return DataLoader(RecDataset(dataframe), batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)
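
As a quick, optional sanity check (not part of the original flow), the DataLoader's default collate turns each batch into a dict of stacked tensors:

# Inspect one batch; expect three tensors of shape torch.Size([4])
sample_batch = next(iter(get_dataloader(train_df, batch_size=4)))
print(sample_batch["userId"].shape, sample_batch["movieId"].shape, sample_batch["label"].shape)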

## 3) Required training configuration

All the arguments and settings needed for this training example are consolidated in the following cell. Any of them can be modified to fit your use case.

@dataclass
class Args:
    epochs: int = 3
    embedding_dim: int = 128
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
    learning_rate: float = 0.01
    batch_size: int = 1024

cat_cols = ["userId", "movieId"]
emb_counts = [num_users, num_movies]  # computed from data in section 2
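
For example, a shorter single-epoch configuration can be created by overriding fields at construction time, which is exactly what section 5 below does:

quick_args = Args(epochs=1)  # hypothetical name; all other fields keep their dataclass defaults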

## 4) Define the two-tower recommendation model

This section defines the model using PyTorch Lightning. For more information, see the documentation:

class TwoTowerModel(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        layer_sizes: List[int],
        device: Optional[torch.device] = None
    ) -> None:
        super().__init__()
        assert len(embedding_bag_collection.embedding_bag_configs()) == 2, "Expected two EmbeddingBags in the two tower model"
        assert embedding_bag_collection.embedding_bag_configs()[0].embedding_dim == embedding_bag_collection.embedding_bag_configs()[1].embedding_dim, "Both EmbeddingBagConfigs must have the same dimension"
        embedding_dim = embedding_bag_collection.embedding_bag_configs()[0].embedding_dim
        self._feature_names_query: List[str] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
        self._candidate_feature_names: List[str] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
        self.ebc = embedding_bag_collection
        self.query_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
        self.candidate_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)

    def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pooled_embeddings = self.ebc(kjt)
        query_embedding: torch.Tensor = self.query_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._feature_names_query],
                dim=1,
            )
        )
        candidate_embedding: torch.Tensor = self.candidate_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._candidate_feature_names],
                dim=1,
            )
        )
        return query_embedding, candidate_embedding

class LitTwoTower(pl.LightningModule):
    """
    PyTorch Lightning module wrapping a TwoTowerModel.
    Uses torchmetrics AUROC for train/val metrics.
    """
    def __init__(
        self,
        two_tower: nn.Module,
        device: torch.device,
        emb_counts: Optional[List[int]],
        cat_cols: List[str],
        lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.two_tower = two_tower
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.train_auroc = AUROC(task="binary")
        self.val_auroc = AUROC(task="binary")
        self.lr = lr

        # Store metadata used in batch transform
        self.emb_counts = emb_counts
        self.cat_cols = cat_cols

        self.save_hyperparameters(ignore=["two_tower", "device"])

    def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
        kjt_batch = self._transform_to_torchrec_batch(batch, self.emb_counts)
        query_embedding, candidate_embedding = self.two_tower(kjt_batch.sparse_features)
        logits = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
        return logits

    def _loss(self, outputs: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
        labels = self._get_batch_labels(batch)
        return self.loss_fn(outputs, labels)

    def _update_metric(self, batch: Dict[str, Any], outputs: Optional[torch.Tensor], metric: AUROC) -> None:
        if outputs is None:
            outputs = self.forward(batch)
        preds = torch.sigmoid(outputs)
        labels = self._get_batch_labels(batch)
        metric.update(preds, labels)

    def training_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        # Metric update
        self._update_metric(batch, logits, self.train_auroc)

        # Log both step and epoch loss series; enable sync_dist for multi-GPU/DDP
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log("train_auroc", self.train_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        self._update_metric(batch, logits, self.val_auroc)

        # Typically only epoch-level val metrics are needed for monitoring
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = KeyedOptimizerWrapper(
            dict(self.two_tower.named_parameters()),
            lambda params: torch.optim.Adam(params, lr=self.lr),
        )
        return optimizer

    def _get_batch_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        return batch["label"].to(dtype=torch.float32, device=self.device)

    def _transform_to_torchrec_batch(
        self,
        batch: Dict[str, Any],
        num_embeddings_per_feature: Optional[List[int]],
    ) -> Batch:
        kjt_values_list = []
        kjt_lengths_list = []
        for col_idx, col_name in enumerate(self.cat_cols):
            values = batch[col_name]
            num_emb = num_embeddings_per_feature[col_idx]
            kjt_values_list.append(values % num_emb)
            kjt_lengths_list.append(torch.ones(len(values), dtype=torch.int64))

        values_t = torch.cat(kjt_values_list).to(dtype=torch.int64, device=self.device)
        lengths_t = torch.cat(kjt_lengths_list).to(device=self.device)

        sparse_features = KeyedJaggedTensor.from_lengths_sync(
            self.cat_cols,
            values_t,
            lengths_t,
        )

        labels = batch["label"].to(dtype=torch.int64, device=self.device)

        return Batch(
            dense_features=torch.zeros(1, device=self.device),
            sparse_features=sparse_features,
            labels=labels,
        )
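
To make _transform_to_torchrec_batch concrete, here is a standalone toy illustration (the id values are made up): each categorical feature contributes exactly one id per example, so all lengths are 1 and the values are concatenated feature by feature.

# Toy batch of 3 examples: userId=[1, 2, 3], movieId=[10, 20, 30]
toy_kjt = KeyedJaggedTensor.from_lengths_sync(
    ["userId", "movieId"],
    torch.tensor([1, 2, 3, 10, 20, 30]),  # values, concatenated per feature
    torch.ones(6, dtype=torch.int64),     # one id per example per feature
)
print(toy_kjt["userId"].values())  # tensor([1, 2, 3])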

def create_two_tower_model(args, device, cat_cols, emb_counts) -> LitTwoTower:
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]
    ebc = EmbeddingBagCollection(tables=eb_configs, device=device)
    base = TwoTowerModel(
        embedding_bag_collection=ebc, layer_sizes=args.layer_sizes, device=device
    )
    lit = LitTwoTower(
        base, cat_cols=cat_cols, emb_counts=emb_counts, device=device, lr=args.learning_rate
    )
    return lit

## 5) Create the main training function

Next, use the @distributed decorator from the serverless_gpu library, together with the helper functions and the PyTorch Lightning Trainer API, to launch training across multiple GPUs.

# setup mlflow experiment
username = spark.sql("SELECT current_user()").first()['current_user()']
experiment_path = f'/Users/{username}/sgc-torchrec-example'
experiment = mlflow.set_experiment(experiment_path)
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_path
from serverless_gpu import distributed

# get DB Host and Token
db_host = f"https://{dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()}/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# setup arguments for training function
args = Args(epochs=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = f"/Volumes/{catalog}/{schema}/{volume}/checkpoints"

@distributed(gpus=8, gpu_type="H100")
def training_function(args=args, cat_cols=cat_cols, emb_counts=emb_counts, device=device,
                      train_data=train_df, val_data=val_df, checkpoint_path=CHECKPOINT_PATH):
    mlflow.pytorch.autolog()
    model = create_two_tower_model(args, device=device, cat_cols=cat_cols, emb_counts=emb_counts)
    train_dataloader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
    eval_dataloader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)

    mlflow_logger = MLFlowLogger(
        experiment_name=experiment_path,
        log_model="all",
    )

    ckpt_cb = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="val_auroc",
        mode="max",
        save_top_k=1,
        save_last=True,                        # enables last_model_path
        filename="{epoch}-{val_auroc:.4f}",
    )

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        DeviceStatsMonitor(),
        ckpt_cb,
    ]

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator="gpu",
        strategy="ddp",
        devices=8,
        log_every_n_steps=20,
        logger=mlflow_logger,
        callbacks=callbacks,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader
    )

    # Return run_id and best checkpoint path
    result = {
        "run_id": trainer.logger.run_id,                   # MLflow run id
        "best_model_checkpoint": ckpt_cb.best_model_path,  # best checkpoint path
        "last_model_checkpoint": ckpt_cb.last_model_path   # last checkpoint path
    }
    return result

## 6) Train the two-tower model using the serverless GPU distributed training API

result = training_function.distributed()

## 7) Test the best model checkpoint

Retrieve the best model checkpoint and run a quick test to verify the results.

print(f"Experiment Name: {experiment.name}")
print(f"Experiment ID: {experiment.experiment_id}")
print(f"Artifact Location: {experiment.artifact_location}")
print(f"Lifecycle_stage: {experiment.lifecycle_stage}")

ranked_checkpoints = mlflow.search_logged_models(
  experiment_ids=[experiment.experiment_id],
  output_format="list",
  order_by=[{"field_name": "metrics.val_auroc", "ascending": False}]  # rank by the logged validation AUROC
)

best_checkpoint: mlflow.entities.LoggedModel = ranked_checkpoints[0]
print(best_checkpoint.metrics[0])
run_id = best_checkpoint.source_run_id
# Load the best logged model directly by its model ID (the same models:/ scheme used in section 8)
model_uri = f"models:/{best_checkpoint.model_id}"
two_tower_model = mlflow.pytorch.load_model(model_uri)

num_batches = 5  # Number of test batches to score and print
batch_size = 1   # Use batch size 1 to print each row individually

test_dataloader = iter(get_dataloader(test_df, batch_size=batch_size, shuffle=False))

device = torch.device("cuda:0")
two_tower_model.to(device)
two_tower_model.eval()

for _ in range(num_batches):
    next_batch = next(test_dataloader)
    expected_result = next_batch["label"][0]

    actual_result = two_tower_model(next_batch)
    actual_result = torch.sigmoid(actual_result)
    print(f"Expected Result: {expected_result}; Actual Result: {actual_result.round().item()}")

## 8) Register your model in MLflow for serving

Once the model from the previous step checks out, use the run_id from that run to register the model. To make this easy to serve, create a PyFunc that wraps the two-tower model and simplifies the input/output interface (Dict[str, List] -> List[float]).

class TwoTowerWrapper(PythonModel):
    """
    MLflow PythonModel wrapper for TwoTower model that handles dictionary input and returns list outputs
    """
    def __init__(self, two_tower_model):
        self.two_tower_model = two_tower_model

    def predict(self, model_input: Dict[str, List]) -> List[float]:
        batch = {key: torch.tensor(value) for key, value in model_input.items()}
        if "label" not in batch:
            batch["label"] = torch.zeros(len(next(iter(batch.values()))))
        with torch.no_grad():
            output = self.two_tower_model(batch).cpu()
        output = torch.sigmoid(output)
        return output.tolist()

def preprocess_data(batch):
    # turn the example test dataset from Dict[str, Tensor] to Dict[str, List] and remove the label
    return {key: tensor.tolist() for key, tensor in batch.items() if key != "label"}

def add_and_get_model_signature(two_tower_model, test_dataloader):
    current_batch = preprocess_data(next(test_dataloader))

    pyfunc_two_tower_model = TwoTowerWrapper(two_tower_model)
    current_output = pyfunc_two_tower_model.predict(current_batch)
    signature = infer_signature(current_batch, current_output)
    logged_model = mlflow.pyfunc.log_model(
        artifact_path="two_tower_pyfunc",
        python_model=pyfunc_two_tower_model,
        signature=signature,
        input_example=current_batch
    )
    return signature, logged_model

signature, logged_model = add_and_get_model_signature(two_tower_model, test_dataloader)
model_name = "two_tower_model"
uc_model_version = mlflow.register_model(
    f"models:/{logged_model.model_id}",
    name=f"{catalog}.{schema}.{model_name}"
)
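
Once registered, the Unity Catalog model can be loaded back for scoring. The following is a minimal sketch, assuming the registration above succeeded; the input keys match what TwoTowerWrapper expects:

# Load the registered model version and score a small hypothetical batch
loaded_model = mlflow.pyfunc.load_model(
    f"models:/{catalog}.{schema}.{model_name}/{uc_model_version.version}"
)
print(loaded_model.predict({"userId": [0, 1], "movieId": [10, 20]}))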
