2 changes: 1 addition & 1 deletion recipes/configs/llama3_2_vision/11B_lora.yaml
@@ -81,7 +81,7 @@ enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/full-llama3.2-vision-finetune
output_dir: /tmp/lora-llama3.2-vision-finetune
Member: whoops

metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
@@ -80,7 +80,7 @@ enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/full-llama3.2-vision-finetune
output_dir: /tmp/lora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
88 changes: 88 additions & 0 deletions recipes/configs/llama3_2_vision/11B_qlora.yaml
@@ -0,0 +1,88 @@
# Config for multi-device QLoRA finetuning in lora_finetune_distributed.py
# using a Llama3.2 11B Vision Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training:
# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device QLoRA finetuning please use 11B_qlora_single_device.yaml

# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
Contributor: Can you confirm that this is identical to lora except for this line? Whenever you do a merge you should re-check that assumption.

Contributor Author: Yep, did this.

decoder_trainable: "frozen"
encoder_trainable: "lora"
fusion_trainable: "lora"
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.0
image_size: 560 # Make sure this matches the image_size in tokenizer

# Transform
tokenizer:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
image_size: 560
max_seq_len: 8192

# Checkpointer
checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
checkpoint_files: [consolidated.pth]
recipe_checkpoint: null
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False

# Dataset
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset
subset: ocrvqa
seed: null
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 2
gradient_accumulation_steps: 4
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 2e-5
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
compile: False # set it to True for better memory usage and performance

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: False
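
As an aside for readers less familiar with torchtune's config system: the `model:` section above resolves, roughly, into a direct call to the named builder with the listed keyword arguments. A minimal Python sketch, assuming the builder's signature matches these keys (every value is copied verbatim from the YAML above, nothing new):

# Sketch: what the `model:` section of 11B_qlora.yaml resolves to at runtime.
# All argument values are taken from the config above.
from torchtune.models.llama3_2_vision import qlora_llama3_2_vision_11b

model = qlora_llama3_2_vision_11b(
    decoder_trainable="frozen",      # decoder weights stay frozen
    encoder_trainable="lora",        # LoRA adapters on the vision encoder
    fusion_trainable="lora",         # LoRA adapters on the fusion layers
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.0,
    image_size=560,                  # must match the tokenizer/transform image_size
)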
113 changes: 113 additions & 0 deletions recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -0,0 +1,113 @@
# Config for single device QLoRA finetuning in lora_finetune_single_device.py
# using a Llama3.2 11B Vision Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training:
# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
decoder_trainable: "frozen"
encoder_trainable: "lora"
fusion_trainable: "lora"
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.0
image_size: 560 # Make sure this matches the image_size in tokenizer

# Transform
tokenizer:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
image_size: 560
max_seq_len: 8192

# Checkpointer
checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
checkpoint_files: [consolidated.pth]
recipe_checkpoint: null
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False

# Dataset
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset
subset: ocrvqa
seed: null
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 2
gradient_accumulation_steps: 16
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 2e-5
optimizer_in_bwd: False
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
compile: False # set it to True for better memory usage and performance

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: False

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: True
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 1
warmup_steps: 2
active_steps: 1
num_cycles: 1
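
The profiler's `wait_steps`/`warmup_steps`/`active_steps`/`num_cycles` keys map onto `torch.profiler.schedule`, as the comment in the config notes. A minimal sketch of that mapping (illustrative only; the recipe's `setup_torch_profiler` component handles this when `enabled: True`):

# Sketch: how the schedule keys above translate to torch.profiler arguments.
import torch

schedule = torch.profiler.schedule(
    wait=1,      # wait_steps: steps skipped before profiling starts
    warmup=2,    # warmup_steps: steps traced but then discarded
    active=1,    # active_steps: steps actually recorded
    repeat=1,    # num_cycles: how many times to repeat the wait/warmup/active cycle
)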
9 changes: 3 additions & 6 deletions tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -40,10 +40,6 @@ class TestNF4Linear:
Class for testing our NF4Linear implementation.
"""

def test_bias_unsupported(self):
with pytest.raises(RuntimeError, match="does not currently support biases"):
_ = FrozenNF4Linear(1, 1, bias=True)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_parameters(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
@@ -59,9 +55,10 @@ def test_state_dict(self, dtype):
assert isinstance(state_dict["weight"], NF4Tensor)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_output_dtype(self, dtype):
@pytest.mark.parametrize("bias", [True, False])
Member: What is the point of adding bias to this test? The dtype isn't changing and you're only checking the dtype?

Contributor Author: Agreed, it's pretty trivial, but I'd like to at least build FrozenNF4Linear with bias somewhere in our unit tests, and the overhead of this unit test is tiny.

def test_output_dtype(self, dtype, bias):
# Test to ensure W4A16 produces A16 / W4A32 produces A32
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype, bias=bias)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
out = nf4_linear(inp)
assert out.dtype == dtype
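For context on the reviewer exchange above: the new parametrization exercises FrozenNF4Linear's bias path. A quick standalone sketch of the behavior it builds and checks; note the NF4Tensor import path and the bias-dtype assertion are my assumptions here, not lines from the test itself:

# Sketch: the weight is stored as a 4-bit NF4 tensor, while the optional bias
# stays in the working dtype and the output dtype follows the activations.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor  # assumed import path
from torchtune.modules import FrozenNF4Linear

lin = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16, bias=True)
assert isinstance(lin.weight, NF4Tensor)                           # quantized weight
assert lin.bias is not None and lin.bias.dtype == torch.bfloat16   # assumed: bias is not quantized
out = lin(torch.randn(2, 512, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16                                 # W4A16 produces A16 output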
70 changes: 34 additions & 36 deletions tests/torchtune/modules/peft/test_lora.py
@@ -61,19 +61,22 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear:
return lora_linear

@pytest.fixture
def qlora_linear(self, in_dim, out_dim) -> LoRALinear:
with training.set_default_dtype(torch.bfloat16):
qlora_linear = LoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
quantize_base=True,
)
fixed_init_model(qlora_linear, dtype=torch.bfloat16)
def qlora_linear(self):
def create_qlora_linear(use_bias, dtype):
with training.set_default_dtype(dtype):
qlora_linear = LoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=use_bias,
quantize_base=True,
)
# fixed_init_model(qlora_linear, dtype=torch.bfloat16)
Member: Did you mean to comment this out?

Contributor Author: Oh yeah, oops, lemme update.

return qlora_linear

return create_qlora_linear

@torch.no_grad()
def set_dummy_weights_for_merge(self, lora_module):
lora_module.lora_a.weight = nn.Parameter(
@@ -97,50 +100,45 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None:
assert actual.shape == (BSZ, SEQ_LEN, out_dim)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

def test_lora_weight_nf4_when_quantized(self, qlora_linear):
@pytest.mark.parametrize("use_bias", [True, False])
def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear):
qlora_linear = qlora_linear(use_bias=use_bias, dtype=torch.bfloat16)
assert isinstance(qlora_linear.weight, NF4Tensor)

def test_quantize_with_bias_raises(self):
with pytest.raises(NotImplementedError, match="does not support bias"):
LoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=True,
quantize_base=True,
)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_parity(self, dtype):
if use_bias:
assert not isinstance(qlora_linear.bias, NF4Tensor)
assert qlora_linear.bias.dtype == torch.bfloat16

# Note: with bfloat16 F.linear(x, weight, bias) != F.linear(x, weight) + bias.
# This means we would get different results (irrespective of QLoRA).
# So we leave that test case out
@pytest.mark.parametrize(
"use_bias, dtype",
[(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
)
def test_qlora_parity(self, use_bias, dtype, qlora_linear):
qlora_linear = qlora_linear(use_bias=use_bias, dtype=dtype)
with training.set_default_dtype(dtype):
qlora_linear = LoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
quantize_base=True,
)
lora_linear = LoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
use_bias=use_bias,
quantize_base=False,
)

# set weight of lora_linear to unquantized weight of qlora_linear and check
# parity.
lora_linear.weight.data = qlora_linear.weight.to(dtype)

if use_bias:
lora_linear.bias.data = qlora_linear.bias.detach().clone()
# Ensure forward passes are the same. This is because LoRALinear should use a special
# quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor)
# for autograd.
inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype)
lora_linear_out = lora_linear(inputs)
qlora_linear_out = qlora_linear(inputs)

torch.testing.assert_close(lora_linear_out, qlora_linear_out)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
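Regarding the note in test_qlora_parity about why the (use_bias=True, bfloat16) case is left out: in reduced precision, the fused bias path of F.linear can round differently from a separate matmul-plus-add, so exact parity is not guaranteed even without quantization in the picture. A quick illustration (not part of the PR):

# Illustration: in bfloat16, F.linear(x, w, b) is not guaranteed to match
# F.linear(x, w) + b bit-for-bit, because the fused path may round differently.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 64, 512, dtype=torch.bfloat16)
w = torch.randn(512, 512, dtype=torch.bfloat16)
b = torch.randn(512, dtype=torch.bfloat16)

fused = F.linear(x, w, b)
unfused = F.linear(x, w) + b
# May print a nonzero count: the two paths need not agree exactly in bf16.
print((fused != unfused).sum().item())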
8 changes: 8 additions & 0 deletions torchtune/_recipe_registry.py
@@ -216,6 +216,10 @@ class Recipe:
name="llama3_2_vision/11B_lora_single_device",
file_path="llama3_2_vision/11B_lora_single_device.yaml",
),
Config(
name="llama3_2_vision/11B_qlora_single_device",
file_path="llama3_2_vision/11B_qlora_single_device.yaml",
),
],
supports_distributed=False,
),
@@ -289,6 +293,10 @@ class Recipe:
name="llama3_2_vision/11B_lora",
file_path="llama3_2_vision/11B_lora.yaml",
),
Config(
name="llama3_2_vision/11B_qlora",
file_path="llama3_2_vision/11B_qlora.yaml",
),
],
supports_distributed=True,
),
6 changes: 3 additions & 3 deletions torchtune/models/clip/_component_builders.py
@@ -18,8 +18,8 @@
from torchtune.modules import (
FeedForward,
Fp32LayerNorm,
FrozenNF4Linear,
MultiHeadAttention,
TanhGate,
TransformerSelfAttentionLayer,
)

@@ -170,12 +170,12 @@ def clip_mlp(
gate_proj = (
nn.Linear(in_dim, hidden_dim)
if not quantize_base
else FrozenNF4Linear(in_dim, hidden_dim)
else FrozenNF4Linear(in_dim, hidden_dim, bias=True)
)
down_proj = (
nn.Linear(hidden_dim, out_dim)
if not quantize_base
else FrozenNF4Linear(hidden_dim, out_dim)
else FrozenNF4Linear(hidden_dim, out_dim, bias=True)
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
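One note on the clip_mlp change above, as my own inference rather than a statement from the PR: nn.Linear constructs a bias by default, so passing bias=True to FrozenNF4Linear keeps the quantized branch structurally consistent with the unquantized one. A tiny sketch with illustrative sizes:

# Sketch: both branches of the CLIP MLP projection now carry a bias term.
import torch
import torch.nn as nn
from torchtune.modules import FrozenNF4Linear

in_dim, hidden_dim = 512, 2048   # illustrative sizes, not CLIP's actual dims
dense = nn.Linear(in_dim, hidden_dim)                                            # bias=True by default
quant = FrozenNF4Linear(in_dim, hidden_dim, device="cpu", dtype=torch.bfloat16, bias=True)
assert dense.bias is not None and quant.bias is not None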