
Commit 3182529

stas00, leandro, lvwerra, and sgugger authored and committed
[WIP] [doc] performance/scalability revamp (huggingface#15723)
* [doc] performance/scalability revamp * link the new docs * no : * mixed precision * work on the first doc * expand the main doc * Trigger CI * style * revamp single GPU training section * work on training performance * remove files not used anymore or will be added later * final touches * fix rebase * Add hardware section to toctree * fix toctree again * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * remove `fast_tokenizers` entry that was copied in rebase * add warning about DP vs DDP * remove todo * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * fix missing closure of codeblock * Update docs/source/en/perf_train_gpu_many.mdx Co-authored-by: Sylvain Gugger <[email protected]> * sync with huggingface#16860 * update toc Co-authored-by: leandro <[email protected]> Co-authored-by: Leandro von Werra <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]>
1 parent d519e22 commit 3182529

File tree

5 files changed: +1043 -1071 lines


docs/source/en/_toctree.yml

Lines changed: 8 additions & 4 deletions
@@ -1,4 +1,4 @@
-- sections:
+- sections:
   - local: index
     title: 🤗 Transformers
   - local: quicktour
@@ -60,11 +60,9 @@
   - local: serialization
     title: Export 🤗 Transformers models
   - local: performance
-    title: 'Performance and Scalability: How To Fit a Bigger Model and Train It Faster'
+    title: Performance and scalability
   - local: big_models
     title: Instantiating a big model
-  - local: parallelism
-    title: Model Parallelism
   - local: benchmarks
     title: Benchmarks
   - local: migration
@@ -83,6 +81,12 @@
     title: "How to add a model to 🤗 Transformers?"
   - local: add_new_pipeline
     title: "How to add a pipeline to 🤗 Transformers?"
+  - local: perf_train_gpu_one
+    title: Training on one GPU
+  - local: perf_train_gpu_many
+    title: Training on many GPUs
+  - local: perf_hardware
+    title: Custom hardware for training
   - local: testing
     title: Testing
   - local: pr_checks

docs/source/en/perf_hardware.mdx

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
<!---
Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Custom hardware for training

The hardware you use to run model training and inference can have a big effect on performance. For a deep dive into GPUs, make sure to check out Tim Dettmers' excellent [blog post](https://timdettmers.com/2020/09/07/which-gpu-for-deep-learning/).

Let's have a look at some practical advice for GPU setups.

## GPU

When you train bigger models you have essentially three options:

- bigger GPUs
- more GPUs
- more CPU and NVMe (offloaded to by [DeepSpeed-Infinity](main_classes/deepspeed#nvme-support))

Let's start with the case where you have a single GPU.

### Power and Cooling

If you bought an expensive high-end GPU, make sure to give it correct power and sufficient cooling.

**Power**:

Some high-end consumer GPU cards have 2 and sometimes 3 PCI-E 8-Pin power sockets. Make sure you have as many independent 12V PCI-E 8-Pin cables plugged into the card as there are sockets. Do not use the 2 splits at one end of the same cable (also known as a pigtail cable). That is, if you have 2 sockets on the GPU, you want 2 PCI-E 8-Pin cables going from your PSU to the card and not one that has 2 PCI-E 8-Pin connectors at the end! You won't get the full performance out of your card otherwise.

Each PCI-E 8-Pin power cable needs to be plugged into a 12V rail on the PSU side and can supply up to 150W of power.

Some other cards may use PCI-E 12-Pin connectors, and these can deliver up to 500-600W of power.

Low-end cards may use 6-Pin connectors, which supply up to 75W of power.

Additionally, you want a high-end PSU that provides stable voltage. Some lower-quality ones may not give the card the stable voltage it needs to function at its peak.

And of course the PSU needs to have enough unused Watts to power the card.

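As a rough sanity check for that last point, here is a minimal budgeting sketch that adds up how much a card can draw. The 150W-per-8-Pin figure comes from the text above; the 75W a card can pull from the PCIe slot itself, the rest-of-system draw, and the 20% headroom factor are illustrative assumptions, so plug in your own numbers.

```python
# Rough PSU budget sketch for a card with two PCI-E 8-Pin power sockets.
# Assumptions: ~75W drawn from the PCIe slot, 150W per 8-Pin cable (as above),
# and an illustrative ~350W for the rest of the system (CPU, drives, fans, ...).
slot_watts = 75
num_8pin_cables = 2
watts_per_8pin = 150
rest_of_system_watts = 350  # purely illustrative, measure or estimate your own

max_gpu_draw = slot_watts + num_8pin_cables * watts_per_8pin   # 375W for this card
suggested_psu = (max_gpu_draw + rest_of_system_watts) * 1.2    # ~20% headroom, an assumption

print(f"GPU can draw up to {max_gpu_draw}W")
print(f"A PSU of roughly {suggested_psu:.0f}W would leave some headroom")
```
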
**Cooling**:

When a GPU gets overheated it will start throttling down and will not deliver full performance, and it can even shut down if it gets too hot.

It's hard to tell the exact best temperature to strive for when a GPU is heavily loaded, but probably anything under +80C is good, and lower is better - perhaps 70-75C is an excellent range to be in. Throttling is likely to start at around 84-90C. But other than throttling performance, a prolonged very high temperature is also likely to reduce the lifespan of a GPU.

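If you want to keep an eye on the temperature from Python while a training job is running, here is a minimal monitoring sketch. It assumes the `pynvml` NVML bindings (the `nvidia-ml-py` package on PyPI) are installed and only queries GPU 0; adjust the index and sampling for your setup. The same information is also available from `nvidia-smi` if you prefer to watch it from the shell.

```python
# Minimal sketch: sample the temperature of GPU 0 once per second for 10 seconds.
# Assumes `pip install nvidia-ml-py` (which provides the `pynvml` module).
import time

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0; change the index for other cards

for _ in range(10):
    temp_c = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
    print(f"GPU0 temperature: {temp_c}C")
    time.sleep(1)

pynvml.nvmlShutdown()
```
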
Next, let's have a look at one of the most important aspects of having multiple GPUs: connectivity.

### Multi-GPU Connectivity

If you use multiple GPUs, the way the cards are interconnected can have a huge impact on the total training time. If the GPUs are on the same physical node, you can run:

```
nvidia-smi topo -m
```

and it will tell you how the GPUs are interconnected. On a machine with two GPUs connected with NVLink, you will most likely see something like:

```
        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV2     0-23            N/A
GPU1    NV2      X      0-23            N/A
```

On a different machine without NVLink you may see:

```
        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      PHB     0-11            N/A
GPU1    PHB      X      0-11            N/A
```

The report includes this legend:

```
  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
```

So in the first report, `NV2` tells us the GPUs are interconnected with 2 NVLinks, while in the second report `PHB` tells us we have a typical consumer-level PCIe + host-bridge setup.

Check what type of connectivity you have on your setup. Some of these will make the communication between cards faster (e.g. NVLink), others slower (e.g. PHB).

Depending on the type of scalability solution used, the connectivity speed can have a major or a minor impact. If the GPUs need to sync rarely, as in DDP, the impact of a slower connection will be less significant. If the GPUs need to send messages to each other often, as in ZeRO-DP, then faster connectivity becomes very important for achieving faster training.

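If you'd rather check from Python whether two GPUs on the node can use direct peer-to-peer access (which NVLink and some PCIe topologies allow), here is a minimal sketch. It assumes PyTorch was installed with CUDA support and only tests GPUs 0 and 1; note that it tells you whether P2P is possible at all, not which interconnect provides it, so `nvidia-smi topo -m` remains the authoritative report.

```python
# Minimal sketch: check whether GPU 0 and GPU 1 can access each other directly (P2P).
# Assumes a machine with at least two CUDA GPUs and a CUDA-enabled PyTorch install.
import torch

if torch.cuda.device_count() >= 2:
    p2p_possible = torch.cuda.can_device_access_peer(0, 1)
    print(f"GPU0 <-> GPU1 peer-to-peer access possible: {p2p_possible}")
else:
    print("Fewer than 2 GPUs visible, nothing to check")
```
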
#### NVLink

[NVLink](https://en.wikipedia.org/wiki/NVLink) is a wire-based serial multi-lane near-range communications link developed by Nvidia.

Each new generation provides faster bandwidth; for example, here is a quote from [Nvidia Ampere GA102 GPU Architecture](https://www.nvidia.com/content/dam/en-zz/Solutions/geforce/ampere/pdf/NVIDIA-ampere-GA102-GPU-Architecture-Whitepaper-V1.pdf):

> Third-Generation NVLink®
> GA102 GPUs utilize NVIDIA’s third-generation NVLink interface, which includes four x4 links,
> with each link providing 14.0625 GB/sec bandwidth in each direction between two GPUs. Four
> links provide 56.25 GB/sec bandwidth in each direction, and 112.5 GB/sec total bandwidth
> between two GPUs. Two RTX 3090 GPUs can be connected together for SLI using NVLink.
> (Note that 3-Way and 4-Way SLI configurations are not supported.)

So the higher the `X` in the `NVX` entry reported by `nvidia-smi topo -m`, the better. The generation will depend on your GPU architecture.

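To make the arithmetic in the quote explicit, here is a tiny sketch that recomputes the quoted figures from the per-link bandwidth:

```python
# Recompute the third-generation NVLink figures quoted above for GA102.
num_links = 4             # "four x4 links"
per_link_gb_s = 14.0625   # GB/s per link, in each direction

per_direction = num_links * per_link_gb_s  # 56.25 GB/s in each direction
total = 2 * per_direction                  # 112.5 GB/s total between two GPUs

print(f"{per_direction} GB/s per direction, {total} GB/s total")
```
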
Let's compare the execution of a gpt2 language model training over a small sample of wikitext.

The results are:

| NVLink | Time |
| ------ | ---: |
| Y      | 101s |
| N      | 131s |

You can see that NVLink completes the training ~23% faster. In the second benchmark we use `NCCL_P2P_DISABLE=1` to tell the GPUs not to use NVLink.

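As a quick sanity check of that ~23% figure, you can recompute it from the two `train_runtime` values reported in the benchmark output below:

```python
# Recompute the speedup from the train_runtime values in the benchmark output below.
with_nvlink = 101.9003     # seconds, DDP w/ NVLink
without_nvlink = 131.4367  # seconds, DDP w/o NVLink

reduction = (without_nvlink - with_nvlink) / without_nvlink
print(f"Training time is reduced by {reduction:.1%} with NVLink enabled")  # ~22.5%
```
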
Here is the full benchmark code and outputs:

```bash
# DDP w/ NVLink

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train \
--output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 101.9003, 'train_samples_per_second': 1.963, 'epoch': 0.69}

# DDP w/o NVLink

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 NCCL_P2P_DISABLE=1 python -m torch.distributed.launch \
--nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train \
--output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 131.4367, 'train_samples_per_second': 1.522, 'epoch': 0.69}
```

Hardware: 2x TITAN RTX (24GB each) + NVLink with 2 NVLinks (`NV2` in `nvidia-smi topo -m`)
Software: `pytorch-1.8-to-be` + `cuda-11.0` / `transformers==4.3.0.dev0`
