Skip to content

Commit 5cb2994

Browse files
mwhittaker authored and Google-ML-Automation committed
Warn the user if transparent huge pages aren't enabled.
PiperOrigin-RevId: 735431881
1 parent 14b215f commit 5cb2994

File tree

3 files changed

+54
-17
lines changed

3 files changed

+54
-17
lines changed

jax/_src/cloud_tpu_init.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import os
1717
import re
18+
import warnings
1819
from jax import version
1920
from jax._src import config
2021
from jax._src import hardware_utils
@@ -72,7 +73,19 @@ def cloud_tpu_init() -> None:
7273

7374
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
7475
libtpu_path = get_tpu_library_path()
75-
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
76+
num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id()
77+
if (
78+
tpu_id is not None
79+
and tpu_id >= hardware_utils.TpuVersion.v5e
80+
and not hardware_utils.transparent_hugepages_enabled()
81+
):
82+
warnings.warn(
83+
'Transparent hugepages are not enabled. TPU runtime startup and'
84+
' shutdown time should be significantly improved on TPU v5e and newer.'
85+
' If not already set, you may need to enable transparent hugepages in'
86+
' your VM image (sudo sh -c "echo always >'
87+
' /sys/kernel/mm/transparent_hugepage/enabled")'
88+
)
7689
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
7790
return
7891

jax/_src/hardware_utils.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,49 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import enum
1516
import os
1617
import pathlib
1718
import glob
1819

1920
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
20-
_TPU_PCI_DEVICE_IDS = [
21-
# TPU v2, v3
22-
'0x0027',
23-
# No public name (plc)
24-
'0x0056',
25-
# TPU v4
26-
'0x005e',
27-
# TPU v5p
28-
'0x0062',
29-
# TPU v5e
30-
'0x0063',
31-
# TPU v6e
32-
'0x006f',
33-
]
3421

3522
# Device paths probed by has_visible_nvidia_gpu(); any one existing counts as
# a visible NVIDIA GPU.
_NVIDIA_GPU_DEVICES = [
    '/dev/nvidia0',
    '/dev/nvidiactl',  # Docker/Kubernetes
    '/dev/dxg',  # WSL2
]
4027

28+
29+
class TpuVersion(enum.IntEnum):
  """TPU hardware versions as comparable integers.

  Members are plain ints, so callers can use ordering comparisons such as
  `version >= TpuVersion.v5e`. Note that with these values v5p (4) compares
  lower than v5e (5).
  """

  v2 = 0  # TPU v2
  v3 = 1  # TPU v3
  plc = 2  # No public name (plc)
  v4 = 3  # TPU v4
  v5p = 4  # TPU v5p
  v5e = 5  # TPU v5e
  v6e = 6  # TPU v6e
43+
44+
45+
# Maps the PCI device id string read from sysfs to the TPU version it denotes.
_TPU_PCI_DEVICE_IDS = {
    '0x0027': TpuVersion.v3,  # Id shared by TPU v2 and v3; reported as v3.
    '0x0056': TpuVersion.plc,  # No public name (plc)
    '0x005e': TpuVersion.v4,  # TPU v4
    '0x0062': TpuVersion.v5p,  # TPU v5p
    '0x0063': TpuVersion.v5e,  # TPU v5e
    '0x006f': TpuVersion.v6e,  # TPU v6e
}
53+
4154
def num_available_tpu_chips_and_device_id():
4255
"""Returns the device id and number of TPU chips attached through PCI."""
4356
num_chips = 0
44-
device_id = ''
57+
tpu_version = None
4558
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
4659
vendor_id = pathlib.Path(vendor_path).read_text().strip()
4760
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
@@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id():
5063
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
5164
device_id = pathlib.Path(device_path).read_text().strip()
5265
if device_id in _TPU_PCI_DEVICE_IDS:
66+
tpu_version = _TPU_PCI_DEVICE_IDS[device_id]
5367
num_chips += 1
5468

55-
return num_chips, device_id
69+
return num_chips, tpu_version
5670

5771

5872
def has_visible_nvidia_gpu() -> bool:
  """Returns True if any path in _NVIDIA_GPU_DEVICES exists, else False."""
  for device_node in _NVIDIA_GPU_DEVICES:
    # The first existing device node is enough; no need to probe the rest.
    if os.path.exists(device_node):
      return True
  return False
76+
77+
78+
def transparent_hugepages_enabled() -> bool:
  """Returns whether transparent huge pages are set to `always` on this host.

  Reads /sys/kernel/mm/transparent_hugepage/enabled, in which the kernel lists
  the available modes and wraps the active one in brackets, e.g.
  `[always] madvise never`.

  Returns:
    True only when the active (bracketed) mode is `always`. False when another
    mode is active, or when the file is missing or cannot be read.
  """
  # See https://docs.kernel.org/admin-guide/mm/transhuge.html for more
  # information about transparent huge pages.
  path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled')
  try:
    modes = path.read_text().split()
  except OSError:
    # Missing file (non-Linux host, THP not built in) or unreadable sysfs
    # (restricted sandbox). This probe only gates a warning, so report
    # "not enabled" instead of propagating the error to the caller.
    return False
  # Match the bracketed token rather than the whole line so the check does not
  # depend on the exact set or order of the other modes listed.
  return '[always]' in modes

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ filterwarnings = [
6666
# https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358
6767
"ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning",
6868

69+
# TODO(b/401588349): Remove this once transparent hugepages are enabled.
70+
"ignore:Transparent hugepages",
71+
6972
# NOTE: this is probably not where you want to add code to suppress a
7073
# warning. Only pytest tests look at this list, whereas Bazel tests also
7174
# check for warnings and do not check this list. Most likely, you should

0 commit comments

Comments
 (0)