Add falcon h1 #2650
Conversation
        FalconH1ForCausalLM,
        FalconHybridMambaAttentionDynamicCache,
    )
except:
@danielhanchen I think we should consolidate all these imports and version checks into one file that gets called at __init__, and have flags like IS_QWEN_SUPPORTED or IS_FALCON_SUPPORTED passed over to the individual files...
Thoughts?
I agree - we'll do it for a future release
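For context, a minimal sketch of what such a consolidated check could look like (the module layout, flag names, and version thresholds here are illustrative assumptions, not the actual Unsloth implementation):

```python
# hypothetical consolidated compatibility module, imported once at package __init__
from packaging.version import Version
import transformers

transformers_version = Version(transformers.__version__)

# Flags computed once, instead of per-model try/except blocks (thresholds illustrative).
IS_QWEN_SUPPORTED   = transformers_version >= Version("4.37.0")
IS_FALCON_SUPPORTED = transformers_version >= Version("4.53.0")

FalconH1ForCausalLM = None
FalconHybridMambaAttentionDynamicCache = None
if IS_FALCON_SUPPORTED:
    from transformers.models.falcon_h1.modeling_falcon_h1 import (
        FalconH1ForCausalLM,
        FalconHybridMambaAttentionDynamicCache,
    )
```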
    K = K.view(1, K_M, n_heads, head_dim)
    V = V.view(1, V_M, n_heads, head_dim)
    pass
else:
Are there any efficiency gains in doing this? If not, can we use the same structure for both cases?
Actually, from internal checks, yes - view is faster than reshape, since reshape possibly does a copy.
Oh, actually I meant checking for requires_grad and doing separate operations for both...
If I understand correctly, you mean we can get away with view for inference but it's not compatible with training or something?
Otherwise, why can't we use view for both?
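For reference, a small standalone sketch (not from this PR) of the view/reshape trade-off being discussed: both ops are differentiable, so the choice does not affect training vs inference, only whether a copy can occur.

```python
import torch

x = torch.randn(2, 8, 4, 16, requires_grad = True)

v = x.view(1, 2 * 8, 4, 16)     # no copy; requires x to be contiguous
r = x.reshape(1, 2 * 8, 4, 16)  # same result; falls back to a copy only when needed
assert torch.equal(v, r)

# Both ops are differentiable, so view is not inference-only - the question above is
# purely about avoiding a potential copy.
v.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 4, 16])

# After a transpose the tensor is no longer contiguous: view raises, reshape copies.
t = x.transpose(1, 2)
try:
    t.view(1, 2 * 4, 8, 16)
except RuntimeError as err:
    print("view failed:", err)
t.reshape(1, 2 * 4, 8, 16)  # works, at the cost of a copy
```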
unsloth/models/falcon_h1.py
Outdated
K_M = V_M = bsz * kv_seq_len
Q_M = bsz * q_len

has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
Falcon doesn't seem to have SWA (correct me if I'm wrong). We can remove this safely
True, removing it.
past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
Also, @danielhanchen, we should move the attention computation logic to a single file.
Any pre/post processing can exist in the model-specific files...
Yes agreed - was planning to do that!
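A hypothetical sketch of that split (function name, hook names, and the backend call are illustrative assumptions, not existing Unsloth code): one shared attention entry point, with model-specific pre/post processing passed in from each model file.

```python
from typing import Callable, Optional
import torch
import torch.nn.functional as F

def shared_attention_forward(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, head_dim)
    K: torch.Tensor,
    V: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    pre_process:  Optional[Callable] = None,  # e.g. Falcon-H1's key/attention multipliers
    post_process: Optional[Callable] = None,  # e.g. output reshaping / output multipliers
) -> torch.Tensor:
    if pre_process is not None:
        Q, K, V = pre_process(Q, K, V)
    # Backend selection (flash-attn / xformers / SDPA) would live here once,
    # instead of being duplicated in every model file.
    A = F.scaled_dot_product_attention(
        Q, K, V,
        attn_mask = attention_mask,
        is_causal = attention_mask is None,
    )
    if post_process is not None:
        A = post_process(A)
    return A
```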
)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier

hidden_states = mamba_hidden_states + attention_hidden_states
Can we combine L406, L414 and this into a single operation if possible, and please check that gradients flow properly? OR
Let's reorder the operations the way they do in transformers, to look consistent.
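An illustrative check of the kind of fusion and gradient-flow verification being asked for (the second multiplier name is an assumption standing in for the operations at L406/L414; this is not the code in the PR):

```python
import torch

mamba_hidden_states      = torch.randn(2, 16, 64, requires_grad = True)
attention_hidden_states  = torch.randn(2, 16, 64, requires_grad = True)
ssm_out_multiplier       = 0.5   # mirrors self.ssm_out_multiplier above
attention_out_multiplier = 2.0   # hypothetical stand-in for the attention-side scaling

# Single fused expression instead of separate scalings followed by an add.
hidden_states = (
    mamba_hidden_states * ssm_out_multiplier
    + attention_hidden_states * attention_out_multiplier
)

# Confirm gradients reach both branches with the expected values.
hidden_states.sum().backward()
assert mamba_hidden_states.grad is not None
assert attention_hidden_states.grad is not None
assert torch.allclose(
    mamba_hidden_states.grad,
    torch.full_like(mamba_hidden_states, ssm_out_multiplier),
)
```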
    (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if use_cache and hasattr(self, "_flag_for_generation"):
@danielhanchen we should also refactor these a little.
Instead of having different code blocks doing entirely the same thing (except the norm), we can consolidate:
if use_cache and hasattr(self, "_flag_for_generation"):
    layernorm = fast_rms_layernorm_inference
else:
    layernorm = fast_rms_layernorm
...and keep the rest of the code together.
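A self-contained illustration of this select-once pattern (the norm functions here are simple stand-ins, not Unsloth's fused kernels; the flag handling is simplified):

```python
import torch

def rms_norm_train(weight, x, eps = 1e-6):
    # plain RMSNorm as a stand-in for fast_rms_layernorm
    return weight * x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim = True) + eps).to(x.dtype)

def rms_norm_inference(weight, x, eps = 1e-6):
    # stand-in for the fused inference kernel; same math here
    return rms_norm_train(weight, x, eps)

use_cache, is_generating = True, True
layernorm = rms_norm_inference if (use_cache and is_generating) else rms_norm_train

weight = torch.ones(64)
x = torch.randn(2, 8, 64)
# The selected function is then used for every norm call in the block,
# so the rest of the forward pass is written only once.
x = layernorm(weight, x)
```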
unsloth/models/falcon_h1.py
Outdated
    return outputs
pass


def _FalconH1_fast_forward_inference(attention_fast_forward_inference=LlamaAttention_fast_forward_inference, mlp_fast_forward_inference=fast_swiglu_inference):
NIT: Let the default be FalconAttention_fast_forward_inference
    hidden_states: torch.Tensor,
    causal_mask = None,
    attention_mask: Optional[torch.Tensor] = None,
    mamba_attention_mask: Optional[torch.Tensor] = None,
@Datta0 how does unsloth handle forward kwargs? Does it get them directly from HF transformers?
Context: for Mamba, the mamba_mask is created here: https://github.com/huggingface/transformers/blob/51d732709e5ae424e8fb6c4e58b72057a3e413c2/src/transformers/models/falcon_h1/modeling_falcon_h1.py#L1305 - will we need to add similar logic in LlamaModel_fast_forward to incorporate the mamba mask?
Yeah, if Mamba's decoder layer expects mamba_mask, then it's the job of FalconModel's forward function to handle that.
If the model's forward functionality is otherwise similar to Llama, with this being the only change, we can reuse LlamaModel's forward while checking for the model type to be falcon.
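For reference, the linked transformers code builds the mamba mask roughly as sketched below (paraphrased; exact details may differ between versions). If LlamaModel_fast_forward is reused for falcon_h1, it would need to compute this and pass it down to each decoder layer as mamba_attention_mask.

```python
import torch

def _update_mamba_mask(attention_mask, cache_position):
    # The mamba mask is skipped during cached decoding, or when no token is actually padded;
    # otherwise the 2D attention mask is reused as-is.
    mamba_mask = attention_mask
    if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
        mamba_mask = None
    return mamba_mask
```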
@dhiaEddineRhaiem

Many thanks @Datta0.

Hey @Datta0,
got: the small differences may be the result of the numerical differences between the kernels in unsloth vs native kernels on HF, plus potentially LoRA initializations not being consistent. For inference:

@dhiaEddineRhaiem Oh, apologies for the delay - is it possible to get a plot, in say Excel / Google Sheets, of the HF loss vs Unsloth loss over 60 steps? If they mainly match, then we can merge this :) Thanks!
Co-authored-by: Daniel Han <[email protected]>
Co-authored-by: Daniel Han <[email protected]>
Oh also, is inference ok? I'm assuming a new PR will be needed? Also, if it's possible to make a notebook and place it in the Unsloth notebooks repo, that'll also be cool - also a reminder to add "contributed by the folks at TII" for example and link your website / HF page for extra recognition :)

@danielhanchen, many thanks for the review. For inference, we think it might be better to raise it in a separate PR.

Ok cool! I think it's because it's Mamba, right - so inference is different since there's no KV cache. A dumb approach: if inference is enabled, don't use the fast path, and use the original Mamba forward.
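A sketch of that fallback under assumed names (`_flag_for_generation` mirrors the flag used elsewhere in this PR; `original_mixer_forward` and `fast_mixer_forward` are hypothetical placeholders, not existing functions):

```python
def FalconH1Mixer_forward(self, hidden_states, cache_params = None, attention_mask = None):
    if hasattr(self, "_flag_for_generation"):
        # Generation path: there is no KV-cache-style shortcut for the SSM branch,
        # so fall back to the unpatched transformers forward saved at patch time.
        return self.original_mixer_forward(
            hidden_states, cache_params = cache_params, attention_mask = attention_mask,
        )
    # Training path: use the patched fast kernels.
    return fast_mixer_forward(self, hidden_states, cache_params, attention_mask)
```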
Falcon H1 is hybrid (a parallel design with Mamba and attention).
We are still debugging on our side to make it work as expected.

@danielhanchen, @Datta0, we managed to do 60 steps for FalconH1 with Unsloth and HF. Also, what is the best way to push the notebook? (Inside HF, or maybe inside this repo?)
Thanks a lot for the work. Have a few minor comments. Rest looks good, especially the losses matching with HF :)
unsloth/models/llama.py
Outdated
IS_GRANITE = self.config.model_type.startswith("granite")
IS_FALCON_H1 = self.config.model_type.startswith("falcon_h1")

if IS_FALCON_H1:
NIT: Either move this to L790/803 or move that here, as the code is very similar.
Maybe we can do:
if IS_FALCON_H1 or IS_GRANITE:
    inputs_embeds *= self.config.embedding_multiplier
unsloth/models/llama.py
Outdated
(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
    (self.norm, hidden_states)
if IS_FALCON_H1:
hidden_states = fast_rms_layernorm_inference(self.final_layernorm, hidden_states)
NIT: Fix indentation
unsloth/models/llama.py
Outdated
lm_head_device = lm_head.device

if self.config.model_type == "falcon_h1":
    lm_head = lm_head * self.config.lm_head_multiplier
Do we multiply lm_head or do we multiply the logits?
If the latter, we can look at merging this with L1189/L1205
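For a bias-free linear head the two are interchangeable, which is what would make the suggested merge possible: (c * W) @ x equals c * (W @ x). A standalone numeric check (sizes and names are illustrative):

```python
import torch

hidden = torch.randn(4, 256)
W_lm   = torch.randn(32000, 256)   # stand-in lm_head weight
lm_head_multiplier = 0.75          # plays the role of config.lm_head_multiplier

scaled_weight_logits = hidden @ (W_lm * lm_head_multiplier).T   # scale the weights
scaled_output_logits = (hidden @ W_lm.T) * lm_head_multiplier   # scale the logits
assert torch.allclose(scaled_weight_logits, scaled_output_logits, rtol = 1e-4, atol = 1e-5)
```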
Done addressing the comments @Datta0! Let us know if all is good.
LGTM. Great work :)
Many thanks @Datta0 @danielhanchen @younesbelkada

Great work!

This PR introduces support for the FalconH1 model family within the Unsloth library.
It addresses the following issue.
Training integration has been validated using the following command:
Note: Inference support is currently under active development and is being debugged.
@danielhanchen, do you think we can first have FalconH1 training supported for the community in Unsloth, and then raise a separate PR to fix inference?
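For readers landing here, a minimal training sketch assuming this PR is merged (the checkpoint name and hyperparameters are illustrative assumptions, and this is not the exact validation command referenced above):

```python
from unsloth import FastLanguageModel

# Assumed Falcon-H1 checkpoint id; substitute the model you actually train.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "tiiuae/Falcon-H1-0.5B-Instruct",
    max_seq_length = 2048,
    load_in_4bit   = True,
)

# Standard LoRA setup; the attention projection names below follow the usual
# q/k/v/o convention (an assumption for Falcon-H1's attention branch).
model = FastLanguageModel.get_peft_model(
    model,
    r              = 16,
    lora_alpha     = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Training then proceeds with TRL's SFTTrainer exactly as in the existing Unsloth notebooks.
```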