-
Notifications
You must be signed in to change notification settings - Fork 17.7k
Expand file tree
/
Copy pathtest_predicted_outputs.py
More file actions
62 lines (54 loc) · 2.02 KB
/
test_predicted_outputs.py
File metadata and controls
62 lines (54 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pytest
from utils import *
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Create and configure the module-wide ``server`` handle.

    The handle is built from the ``tinyllama2`` preset with ``draft_max``
    raised to 1024 and ``debug`` switched on. The fixture only configures
    the handle; each test calls ``server.start()`` itself.
    """
    global server
    # Publish the handle first, then tune it through a short local alias.
    server = preset = ServerPreset.tinyllama2()
    preset.draft_max = 1024
    preset.debug = True
def test_with_and_without_prediced_outputs():
    """Check that supplying a prediction does not change the generated text.

    The same greedy chat request is sent twice, each time against a freshly
    started server: once without a ``prediction`` field (0 accepted
    prediction tokens expected) and once with one (54 accepted prediction
    tokens expected). The two message contents must be identical.
    """
    global server

    predicted_text = '''"Here?" Annabyed.
"Okay, Annabyes!" Annabyed.
As Annagged, Annap came and said,'''

    def run_once(extra_fields, expected_accepted):
        # One full start -> request -> verify -> stop cycle; returns the reply text.
        payload = {
            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
            "temperature": 0.0,
            "top_k": 1,
        }
        payload.update(extra_fields)

        server.start()
        response = server.make_request("POST", "/v1/chat/completions", data=payload)
        assert response.status_code == 200
        token_details = response.body["usage"]["completion_tokens_details"]
        assert token_details["accepted_prediction_tokens"] == expected_accepted
        reply = response.body["choices"][0]["message"]["content"]
        server.stop()
        return reply

    content_no_pred = run_once({}, 0)
    content_pred = run_once({"prediction": {"content": predicted_text}}, 54)

    assert content_no_pred == content_pred
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 2),
    (2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
    """Send several identical chat requests with a prediction concurrently.

    Args:
        n_slots: value assigned to ``server.n_slots`` before the server starts.
        n_requests: number of requests handed to ``parallel_function_calls``.

    Every response must return HTTP 200 and its message content must match
    the expected-words pattern.
    """
    global server
    server.n_slots = n_slots
    server.start()

    # Queue the calls rather than issuing them inline. Previously ``tasks``
    # was never filled, so the requests ran one after another, their
    # responses were thrown away, and the assertion loop below iterated
    # over the results of an empty task list.
    # NOTE(review): assumes ``parallel_function_calls`` takes a list of
    # ``(callable, args_tuple)`` pairs and that ``data`` is the third
    # positional parameter of ``make_request`` - confirm against utils.
    tasks = [
        (server.make_request, ("POST", "/v1/chat/completions", {
            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
            "temperature": 0.0,
            "top_k": 1,
            "prediction": {"content": " believe the meaning of life is"},
        }))
        for _ in range(n_requests)
    ]
    results = parallel_function_calls(tasks)

    # Guard against the checks below silently passing on an empty result list.
    assert len(results) == n_requests
    for res in results:
        assert res.status_code == 200
        # Chat completions carry the text under choices[0].message.content,
        # the same path the other test in this file reads.
        content = res.body["choices"][0]["message"]["content"]
        # NOTE(review): this pattern was never actually evaluated before the
        # fix above; confirm it matches the model's chat-endpoint output.
        assert match_regex("(wise|kind|owl|answer)+", content)