34 changes: 33 additions & 1 deletion torchtune/models/clip/_position_embeddings.py
@@ -10,6 +10,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor


class TokenPositionalEmbedding(nn.Module):
@@ -137,6 +138,10 @@ def _load_state_dict_hook(
inpt_local_pos_embed = state_dict.get(
prefix + "local_token_positional_embedding"
)
local_device = inpt_local_pos_embed.device
if isinstance(inpt_local_pos_embed, DTensor):
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
Contributor

Delete "local_device" since it's not used.

Put this inside of the "if inpt_local_pos_embed is not None:" block.
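
A minimal sketch of the suggested rearrangement (the surrounding hook code and names are assumed from the diff):

# Sketch: drop the unused `local_device` and gather the DTensor to a full
# tensor only once we know the key is present in the state dict.
inpt_local_pos_embed = state_dict.get(
    prefix + "local_token_positional_embedding"
)
if inpt_local_pos_embed is not None:
    if isinstance(inpt_local_pos_embed, DTensor):
        inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
    # ... existing interpolation logic continues here ...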


if inpt_local_pos_embed is not None:
Contributor

We should probably add here:

if inpt_local_pos_embed is not None and inpt_local_pos_embed.shape != self.local_token_positional_embedding

But testing becomes a bit trickier. Maybe for now it's better not to add it until testing is completed.

Contributor Author

Yeah, will leave it out for now.

Contributor

I think we should add it before the PR is finalized, though. What do you think?

Contributor Author

Why do you say testing becomes trickier?

Contributor Author

After thinking about it more, I don't think we should add it. The DTensor fix resolves the issue, and there's no need to add extra logic on top of it that wasn't present before.
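
For reference, a rough sketch of the guard discussed in this thread (ultimately left out of this PR); the trailing .shape on the module attribute is an assumption about what the comparison was meant to read:

# Hypothetical guard, not part of the PR: only run the resize/interpolation
# when the checkpoint embedding's shape differs from the module's parameter.
if (
    inpt_local_pos_embed is not None
    and inpt_local_pos_embed.shape
    != self.local_token_positional_embedding.shape
):
    # ... interpolation logic would go here ...
    ...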


# sanity check
@@ -159,6 +164,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if isinstance(inpt_local_pos_embed, DTensor):
Contributor

inpt_local_pos_embed is not a DTensor here, because it's the full tensor now.

We need to check whether self.local_token_positional_embedding is a DTensor. We should probably do the same in the previous check (there you check whether inpt_local_pos_embed is a DTensor).

inpt_local_pos_embed = distribute_tensor(
inpt_local_pos_embed,
device_mesh=self.local_token_positional_embedding.device_mesh,
placements=self.local_token_positional_embedding.placements,
)
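
Following up on the comment above, a sketch of what keying the check off the module's own parameter could look like (assuming self.local_token_positional_embedding may itself be a DTensor when the module is sharded):

# Sketch: after full_tensor(), inpt_local_pos_embed is a plain torch.Tensor,
# so check the module's parameter before redistributing.
if isinstance(self.local_token_positional_embedding, DTensor):
    inpt_local_pos_embed = distribute_tensor(
        inpt_local_pos_embed,
        device_mesh=self.local_token_positional_embedding.device_mesh,
        placements=self.local_token_positional_embedding.placements,
    )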

# update state dict
state_dict[
prefix + "local_token_positional_embedding"
@@ -176,6 +188,10 @@ def _load_state_dict_hook(
inpt_global_pos_embed = state_dict.get(
prefix + "global_token_positional_embedding"
)
global_device = inpt_global_pos_embed.device
if isinstance(inpt_global_pos_embed, DTensor):
inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
Contributor

Same comments as above.


if inpt_global_pos_embed is not None:

_, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape
@@ -202,6 +218,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if isinstance(inpt_global_pos_embed, DTensor):
inpt_global_pos_embed = distribute_tensor(
inpt_global_pos_embed,
device_mesh=self.global_token_positional_embedding.device_mesh,
placements=self.global_token_positional_embedding.placements,
)

# update state dict
state_dict[
prefix + "global_token_positional_embedding"
@@ -497,7 +520,9 @@ def _load_state_dict_hook(
"""

embedding = state_dict.get(prefix + "embedding")

device = embedding.device
if isinstance(embedding, DTensor):
embedding = embedding.full_tensor()
Contributor

Same comments as above.

if embedding is not None:

# ckpt pos emb
@@ -534,6 +559,13 @@ def _load_state_dict_hook(
embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
)

if isinstance(embedding_new, DTensor):
embedding_new = distribute_tensor(
embedding_new,
device_mesh=self.embedding.device_mesh,
placements=self.embedding.placements,
)

# update state dict
state_dict[prefix + "embedding"] = embedding_new
if embedding_new.shape != self.embedding.shape: