227 changes: 107 additions & 120 deletions tests/e2e/test_sampling_params.py
@@ -7,18 +7,19 @@

pytestmark = [pytest.mark.full_model, pytest.mark.other_e2e]

+# TODO: REVERT THIS CHANGE!
Collaborator: Can we parametrize the tests so that they get executed for both SB (static batching) and CB (continuous batching)?

Collaborator: Or do we already have a similar test for CB?

Collaborator (author): We do not have this test for CB. I think the issue is that it would increase CI time too much. Moreover, these tests do not reproduce the issue in this PR very well.

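A minimal sketch of the parametrization suggested above, assuming get_cached_llm accepts the warmup_shapes and use_cb arguments exactly as used in this diff; the use_cb fixture name is hypothetical, and the model, backend, use_llm_cache, and warmup_shapes fixtures are assumed to exist in the suite's conftest:

import pytest

# Hypothetical fixture: run each test once under static batching (SB)
# and once under continuous batching (CB).
@pytest.fixture(params=[False, True], ids=["SB", "CB"])
def use_cb(request):
    return request.param


def test_spyre_batch1_temperature(model: ModelInfo, backend, monkeypatch,
                                  use_llm_cache, warmup_shapes, use_cb):
    spyre_model = get_cached_llm(
        model=model,
        max_model_len=128,
        tensor_parallel_size=1,
        backend=backend,
        monkeypatch=monkeypatch,
        # CB does not use warmup shapes, so pass them only for SB
        warmup_shapes=None if use_cb else warmup_shapes,
        use_cb=use_cb,
    )
    # ... the body of the test is unchanged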

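And if CI time is the blocker, a possible refinement (the cb marker is hypothetical and would need to be registered in the pytest config) is to mark only the CB variants so the default CI job can deselect them:

import pytest

# Hypothetical marker: register "cb" in pytest.ini or pyproject.toml,
# then exclude CB variants from the default CI run with `pytest -m "not cb"`.
@pytest.fixture(params=[
    pytest.param(False, id="SB"),
    pytest.param(True, id="CB", marks=pytest.mark.cb),
])
def use_cb(request):
    return request.param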

def test_spyre_batch1_temperature(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):

-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)

prompt = "The capital of the United Kingdom is"
params1 = SamplingParams(temperature=0.0, seed=8780, max_tokens=20)
@@ -36,14 +37,13 @@ def test_spyre_batch1_temperature(model: ModelInfo, backend, monkeypatch,
def test_spyre_batch1_max_tokens(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):

-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)

prompt = "Count to twenty"
params1 = SamplingParams(temperature=0, seed=8780, max_tokens=15)
@@ -58,14 +58,13 @@ def test_spyre_batch1_max_tokens(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_stop_sequence(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
stop_str = "train"
prompt = "The best way to travel from Paris to Berlim is by "

@@ -90,14 +89,13 @@ def max_repetitions(output):

def test_spyre_batch1_presence_penalty(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "REPEAT OVER AND OVER AGAIN THE MINIMUM "\
"TIMES POSSIBLE: one one one one one"

@@ -116,14 +114,13 @@ def test_spyre_batch1_presence_penalty(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_frequency_penalty(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)

prompt = 'repeat the word hi ten times:'

@@ -141,14 +138,13 @@ def test_spyre_batch1_frequency_penalty(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_n_generations(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "The three most popular sports in the world are: "

params = SamplingParams(n=3, seed=8780, max_tokens=20)
@@ -172,14 +168,13 @@ def token_diversity(spyre_model, prompt, params, n_experiments):

def test_spyre_batch1_top_p(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "The first three letters of the alphabet are"
params1 = SamplingParams(top_p=0.01, temperature=1, max_tokens=10)
params2 = SamplingParams(temperature=1, max_tokens=10)
@@ -191,14 +186,13 @@ def test_spyre_batch1_top_p(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "The opposite of hot is"
params1 = SamplingParams(temperature=1, top_k=1, max_tokens=5)
params2 = SamplingParams(temperature=1, max_tokens=5)
@@ -210,14 +204,13 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
tokenizer = spyre_model.get_tokenizer()
banned_word = "train"
forced_word = "plane"
@@ -249,14 +242,13 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "What is the capital of the USA?"
tokenizer = spyre_model.get_tokenizer()
eos_id = tokenizer.eos_token_id
@@ -276,14 +268,13 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
tokenizer = spyre_model.get_tokenizer()
eos_id = tokenizer.eos_token_id
prompt = "This is the end of the story"
@@ -310,14 +301,13 @@ def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "The opposite of black is"
params1 = SamplingParams(min_p=0.5, temperature=1, max_tokens=5)
params2 = SamplingParams(temperature=1, max_tokens=5)
@@ -330,14 +320,13 @@ def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_bad_words(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "The capital of France is"
params1 = SamplingParams(max_tokens=5,
temperature=0,
@@ -355,14 +344,13 @@ def test_spyre_batch1_bad_words(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_detokenize(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
prompt = "Hello, world!"
params = SamplingParams(max_tokens=5,
seed=8780,
@@ -376,14 +364,13 @@ def test_spyre_batch1_logprobs(model: ModelInfo, backend, monkeypatch,

def test_spyre_batch1_logprobs(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
-    spyre_model = get_cached_llm(
-        model=model,
-        max_model_len=128,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+    spyre_model = get_cached_llm(model=model,
+                                 max_model_len=128,
+                                 tensor_parallel_size=1,
+                                 backend=backend,
+                                 monkeypatch=monkeypatch,
+                                 warmup_shapes=None,
+                                 use_cb=True)
num_logprobs = 5
prompt = "The sky is"
params = SamplingParams(max_tokens=5,