Merged · Changes from 9 commits
3 changes: 3 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -104,6 +104,9 @@ These are utilities that are common to and can be used by all modules.
:nosignatures:

common_utils.reparametrize_as_dtype_state_dict_post_hook
common_utils.setup_use_local_kv_cache
common_utils.use_persistent_kv_cache
common_utils.delete_kv_caches


Vision Transforms
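For context on how the newly documented helpers are used, here is a minimal sketch of scoped KV-cache setup around generation, mirroring the call sites added to recipes/eleuther_eval.py below. The `llama3_2_1b` builder and the concrete sizes are illustrative assumptions, not part of this PR; only `setup_use_local_kv_cache` and its arguments come from the changes here.

```python
# Sketch only: assumes torchtune with this PR installed. The model builder and
# sizes are illustrative; the context-manager arguments mirror the eval recipe.
import torch

from torchtune.generation import generate
from torchtune.models.llama3_2 import llama3_2_1b  # any TransformerDecoder works
from torchtune.modules.common_utils import setup_use_local_kv_cache

model = llama3_2_1b()
model.eval()
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)

# Caches are allocated on entry and torn down on exit, leaving the model
# cache-free for whatever runs next.
with setup_use_local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=512,
):
    tokens, _ = generate(
        model,
        prompt,
        max_generated_tokens=16,
        temperature=0.0,
        top_k=None,
    )
```

Scoping the caches this way is what lets the eval recipe drop its manual "set up caches, or reset them if they already exist" branching further down in this diff.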
1 change: 1 addition & 0 deletions recipes/configs/llama3_2_vision/evaluation.yaml
@@ -41,6 +41,7 @@ tasks: ["mmmu_val_science"] # Defaulting to science as a good subset
limit: null
batch_size: 1
enable_kv_cache: True
max_seq_length: 8192

# Quantization specific args
# Quantization is not supported in this specific config
108 changes: 49 additions & 59 deletions recipes/eleuther_eval.py
@@ -24,6 +24,7 @@
)
from torchtune.generation import generate, sample
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import setup_use_local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
Expand All @@ -47,7 +48,7 @@
)
sys.exit(1)

from lm_eval.evaluator import evaluate, get_task_list
from lm_eval.evaluator import evaluate

# User doesn't have to have nightlies installed, they just won't be able
# to use the multimodal model
@@ -253,18 +254,11 @@ def _model_multimodal_generate(
"multimodal generation."
)

# 2. Setup KV cache and masks for bsz 1
encoder_max_seq_len = (
self.model_transform.image_seq_len * self._max_images_per_sample
)
# Setup masks for bsz 1
with self.device:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=self.model_transform.image_seq_len
* self._max_images_per_sample,
decoder_max_seq_len=self.max_length,
)
causal_mask = torch.tril(
torch.ones(
size=(self.max_length, self.max_length),
@@ -276,28 +270,37 @@
batch["input_pos"] = input_pos[None, :seq_len]
batch["mask"] = causal_mask[None, :seq_len]

# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
# 2. Setup KV cache
with setup_use_local_kv_cache(
self.model,
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=self.max_length,
):
# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

# 5. Return generated tokens
return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
@@ -417,18 +420,6 @@ def _model_generate(
"Any decoding strategy other than greedy is not supported."
)

# Setup KV caches OR reset them if they're already set up
if self.enable_kv_cache:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
with self.device:
self.model.setup_caches(
batch_size=self.batch_size,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
)

# if we've received fewer than self._batch_size samples in the current
# batch we need to pad the batch out. here we're padding the end of the
# current batch to the correct length. this is because when we use static
@@ -438,15 +429,21 @@
(0, 0, 0, self._batch_size - bsz),
value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early
)

toks, _ = generate(
with setup_use_local_kv_cache(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
):
toks, _ = generate(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
return toks[:bsz]


@@ -555,13 +552,6 @@ def evaluate(self) -> None:
# Initialize tasks for the harness
task_manager = TaskManager(include_path=self.include_path)
task_dict = get_task_dict(self.tasks, task_manager)
task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
if len(task_types) > 1 and "generate_until" in task_types:
raise RuntimeError(
"Evaluating on multiple task types where any one task involves "
"generation is currently not supported. See the issue below for more info: "
"https://github.com/pytorch/torchtune/issues/1621"
)

# Run evaluation
t0 = time.time()
9 changes: 0 additions & 9 deletions tests/recipes/test_eleuther_eval.py
@@ -199,12 +199,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
Contributor (review comment): Should we instead have a test that this now actually works?
self, caplog, capsys, monkeypatch, tmpdir
):
# We can't currently specify both generate_until and mc_tasks in the same run
# b/c the KV cache won't be reset and the result will be different. This test
# catches that error
pass
6 changes: 3 additions & 3 deletions tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -32,7 +32,7 @@ def __init__(self, dim, vocab_size):
def setup_caches(self, batch_size, dtype, *args, **kwargs):
self.cache_enabled = True

def caches_are_enabled(self):
def caches_are_setup(self):
return self.cache_enabled

def reset_caches(self):
@@ -144,9 +144,9 @@ def test_setup_cache(self, fused_model):
Test that the cache methods works as expected.
"""
fused_model.setup_caches(2, torch.float32)
assert fused_model.caches_are_enabled()
assert fused_model.caches_are_setup()
fused_model.reset_caches()
assert not fused_model.caches_are_enabled()
assert not fused_model.caches_are_setup()

def test_set_trainable_params(self, fused_model, encoder, decoder):
"""
3 changes: 3 additions & 0 deletions tests/torchtune/modules/test_attention.py
@@ -141,6 +141,7 @@ def gqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -195,6 +196,7 @@ def mha_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -249,6 +251,7 @@ def mqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
4 changes: 2 additions & 2 deletions torchtune/generation/_generation.py
@@ -252,7 +252,7 @@ def generate(
total_response_length = prompt_length + max_generated_tokens

generated_tokens = prompt.clone()
incremental_decoding = model.caches_are_enabled()
incremental_decoding = model.caches_are_setup()

# grab the correct max_seq_len to generate full causal masks/position ids
# this is the model's max cache len if incremental decoding, or the sequence
@@ -366,7 +366,7 @@ def generate(
tokens, logits = custom_generate_next_token(
model,
input_pos=curr_input_pos,
x=tokens,
x=tokens.clone(),
mask=curr_masks,
temperature=temperature,
top_k=top_k,
2 changes: 1 addition & 1 deletion torchtune/models/clip/_position_embeddings.py
@@ -570,7 +570,7 @@ def _load_state_dict_hook(
if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
raise ValueError(
"Expected embedding shape to be (..., num_tokens, tgt_emb) to match"
f" but found shapes {self.embedding.shape} and {state_dict[prefix+'embedding'].shape}"
f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
)

if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
2 changes: 1 addition & 1 deletion torchtune/models/gemma/transformer.py
@@ -70,7 +70,7 @@ def __init__(
self.norm_embeddings = norm_embeddings
self.num_output_chunks = 0

def caches_are_enabled(self) -> bool:
def caches_are_setup(self) -> bool:
"""Check if the key value caches are setup."""
return self.layers[0].cache_enabled

10 changes: 9 additions & 1 deletion torchtune/modules/__init__.py
@@ -6,7 +6,12 @@

from .attention import MultiHeadAttention # noqa
from .attention_utils import create_block_causal_mask, packed_block_causal_mask
from .common_utils import reparametrize_as_dtype_state_dict_post_hook
from .common_utils import (
delete_kv_caches,
reparametrize_as_dtype_state_dict_post_hook,
setup_use_local_kv_cache,
use_persistent_kv_cache,
)
from .feed_forward import FeedForward # noqa
from .kv_cache import KVCache # noqa
from .layer_norm import Fp32LayerNorm # noqa
@@ -43,4 +48,7 @@
"reparametrize_as_dtype_state_dict_post_hook",
"create_block_causal_mask",
"packed_block_causal_mask",
"setup_use_local_kv_cache",
"delete_kv_caches",
"use_persistent_kv_cache",
]
8 changes: 7 additions & 1 deletion torchtune/modules/attention.py
@@ -140,6 +140,11 @@ def __init__(
# Use flex attention if supported and we are sample packing
self._attention_call = _sdpa_or_flex_attention()

# this flag indicates whether to update the kv-cache during forward
# passes. when disabled, we can have the cache setup but still
# perform normal forward passes
self.cache_enabled = False

def setup_cache(
self, batch_size: int, dtype: torch.dtype, max_seq_len: int
) -> None:
@@ -164,6 +169,7 @@ def setup_cache(
head_dim=self.head_dim,
dtype=dtype,
)
self.cache_enabled = True

def reset_cache(self):
"""Reset the key value caches."""
@@ -291,7 +297,7 @@ def forward(
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

output = self._attention_call(
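To make the semantics of the new `cache_enabled` flag concrete, here is a toy sketch (not torchtune code) of the gating pattern this file now follows: a module can have its cache allocated yet still run ordinary forward passes, because cache writes require both that the cache exists and that the flag is on. This is also why the attention test fixtures above, which attach a `kv_cache` directly in the constructor rather than via `setup_cache`, now set `attn.cache_enabled = True` explicitly.

```python
# Toy illustration of the cache_enabled gating pattern; this is not
# torchtune's MultiHeadAttention, just the same flag logic in miniature.
import torch
from torch import nn


class ToyCachedLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.kv_cache = None
        # The cache may be *set up* without being *enabled* for updates.
        self.cache_enabled = False

    def setup_cache(self, batch_size: int, max_seq_len: int, dim: int) -> None:
        self.kv_cache = torch.zeros(batch_size, max_seq_len, dim)
        self.cache_enabled = True

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        y = self.proj(x)
        # Mirrors `if self.kv_cache is not None and self.cache_enabled:` above.
        if self.kv_cache is not None and self.cache_enabled:
            self.kv_cache[:, input_pos] = y
        return y


layer = ToyCachedLayer(dim=8)
layer.setup_cache(batch_size=1, max_seq_len=16, dim=8)
layer.cache_enabled = False  # cache stays allocated, but a plain forward pass skips it
out = layer(torch.randn(1, 4, 8), input_pos=torch.arange(4))
```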