diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 0940d49359..150261fd23 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -9,22 +9,26 @@ import torch from torch import nn - -from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection -from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding +from torchtune.models.clip._position_embeddings import ( + TiledTokenPositionalEmbedding, + TilePositionalEmbedding, + TokenPositionalEmbedding, +) from torchtune.modules import ( - TransformerSelfAttentionLayer, + FeedForward, + Fp32LayerNorm, MultiHeadAttention, TanhGate, - FeedForward, - Fp32LayerNorm + TransformerSelfAttentionLayer, ) from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear +from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer + def clip_vision_encoder( tile_size: int, @@ -43,7 +47,7 @@ def clip_vision_encoder( ) -> VisionTransformer: """ Builds the vision encoder associated with the clip model. This includes: - + - TransformerEncoderLayer - positional embeddings - CLS projection (optional) @@ -82,21 +86,25 @@ def clip_vision_encoder( """ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + cls_projection = ( + CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) + if output_cls_projection + else None + ) # transformer layer self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, - q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - pos_embeddings=None, - attn_dropout=0.0, - is_causal=False, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + pos_embeddings=None, + attn_dropout=0.0, + is_causal=False, ) mlp = clip_mlp( in_dim=embed_dim, @@ -107,8 +115,8 @@ def clip_vision_encoder( transformer_layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, - sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5), - mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5), sa_scale=None, mlp_scale=None, ) @@ -118,17 +126,21 @@ def clip_vision_encoder( pre_tile_pos_embed = None post_tile_pos_embed = None token_pos_embedding = TokenPositionalEmbedding( - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size + ) else: - pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) - post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) + post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) token_pos_embedding = TiledTokenPositionalEmbedding( - max_num_tiles=max_num_tiles, - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size, + ) return VisionTransformer( num_layers=num_layers, @@ -145,13 +157,29 @@ def clip_vision_encoder( ) -def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward: +def clip_mlp( + in_dim: int, + out_dim: int, + hidden_dim: int, + activation: nn.Module, + quantize_base: bool = False, +) -> FeedForward: """ Build the MLP layer associated with the clip model. """ - gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim) - down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim) - return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) + gate_proj = ( + nn.Linear(in_dim, hidden_dim) + if not quantize_base + else FrozenNF4Linear(in_dim, hidden_dim) + ) + down_proj = ( + nn.Linear(hidden_dim, out_dim) + if not quantize_base + else FrozenNF4Linear(hidden_dim, out_dim) + ) + return FeedForward( + gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation + ) # ------------------ LoRA CLIP ------------------ @@ -222,7 +250,7 @@ def lora_clip_vision_encoder( quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. - + Returns: VisionTransformer: Instantiation of VisionTransformer model. @@ -230,34 +258,38 @@ def lora_clip_vision_encoder( assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" # TODO: add support for quantizing and LoRA for the final output projection - cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + cls_projection = ( + CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) + if output_cls_projection + else None + ) # transformer layer self_attn = lora_clip_attention( - lora_modules=lora_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, - attn_dropout=0.0, + lora_modules=lora_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + if apply_lora_to_mlp: + mlp = lora_clip_mlp( + in_dim=embed_dim, + hidden_dim=4 * embed_dim, + out_dim=embed_dim, + activation=activation(), lora_rank=lora_rank, lora_alpha=lora_alpha, + quantize_base=quantize_base, lora_dropout=lora_dropout, use_dora=use_dora, - quantize_base=quantize_base, - ) - if apply_lora_to_mlp: - mlp = lora_clip_mlp( - in_dim=embed_dim, - hidden_dim=4 * embed_dim, - out_dim=embed_dim, - activation=activation(), - lora_rank=lora_rank, - lora_alpha=lora_alpha, - quantize_base=quantize_base, - lora_dropout=lora_dropout, - use_dora=use_dora, - ) + ) else: mlp = clip_mlp( in_dim=embed_dim, @@ -269,8 +301,8 @@ def lora_clip_vision_encoder( transformer_layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, - sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5), - mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5), sa_scale=None, mlp_scale=None, ) @@ -280,17 +312,21 @@ def lora_clip_vision_encoder( pre_tile_pos_embed = None post_tile_pos_embed = None token_pos_embedding = TokenPositionalEmbedding( - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size + ) else: - pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) - post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) + post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) token_pos_embedding = TiledTokenPositionalEmbedding( - max_num_tiles=max_num_tiles, - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size, + ) model = VisionTransformer( num_layers=num_layers, @@ -467,19 +503,23 @@ def lora_clip_mlp( """ adapter_cls = DoRALinear if use_dora else LoRALinear gate_proj = adapter_cls( - in_dim=dim, + in_dim=in_dim, out_dim=hidden_dim, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout, quantize_base=quantize_base, + use_bias=True, ) down_proj = adapter_cls( in_dim=hidden_dim, - out_dim=dim, + out_dim=out_dim, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout, quantize_base=quantize_base, + use_bias=True, + ) + return FeedForward( + gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation ) - return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index 77f3ed167b..b9f68e67db 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -8,34 +8,34 @@ from typing import List, Optional import torch +from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType from torchtune.models.llama3_2_vision._component_builders import ( # noqa + llama3_2_vision_decoder, + llama3_2_vision_encoder, lora_llama3_2_vision_decoder, lora_llama3_2_vision_encoder, LoRATrainable, - llama3_2_vision_decoder, - llama3_2_vision_encoder, ) from torchtune.models.llama3_2_vision._encoder import Llama3VisionEncoder from torchtune.models.llama3_2_vision._transform import Llama3VisionTransform from torchtune.modules.model_fusion import DeepFusionModel -from torchtune.modules.tokenizers import parse_hf_tokenizer_json -from torchtune.data._prompt_templates import _TemplateType -from torchtune.data._prompt_templates import _get_prompt_template from torchtune.modules.peft import LORA_ATTN_MODULES +from torchtune.modules.tokenizers import parse_hf_tokenizer_json + def llama3_2_vision_11b( decoder_trainable: bool = False, encoder_trainable: bool = True, fusion_trainable: bool = True, - image_size: int = 560 - ) -> DeepFusionModel: - """ Llama 3.2 Vision 11B model + image_size: int = 560, +) -> DeepFusionModel: + """Llama 3.2 Vision 11B model Args: decoder_trainable (bool): Whether to make decoder params trainable. Default is False. encoder_trainable (bool): Whether to make encoder params trainable. Default is True. fusion_trainable (bool): Whether to make fusion params trainable. Default is True. - image_size (int): Base image size that images will be tiled and resized to. + image_size (int): Base image size that images will be tiled and resized to. Default is 560 for Instruct weights, use 448 for pre-trained. Returns: @@ -62,7 +62,7 @@ def llama3_2_vision_11b( num_kv_heads=8, embed_dim=4096, max_seq_len=131_072, - encoder_max_seq_len=128_080, + encoder_max_seq_len=128_080, # 20*6404 rope_base=500000.0, intermediate_dim=14336, ) @@ -76,8 +76,12 @@ def llama3_2_vision_11b( def llama3_2_vision_transform( - path: str, max_seq_len: int = 8192, image_size: int = 560, special_tokens_path: Optional[str] = None, prompt_template: Optional[_TemplateType] = None - ) -> Llama3VisionTransform: + path: str, + max_seq_len: int = 8192, + image_size: int = 560, + special_tokens_path: Optional[str] = None, + prompt_template: Optional[_TemplateType] = None, +) -> Llama3VisionTransform: """ Data Transforms (including Tokenizer) for Llama3 Vision. @@ -85,21 +89,27 @@ def llama3_2_vision_transform( path (str): path to the tokenizer max_seq_len (int): maximum sequence length for tokenizing a single list of messages, after which the input will be truncated. - image_size (int): Base image size that images will be tiled and resized to. + image_size (int): Base image size that images will be tiled and resized to. Default is 560 for Instruct weights, use 448 for pre-trained. special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face - model files that contains all registered special tokens, or a local json file + model files that contains all registered special tokens, or a local json file structured similarly. Default is None to use the canonical Llama3 special tokens. prompt_template (Optional[_TemplateType]): optional specified prompt template. If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface` class. If a dictionary, it is assumed to be a custom prompt template mapping role to the prepend/append tags. - + Returns: Llama3VisionTransform: Instantiation of the Llama 3.2 vision transform """ - special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None - template = _get_prompt_template(prompt_template) if prompt_template is not None else None + special_tokens = ( + parse_hf_tokenizer_json(special_tokens_path) + if special_tokens_path is not None + else None + ) + template = ( + _get_prompt_template(prompt_template) if prompt_template is not None else None + ) return Llama3VisionTransform( path=path, special_tokens=special_tokens, @@ -115,7 +125,7 @@ def llama3_2_vision_transform( def lora_llama3_2_vision_11b( lora_attn_modules: List[LORA_ATTN_MODULES], - decoder_trainable: str = "frozen", + decoder_trainable: str = "frozen", encoder_trainable: str = "lora", fusion_trainable: str = "lora", apply_lora_to_mlp: bool = False, @@ -125,7 +135,7 @@ def lora_llama3_2_vision_11b( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, - image_size: int = 560 + image_size: int = 560, ) -> DeepFusionModel: """ Return a version of Llama3.2 vision (an instance of :func:`~torchtune.modules.model_fusion.DeepFusionModel`) @@ -135,11 +145,11 @@ def lora_llama3_2_vision_11b( lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers LoRA should be applied to in each self-attention block. Options are ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. - decoder_trainable (str): Option to set decoder params as fully trainble (full), lora trainable (lora), + decoder_trainable (str): Option to set decoder params as fully trainble (full), lora trainable (lora), or frozen (frozen). The default is "frozen". - encoder_trainable (str): Option to set encoder params as fully trainble (full), lora trainable (lora), + encoder_trainable (str): Option to set encoder params as fully trainble (full), lora trainable (lora), or frozen (frozen). The default is "lora". - fusion_trainable (str): Option to set fusion params as fully trainble (full), lora trainable (lora), + fusion_trainable (str): Option to set fusion params as fully trainble (full), lora trainable (lora), or frozen (frozen). The default is "lora". apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. Default: False @@ -151,7 +161,7 @@ def lora_llama3_2_vision_11b( quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. - image_size (int): Base image size that images will be tiled and resized to. + image_size (int): Base image size that images will be tiled and resized to. Default is 560 for Instruct weights, use 448 for pre-trained. Returns: @@ -197,8 +207,8 @@ def lora_llama3_2_vision_11b( num_heads=32, num_kv_heads=8, embed_dim=4096, - max_seq_len=8192, - encoder_max_seq_len=64040, + max_seq_len=131_072, + encoder_max_seq_len=128_080, # 20*6404 rope_base=500000.0, intermediate_dim=14336, lora_rank=lora_rank,