torchtune/models/clip/_component_builders.py (180 changes: 110 additions & 70 deletions)
@@ -9,22 +9,26 @@

import torch
from torch import nn

from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection
from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding
from torchtune.models.clip._position_embeddings import (
TiledTokenPositionalEmbedding,
TilePositionalEmbedding,
TokenPositionalEmbedding,
)

from torchtune.modules import (
TransformerSelfAttentionLayer,
FeedForward,
Fp32LayerNorm,
MultiHeadAttention,
TanhGate,
FeedForward,
Fp32LayerNorm
TransformerSelfAttentionLayer,
)

from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer
Comment on lines +12 to +30 (Contributor Author):

precommit hook reordering


def clip_vision_encoder(
tile_size: int,
@@ -43,7 +47,7 @@ def clip_vision_encoder(
) -> VisionTransformer:
"""
Builds the vision encoder associated with the clip model. This includes:

- TransformerEncoderLayer
- positional embeddings
- CLS projection (optional)
@@ -82,21 +86,25 @@ def clip_vision_encoder(
"""
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None
cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)

# transformer layer
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
pos_embeddings=None,
attn_dropout=0.0,
is_causal=False,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
pos_embeddings=None,
attn_dropout=0.0,
is_causal=False,
)
mlp = clip_mlp(
in_dim=embed_dim,
@@ -107,8 +115,8 @@ def clip_vision_encoder(
transformer_layer = TransformerSelfAttentionLayer(
attn=self_attn,
mlp=mlp,
sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
sa_scale=None,
mlp_scale=None,
)
@@ -118,17 +126,21 @@ def clip_vision_encoder(
pre_tile_pos_embed = None
post_tile_pos_embed = None
token_pos_embedding = TokenPositionalEmbedding(
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size)
embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
)
else:
pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
pre_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
post_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
token_pos_embedding = TiledTokenPositionalEmbedding(
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size)
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size,
)

return VisionTransformer(
num_layers=num_layers,
@@ -145,13 +157,29 @@ def clip_vision_encoder(
)
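Reviewer note: for orientation, a minimal usage sketch of the builder above. The signature is collapsed in this view, so keyword names are inferred from the function body (tile_size, patch_size, embed_dim, num_layers, num_heads, max_num_tiles, output_cls_projection); values are illustrative only.

```python
# Hedged sketch, not part of the PR: building the CLIP vision encoder with the
# builder shown above. Keyword names are taken from the function body in this
# diff; the collapsed parts of the signature are assumptions.
from torchtune.models.clip._component_builders import clip_vision_encoder

encoder = clip_vision_encoder(
    tile_size=224,
    patch_size=14,
    embed_dim=1280,               # must be divisible by num_heads (see the assert above)
    num_layers=32,
    num_heads=16,
    max_num_tiles=4,              # > 1 selects the tiled positional-embedding branch
    output_cls_projection=False,  # True attaches the CLSProjection head instead
)
```

With max_num_tiles == 1 the builder falls back to the plain TokenPositionalEmbedding and skips both tile positional embeddings, as the branch above shows.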


def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward:
def clip_mlp(
in_dim: int,
out_dim: int,
hidden_dim: int,
activation: nn.Module,
quantize_base: bool = False,
) -> FeedForward:
"""
Build the MLP layer associated with the clip model.
"""
gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim)
down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim)
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation)
gate_proj = (
nn.Linear(in_dim, hidden_dim)
if not quantize_base
else FrozenNF4Linear(in_dim, hidden_dim)
)
down_proj = (
nn.Linear(hidden_dim, out_dim)
if not quantize_base
else FrozenNF4Linear(hidden_dim, out_dim)
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
)
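Reviewer note: a small sketch of the quantize_base switch. FrozenNF4Linear is imported elsewhere in this module and is not visible in this hunk; the forward call assumes FeedForward supports up_proj=None as used here.

```python
# Hedged sketch, not part of the PR: clip_mlp wires gate_proj -> activation ->
# down_proj with no up_proj; quantize_base=True swaps both projections to
# frozen NF4 linears while keeping the same wiring.
import torch
from torch import nn
from torchtune.models.clip._component_builders import clip_mlp

mlp = clip_mlp(in_dim=1280, out_dim=1280, hidden_dim=5120, activation=nn.GELU())
y = mlp(torch.randn(2, 1280))  # -> shape (2, 1280)

# Frozen NF4 base weights; trainable parameters would come from the LoRA/DoRA
# adapters added by the builders in the LoRA section below.
mlp_nf4 = clip_mlp(
    in_dim=1280, out_dim=1280, hidden_dim=5120, activation=nn.GELU(), quantize_base=True
)
```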


# ------------------ LoRA CLIP ------------------
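Reviewer note: before the LoRA builder diff, a hedged usage sketch of lora_clip_vision_encoder as defined below. Only keyword names visible in the builder body and docstring are used; the rest of the collapsed signature is an assumption.

```python
# Hedged sketch, not part of the PR: adding LoRA adapters to the CLIP encoder
# via the builder defined below. lora_modules entries come from
# LORA_ATTN_MODULES; values are illustrative.
from torchtune.models.clip._component_builders import lora_clip_vision_encoder

lora_encoder = lora_clip_vision_encoder(
    lora_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=True,   # routes the MLP through lora_clip_mlp (fixed in this PR)
    tile_size=224,
    patch_size=14,
    embed_dim=1280,
    num_layers=32,
    num_heads=16,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.0,
    use_dora=False,           # True swaps LoRALinear for DoRALinear adapters
    quantize_base=False,
)
```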
@@ -222,42 +250,46 @@ def lora_clip_vision_encoder(
quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
weights within linear layers LoRA is applied to. The final output linear projection is not
supported for quantization currently.


Returns:
VisionTransformer: Instantiation of VisionTransformer model.
"""
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

# TODO: add support for quantizing and LoRA for the final output projection
cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None
cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)

# transformer layer
self_attn = lora_clip_attention(
lora_modules=lora_modules,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
attn_dropout=0.0,
lora_modules=lora_modules,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
attn_dropout=0.0,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
)
if apply_lora_to_mlp:
mlp = lora_clip_mlp(
in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
quantize_base=quantize_base,
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
)
if apply_lora_to_mlp:
mlp = lora_clip_mlp(
in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
quantize_base=quantize_base,
lora_dropout=lora_dropout,
use_dora=use_dora,
)
)
else:
mlp = clip_mlp(
in_dim=embed_dim,
@@ -269,8 +301,8 @@ def lora_clip_vision_encoder(
transformer_layer = TransformerSelfAttentionLayer(
attn=self_attn,
mlp=mlp,
sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
sa_scale=None,
mlp_scale=None,
)
@@ -280,17 +312,21 @@ def lora_clip_vision_encoder(
pre_tile_pos_embed = None
post_tile_pos_embed = None
token_pos_embedding = TokenPositionalEmbedding(
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size)
embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
)
else:
pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
pre_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
post_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
token_pos_embedding = TiledTokenPositionalEmbedding(
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size)
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size,
)

model = VisionTransformer(
num_layers=num_layers,
@@ -467,19 +503,23 @@ def lora_clip_mlp(
"""
adapter_cls = DoRALinear if use_dora else LoRALinear
gate_proj = adapter_cls(
in_dim=dim,
in_dim=in_dim,
out_dim=hidden_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
)
down_proj = adapter_cls(
in_dim=hidden_dim,
out_dim=dim,
out_dim=out_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
Comment on lines +506 to +521 (Contributor Author):

this changed

)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
)
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation)
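The comment above marks the substantive fix in this hunk: the adapter projections now use the in_dim/out_dim arguments (the previous body referenced dim rather than the names used at the call site) and pass use_bias=True, presumably so the adapted layers match the biased nn.Linear projections built by clip_mlp. A hedged sketch of the corrected call, mirroring the call site in lora_clip_vision_encoder:

```python
# Hedged sketch, not part of the PR: calling the fixed lora_clip_mlp the way
# lora_clip_vision_encoder does when apply_lora_to_mlp=True. All keyword names
# appear at that call site; values are illustrative.
from torch import nn
from torchtune.models.clip._component_builders import lora_clip_mlp

embed_dim = 1280
lora_mlp = lora_clip_mlp(
    in_dim=embed_dim,
    hidden_dim=4 * embed_dim,
    out_dim=embed_dim,
    activation=nn.GELU(),
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.0,
    use_dora=False,     # DoRALinear adapters when True, LoRALinear otherwise
    quantize_base=False,
)
```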