Merged
3 changes: 3 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -103,6 +103,9 @@ These are utilities that are common to and can be used by all modules.
:nosignatures:

common_utils.reparametrize_as_dtype_state_dict_post_hook
common_utils.local_kv_cache
common_utils.disable_kv_cache
common_utils.delete_kv_caches


Vision Transforms
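For orientation, the three helpers added to the API reference above are what the recipe changes below build on. Here is a minimal sketch of the `local_kv_cache` context-manager pattern, mirroring the call sites in `recipes/eleuther_eval.py` in this diff; the wrapper function, prompt shape, and numeric values are illustrative assumptions (`disable_kv_cache` and `delete_kv_caches` are only listed in the docs entry and not exercised here).

```python
import torch

from torchtune.generation import generate
from torchtune.modules.common_utils import local_kv_cache


# `model` is assumed to be a torchtune TransformerDecoder and `prompt` an
# integer tensor of shape [batch_size, seq_len]; all values are placeholders.
def generate_with_temporary_caches(model, prompt, stop_tokens):
    with local_kv_cache(
        model,
        batch_size=prompt.shape[0],
        device=prompt.device,
        dtype=torch.bfloat16,
        decoder_max_seq_len=2048,
    ):
        # Caches are attached on entry and removed on exit, so the caller no
        # longer needs the setup_caches()/reset_caches() bookkeeping that this
        # PR deletes from the recipe.
        tokens, _ = generate(
            model,
            prompt,
            max_generated_tokens=128,
            temperature=0.0,
            top_k=None,
            stop_tokens=stop_tokens,
        )
    return tokens
```

Scoping the cache lifetime to the block is what appears to let a single evaluation run mix generation tasks with likelihood-based tasks, which the removed `get_task_list` check further down used to forbid.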
108 changes: 49 additions & 59 deletions recipes/eleuther_eval.py
@@ -13,7 +13,7 @@

import torch

from lm_eval.evaluator import evaluate, get_task_list
from lm_eval.evaluator import evaluate
from lm_eval.models.hf_vlms import HFMultimodalLM
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict, TaskManager
@@ -29,6 +29,7 @@
)
from torchtune.generation import generate, sample
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
@@ -224,18 +225,11 @@ def _model_multimodal_generate(
"multimodal generation."
)

# 2. Setup KV cache and masks for bsz 1
encoder_max_seq_len = (
self.model_transform.image_seq_len * self._max_images_per_sample
)
# Setup masks for bsz 1
with self.device:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=self.model_transform.image_seq_len
* self._max_images_per_sample,
decoder_max_seq_len=self.max_length,
)
causal_mask = torch.tril(
torch.ones(
size=(self.max_length, self.max_length),
@@ -247,28 +241,37 @@
batch["input_pos"] = input_pos[None, :seq_len]
batch["mask"] = causal_mask[None, :seq_len]

# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
# 2. Setup KV cache
with local_kv_cache(
self.model,
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=self.max_length,
):
# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

# 5. Return generated tokens
return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
@@ -388,18 +391,6 @@ def _model_generate(
"Any decoding strategy other than greedy is not supported."
)

# Setup KV caches OR reset them if they're already set up
if self.enable_kv_cache:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
with self.device:
self.model.setup_caches(
batch_size=self.batch_size,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
)

# if we've received fewer than self._batch_size samples in the current
# batch we need to pad the batch out. here we're padding the end of the
# current batch to the correct length. this is because when we use static
@@ -409,15 +400,21 @@
(0, 0, 0, self._batch_size - bsz),
value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early
)

toks, _ = generate(
with local_kv_cache(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
):
toks, _ = generate(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
return toks[:bsz]


@@ -536,13 +533,6 @@ def evaluate(self) -> None:
# Initialize tasks for the harness
task_manager = TaskManager(include_path=self.include_path)
task_dict = get_task_dict(self.tasks, task_manager)
task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
if len(task_types) > 1 and "generate_until" in task_types:
raise RuntimeError(
"Evaluating on multiple task types where any one task involves "
"generation is currently not supported. See the issue below for more info: "
"https://github.com/pytorch/torchtune/issues/1621"
)

# Run evaluation
t0 = time.time()
9 changes: 0 additions & 9 deletions tests/recipes/test_eleuther_eval.py
@@ -194,12 +194,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
Contributor review comment (a possible replacement test is sketched after this file's diff):
Should we instead have a test that this now actually works?

self, caplog, capsys, monkeypatch, tmpdir
):
# We can't currently specify both generate_until and mc_tasks in the same run
# b/c the KV cache won't be reset and the result will be different. This test
# catches that error
pass
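Regarding the review question above ("Should we instead have a test that this now actually works?"): a positive integration test could replace the deleted one along roughly these lines. This is only a hedged sketch — the command-line overrides and task names are hypothetical, and it assumes the module's existing imports (`sys`, `runpy`, `pytest`, `TUNE_PATH`) plus this class's checkpoint/tokenizer fixtures, which are not shown in this diff.

```python
@pytest.mark.integration_test
def test_eval_recipe_runs_generation_and_mc_tasks_together(
    self, caplog, capsys, monkeypatch, tmpdir
):
    # Hypothetical invocation; a real test would reuse this class's fixtures
    # to point the config at a small test checkpoint and tokenizer.
    cmd = f"""
    tune run eleuther_eval
        --config eleuther_evaluation
        output_dir={tmpdir}
        tasks=[truthfulqa_gen,truthfulqa_mc2]
        limit=1
    """.split()
    monkeypatch.setattr(sys, "argv", cmd)
    runpy.run_path(TUNE_PATH, run_name="__main__")

    # With block-scoped KV caches, mixing a generate_until task with a
    # multiple-choice task should now complete instead of raising.
    out = caplog.text + capsys.readouterr().out
    assert "truthfulqa" in out
```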
20 changes: 12 additions & 8 deletions tests/torchtune/modules/model_fusion/test_fusion_layer.py
@@ -25,10 +25,13 @@ def __init__(self, dim):
self.cache_enabled = False
self.encoder_max_seq_len = None

def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.encoder_max_seq_len = encoder_max_seq_len

def caches_are_enabled(self):
return self.cache_enabled

def reset_cache(self):
self.cache_enabled = False

@@ -43,10 +46,13 @@ def __init__(self, dim):
self.cache_enabled = False
self.decoder_max_seq_len = None

def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.decoder_max_seq_len = decoder_max_seq_len

def caches_are_enabled(self):
return self.cache_enabled

def reset_cache(self):
self.cache_enabled = False

@@ -131,22 +137,20 @@ def test_fusion_params(self, fused_layer):
"fusion_layer.linear.bias",
}

def test_setup_cache(self, fused_layer):
def test_setup_caches(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(
fused_layer.setup_caches(
2, torch.float32, encoder_max_seq_len=10, decoder_max_seq_len=10
)
assert fused_layer.cache_enabled
fused_layer.reset_cache()
assert not fused_layer.cache_enabled
assert fused_layer.caches_are_enabled()

def test_setup_cache_different_cache_seq_len(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(
fused_layer.setup_caches(
2, torch.float32, encoder_max_seq_len=5, decoder_max_seq_len=10
)

6 changes: 3 additions & 3 deletions tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -32,7 +32,7 @@ def __init__(self, dim, vocab_size):
def setup_caches(self, batch_size, dtype, *args, **kwargs):
self.cache_enabled = True

def caches_are_enabled(self):
def caches_are_setup(self):
return self.cache_enabled

def reset_caches(self):
@@ -144,9 +144,9 @@ def test_setup_cache(self, fused_model):
Test that the cache methods work as expected.
"""
fused_model.setup_caches(2, torch.float32)
assert fused_model.caches_are_enabled()
assert fused_model.caches_are_setup()
fused_model.reset_caches()
assert not fused_model.caches_are_enabled()
assert not fused_model.caches_are_setup()

def test_set_trainable_params(self, fused_model, encoder, decoder):
"""
3 changes: 3 additions & 0 deletions tests/torchtune/modules/test_attention.py
@@ -141,6 +141,7 @@ def gqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -195,6 +196,7 @@ def mha_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -249,6 +251,7 @@ def mqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn