次の方法で共有


Lightning を使用した 2 つのタワーレコメンデーション モデルの分散トレーニング

このノートブックでは、PyTorch Lightning Trainer API を使用して 2 つのタワーレコメンデーション モデルを作成し、1 つのノード上の 8 個の H100 GPU に分散トレーニングする方法を示します。

アラーム:

  • このデモでは、 GPU 8xH100 にアタッチして、複数の GPU に分散トレーニングを活用します。
  • @distributed Python ライブラリの serverless_gpu デコレーターは、PyTorch Lightning トレーニング関数を 8 個の H100 GPU に分散します。

開始するには、サーバーレス GPU を使用するようにノートブックを構成します。

  1. 上部にある [接続 ] ドロップダウンをクリックして、コンピューティング セレクターを開きます。
  2. [サーバーレス GPU] を選択します
  3. 右側の [環境 ] パネルを開きます。
  4. アクセラレータとして 8xH100 を選択します。
  5. 続行するために、環境パネルでパッケージの依存関係を構成する必要はありません。 [ 適用] をクリックし、[ 確認] をクリックします。

これで、ノートブックがサーバーレス GPU コンピューティングに接続されました。 @distributedデコレーターは、8 つの GPU すべてでトレーニングの開始を処理します。

前提条件

このデモを実行する前に、このノートブックの上部にあるウィジェット変数を構成します。

  • catalog: トレーニング済みのモデルが登録される Unity カタログ カタログ。
  • schema: トレーニング済みモデルが登録される上記のカタログの Unity カタログ スキーマ。

データセットは、ノートブックの実行中に自動的にダウンロードされます。 モデルは <catalog>.<schema>.<model_name_in_registry>に保存されます。

二重塔レコメンデーションモデル

2 つのタワーレコメンデーション モデルの詳細については、次のリソースを参照してください。

###Instructions: 以下のコードでは、次の方法について説明します。

  1. 依存関係のインストール
  2. データセットのダウンロードと準備
  3. 必要なトレーニング構成
  4. 2 つのタワー推奨モデルの定義
  5. メイン トレーニング関数の作成
  6. 2 つのタワー モデルのトレーニング
  7. 推論を実行する
  8. サービスを提供するためにモデルを MLflow に登録する

##1) 依存関係をインストールする 最初に、必要なすべてのライブラリをインストールし、環境を準備します。

Pip インストールの依存関係

TorchRecおよび関連する依存関係は公式にリリースされておらず、サーバーレス GPU コンピューティングの一部としてインストールされていないため、これらのライブラリを手動でインストールします。

%pip install mlflow==3.7
%pip install -q iopath==0.1.10 pyre_extensions
%pip install -q --upgrade --no-deps --force-reinstall fbgemm-gpu==0.8.0 torchrec==0.8.0 torchmetrics==1.0.3 --index-url https://download.pytorch.org/whl/cu124
%pip install lightning==2.6.1
dbutils.library.restartPython()

####パッケージのインポート

コード全体で使用される一連のインポートがあります。 次のセルは、必要なすべてのインポートを統合します。

# General Imports
import os
import urllib.request
import zipfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# Data Processing Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Databricks Specific Imports
import mlflow
from mlflow import MlflowClient
from mlflow.models.signature import infer_signature
from mlflow.pyfunc import PythonModel

# Torch Specific Imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import AUROC

# PyTorch Lightning
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
from lightning.pytorch.loggers import MLFlowLogger

# TorchRec Specific Imports
from torchrec.datasets.utils import Batch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

##2) データセットをダウンロードして準備 します。データセットから学習 をダウンロードし、それを前処理し、トレーニング/検証/テスト セットに分割します。

dbutils.widgets.text("catalog", "main")
dbutils.widgets.text("schema", "default")
dbutils.widgets.text("volume", "recsys")

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
volume = dbutils.widgets.get("volume")
DATASET_URL = "https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip"

DATASET_PATH = f"/Volumes/{catalog}/{schema}/{volume}/dataset"
ZIP_PATH = f"{DATASET_PATH}/learning-from-sets-2019.zip"
CSV_PATH = f"{DATASET_PATH}/learning-from-sets-2019/item_ratings.csv"

# Download and extract
if not os.path.exists(CSV_PATH):
    os.makedirs(DATASET_PATH, exist_ok=True)
    print("Downloading dataset...")
    urllib.request.urlretrieve(DATASET_URL, ZIP_PATH)
    with zipfile.ZipFile(ZIP_PATH, "r") as zf:
        zf.extractall(DATASET_PATH)
    print("Download complete.")

# Load and preprocess
df = pd.read_csv(CSV_PATH)
df = df.sort_values(["userId", "movieId"]).head(100_000)

# Encode userId to contiguous integers
user_encoder = LabelEncoder()
df["userId"] = user_encoder.fit_transform(df["userId"])

# Binarize ratings: 1 if >= mean, else 0
mean_rating = df["rating"].mean()
df["label"] = (df["rating"] >= mean_rating).astype(np.int64)
df = df[["userId", "movieId", "label"]]

# Compute embedding table sizes from data
num_users = int(df["userId"].nunique())
num_movies = int(df["movieId"].nunique())
print(f"Dataset: {len(df)} rows, {num_users} users, {num_movies} movies")

# Split: 70% train, 21% validation, 9% test
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.33, random_state=42)
print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")
class RecDataset(Dataset):
    """Wraps a DataFrame with columns [userId, movieId, label] as a PyTorch Dataset."""
    def __init__(self, dataframe: pd.DataFrame):
        self.users = dataframe["userId"].values.astype(np.int64)
        self.movies = dataframe["movieId"].values.astype(np.int64)
        self.labels = dataframe["label"].values.astype(np.int64)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        return {"userId": self.users[idx], "movieId": self.movies[idx], "label": self.labels[idx]}


def get_dataloader(dataframe: pd.DataFrame, batch_size: int = 1024, shuffle: bool = True) -> DataLoader:
    return DataLoader(RecDataset(dataframe), batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)

##3) 必要なトレーニング構成

このトレーニング例に必要なすべての引数と情報は、次のセルに統合されます。 これらはすべて、ユース ケースに合わせて変更できます。

@dataclass
class Args:
    epochs: int = 3
    embedding_dim: int = 128
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
    learning_rate: float = 0.01
    batch_size: int = 1024

cat_cols = ["userId", "movieId"]
emb_counts = [num_users, num_movies]  # computed from data in section 2

##4) 2 つのタワーレコメンデーション モデルの定義

このセクションでは、PyTorch Lightning を使用してモデルを定義します。 詳細については、次のドキュメントを参照してください。

class TwoTowerModel(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        layer_sizes: List[int],
        device: Optional[torch.device] = None
    ) -> None:
        super().__init__()
        assert len(embedding_bag_collection.embedding_bag_configs()) == 2, "Expected two EmbeddingBags in the two tower model"
        assert embedding_bag_collection.embedding_bag_configs()[0].embedding_dim == embedding_bag_collection.embedding_bag_configs()[1].embedding_dim, "Both EmbeddingBagConfigs must have the same dimension"
        embedding_dim = embedding_bag_collection.embedding_bag_configs()[0].embedding_dim
        self._feature_names_query: List[str] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
        self._candidate_feature_names: List[str] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
        self.ebc = embedding_bag_collection
        self.query_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
        self.candidate_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)

    def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pooled_embeddings = self.ebc(kjt)
        query_embedding: torch.Tensor = self.query_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._feature_names_query],
                dim=1,
            )
        )
        candidate_embedding: torch.Tensor = self.candidate_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._candidate_feature_names],
                dim=1,
            )
        )
        return query_embedding, candidate_embedding

