3 changes: 3 additions & 0 deletions tests/cache_artifacts.sh
@@ -18,6 +18,9 @@ SMALL_MODEL_URLS=(
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-03082024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-tune-llama3-05052024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-reward-07122024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-meta-vision-10172024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-vision-10172024.pt"

)
FULL_MODEL_URL=("s3://pytorch-multimodal/llama2-7b-torchtune.pt")
TOKENIZER_URLS=(
135 changes: 117 additions & 18 deletions tests/recipes/test_eleuther_eval.py
@@ -13,11 +13,40 @@
import pytest

from tests.common import TUNE_PATH
from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS
from tests.recipes.utils import (
llama2_test_config,
llama3_2_vision_test_config,
write_hf_ckpt_config,
write_hf_vision_ckpt_config,
)
from tests.test_utils import CKPT_MODEL_PATHS, gpu_test


class TestEleutherEval:
@pytest.fixture
def hide_correct_version_number(self, monkeypatch):
import importlib.metadata

import_orig = importlib.metadata.version

def mocked_import(name, *args, **kwargs):
if name == "lm-eval":
return "0.4.4" # Hardcode wrong version number
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(importlib.metadata, "version", mocked_import)

@pytest.fixture
def expected_vision_acc(self):
return {
"Science": 0.35,
"Biology": 0.25,
"Chemistry": 0.25,
"Geography": 0.5,
"Math": 0.0,
"Physics": 0.75,
}

@pytest.mark.parametrize(
"eval_name, expected_acc, bsz",
[
@@ -74,22 +103,9 @@ def test_torchtune_checkpoint_eval_results(
acc_result = float(search_results.group(1))
assert math.isclose(acc_result, expected_acc, abs_tol=0.05)

@pytest.fixture
def hide_correct_version_number(self, monkeypatch):
import importlib.metadata

import_orig = importlib.metadata.version

def mocked_import(name, *args, **kwargs):
if name == "lm-eval":
return "0.4.4" # Hardcode wrong version number
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(importlib.metadata, "version", mocked_import)

@pytest.mark.integration_test
@pytest.mark.usefixtures("hide_correct_version_number")
def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
def test_eval_recipe_errors_without_lm_eval(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -123,7 +139,7 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):

@pytest.mark.integration_test
def test_eval_recipe_errors_with_quantization_hf_checkpointer(
self, capsys, monkeypatch, tmpdir
self, monkeypatch, tmpdir
):
ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -162,7 +178,7 @@ def test_eval_recipe_errors_with_quantization_hf_checkpointer(
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
def test_eval_recipe_errors_with_qat_quantizer(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -194,3 +210,86 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
ckpt = "llama3_2_vision_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

cmd = f"""
tune run eleuther_eval \
--config llama3_2_vision/11B_evaluation \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelMetaCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
~checkpointer.checkpoint_files.filename_format \
~checkpointer.checkpoint_files.max_filename \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3_VISION \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
limit=4 \
dtype=bf16 \
device=cuda \
""".split()

model_config = llama3_2_vision_test_config()
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

out = caplog.text

pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

matches = re.findall(pattern, out, re.MULTILINE)
for task_name, _, accuracy in matches:
assert math.isclose(float(accuracy), expected_vision_acc[task_name])

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
ckpt = "llama3_2_vision_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

# Config file needed for model conversion.
write_hf_vision_ckpt_config(ckpt_dir)

cmd = f"""
tune run eleuther_eval \
--config llama3_2_vision/11B_evaluation \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
~checkpointer.checkpoint_files.filename_format \
~checkpointer.checkpoint_files.max_filename \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3_VISION \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
limit=4 \
dtype=bf16 \
device=cuda \
""".split()

model_config = llama3_2_vision_test_config()
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

out = caplog.text

pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

matches = re.findall(pattern, out, re.MULTILINE)
for task_name, _, accuracy in matches:
assert math.isclose(float(accuracy), expected_vision_acc[task_name])
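Aside (not part of the diff): the two vision tests above recover per-task accuracies by regex-matching lm-eval's printed results table. A minimal sketch of that parsing, using a hand-written sample row that only approximates the harness's table format:

```python
# Illustrative only: pull a task name and accuracy out of an lm-eval-style
# results row with the same `pattern` the vision eval tests use above.
import re

pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
sample_row = "| - Physics|      0|none  |     0|acc   |↑  |0.75  |±  |0.2165|"

for task_name, _, accuracy in re.findall(pattern, sample_row, re.MULTILINE):
    print(task_name, float(accuracy))  # Physics 0.75
```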
73 changes: 73 additions & 0 deletions tests/recipes/utils.py
@@ -128,6 +128,58 @@ def llama3_test_config() -> List[str]:
]


def llama3_2_vision_test_config() -> List[str]:
return [
"model=tests.recipes.utils.dummy_vision_model",
"tokenizer._component_=torchtune.models.llama3_2_vision._transform.Llama3VisionTransform",
"tokenizer.patch_size=9",
"tokenizer.max_num_tiles=2",
"tokenizer.tile_size=18",
"tokenizer.max_seq_len=4096",
]


def dummy_vision_model():
from torchtune.models.llama3_2_vision._component_builders import (
llama3_2_vision_decoder,
llama3_2_vision_encoder,
)
from torchtune.modules.model_fusion import DeepFusionModel

vision_encoder = llama3_2_vision_encoder(
clip_embed_dim=128,
clip_num_layers=4,
num_heads=4,
tile_size=18,
patch_size=9,
max_num_tiles=2,
in_channels=3,
clip_hidden_states=[0, 1],
num_layers_projection=2,
decoder_embed_dim=128,
)
vision_decoder = llama3_2_vision_decoder(
vocab_size=128256,
num_layers=4,
fusion_interval=2,
num_special_tokens=2,
num_heads=8,
num_kv_heads=4,
embed_dim=128,
max_seq_len=4096,
encoder_max_seq_len=4096,
)

model = DeepFusionModel(
encoder=vision_encoder,
decoder=vision_decoder,
encoder_trainable=False,
decoder_trainable=False,
fusion_trainable=False,
)
return model


def lora_llama2_test_config(
lora_attn_modules,
apply_lora_to_mlp: bool = False,
@@ -199,6 +251,27 @@ def write_hf_ckpt_config(ckpt_dir: str):
json.dump(config, f)


def write_hf_vision_ckpt_config(ckpt_dir: str):
config = {
"text_config": {
"num_attention_heads": 8,
"num_key_value_heads": 4,
"hidden_size": 128,
"vocab_size": 128256,
"cross_attention_layers": [1, 4],
},
"vision_config": {
"hidden_size": 128,
"image_size": 18,
"max_num_tiles": 2,
"supported_aspect_ratios": [[1, 1], [1, 2], [2, 1]],
},
}
config_file = Path.joinpath(Path(ckpt_dir), "config.json")
with config_file.open("w") as f:
json.dump(config, f)


MODEL_TEST_CONFIGS = {
"llama2": llama2_test_config(),
"llama3": llama3_test_config(),
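Aside (not part of the PR): the lists returned by helpers such as llama3_2_vision_test_config() are appended to the tune command as dot-list overrides. A rough sketch of how such overrides resolve into a nested config, assuming OmegaConf-style dot-list parsing:

```python
# Sketch only: turn the test's dot-list overrides into a nested config object.
from omegaconf import OmegaConf

overrides = [
    "model=tests.recipes.utils.dummy_vision_model",
    "tokenizer._component_=torchtune.models.llama3_2_vision._transform.Llama3VisionTransform",
    "tokenizer.patch_size=9",
    "tokenizer.tile_size=18",
]
cfg = OmegaConf.from_dotlist(overrides)
print(cfg.tokenizer.patch_size, cfg.tokenizer.tile_size)  # 9 18
```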
2 changes: 2 additions & 0 deletions tests/test_utils.py
@@ -34,6 +34,8 @@
"llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt",
"llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt",
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
"llama3_2_vision_hf": "/tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt",
"llama3_2_vision_meta": "/tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt",
}

TOKENIZER_PATHS = {
39 changes: 38 additions & 1 deletion tests/torchtune/modules/test_transformer_decoder.py
@@ -183,11 +183,48 @@ def transformer_layer(
transformer_layer.eval()
return transformer_layer

@mps_ignored_test()
def test_forward_kv_cache(
self,
input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
transformer_layer: TransformerCrossAttentionLayer,
input_params: Tuple[int, int, int, int],
):

b, _, encoder_seq_len, _ = input_params
transformer_layer.setup_caches(
batch_size=b,
dtype=torch.float32,
encoder_max_seq_len=encoder_seq_len,
decoder_max_seq_len=None,
)
input_x, input_y, mask = input
with torch.no_grad():
# make an initial forward pass which should fill the encoder cache
first_output = transformer_layer(
input_x,
encoder_input=input_y,
encoder_mask=mask,
)
# the second pass should just retrieve from the kv-cache and produce
# identical outputs
output = transformer_layer(
input_x,
encoder_input=None,
encoder_mask=mask,
)

assert_expected(output.mean(), torch.tensor(1.7762), atol=1e-8, rtol=1e-3)
assert_expected(output.shape, input_x.shape)

assert_expected(first_output.shape, output.shape)
assert_expected(first_output.mean(), output.mean())

@mps_ignored_test()
def test_forward(
self,
input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
transformer_layer: TransformerSelfAttentionLayer,
transformer_layer: TransformerCrossAttentionLayer,
Review comment (Contributor Author): idek why we're typing tests

) -> None:
input_x, input_y, mask = input
with torch.no_grad():
2 changes: 1 addition & 1 deletion torchtune/models/gemma2/_attention.py
@@ -240,7 +240,7 @@ def forward(
q = self.q_norm(q)

if y is None:
if self.kv_cache is None:
if self.kv_cache is None or not self.cache_enabled:
raise ValueError(
"Must provide y input or use kv_cache to enable streaming decoding"
)
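Aside (not part of the diff): the changed guard means that when no y input is given, the layer can only run in streaming-decode mode, which requires a KV-cache that is both set up and currently enabled. A standalone sketch of that condition, with hypothetical argument names:

```python
# Sketch of the guard above: streaming decoding (y is None) needs a cache
# that exists and is enabled; otherwise the call should fail loudly.
from typing import Any, Optional


def check_streaming_inputs(
    y: Optional[Any], kv_cache: Optional[Any], cache_enabled: bool
) -> None:
    if y is None and (kv_cache is None or not cache_enabled):
        raise ValueError(
            "Must provide y input or use kv_cache to enable streaming decoding"
        )
```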
26 changes: 13 additions & 13 deletions torchtune/modules/attention.py
@@ -195,7 +195,7 @@ def forward(
and before the softmax. Either:

A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
@@ -249,7 +249,7 @@ def forward(
q = self.q_norm(q)

if y is None:
if self.kv_cache is None:
if self.kv_cache is None or not self.cache_enabled:
Review comment (Contributor Author): This should have been added in #1763

raise ValueError(
"Must provide y input or use kv_cache to enable streaming decoding"
)
@@ -273,21 +273,21 @@
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
expand_shape = (-1, -1, q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)
# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

output = self._attention_call(
q,
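Aside (not part of the diff): the rewritten expansion repeats each of the n_kv key/value heads q_per_kv times so they line up with the n_h query heads. A self-contained sketch of the shape change, with made-up sizes:

```python
# Sketch of the GQA expansion: [b, n_kv, s, h_d] -> [b, n_h, s, h_d],
# where n_h = n_kv * q_per_kv. All sizes below are arbitrary.
import torch

b, n_kv, q_per_kv, s, h_d = 2, 4, 2, 16, 32  # implies n_h = 8
k = torch.randn(b, n_kv, s, h_d)

expand_shape = (b, n_kv, q_per_kv, -1, h_d)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
print(k.shape)  # torch.Size([2, 8, 16, 32])
```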
2 changes: 1 addition & 1 deletion torchtune/modules/transformer.py
@@ -781,7 +781,7 @@ def setup_caches(
isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
Review comment (Collaborator): :O

)
if has_encoder_layers:
if encoder_max_seq_len is not None:
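Aside (not part of the diff): switching from self.layers to self.modules() matters because Module.modules() recurses into nested submodules, so self-attention layers wrapped inside fusion blocks are still detected. A tiny illustration with stand-in classes:

```python
# Sketch: iterating a top-level ModuleList misses nested layers, while
# Module.modules() walks every submodule recursively.
import torch.nn as nn


class Inner(nn.Module):
    pass


class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()


layers = nn.ModuleList([Wrapper()])
print(any(isinstance(l, Inner) for l in layers))            # False
print(any(isinstance(m, Inner) for m in layers.modules()))  # True
```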