Skip to content

Commit 05943a0

Browse files
adapt API and add tests to make sure the generated URIs are correct
1 parent d55a2c3 commit 05943a0

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from packaging.version import parse
17+
18+
from sagemaker.huggingface import get_huggingface_llm_image_uri
19+
from tests.unit.sagemaker.image_uris import expected_uris
20+
21+
# Mapping of vLLM versions to expected image tags
22+
VLLM_VERSIONS_MAPPING = {
23+
"inf2": {
24+
"0.10.2": "0.10.2-neuronx-py310-sdk2.26.0-ubuntu22.04",
25+
},
26+
}
27+
28+
29+
@pytest.mark.parametrize("load_config", ["huggingface-vllm-neuronx.json"], indirect=True)
30+
def test_vllm_neuronx_uris(load_config):
31+
"""Test that vLLM NeuronX image URIs are correctly generated."""
32+
VERSIONS = load_config["inference"]["versions"]
33+
device = load_config["inference"]["processors"][0]
34+
35+
# Fail if device is not in mapping
36+
if device not in VLLM_VERSIONS_MAPPING:
37+
raise ValueError(f"Device {device} not found in VLLM_VERSIONS_MAPPING")
38+
39+
# Get highest version for the device
40+
highest_version = max(VLLM_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x))
41+
42+
for version in VERSIONS:
43+
ACCOUNTS = load_config["inference"]["versions"][version]["registries"]
44+
for region in ACCOUNTS.keys():
45+
uri = get_huggingface_llm_image_uri(
46+
"huggingface-vllm-neuronx",
47+
region=region,
48+
version=version,
49+
)
50+
51+
# Skip only if test version is higher than highest known version
52+
if parse(version) > parse(highest_version):
53+
print(
54+
f"Skipping version check for {version} as it is higher than "
55+
f"the highest known version {highest_version} in VLLM_VERSIONS_MAPPING."
56+
)
57+
continue
58+
59+
expected = expected_uris.huggingface_llm_framework_uri(
60+
"huggingface-vllm-inference-neuronx",
61+
ACCOUNTS[region],
62+
version,
63+
VLLM_VERSIONS_MAPPING[device][version],
64+
region=region,
65+
)
66+
assert expected == uri
67+
68+
69+
@pytest.mark.parametrize("load_config", ["huggingface-vllm-neuronx.json"], indirect=True)
70+
def test_vllm_neuronx_version_aliases(load_config):
71+
"""Test that version aliases work correctly."""
72+
version_aliases = load_config["inference"].get("version_aliases", {})
73+
74+
for alias, full_version in version_aliases.items():
75+
uri_alias = get_huggingface_llm_image_uri(
76+
"huggingface-vllm-neuronx",
77+
region="us-east-1",
78+
version=alias,
79+
)
80+
uri_full = get_huggingface_llm_image_uri(
81+
"huggingface-vllm-neuronx",
82+
region="us-east-1",
83+
version=full_version,
84+
)
85+
# URIs should be identical
86+
assert uri_alias == uri_full
87+
88+
89+
@pytest.mark.parametrize("load_config", ["huggingface-vllm-neuronx.json"], indirect=True)
90+
def test_vllm_neuronx_all_regions(load_config):
91+
"""Test that all regions have valid registry mappings."""
92+
version = "0.10.2"
93+
registries = load_config["inference"]["versions"][version]["registries"]
94+
95+
for region in registries.keys():
96+
uri = get_huggingface_llm_image_uri(
97+
"huggingface-vllm-neuronx",
98+
region=region,
99+
version=version,
100+
)
101+
# Validate URI format
102+
assert uri.startswith(f"{registries[region]}.dkr.ecr.{region}")
103+
assert "huggingface-vllm-inference-neuronx" in uri
104+
assert "0.10.2" in uri

0 commit comments

Comments
 (0)