class LitTwoTower(pl.LightningModule):
    """
    PyTorch Lightning module wrapping a TwoTowerModel.
    Uses torchmetrics AUROC for train/val metrics.
    """
    def __init__(
        self,
        two_tower: nn.Module,
        device: torch.device,
        emb_counts: Optional[List[int]],
        cat_cols: List[str],
        lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.two_tower = two_tower
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.train_auroc = AUROC(task="binary")
        self.val_auroc = AUROC(task="binary")
        self.lr = lr

        # Store metadata used in batch transform
        self.emb_counts = emb_counts
        self.cat_cols = cat_cols

        self.save_hyperparameters(ignore=["two_tower", "device"])

    def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
        kjt_batch = self._transform_to_torchrec_batch(batch, self.emb_counts)
        query_embedding, candidate_embedding = self.two_tower(kjt_batch.sparse_features)
        logits = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
        return logits

    def _loss(self, outputs: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
        labels = self._get_batch_labels(batch)
        return self.loss_fn(outputs, labels)

    def _update_metric(self, batch: Dict[str, Any], outputs: Optional[torch.Tensor], metric: AUROC) -> None:
        if outputs is None:
            outputs = self.forward(batch)
        preds = torch.sigmoid(outputs)
        labels = self._get_batch_labels(batch)
        metric.update(preds, labels)

    def training_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        # Metric update
        self._update_metric(batch, logits, self.train_auroc)

        # Log both step and epoch loss series; enable sync_dist for multi-GPU/DDP
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log("train_auroc", self.train_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        self._update_metric(batch, logits, self.val_auroc)

        # Typically only epoch-level val metrics are needed for monitoring
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = KeyedOptimizerWrapper(
            dict(self.two_tower.named_parameters()),
            lambda params: torch.optim.Adam(params, lr=self.lr),
        )
        return optimizer

    def _get_batch_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        return batch["label"].to(dtype=torch.float32, device=self.device)

    def _transform_to_torchrec_batch(
        self,
        batch: Dict[str, Any],
        num_embeddings_per_feature: Optional[List[int]],
    ) -> Batch:
        kjt_values_list = []
        kjt_lengths_list = []
        for col_idx, col_name in enumerate(self.cat_cols):
            values = batch[col_name]
            num_emb = num_embeddings_per_feature[col_idx]
            kjt_values_list.append(values % num_emb)
            kjt_lengths_list.append(torch.ones(len(values), dtype=torch.int64))

        values_t = torch.cat(kjt_values_list).to(dtype=torch.int64, device=self.device)
        lengths_t = torch.cat(kjt_lengths_list).to(device=self.device)

        sparse_features = KeyedJaggedTensor.from_lengths_sync(
            self.cat_cols,
            values_t,
            lengths_t,
        )

        labels = batch["label"].to(dtype=torch.int64, device=self.device)

        return Batch(
            dense_features=torch.zeros(1, device=self.device),
            sparse_features=sparse_features,
            labels=labels,
        )

def create_two_tower_model(args, device, cat_cols, emb_counts) -> LitTwoTower:
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]
    ebc = EmbeddingBagCollection(tables=eb_configs, device=device)
    base = TwoTowerModel(
        embedding_bag_collection=ebc, layer_sizes=args.layer_sizes, device=device
    )
    lit = LitTwoTower(
        base, cat_cols=cat_cols, emb_counts=emb_counts, device=device, lr=args.learning_rate
    )
    return lit

##5) メイン トレーニング関数を作成する

次に、@distributed ライブラリの serverless_gpu デコレーターと、ヘルパー関数と PyTorch Lightning による Trainer API を使用して、複数の GPU でトレーニングを開始します。

# setup mlflow experiment
username = spark.sql("SELECT current_user()").first()['current_user()']
experiment_path = f'/Users/{username}/sgc-torchrec-example'
experiment = mlflow.set_experiment(experiment_path)
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_path
from serverless_gpu import distributed

# get DB Host and Token
db_host = f"https://{dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()}/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# setup arguments for training function
args = Args(epochs=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = f"/Volumes/{catalog}/{schema}/{volume}/checkpoints"

@distributed(gpus=8, gpu_type="H100")
def training_function(args=args, cat_cols=cat_cols, emb_counts=emb_counts, device=device,
                      train_data=train_df, val_data=val_df, checkpoint_path=CHECKPOINT_PATH):
    mlflow.pytorch.autolog()
    model = create_two_tower_model(args, device=device, cat_cols=cat_cols, emb_counts=emb_counts)
    train_dataloader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
    eval_dataloader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)

    mlflow_logger = MLFlowLogger(
        experiment_name=experiment_path,
        log_model="all",
    )

    ckpt_cb = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="val_auroc",
        mode="max",
        save_top_k=1,
        save_last=True,                        # enables last_model_path
        filename="{epoch}-{val_auroc:.4f}",
    )

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        DeviceStatsMonitor(),
        ckpt_cb,
    ]

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator="gpu",
        strategy="ddp",
        devices=8,
        log_every_n_steps=20,
        logger=mlflow_logger,
        callbacks=callbacks,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader
    )

    # Return run_id and best checkpoint path
    result = {
        "run_id": trainer.logger.run_id,                   # MLflow run id
        "best_model_checkpoint": ckpt_cb.best_model_path,  # best checkpoint path
        "last_model_checkpoint": ckpt_cb.last_model_path   # last checkpoint path
    }
    return result

##6) サーバーレス GPU 分散トレーニング API を使用して 2 つのタワー モデルをトレーニングする

result = training_function.distributed()

##7) 最適なモデル チェックポイントをテストする

最適なモデル チェックポイントを取得し、テストを実行して結果を確認する

print(f"Experiment Name: {experiment.name}")
print(f"Experiment ID: {experiment.experiment_id}")
print(f"Artifact Location: {experiment.artifact_location}")
print(f"Lifecycle_stage: {experiment.lifecycle_stage}")

ranked_checkpoints = mlflow.search_logged_models(
  experiment_ids=[experiment.experiment_id],
  output_format="list",
  order_by=[{"field_name": "metrics.accuracy", "ascending": False}]
)

best_checkpoint: mlflow.entities.LoggedModel = ranked_checkpoints[0]
print(best_checkpoint.metrics[0])
run_id = best_checkpoint.source_run_id
artifact_path = best_checkpoint.artifact_location
model_uri = f"runs:/{run_id}/{artifact_path}"
two_tower_model = mlflow.pytorch.load_model(model_uri)

num_batches = 5 # Number of batches to print out at a time
batch_size = 1 # Print out each individual row

test_dataloader = iter(get_dataloader(test_df, batch_size=batch_size, shuffle=False))

device = torch.device("cuda:0")
two_tower_model.to(device)
two_tower_model.eval()

for _ in range(num_batches):
    next_batch = next(test_dataloader)
    expected_result = next_batch["label"][0]

    actual_result = two_tower_model(next_batch)
    actual_result = torch.sigmoid(actual_result)
    print(f"Expected Result: {expected_result}; Actual Result: {actual_result.round().item()}")

##8) サービスを提供するためにモデルを MLflow に登録する

前の手順でモデルが正しく表示されたら、最新の実行の対応する run_id を使用してモデルを登録します。 これを簡単に提供するには、2 つのタワー モデルをラップして単純な入力 ( (Dict[str, List] -> List[float])) を取り込む PyFunc を作成します。

class TwoTowerWrapper(PythonModel):
    """
    MLflow PythonModel wrapper for TwoTower model that handles dictionary input and returns list outputs
    """
    def __init__(self, two_tower_model):
        self.two_tower_model = two_tower_model

    def predict(self, model_input: Dict[str, List]) -> List[float]:
        batch = {key: torch.tensor(value) for key, value in model_input.items()}
        if "label" not in batch:
            batch["label"] = torch.zeros(len(next(iter(batch.values()))))
        with torch.no_grad():
            output = self.two_tower_model(batch).cpu()
        output = torch.sigmoid(output)
        return output.tolist()
def preprocess_data(batch):
    # turn the example test dataset from Dict[str, Tensor] to Dict[str, List] and remove the label
    return {key: tensor.tolist() for key, tensor in batch.items() if key != "label"}

def add_and_get_model_signature(two_tower_model, test_dataloader):
    current_batch = preprocess_data(next(test_dataloader))

    pyfunc_two_tower_model = TwoTowerWrapper(two_tower_model)
    current_output = pyfunc_two_tower_model.predict(current_batch)
    signature = infer_signature(current_batch, current_output)
    logged_model = mlflow.pyfunc.log_model(
        artifact_path="two_tower_pyfunc",
        python_model=pyfunc_two_tower_model,
        signature=signature,
        input_example=current_batch
    )
    return signature, logged_model

signature, logged_model = add_and_get_model_signature(two_tower_model, test_dataloader)
model_name = "two_tower_model"
uc_model_version = mlflow.register_model(
    f"models:/{logged_model.model_id}",
    name=f"{catalog}.{schema}.{model_name}"
)

ノートブックの例

Lightning を使用した 2 つのタワーレコメンデーション モデルの分散トレーニング

ノートブックを入手