Google TPUs
Introduction
- CPUs have one really fast processor and are bound by memory load/save speed.
- GPUs have many tiny CPUs inside, and really large memory load/save speed.
- TPUs don’t have either. Instead, e.g. TPU v3:
- Has two systolic ALUs of size 128 x 128. These are ALUs that are laid out in 128 x 128 grid.
The next animation shows how the network “weights” are laid out inside a TPU.
And this animation shows the “systolic” movement of data inputs into TPUs:
So, as far as I understand, these two above animations show a pair-wise convolution between all weights and all inputs. The only thing that is really important here is that from the animations it seems TPUs don’t re-load the used inputs into HBM. From the animations it seems each ALU passes on the inputs to the next ALU vertically after processing. HBM access is really slow. If ALUs can really pass data upstream like that, that’d make the processing much faster. Interesting indeed.
On the horizontal side, it also seems that after processing a chunk of the weights x inputs, ALU passes the accumulated results to the right-side ALU. It is interesting what exactly passes in ALU-to-ALU transfer operations.
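To make the data movement concrete, here is a minimal pure-Python sketch of a weight-stationary systolic array. It is an illustration under assumptions, not the actual TPU design: the name systolic_matmul, the flow directions (inputs moving right, partial sums moving down) and the one-cycle skew between rows are all choices made for this sketch. What it shows is the point from the animations: each input is fed in once at the edge and then handed from cell to cell, so nothing is re-read from memory.

```python
def systolic_matmul(X, W):
    """Compute X @ W (X is M x K, W is K x N) on a simulated K x N cell grid.

    Cell (k, n) permanently holds weight W[k][n]. Each cycle it reads an
    input from its left neighbour and a partial sum from the cell above.
    """
    M, K, N = len(X), len(W), len(W[0])
    a_reg = [[0] * N for _ in range(K)]  # input value currently held by each cell
    p_reg = [[0] * N for _ in range(K)]  # partial sum produced by each cell
    out = [[0] * N for _ in range(M)]

    for t in range(M + K + N - 2):  # cycles until the last result drains out
        new_a = [[0] * N for _ in range(K)]
        new_p = [[0] * N for _ in range(K)]
        for k in range(K):
            for n in range(N):
                if n == 0:
                    # Left edge: row k receives X[m][k] at cycle m + k (skewed feed).
                    m = t - k
                    a_in = X[m][k] if 0 <= m < M else 0
                else:
                    a_in = a_reg[k][n - 1]  # handed over by the left neighbour
                p_in = p_reg[k - 1][n] if k > 0 else 0  # from the cell above
                new_a[k][n] = a_in
                new_p[k][n] = p_in + a_in * W[k][n]
        a_reg, p_reg = new_a, new_p

        # The bottom row emits the finished dot product for input row m.
        for n in range(N):
            m = t - (K - 1) - n
            if 0 <= m < M:
                out[m][n] = p_reg[K - 1][n]
    return out


result = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
```

In this sketch the only reads from "memory" are the X[m][k] lookups at the left edge, one per input element; everything else is register-to-register traffic between neighbouring cells.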
Recommendations:
- Reshape operations should be avoided. Shapes must stay constant, because they are actually “compiled”, or baked into the model. If a dimension is not a multiple of 8, the compiler will still pad it with the extra cells and discard them afterwards.
- Matrices should be as large as possible, with dimensions that are multiples of 8.
- Matrix multiplications are the easiest to perform. Any other operation (add, sub, reshape) is slower.
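The cost of odd shapes can be estimated with a little arithmetic. The sketch below is hypothetical: padded_dim and wasted_fraction are names made up here, and the (8, 128) tile is an assumption that combines the multiple-of-8 rule above with the 128-wide grid mentioned earlier. It only estimates how much of a padded tile would hold padding rather than data.

```python
import math


def padded_dim(dim: int, multiple: int = 8) -> int:
    """Round dim up to the next multiple, the way a padding compiler would."""
    return math.ceil(dim / multiple) * multiple


def wasted_fraction(shape, multiples=(8, 128)):
    """Fraction of the padded 2D array that is padding instead of data.

    The (8, 128) default is an assumption made for this sketch.
    """
    rows, cols = shape
    padded = padded_dim(rows, multiples[0]) * padded_dim(cols, multiples[1])
    return 1 - (rows * cols) / padded


# One extra row and column can more than triple the memory touched.
waste = wasted_fraction((9, 129))
```

Under these assumptions a (9, 129) array is padded to (16, 256), so roughly 72% of the padded array is wasted, while an (8, 128) array wastes nothing.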
Summary
Surprise-surprise, vanilla transformers are really well-suited for Google’s TPUs. Especially large ones with a large enough hidden dimension. It’d be fair to say that the depth of the model won’t matter as much as the hidden size.
Second, it’s best to use JAX for getting your model to work on a TPU. Don’t use PyTorch. JAX has better support for TPUs. Don’t use TensorFlow under any circumstance.
We know something about the heart of the machine now – what about the rest? What about the TPU host, its disk size, case studies, etc.
System arch
TPU-related Terms
- Batch Inference: Run when available. Slow as heck.
- TPU chip contains TensorCores (TC). Better TPUs have more/faster TCs. TC has many matrix-multiply units (MXUs), a vector unit and scalar unit. MXU is either 256x256 (v6e and later) or 128x128 multiply accumulator. MXUs are the workhorse of the TPUs. TPUs do multiplication in bfloat16, and accumulation in FP32. Vector unit does activations and softmax. Scalar unit does control flow and mem address ops.
- TPU cube. This is a topology of several connected TPUs. TPU is adjacent to 6 other TPUs, I guess. Starts with v4 and later TPUs.
- Multislice is a connection between several TPU slices. Of course, chips inside a single slice communicate faster than chips across a multislice.
- Cloud TPU ICI resiliency – it’s just a safeguard used when connecting TPU cubes with each other.
- Queued resource – A representation of TPU resources, used to enqueue and manage a request for a single-slice or multi-slice TPU environment. Basically, queued resource is how you stand in queue, when wanting to access a TPU. This has to be connected with “Batch Inference”.
- Single host, multi host, and sub host – TPU “host” is a linux VM that controls a TPU. TPUs can have multiple hosts and a host can have many TPUs. A Sub-host workload would mean a process doesn’t use all TPUs of the host.
- Slice – Pod slice is a collection of chips that are connected with really fast chip interconnects (ICI). Slices can have different ICI connection shapes (this is called chip topology) and different numbers of chips. A slice can also be described by the total number of TensorCores it has.
- SparseCore – if you have a large recommendation engine that relies on embedding vectors, you likely want to use SparseCore. v5 and later chips have specialized hardware for this, called “SparseCores”.
- TPU Pod – A collection of TPUs that are physically close to each other and connected with a really fast network. TPU Pods are what you will use if you want to train a large model that doesn’t fit on a single TPU HBM.
- TPU VM or “worker” – A virtual machine running Linux that has access to the underlying TPUs. For practical purposes, it’s an ssh terminal that can compile and run TPU based XLA code.
- TensorCores – these do efficient matrix multiplications. See this ACM article for details.
- TPU versions – TPU chip architecture changes a lot between chips. AFAIK, there’s smaller incentive in keeping everything backwards-compatible, so google changes TPUs a lot between versions.
TPU VM images
A default VM image for TPUs is tpu-ubuntu2204-base. There are others too.
What’s inside tpu-ubuntu2204-base?
TPU versions
All TPUs have bf16 mul, fp32 accumulation.
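As a rough illustration of what "bf16 mul, fp32 accumulation" means numerically, here is a small standard-library sketch. The helper names to_bfloat16 and dot_bf16_inputs are made up for this example, the rounding mode (round to nearest, ties to even) is an assumption, and a Python float (64-bit) stands in for the FP32 accumulator. It only handles finite values in the float32 range.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    bfloat16 is the top 16 bits of a float32: same exponent range,
    but only 7 explicit mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped bits
    bits = ((bits + bias) & 0xFFFFFFFF) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_bf16_inputs(a, b):
    """Dot product with inputs rounded to bfloat16 and a wide accumulator."""
    acc = 0.0  # stand-in for the FP32 accumulator (assumption: Python float is wider)
    for x, y in zip(a, b):
        acc += to_bfloat16(x) * to_bfloat16(y)
    return acc


coarse = to_bfloat16(1.01)  # neighbouring bfloat16 values near 1.0 are 1/128 apart
```

The inputs lose precision when rounded, but the running sum does not, which is why long dot products stay usable even with such a coarse input format.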
TPU v6e
Somewhat similar to v5e, but newer. Designed for both training and tuning.
Each v6e has 1 TC and each TC has 2 MXUs.
- Single host can have 8 TPUs at most
- Memory/BW: 32GB, 1640 GBps
- Peak perf: 918 TFLOPs
- ICI BW: 3584 Gbps
- Max pod size: 256 Chips
- All-reduce bandwidth per Pod: 102 TB/s
TODO how much CPU resources does host have in for v6e?
TPU v5p
For training. This chip is an absolute beast.
- Single host can have 8 TPUs at most
- Memory/BW: 95GB, 2765 GBps
- Peak perf: 459 TFLOPs
- Max Pod size: 8960 TPUs (!!)
- ICI BW: 4800 Gbps
TPU v5e
For both training and serving.
- Single host can have 8 TPUs at most.
- Memory/BW: 16GB, 819 GBps
- All reduce BW per pod: 51.2 TB/s
- Peak perf: 197 TFLOPs
- ICI BW: 1600 Gbps. Interchip/intrahost is faster than CPU-to-chip.
- Max pod: 256 Chips
- Single CPU can most efficiently access TPUs 0,1,2,3.
Single-host AcceleratorTypes are:
- v5litepod-1
- v5litepod-4
- v5litepod-8
After this, you need to use Sax to manage several hosts (remember, 1 host has max 8 TPUs). Then it goes to:
- v5litepod-16, 32, etc …
TPU v4
- Single host can have 8 TPUs at most.
- Memory/BW: 32GB, 1200 GBps
- All reduce BW per pod: 1100 TB/s
- Peak perf: 275 TFLOPs
- ICI BW: 1600 Gbps. Interchip/intrahost is faster than CPU-to-chip.
- Max pod: 4096 Chips
Using v4-8 means having 8 TensorCores (TCs), with one single VM controlling them. This makes it the easiest to work with.
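The peak-performance and memory-bandwidth figures listed for each chip can be combined into a rough "ridge point": how many FLOPs a kernel has to perform per byte read from HBM before the chip stops being memory-bound. This is a back-of-the-envelope sketch, not a benchmark. It assumes the memory bandwidth figures are in gigabytes per second, and the function names are made up for illustration.

```python
# chip: (peak TFLOPs, HBM bandwidth in GB/s), copied from the lists above.
# Assumption: the memory bandwidth figures are gigabytes per second.
CHIPS = {
    "v6e": (918, 1640),
    "v5p": (459, 2765),
    "v5e": (197, 819),
    "v4": (275, 1200),
}


def ridge_point(peak_tflops: float, hbm_gb_per_s: float) -> float:
    """FLOPs per HBM byte needed to keep the compute units busy."""
    return (peak_tflops * 1e12) / (hbm_gb_per_s * 1e9)


def matmul_intensity(m: int, k: int, n: int, bytes_per_element: int = 2) -> float:
    """FLOPs per byte for an (m x k) @ (k x n) matmul, each matrix touched once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_element * (m * k + k * n + m * n)
    return flops / bytes_moved


ridge = {name: ridge_point(*spec) for name, spec in CHIPS.items()}
```

With these numbers the v5e ridge point is about 240 FLOPs per byte, and a square bf16 matmul of size n has an intensity of about n/3, so under this simple model matrices need to be on the order of several hundred wide before the MXUs rather than HBM become the limit. That is consistent with the "as large as possible" recommendation earlier.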
TPU VMs themselves
- n2d-48-24-v5lite-tpu for v5litepod-1 has:
  - 1 v5e TPU
  - 24 CPUs
  - 48 GB CPU RAM
  - 1 NUMA
  - Disruption: High. This just means that a GKE update could possibly terminate your run. You shouldn’t worry about this, as long as you do periodic model checkpoints. Read up on this here.
- n2d-192-112-v5lite-tpu for v5litepod-4 has:
  - 4 v5e TPUs
  - 112 CPUs
  - 192 GB CPU RAM (!)
  - 1 NUMA
  - Fewer disruptions
- n2d-384-224-v5lite-tpu for v5litepod-8 has:
  - 8 v5e TPUs
  - 224 CPUs (!)
  - 384 GB CPU RAM (!!)
  - 2 NUMAs
  - Even fewer GKE disruptions
Larger TPUs just use a number of hosts with n2d-384-224-v5lite-tpu VMs. It would be interesting to know how the connection between TPUs on different hosts is implemented. I’m assuming
Regions and zones
In EU, TPUs are only in europe-west4. Latest v6e is only in us-east1-d and us-east5-b.
Supported models
There is this Google repository called MaxText. It’s an “optimization free” TPU-based LLM training codebase that is written in Python/JAX and that achieves really high TPU usage. There are samples for training or tuning Llama2, Mistral and Gemma.
MaxText claims to be inspired by MinGPT, which itself is a small self-contained implementation of the GPT-2 model in Python. MinGPT primarily uses just two main 300 line scripts, train.py and model.py.
Get started
Let’s do a code walkthrough.
Asking Google for TPUs
First, in google cloud, use search bar for “Cloud TPU”. Click “Enable”.
You then need to ask Google to give you TPU access. This is done again by using the above search bar, and typing “Quotas & System Limits”.
Down in the Filter tab, you need to search for “TPU v5 lite pod cores for serving per project per zone”, in the zone of your choice.
Click three dots on the right side and choose “Edit quota”. Follow the approval flow and you will likely get your access soon.
Running TPUs
First install the gcloud command-line tool, and then run these commands by pasting them one-by-one in your console.
# TODO: use `gcloud projects list` and paste 'PROJECT_ID' value below
export PROJECT_ID=
# TODO: use `gcloud iam service-accounts list` and paste 'EMAIL' value below
export SERVICE_ACCOUNT=xyz-compute@developer.gserviceaccount.com
# TODO: Come up with a simple unique name
export RESOURCE_NAME=v5litepod-1-resource
# TODO: Use browser "login with google" window to do auth
gcloud auth login # opens a browser window
# Enable TPUs
gcloud services enable tpu.googleapis.com
# Allow TPUs to access the rest of your gcloud with this account
gcloud beta services identity create \
--service tpu.googleapis.com \
--project $PROJECT_ID
# This prints
# "Service identity created: `service-xyz@cloud-tpu.iam.gserviceaccount.com`"
# This email isn't needed. Ignore it.
# Stand in a queue to access a TPU node
gcloud alpha compute tpus queued-resources create $RESOURCE_NAME \
--node-id v5litepod \
--project $PROJECT_ID \
--zone us-central1-a \
--accelerator-type v5litepod-1 \
--runtime-version v2-alpha-tpuv5-lite \
--valid-until-duration 1d \
--service-account $SERVICE_ACCOUNT
# check if it's your turn in the queue
gcloud alpha compute tpus queued-resources describe $RESOURCE_NAME \
--project $PROJECT_ID \
--zone us-central1-a
# You will see something like
# createTime: '2024-12-20T09:25:02.391983567Z'
# guaranteed: {}
# name: projects/<Project>/locations/us-central1-a/queuedResources/v5litepod-1-resource
# queueingPolicy:
# validUntilTime: '2024-12-21T09:25:02.423775424Z'
# state:
# state: PROVISIONING <----- REPEAT UNTIL THIS IS "ACTIVE"
# tpu:
# nodeSpec:
# - node:
# acceleratorType: v5litepod-1
# networkConfig:
# enableExternalIps: true
# network: default
# queuedResource: projects/<project>/locations/us-central1-a/queuedResources/v5litepod-1-resource
# runtimeVersion: v2-alpha-tpuv5-lite
# schedulingConfig: {}
# serviceAccount:
# email: xyz-compute@developer.gserviceaccount.com
# shieldedInstanceConfig: {}
# nodeId: v5litepod
# parent: projects/<project>/locations/us-central1-a
# ssh into TPU node
gcloud alpha compute tpus tpu-vm ssh v5litepod \
--project $PROJECT_ID \
--zone us-central1-a
## DON'T FORGET TO STOP THE TPU!
gcloud alpha compute tpus queued-resources delete v5litepod-1-resource \
--project $PROJECT_ID \
--zone us-central1-a --force --quiet
## Delete the disk too (only needed if you attached one; set DISK_NAME to its name first)
gcloud compute disks delete $DISK_NAME \
--zone us-central1-a
You will see something like:
SSH key found in project metadata; not updating instance.
Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 1 workers.
SSH: Attempting to connect to worker 0...
Warning: Permanently added 'tpu.xyz-0-voj4cz' (ED25519) to the list of known hosts.
Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 6.5.0-1013-gcp x86_64)
* Documentation: https://help.ubuntu.com
* Management: https://landscape.canonical.com
* Support: https://ubuntu.com/pro
System information as of Fri Dec 20 09:27:55 UTC 2024
System load: 1.09521484375 Processes: 353
Usage of /: 8.8% of 96.73GB Users logged in: 0
Memory usage: 2% IPv4 address for docker0: xyz
Swap usage: 0% IPv4 address for ens6: xyz
Expanded Security Maintenance for Applications is not enabled.
145 updates can be applied immediately.
88 of these updates are standard security updates.
To see these additional updates run: apt list --upgradable
7 additional security updates can be applied with ESM Apps.
Learn more about enabling ESM Apps service at https://ubuntu.com/esm
The list of available updates is more than a week old.
To check for new updates run: sudo apt update
The programs included with the Ubuntu system are free software;
the exact distribution terms for each program are described in the
individual files in /usr/share/doc/*/copyright.
Ubuntu comes with ABSOLUTELY NO WARRANTY, to the extent permitted by
applicable law.
tornikeo@t1v-n-5cb9aa1d-w-0:~$
Yay. At this point you can access your TPU. If you are in EU like me, you will get a terminal that is really laggy, since the actual TPU is in US. Latency is really high. You can take a look around.
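The "repeat until ACTIVE" step from the commands above can be automated. Below is a minimal sketch that shells out to the same describe command and polls until the state changes. The helper names, the poll interval and the set of failure states ("FAILED", "SUSPENDED") are assumptions made for this example. The --format value(state.state) projection simply picks the nested state field shown in the describe output above.

```python
import subprocess
import time


def queued_resource_state(name, project, zone, run=subprocess.run):
    """Return the queued resource state string, e.g. PROVISIONING or ACTIVE."""
    cmd = [
        "gcloud", "alpha", "compute", "tpus", "queued-resources", "describe", name,
        "--project", project,
        "--zone", zone,
        "--format", "value(state.state)",
    ]
    result = run(cmd, capture_output=True, text=True, check=True)
    return result.stdout.strip()


def wait_until_active(get_state, poll_seconds=30.0, max_polls=120, sleep=time.sleep):
    """Poll get_state() until it returns ACTIVE. Returns False on timeout."""
    for _ in range(max_polls):
        state = get_state()
        if state == "ACTIVE":
            return True
        if state in ("FAILED", "SUSPENDED"):  # assumed terminal states
            raise RuntimeError(f"queued resource ended up in state {state}")
        sleep(poll_seconds)
    return False
```

A possible call, reusing the names from the commands above, would be wait_until_active(lambda: queued_resource_state("v5litepod-1-resource", project_id, "us-central1-a")), where project_id holds your project ID.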
Look around inside TPU VM
To see how much disk space you have, use df -h. I’m getting something like below:
tornikeo@t1v-n-5cb9aa1d-w-0:~$ df -h
Filesystem Size Used Avail Use% Mounted on
/dev/root 97G 14G 84G 15% /
tmpfs 24G 0 24G 0% /dev/shm
tmpfs 9.5G 1.9M 9.5G 1% /run
tmpfs 5.0M 0 5.0M 0% /run/lock
efivarfs 56K 24K 27K 48% /sys/firmware/efi/efivars
/dev/sda15 105M 6.1M 99M 6% /boot/efi
tmpfs 4.8G 4.0K 4.8G 1% /run/user/2001
AFAIK, this means we have around 97GB in total, of which 14GB is already used and 84GB is still available.
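The same check can be done from Python, which is handy before a script starts downloading a dataset. This is a small sketch using only the standard library; disk_report is a made-up helper name.

```python
import shutil


def disk_report(path="/"):
    """Return total, used and free space of the filesystem holding path, in GiB."""
    usage = shutil.disk_usage(path)
    gib = 1024 ** 3
    return {
        "total_gib": usage.total / gib,
        "used_gib": usage.used / gib,
        "free_gib": usage.free / gib,
    }


report = disk_report(".")
```

A download script could compare report["free_gib"] against the expected dataset size and stop early instead of failing halfway through.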
To list the PCI devices attached to the VM, use lspci:
tornikeo@t1v-n-5cb9aa1d-w-0:~$ lspci
00:00.0 Host bridge: Intel Corporation 440FX - 82441FX PMC [Natoma] (rev 02)
00:01.0 ISA bridge: Intel Corporation 82371AB/EB/MB PIIX4 ISA (rev 03)
00:01.3 Bridge: Intel Corporation 82371AB/EB/MB PIIX4 ACPI (rev 03)
00:03.0 IOMMU: Advanced Micro Devices, Inc. [AMD] Milan IOMMU
00:04.0 Non-VGA unclassified device: Red Hat, Inc. Virtio SCSI
00:05.0 Unassigned class [ff00]: Google, Inc. Device 0063
00:06.0 Ethernet controller: Google, Inc. Compute Engine Virtual Ethernet [gVNIC]
00:07.0 Unclassified device [00ff]: Red Hat, Inc. Virtio RNG
I’m guessing the 00:05.0 Unassigned class [ff00]: Google, Inc. Device 0063 could be our TPU. Who knows?
We can learn more about the CPUs by using hwinfo | less:
tornikeo@t1v-n-5cb9aa1d-w-0:~$ hwinfo | less
...
----- /proc/cpuinfo -----
processor : 0
vendor_id : AuthenticAMD
cpu family : 25
model : 1
model name : AMD EPYC 7B13
stepping : 0
microcode : 0xffffffff
cpu MHz : 2450.000
cache size : 512 KB
physical id : 0
siblings : 24
core id : 0
cpu cores : 12
apicid : 0
initial apicid : 0
fpu : yes
fpu_exception : yes
cpuid level : 13
wp : yes
flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm
bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass srso
bogomips : 4900.00
TLB size : 2560 4K pages
clflush size : 64
cache_alignment : 64
address sizes : 48 bits physical, 48 bits virtual
power management:
We have 24 such logical AMD CPUs (vCPUs):
tornikeo@t1v-n-5cb9aa1d-w-0:~$ nproc
24
Total RAM is about 47 GiB (49316192 kB), close to the 48 GB listed for this VM type:
tornikeo@t1v-n-5cb9aa1d-w-0:~$ cat /proc/meminfo | grep MemTotal
MemTotal: 49316192 kB
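The kB figure in /proc/meminfo is easy to misread, so here is the unit conversion spelled out. This is a small sketch with made-up helper names. It assumes the usual convention that "kB" in /proc/meminfo means 1024 bytes.

```python
def mem_total_kib(meminfo_text: str) -> int:
    """Extract the MemTotal value (in kB) from the text of /proc/meminfo."""
    for line in meminfo_text.splitlines():
        if line.startswith("MemTotal:"):
            return int(line.split()[1])
    raise ValueError("MemTotal line not found")


def kib_to_gib(kib: int) -> float:
    """Binary gigabytes (GiB), what most tools print as 'G'."""
    return kib / (1024 ** 2)


def kib_to_gb(kib: int) -> float:
    """Decimal gigabytes (GB = 10^9 bytes)."""
    return kib * 1024 / 1e9


sample = "MemTotal:       49316192 kB\nMemFree:        47000000 kB\n"
total = mem_total_kib(sample)
```

For the value above this gives roughly 47.0 GiB, or 50.5 GB in decimal units.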
Install packages
We already have python and pip, so let’s install packages inside the TPU VM. This can take 3-4 minutes.
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas]
pip install timm
Exit
gcloud alpha compute tpus queued-resources delete v5litepod-1-resource \
--zone us-central1-a --force --quiet \
--project $PROJECT_ID
Making life easier
At this point, you can likely run what you need. Now, of course you don’t need to sit around waiting for pip to finish running. Instead you can specify a startup script that does literally everything.
There’s a neat trick on gcloud, documented here, that allows you to add a startup script to any Linux-based instance, including TPUs! Here’s how you use it, for example:
gcloud alpha compute tpus queued-resources create $RESOURCE_NAME \
--node-id v5litepod \
--project $PROJECT_ID \
--zone us-central1-a \
--accelerator-type v5litepod-1 \
--runtime-version v2-alpha-tpuv5-lite \
--valid-until-duration 1d \
--service-account $SERVICE_ACCOUNT \
--metadata startup-script='#! /bin/bash
touch tornikeo-was-here.txt # to make sure it all works.
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas]
pip install timm'
Or, if you want to keep the startup file separately, use:
gcloud alpha compute tpus queued-resources create $RESOURCE_NAME \
--node-id v5litepod \
--project $PROJECT_ID \
--zone us-central1-a \
--accelerator-type v5litepod-1 \
--runtime-version v2-alpha-tpuv5-lite \
--valid-until-duration 1d \
--service-account $SERVICE_ACCOUNT \
--metadata-from-file startup-script=FILE_PATH
You can get startup script logs inside ssh by reading /var/log/syslog:
tornikeo@t1v-n-6b747f45-w-0:~$ grep startup-scrip /var/log/syslog
This can help you in debugging the startup script in a TPU VM.
Either way, after the instance is ACTIVE, you can check that our startup script was indeed executed in SSH:
tornikeo@t1v-n-6b747f45-w-0:~$ ls /
bin dev home lib32 libx32 media opt root sbin srv tmp usr
boot etc lib lib64 lost+found mnt proc run snap sys tornikeo-was-here.txt var
The startup script can very well be a script that does literally everything:
- Download data
- Download model (if pretrained)
- Authenticate with external services (e.g. github, wandb, huggingface)
- Clone or copy over your full code from github branch or local machine
- Run training, fine-tuning, eval and push model checkpoints to your HF repo
- Shut self down
  - hint: you can simply use gcloud alpha compute tpus queued-resources delete v5litepod-1-resource --force --quiet inside the script.
  - If something goes wrong and you don’t reach that line, things might get expensive. Make sure to reserve the TPU instance only as long as you think you will need.
This way you can rapidly test model hyper-parameters.
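One way to lower the risk of never reaching the "shut self down" line is to put the cleanup in a finally block, so it runs even when training crashes. The sketch below is an illustration under assumptions: run_then_cleanup and CLEANUP_CMD are made-up names, the delete command is the one used earlier in this post, and the run parameter exists only so the logic can be exercised without gcloud.

```python
import subprocess

# The same delete command as used above (resource name and zone from this post).
CLEANUP_CMD = [
    "gcloud", "alpha", "compute", "tpus", "queued-resources", "delete",
    "v5litepod-1-resource", "--zone", "us-central1-a", "--force", "--quiet",
]


def run_then_cleanup(train_cmd, cleanup_cmd=CLEANUP_CMD, run=subprocess.run):
    """Run the training command, then always run the cleanup command.

    The finally block runs whether training succeeds, exits non-zero,
    or raises, so the TPU is released in all of those cases.
    """
    try:
        return run(train_cmd).returncode
    finally:
        run(cleanup_cmd)
```

For example, run_then_cleanup(["python3", "train.py"]) would run a hypothetical train.py and then delete the queued resource. This does not protect against the VM itself hanging, so a sensible reservation length is still needed.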
Configure your TPU
Storage Options
This section is about the disk of the TPU VM. If your TPU instance is small, say v5litepod-1, then you don’t get a lot of disk space and you might run out as you download the dataset.
Most times, if you are performing any serious experiments, you are better off using a FUSE mount (see below), which basically attaches a gs:// bucket to your system as if it were a regular directory. This can also be used for huggingface datasets.
Add a Persistent Disk (PD)
By default, each TPU VM has 100GB boot persistent disk with OS and all packages. This can be used for storing data, as long as the data is small enough. Usually, you need around 20GB for packages and most of the 80GB can be used.
Most times, this 100GB won’t be enough. You need to add a persistent disk (PD) to a TPU VM. PD lives on after TPU gets disposed of, so you could store data or models there.
Use gs:// bucket
You could use a gs://my-bucket as storage. Even if that bucket is the fastest kind and in the same region, it will be much slower than a PD, but it is also not restricted in size.
Use gs:// like a local folder (gcsfuse)
GCSFUSE is for people who want to use a Cloud Storage Bucket (i.e. upload/download to gs://), but do so with regular cp, mv, rm semantics, as if the bucket were an actual local disk. It is simpler than attaching a persistent disk and simpler than using a Cloud Storage Bucket directly, but it could be slower than a disk if you use it incorrectly.
I have used gcsfuse both as a way to asynchronously upload model artifacts and to download really large data. In my experience, simply using a huggingface dataset with streaming=True is a simpler solution.
gcsfuse can be installed in a startup script, like so:
apt update && \
apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
apt update -y && apt -y install gcsfuse
In the same startup script, you can attach a bucket. Note that bucket-name does not contain the gs:// prefix; simply provide the bucket name.
mkdir -p /tmp/fuse/datasets/
gcsfuse -o rw --implicit-dirs \
--type-cache-max-size-mb=-1 --stat-cache-max-size-mb=-1 --kernel-list-cache-ttl-secs=-1 --metadata-cache-ttl-secs=-1 \
bucket-name /tmp/fuse/datasets/
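Once mounted, the bucket can be read with ordinary file APIs. The sketch below walks a directory and streams each file sequentially in large chunks. The assumption here, which is not measured in this post, is that a few large sequential reads suit a FUSE-mounted bucket better than many small random ones. iter_file_chunks and the 8 MiB chunk size are made up for illustration.

```python
import os


def iter_file_chunks(root, chunk_size=8 * 1024 * 1024):
    """Yield (path, chunk) pairs for every file under root, read sequentially."""
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # deterministic traversal order
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            with open(path, "rb") as f:
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break
                    yield path, chunk
```

With the mount from above this could be used as for path, chunk in iter_file_chunks("/tmp/fuse/datasets/"): ..., processing data as it arrives instead of copying the whole bucket to the boot disk first.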
I also encountered an error with gcsfuse, which was solved with
pip install proto-plus==1.24.0.dev1
Filestore file share (like FUSE, but faster)
This seems to be an evolved version of FUSE. AFAIK, the smallest Filestore instance you can create is 1TB.
Training and inference
v6e TPU
As of Dec 2024, this v6e TPU is the latest. It is somewhat similar to v5e, with maximum Pod size of 256 chips.
Due to waiting for the quota increase, I’ll cover this specific chip a bit later on.
v5p TPU
To use this chip, you first need to increase your quota of TPUV5PPerProjectPerZoneForTPUAPI on the quotas page under IAM and Admin in the cloud console. By default you don’t have access to v5p.
Note the minimum amount of v5p chips you can have is 8. Looking at the pricing page, this means that the minimum price for a v5p-8 is $33 per hour (!). So keep this in mind.
While going through the documentation, I found that accessing even the smallest v5p-8 instance is really slow. In the queue, I waited for around half an hour before giving up.
My recommendation definitely is to try fitting your problem on v5e instances first, and only then try to scale up to v5p.
v5e TPU
This TPU variant is marketed as the best value-per-dollar data-center chip. Let’s try it out.
The setup is the usual, but notice the --worker=all. This is needed because we will actually work with a v5litepod-16. This instance will have two VMs. So, while you can connect to each VM individually using SSH, it’s much easier to execute stuff on all of them at the same time using this argument.
# Install Jax
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
# Clone flax and install the imagenet example requirements
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'
# Download the imagenet2012 dataset metadata used by the fake-data benchmark
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'
# Train
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'
If you get an error AttributeError: module 'jax.config' has no attribute 'define_bool_state', then you likely need to use jax==0.4.24 during the first step. See this open discussion on flax on the issue.