Merged
57 changes: 57 additions & 0 deletions torchtune/models/clip/_position_embeddings.py
@@ -10,6 +10,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor


class TokenPositionalEmbedding(nn.Module):
@@ -137,8 +138,20 @@ def _load_state_dict_hook(
inpt_local_pos_embed = state_dict.get(
prefix + "local_token_positional_embedding"
)

if inpt_local_pos_embed is not None:
Contributor:
we should probably add here:

if inpt_local_pos_embed is not None and inpt_local_pos_embed.shape != self.local_token_positional_embedding.shape

But testing becomes a bit trickier. Maybe for now it's better not to add it until testing is completed.

Contributor Author:

Yeah, will leave it out for now.

Contributor:

I think we should add it before the PR is finalized, though. What do you think?

Contributor Author:

Why do you say testing becomes trickier?

Contributor Author:

After thinking about it more, I don't think we should add it. The DTensor fix resolves the issue, and there's no need to add extra logic on top of it that wasn't present before.
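
For reference, a minimal sketch of the guard being discussed, under the assumption that the intent is to skip interpolation when the checkpoint embedding already matches the module's shape. The helper name and the example shapes are hypothetical and not part of the PR:

```python
from typing import Optional

import torch


def needs_resize(
    ckpt_pos_embed: Optional[torch.Tensor], module_pos_embed: torch.Tensor
) -> bool:
    """Return True only when a checkpoint pos embed exists and its shape differs
    from the module's, i.e. when interpolation is actually required."""
    return (
        ckpt_pos_embed is not None
        and ckpt_pos_embed.shape != module_pos_embed.shape
    )


# Hypothetical shapes: (n_tokens_per_tile, embed_dim)
module_embed = torch.zeros(1025, 1280)
ckpt_embed = torch.zeros(577, 1280)
assert needs_resize(ckpt_embed, module_embed)          # shapes differ -> interpolate
assert not needs_resize(module_embed, module_embed)    # same shape -> skip
assert not needs_resize(None, module_embed)            # key missing -> skip
```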


# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(inpt_local_pos_embed, DTensor):
local_embed_is_sharded = True
local_embed_device_mesh = inpt_local_pos_embed.device_mesh
local_embed_placements = inpt_local_pos_embed.placements
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
Comment on lines +149 to +151
Contributor:

I thought we had to use the device_mesh and placements from self.local_token_positional_embedding.

Contributor Author:

I think they should be the same, which is why it was working before. But in my mind this is the more "correct" thing to do: we apply some operation to a DTensor, then restore it to its original state after the fact.

Collaborator:

Will this increase the chances of OOMs, or is the pos_embed small enough that this is not a concern?

Contributor Author:

@RdoubleA well, as of today it doesn't work, so I guess the current chance of OOM is NaN. This is supposed to be a no-op for single device (hence wrapping everything in isinstance(..., DTensor) checks), so there are no memory implications there.
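
For context on the size question, a quick back-of-the-envelope check; the shape and dtype below are illustrative assumptions, not the actual CLIP config:

```python
import torch

# Hypothetical local pos embed: 1025 tokens per tile (32x32 patches + CLS), dim 1280
pos_embed = torch.zeros(1025, 1280, dtype=torch.bfloat16)
size_mb = pos_embed.numel() * pos_embed.element_size() / 1e6
print(f"{size_mb:.2f} MB")  # ~2.62 MB, so gathering the full tensor is cheap
```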

else:
local_embed_is_sharded = False

# sanity check
inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
@@ -159,6 +172,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if local_embed_is_sharded:
inpt_local_pos_embed = distribute_tensor(
inpt_local_pos_embed,
device_mesh=local_embed_device_mesh,
placements=local_embed_placements,
)

# update state dict
state_dict[
prefix + "local_token_positional_embedding"
@@ -176,8 +196,20 @@ def _load_state_dict_hook(
inpt_global_pos_embed = state_dict.get(
prefix + "global_token_positional_embedding"
)

if inpt_global_pos_embed is not None:

# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(inpt_global_pos_embed, DTensor):
global_embed_is_sharded = True
global_embed_device_mesh = inpt_global_pos_embed.device_mesh
global_embed_placements = inpt_global_pos_embed.placements
inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
else:
global_embed_is_sharded = False

_, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape

# sanity check
Expand All @@ -202,6 +234,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if global_embed_is_sharded:
inpt_global_pos_embed = distribute_tensor(
inpt_global_pos_embed,
device_mesh=global_embed_device_mesh,
placements=global_embed_placements,
)

# update state dict
state_dict[
prefix + "global_token_positional_embedding"
@@ -500,6 +539,17 @@ def _load_state_dict_hook(

if embedding is not None:

# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(embedding, DTensor):
embedding_is_sharded = True
device_mesh = embedding.device_mesh
placements = embedding.placements
embedding = embedding.full_tensor()
else:
embedding_is_sharded = False

# ckpt pos emb
(
tgt_max_num_tiles_x,
@@ -534,6 +584,13 @@ def _load_state_dict_hook(
embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
)

if embedding_is_sharded:
embedding_new = distribute_tensor(
embedding_new,
device_mesh=device_mesh,
placements=placements,
)

# update state dict
state_dict[prefix + "embedding"] = embedding_new
if embedding_new.shape != self.embedding.shape:
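
Taken together, all three hooks add the same pattern: gather a DTensor to a full tensor, run F.interpolate, then redistribute with the original device mesh and placements. Below is a standalone sketch of that round trip under simplifying assumptions (square patch grid, no CLS token, made-up shapes); the real hooks handle more cases, so treat this as an illustration rather than the actual implementation:

```python
import math

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def resize_pos_embed(pos_embed: torch.Tensor, tgt_n_tokens_per_tile: int) -> torch.Tensor:
    """Resize a (n_tokens_per_tile, embed_dim) positional embedding by bilinear
    interpolation over the patch grid, round-tripping through a full tensor when
    the input is sharded as a DTensor."""
    # Gather to a vanilla tensor if sharded, remembering how to reshard afterwards.
    is_sharded = isinstance(pos_embed, DTensor)
    if is_sharded:
        device_mesh = pos_embed.device_mesh
        placements = pos_embed.placements
        pos_embed = pos_embed.full_tensor()

    n_tokens_per_tile, embed_dim = pos_embed.shape
    src_grid = int(math.sqrt(n_tokens_per_tile))   # assumes a square grid, no CLS token
    tgt_grid = int(math.sqrt(tgt_n_tokens_per_tile))

    # (n_tokens, dim) -> (1, dim, grid, grid) so F.interpolate sees a 2D "image"
    grid = pos_embed.reshape(src_grid, src_grid, embed_dim).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(tgt_grid, tgt_grid), mode="bilinear", align_corners=True)
    pos_embed = grid.squeeze(0).permute(1, 2, 0).reshape(tgt_grid * tgt_grid, embed_dim)

    # Restore the original sharding so the caller gets back a DTensor.
    if is_sharded:
        pos_embed = distribute_tensor(pos_embed, device_mesh=device_mesh, placements=placements)
    return pos_embed


# Single-device usage: the DTensor branches are no-ops here.
out = resize_pos_embed(torch.randn(16 * 16, 1280), tgt_n_tokens_per_tile=32 * 32)
print(out.shape)  # torch.Size([1024, 1280])
```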