JetStream PyTorch inference on v6e TPU VMs
This tutorial shows how to use JetStream to serve PyTorch models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.
Before you begin
Prepare to provision a TPU v6e with 4 chips:
- Sign in to your Google Account. If you haven't already, sign up for a new account.
- In the Google Cloud console, select or create a Google Cloud project from the project selector page.
- Enable billing for your Google Cloud project. Billing is required for all Google Cloud usage.
- Install the gcloud alpha components. Run the following command to install the latest version of gcloud components:

gcloud components update
- Enable the TPU API through the following gcloud command using Cloud Shell. You can also enable it from the Google Cloud console.

gcloud services enable tpu.googleapis.com
- Create a service identity for the TPU VM.
gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
- Create a TPU service account and grant access to Google Cloud services.

Service accounts allow the Google Cloud TPU service to access other Google Cloud services. A user-managed service account is recommended. Follow these guides to create and grant roles. The following roles are necessary (a sample set of gcloud commands follows this list):
- TPU Admin: Needed to create a TPU
- Storage Admin: Needed for accessing Cloud Storage
- Logs Writer: Needed for writing logs with the Logging API
- Monitoring Metric Writer: Needed for writing metrics to Cloud Monitoring
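As a minimal sketch, assuming the service account name tpu-sa (any name works) and that you have permission to modify the project's IAM policy, you can create the account and bind the roles listed above with gcloud:

# Create a user-managed service account; "tpu-sa" is an example name.
gcloud iam service-accounts create tpu-sa --project=PROJECT_ID

# Grant the four roles listed above to the new service account.
for ROLE in roles/tpu.admin roles/storage.admin roles/logging.logWriter roles/monitoring.metricWriter; do
  gcloud projects add-iam-policy-binding PROJECT_ID \
    --member="serviceAccount:tpu-sa@PROJECT_ID.iam.gserviceaccount.com" \
    --role="${ROLE}"
done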
- Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI.

gcloud auth login
gcloud config set project PROJECT_ID
gcloud config set compute/zone ZONE
Secure capacity
Contact your Cloud TPU sales or account team to request TPU quota and to ask any questions about capacity.
Provision the Cloud TPU environment
You can provision v6e TPUs with GKE, with GKE and XPK, or as queued resources.
Prerequisites
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- This tutorial was tested with the following configuration (a version-check sketch follows this list):
  - Python 3.10 or later
  - Nightly software versions:
    - nightly JAX 0.4.32.dev20240912
    - nightly LibTPU 0.1.dev20240912+nightly
  - Stable software versions:
    - JAX + JAX Lib of v0.4.35
- Verify that your project has enough TPU quota for:
- TPU VM quota
- IP Address quota
- Hyperdisk balanced quota
- User project permissions
- If you are using GKE with XPK, see Cloud Console Permissions on the user or service account for the permissions needed to run XPK.
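To confirm the Python, JAX, and LibTPU versions installed in your environment, you can run a quick check. This is a minimal sketch; it assumes jax and jaxlib are already installed, and the libtpu-nightly package name applies to nightly LibTPU builds only:

python3 --version
python3 -c 'import jax, jaxlib; print("jax", jax.__version__, "jaxlib", jaxlib.__version__)'
# Shows the installed nightly LibTPU build, if any.
pip show libtpu-nightly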
Create environment variables
In a Cloud Shell, create the following environment variables:
export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-4
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
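The variables above reference a custom network with a 9k MTU. As a minimal sketch, assuming an auto-mode VPC and an MTU of 8896 (the largest MTU Google Cloud supports), you can create the network and a permissive firewall rule like this:

gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
    --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp \
    --project=${PROJECT_ID}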
Environment variable descriptions
| Variable | Description |
|---|---|
| NODE_ID | The user-assigned ID of the TPU that is created when the queued resource request is allocated. |
| PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. |
| ZONE | See the TPU regions and zones document for the supported zones. |
| ACCELERATOR_TYPE | See the Accelerator Types documentation for the supported accelerator types. |
| RUNTIME_VERSION | v2-alpha-tpuv6e |
| SERVICE_ACCOUNT | The email address for your service account, which you can find in the Google Cloud console under IAM -> Service Accounts. For example: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com |
| NUM_SLICES | The number of slices to create (needed for Multislice only). |
| QUEUED_RESOURCE_ID | The user-assigned text ID of the queued resource request. |
| VALID_DURATION | The duration for which the queued resource request is valid. |
| NETWORK_NAME | The name of a secondary network to use. |
| NETWORK_FW_NAME | The name of a secondary network firewall to use. |
Provision a TPU v6e
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
    --node-id TPU_NAME \
    --project PROJECT_ID \
    --zone ZONE \
    --accelerator-type v6e-4 \
    --runtime-version v2-alpha-tpuv6e \
    --service-account SERVICE_ACCOUNT
Use the list or describe commands to query the status of your queued resource.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
For a complete list of queued resource request statuses, see the Queued Resources documentation.
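If you want a script to block until the TPU is provisioned rather than re-running describe by hand, a minimal polling sketch follows; the state.state format path is an assumption about the queued resource output schema, so adjust it if your gcloud version reports state differently:

# Poll every 30 seconds until the queued resource reaches the ACTIVE state.
while [[ "$(gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} --zone ${ZONE} \
    --format='value(state.state)')" != "ACTIVE" ]]; do
  echo "Queued resource not yet ACTIVE; waiting..."
  sleep 30
done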
Connect to the TPU using SSH
gcloud compute tpus tpu-vm ssh TPU_NAME
Run the JetStream PyTorch Llama2-7B benchmark
To set up JetStream-PyTorch, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository.
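At a high level, that setup takes the following shape. This is a sketch only; the repository URL and the install_everything.sh script reflect the JetStream-PyTorch README at the time of writing, so defer to the repository instructions if they differ:

# Clone JetStream-PyTorch and install it along with its dependencies.
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
source install_everything.sh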
When the inference benchmark is complete, be sure to clean up the TPU resources.
Clean up
Delete the TPU and the queued resource request:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async
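Deletion runs asynchronously because of the --async flag. To confirm that the queued resource and its TPU have been removed, list the remaining queued resources in the zone:

gcloud compute tpus queued-resources list \
    --project ${PROJECT_ID} \
    --zone ${ZONE}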