Fix CLIP pos embedding interpolation to work on DTensors #1739
@@ -10,6 +10,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import distribute_tensor, DTensor


 class TokenPositionalEmbedding(nn.Module):
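For context (not part of the diff), here is a minimal sketch of the two APIs this import brings in: `distribute_tensor()` shards a regular tensor across a device mesh, and `DTensor.full_tensor()` all-gathers it back into a plain `torch.Tensor`. The launch command and CPU mesh setup are assumptions; this assumes a recent PyTorch and a script run with `torchrun --nproc-per-node=2`.

```python
# Minimal sketch of distribute_tensor / full_tensor; not PR code.
import os

import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# 1-D CPU mesh over all ranks; init_device_mesh sets up the default process group.
mesh = init_device_mesh("cpu", (int(os.environ["WORLD_SIZE"]),))

full = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# Shard dim 0 across the mesh: each rank holds a local shard of the rows.
sharded = distribute_tensor(full, device_mesh=mesh, placements=[Shard(0)])
assert isinstance(sharded, DTensor)

# full_tensor() all-gathers the shards into a regular tensor on every rank.
gathered = sharded.full_tensor()
assert torch.equal(gathered, full)
```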
@@ -137,6 +138,10 @@ def _load_state_dict_hook(
         inpt_local_pos_embed = state_dict.get(
             prefix + "local_token_positional_embedding"
         )
+        local_device = inpt_local_pos_embed.device
+        if isinstance(inpt_local_pos_embed, DTensor):
+            inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
+
         if inpt_local_pos_embed is not None:

             # sanity check
Contributor

We should probably add a guard here: `if inpt_local_pos_embed is not None and inpt_local_pos_embed.shape != self.local_token_positional_embedding.shape`. But testing becomes a bit trickier. Maybe for now it's better not to add it until testing is completed.

Contributor (Author)

Yeah, will leave it out for now.

Contributor

I think we should add it before the PR is finalized, though. What do you think?

Contributor (Author)

Why do you say testing becomes trickier?

Contributor (Author)

After thinking about it more, I don't think we should add it. The DTensor fix resolves the issue, and there is no need to add extra logic that was not present before.
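For reference, the guard discussed in this thread (ultimately not added in this PR) would look roughly like the hypothetical helper below; `needs_interpolation` is an illustrative name, and `module` stands in for the positional-embedding module.

```python
# Hypothetical guard from the thread above; NOT part of this PR.
def needs_interpolation(module, inpt_local_pos_embed) -> bool:
    # Only run the interpolation / resharding path when the checkpoint entry
    # exists and its shape differs from the module parameter.
    return (
        inpt_local_pos_embed is not None
        and inpt_local_pos_embed.shape != module.local_token_positional_embedding.shape
    )
```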
@@ -159,6 +164,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )

+            if isinstance(inpt_local_pos_embed, DTensor):
+                inpt_local_pos_embed = distribute_tensor(
+                    inpt_local_pos_embed,
+                    device_mesh=self.local_token_positional_embedding.device_mesh,
+                    placements=self.local_token_positional_embedding.placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "local_token_positional_embedding"
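Taken together, the two local-embedding hunks give the hook a gather → interpolate → reshard flow: gather the (possibly sharded) checkpoint value, resize it on a plain tensor, then redistribute it with the destination parameter's mesh and placements. Below is a simplified, self-contained sketch of that pattern; `resize_token_dim` and `fit_checkpoint_to_param` are hypothetical names, the linear interpolation is a stand-in for torchtune's actual helper, and a `(n_tokens, emb_dim)` embedding layout is assumed.

```python
# Simplified sketch of the gather -> interpolate -> reshard pattern; not torchtune code.
import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def resize_token_dim(embed: torch.Tensor, tgt_n_tokens: int) -> torch.Tensor:
    """Stand-in for the real interpolation helper: linearly resize the token
    dimension of a (n_tokens, emb_dim) embedding."""
    # F.interpolate wants (batch, channels, length), so go to (1, emb_dim, n_tokens).
    resized = F.interpolate(
        embed.T.unsqueeze(0), size=tgt_n_tokens, mode="linear", align_corners=True
    )
    return resized.squeeze(0).T


def fit_checkpoint_to_param(ckpt_value: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
    """Gather a possibly-sharded checkpoint value, resize it to the parameter's
    token count, and reshard it to match the parameter's mesh/placements."""
    was_dtensor = isinstance(ckpt_value, DTensor)
    if was_dtensor:
        # Interpolation needs a regular tensor, so all-gather the shards first.
        ckpt_value = ckpt_value.full_tensor()

    resized = resize_token_dim(ckpt_value, tgt_n_tokens=param.shape[0])

    if was_dtensor and isinstance(param, DTensor):
        # Reshard exactly like the destination parameter before writing it back
        # into the state dict.
        resized = distribute_tensor(
            resized, device_mesh=param.device_mesh, placements=param.placements
        )
    return resized
```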
@@ -176,6 +188,10 @@ def _load_state_dict_hook(
         inpt_global_pos_embed = state_dict.get(
             prefix + "global_token_positional_embedding"
         )
+        global_device = inpt_global_pos_embed.device
+        if isinstance(inpt_global_pos_embed, DTensor):
+            inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
+
         if inpt_global_pos_embed is not None:

             _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape
@@ -202,6 +218,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )

+            if isinstance(inpt_global_pos_embed, DTensor):
+                inpt_global_pos_embed = distribute_tensor(
+                    inpt_global_pos_embed,
+                    device_mesh=self.global_token_positional_embedding.device_mesh,
+                    placements=self.global_token_positional_embedding.placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "global_token_positional_embedding"
@@ -497,7 +520,9 @@ def _load_state_dict_hook(
         """
         embedding = state_dict.get(prefix + "embedding")

+        device = embedding.device
+        if isinstance(embedding, DTensor):
+            embedding = embedding.full_tensor()

         if embedding is not None:
             # ckpt pos emb
@@ -534,6 +559,13 @@ def _load_state_dict_hook(
                 embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
             )

+            if isinstance(embedding_new, DTensor):
+                embedding_new = distribute_tensor(
+                    embedding_new,
+                    device_mesh=self.embedding.device_mesh,
+                    placements=self.embedding.placements,
+                )
+
             # update state dict
             state_dict[prefix + "embedding"] = embedding_new
             if embedding_new.shape != self.embedding.shape:
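For readers unfamiliar with the mechanism: `_load_state_dict_hook` is registered as a load-state-dict pre-hook, so it runs before `load_state_dict` copies values into parameters and can rewrite checkpoint entries in place. Below is a toy sketch of that wiring using PyTorch's `_register_load_state_dict_pre_hook`; the `TinyPosEmbedding` module and its hook body are illustrative, not torchtune code, and the hook simply truncates rather than interpolates.

```python
# Toy sketch of a load_state_dict pre-hook; illustrative only.
import torch
from torch import nn


class TinyPosEmbedding(nn.Module):
    def __init__(self, n_tokens: int = 4, emb_dim: int = 8) -> None:
        super().__init__()
        self.embedding = nn.Parameter(torch.zeros(n_tokens, emb_dim))
        # Registering the bound method: the hook sees state_dict before values
        # are copied into the module's parameters.
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
        ckpt = state_dict.get(prefix + "embedding")
        if ckpt is not None and ckpt.shape != self.embedding.shape:
            # The real hooks interpolate to the target size (and, per this PR,
            # gather/reshard DTensors); this toy just truncates the token dim.
            state_dict[prefix + "embedding"] = ckpt[: self.embedding.shape[0]]


# Usage: loading a checkpoint with more tokens than the module expects triggers
# the hook, which fixes up the entry before the copy happens.
module = TinyPosEmbedding(n_tokens=4)
module.load_state_dict({"embedding": torch.randn(9, 8)})
```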
Delete `local_device` since it's not used.

Put this inside the `if inpt_local_pos_embed is not None:` block.
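A sketch of the follow-up these comments ask for (an assumption about the eventual change, not code from this commit range): drop the unused `local_device` and move the DTensor gather inside the existing None check.

```python
# Possible follow-up per the review comments above; not part of these commits.
inpt_local_pos_embed = state_dict.get(prefix + "local_token_positional_embedding")

if inpt_local_pos_embed is not None:
    if isinstance(inpt_local_pos_embed, DTensor):
        inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
    # ... existing sanity check, interpolation, and resharding follow here ...
```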