Conversation

@dhiaEddineRhaiem (Contributor) commented May 29, 2025

This PR introduces support for the FalconH1 model family within the Unsloth library.

It addresses the following issue.

Training integration has been validated using the following command:

python unsloth-cli.py --model_name "tiiuae/Falcon-H1-0.5B-Base" \
  --max_seq_length 2048 --dtype bfloat16 --load_in_4bit \
  --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" \
  --per_device_train_batch_size 1

Note: Inference support is currently under active development and is being debugged.
@danielhanchen, do you think we can first have FalconH1 training supported in Unsloth for the community, and then raise a separate PR to fix inference?

FalconH1ForCausalLM,
FalconHybridMambaAttentionDynamicCache,
)
except:

Collaborator:

@danielhanchen I think we should consolidate all these imports and version checks into one file that gets called at init, and have flags like IS_QWEN_SUPPORTED or IS_FALCON_SUPPORTED passed over to the individual files...
Thoughts?
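
For illustration, a minimal sketch of what such a consolidated capability module could look like (the flag names come from the comment above; the version threshold, module paths and file layout are assumptions, not existing Unsloth code):

from importlib.util import find_spec

from packaging.version import Version
import transformers

# Probe the installed transformers once, at package init.
_TRANSFORMERS_VERSION = Version(transformers.__version__)

# Illustrative thresholds / probes only; the real checks would mirror what the
# individual model files currently do.
IS_QWEN_SUPPORTED   = _TRANSFORMERS_VERSION >= Version("4.37.0")
IS_FALCON_SUPPORTED = find_spec("transformers.models.falcon_h1") is not None

# A model-specific file would then only need the boolean, e.g.:
#     from .capabilities import IS_FALCON_SUPPORTED
#     if IS_FALCON_SUPPORTED:
#         from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1ForCausalLM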

Contributor:

I agree - we'll do it for a future release

K = K.view(1, K_M, n_heads, head_dim)
V = V.view(1, V_M, n_heads, head_dim)
pass
else:

Collaborator:

Are there any efficiency gains in doing this? If not, can we use the same structure for both cases?

Contributor:

Actually, from internal checks, yes - view is faster than reshape, since reshape can possibly make a copy.

Collaborator:

Oh, actually I meant checking for requires_grad and doing separate operations for both...
If I understand correctly, you mean we can get away with view for inference but it's not compatible with training or something?
Otherwise, why can't we use view for both?
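
For context on the view-vs-reshape point, a quick standalone illustration (plain PyTorch, not Unsloth code): .view() never copies and errors out on non-contiguous tensors, while .reshape() silently falls back to a copy.

import torch

x = torch.randn(4, 8, 16)                 # contiguous
v = x.view(1, 4 * 8, 16)
print(v.data_ptr() == x.data_ptr())       # True: same storage, no copy

xt = x.transpose(0, 1)                    # non-contiguous
r = xt.reshape(1, 8 * 4, 16)
print(r.data_ptr() == xt.data_ptr())      # False: reshape had to copy
# xt.view(1, 8 * 4, 16) would raise a RuntimeError here instead of copying.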

K_M = V_M = bsz * kv_seq_len
Q_M = bsz * q_len

has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)

Collaborator:

Falcon doesn't seem to have SWA (correct me if I'm wrong). We can remove this safely

Contributor (Author):

True, removing it.

past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):

Collaborator:

Also, @danielhanchen, we should move the attention computation logic to a single file.
Any pre/post processing can exist in the model-specific files...

Contributor:

Yes agreed - was planning to do that!

)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier

hidden_states = mamba_hidden_states + attention_hidden_states

Collaborator:

Can we combine L406, L414 and this into a single operation if possible, and please check that gradients flow properly? Or
let's reorder the operations the way they are done in transformers, to look consistent.

(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if use_cache and hasattr(self, "_flag_for_generation"):

Collaborator:

@danielhanchen we should also refactor these a little.
Instead of having different code blocks doing essentially the same thing (except the norm), we can consolidate:

if use_cache and hasattr(self, "_flag_for_generation"):
	layernorm = fast_rms_layernorm_inference
else:
	layernorm = fast_rms_layernorm
...
Rest of the code together 

return outputs
pass

def _FalconH1_fast_forward_inference(attention_fast_forward_inference=LlamaAttention_fast_forward_inference, mlp_fast_forward_inference=fast_swiglu_inference):

Collaborator:

NIT: Let the default be FalconAttention_fast_forward_inference

hidden_states: torch.Tensor,
causal_mask = None,
attention_mask: Optional[torch.Tensor] = None,
mamba_attention_mask: Optional[torch.Tensor] = None,

Contributor:

@Datta0, how does unsloth handle forward kwargs? Does it get them directly from HF transformers?
Context: for Mamba, the mamba_mask is created here: https://github.com/huggingface/transformers/blob/51d732709e5ae424e8fb6c4e58b72057a3e413c2/src/transformers/models/falcon_h1/modeling_falcon_h1.py#L1305 - will we need to add similar logic in LlamaModel_fast_forward to incorporate the mamba mask?

Collaborator:

Yeah, if Mamba's decoder layer expects mamba_mask, then it's the job of FalconModel's forward function to handle that.
If the model's forward functionality is similar to Llama's, with this being the only change, we can use LlamaModel's forward while checking for the model type to be falcon.
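
For reference, the linked transformers code boils down to roughly the following (a paraphrased sketch, not the exact upstream implementation): the Mamba mask is just the 2D attention mask, and it is dropped entirely when it carries no information (decoding past the first step, or no padding anywhere).

import torch

def update_mamba_mask(attention_mask, cache_position):
    # Keep the 2D mask only when it actually matters for the SSM path.
    mamba_mask = attention_mask
    if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
        mamba_mask = None
    return mamba_mask

# Prefill with no padding -> no mamba mask needed.
print(update_mamba_mask(torch.ones(2, 8, dtype=torch.long), torch.tensor([0])))  # None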

@dhiaEddineRhaiem requested a review from Datta0 on June 8, 2025 at 18:36.

@Datta0 (Collaborator) commented Jun 9, 2025

@dhiaEddineRhaiem
Thanks for the changes. Everything looks good to me. If you can provide a small sample script to verify that our outputs match HF's (for training a few steps and for inference), that'd be great and we can get this merged.

@dhiaEddineRhaiem (Contributor, Author):

Many thanks @Datta0,
and thanks also to @younesbelkada for contributing to this.
I will be providing that shortly.

@dhiaEddineRhaiem (Contributor, Author) commented Jun 18, 2025

Hey @Datta0,
for training:

  1. With Unsloth training, we got the following loss for the first 8 steps using this command:

python unsloth-cli.py --model_name "tiiuae/Falcon-H1-0.5B-Base" --max_seq_length 2048 --dtype bfloat16 --load_in_4bit --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --per_device_train_batch_size 1

{'loss': 1.765, 'grad_norm': 0.7072669863700867, 'learning_rate': 0.0, 'epoch': 0.0}                                                                                     
{'loss': 1.6431, 'grad_norm': 0.7720947265625, 'learning_rate': 4e-05, 'epoch': 0.0}                                                                                     
{'loss': 2.8533, 'grad_norm': 1.7554939985275269, 'learning_rate': 8e-05, 'epoch': 0.0}                                                                                  
{'loss': 1.4332, 'grad_norm': 0.6112934947013855, 'learning_rate': 0.00012, 'epoch': 0.0}                                                                                
{'loss': 1.751, 'grad_norm': 0.4668850898742676, 'learning_rate': 0.00016, 'epoch': 0.0}                                                                                 
{'loss': 1.5516, 'grad_norm': 0.3460712432861328, 'learning_rate': 0.0002, 'epoch': 0.0}                                                                                 
{'loss': 1.8194, 'grad_norm': 0.49015215039253235, 'learning_rate': 0.00019949367088607596, 'epoch': 0.0}                                                                
{'loss': 1.8264, 'grad_norm': 0.3050038516521454, 'learning_rate': 0.0001989873417721519, 'epoch': 0.0}       
  2. With the following HF script (using the same settings), we ran this command:

python hf-cli.py --model_name "tiiuae/Falcon-H1-0.5B-Base" --max_seq_length 2048 --load_in_4bit --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --per_device_train_batch_size 1

and got:

{'loss': 1.7969, 'grad_norm': 0.7279049158096313, 'learning_rate': 0.0, 'epoch': 0.0}                                                                                                             
{'loss': 1.6988, 'grad_norm': 0.7694766521453857, 'learning_rate': 4e-05, 'epoch': 0.0}                                                                                                           
{'loss': 2.9951, 'grad_norm': 1.805290937423706, 'learning_rate': 8e-05, 'epoch': 0.0}                                                                                                            
{'loss': 1.4775, 'grad_norm': 0.6704585552215576, 'learning_rate': 0.00012, 'epoch': 0.0}                                                                                                         
{'loss': 1.7497, 'grad_norm': 0.4692195653915405, 'learning_rate': 0.00016, 'epoch': 0.0}                                                                                                         
{'loss': 1.587, 'grad_norm': 0.37044161558151245, 'learning_rate': 0.0002, 'epoch': 0.0}                                                                                                          
{'loss': 1.8655, 'grad_norm': 0.520369291305542, 'learning_rate': 0.00019949367088607596, 'epoch': 0.0}                                                                                           
{'loss': 1.879, 'grad_norm': 0.3259218633174896, 'learning_rate': 0.0001989873417721519, 'epoch': 0.0}

The small differences may be the result of numerical differences between the kernels in Unsloth and the native kernels in HF, plus potentially LoRA initializations not being consistent.

For inference:
We can address the inference part in a follow-up PR, as it will require some time to make it work (we are still working on that on our end).

@danielhanchen (Contributor):

@dhiaEddineRhaiem Oh, apologies for the delay - is it possible to get a plot, in say Excel / Google Sheets, of the HF loss vs the Unsloth loss over 60 steps?

If they mainly match, then we can merge this :) Thanks!

@danielhanchen (Contributor):

Oh also is inference ok? I'm assuming a new PR will be needed?

Also, if it's possible to make a notebook and place it in the Unsloth notebooks repo, that'd also be cool - also, a reminder to add "contributed by the folks at TII" for example, and link your website / HF page for extra recognition :)

@dhiaEddineRhaiem (Contributor, Author):

@danielhanchen, many thanks for the review.
On our side, we will be preparing the Excel sheet and the notebook quickly.

For inference, we think it might be better to raise it in a separate PR.

@danielhanchen (Contributor):

Ok cool! I think it's because it's Mamba, right - so inference is different since there's no KV cache. A dumb approach: if inference is enabled, don't use the fast path, and use the original Mamba forward.

@dhiaEddineRhaiem (Contributor, Author):

Falcon H1 is a hybrid (parallel design with Mamba and attention).
For that, we designed a new FalconHybridMambaAttentionDynamicCache, a dynamic cache manager that tracks:

  1. key_cache / value_cache for attention layers

  2. conv_states / ssm_states for Mamba layers

We are still debugging on our side to make it work as expected.
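
For readers unfamiliar with the design, here is a simplified, illustrative sketch of the idea behind such a hybrid cache (not the actual transformers class): attention layers append K/V along the sequence dimension, while Mamba layers keep fixed-size conv/SSM states that are overwritten each step.

import torch

class HybridMambaAttentionCacheSketch:
    def __init__(self, num_layers):
        self.key_cache   = [None] * num_layers  # per-layer, grows with sequence length
        self.value_cache = [None] * num_layers
        self.conv_states = [None] * num_layers  # per-layer, fixed-size Mamba conv buffer
        self.ssm_states  = [None] * num_layers  # per-layer, fixed-size Mamba SSM state

    def update_attention(self, key, value, layer_idx):
        # Standard KV-cache behaviour: concatenate along the sequence dimension.
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx], self.value_cache[layer_idx] = key, value
        else:
            self.key_cache[layer_idx]   = torch.cat([self.key_cache[layer_idx], key], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_mamba(self, conv_state, ssm_state, layer_idx):
        # Mamba state is recurrent: replace rather than concatenate.
        self.conv_states[layer_idx] = conv_state
        self.ssm_states[layer_idx]  = ssm_state
        return self.conv_states[layer_idx], self.ssm_states[layer_idx]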

@dhiaEddineRhaiem (Contributor, Author):

@danielhanchen, @Datta0, we managed to do 60 steps for FalconH1 with Unsloth and HF.
This is a plot of the losses obtained, with a dataframe containing them all:
loss_comparison.csv

[plot: Unsloth vs HF training loss over 60 steps]

Also, what is the best way to push the notebook? (Inside HF, or maybe inside this repo?)

@Datta0 (Collaborator) left a comment:

Thanks a lot for the work. Have a few minor comments. Rest looks good, especially the losses matching with HF :)

IS_GRANITE = self.config.model_type.startswith("granite")
IS_FALCON_H1 = self.config.model_type.startswith("falcon_h1")

if IS_FALCON_H1:

Collaborator:

NIT: Either move this to L790/803 or move that here, as the code is very similar.
Maybe we can do:

if IS_FALCON_H1 or IS_GRANITE:
    inputs_embeds *= self.config.embedding_multiplier

(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
if IS_FALCON_H1:
hidden_states = fast_rms_layernorm_inference(self.final_layernorm, hidden_states)

Collaborator:

NIT: Fix indentation

lm_head_device = lm_head.device

if self.config.model_type == "falcon_h1":
lm_head = lm_head * self.config.lm_head_multiplier

Collaborator:

Do we multiply lm_head or do we multiply the logits?
If the latter, we can look at merging this with L1189/L1205
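
A quick toy check of the equivalence behind this question (illustrative only; the multiplier value is a stand-in): scaling the lm_head weight by a constant gives the same logits as scaling the matmul output, since scalar multiplication commutes with the matrix product.

import torch

x = torch.randn(2, 16)             # hidden states
W = torch.randn(32, 16)            # lm_head weight (vocab x hidden)
c = 0.5                            # stand-in for config.lm_head_multiplier

logits_from_scaled_head    = x @ (W * c).T   # multiply the weight
logits_scaled_after_matmul = (x @ W.T) * c   # multiply the logits
print(torch.allclose(logits_from_scaled_head, logits_scaled_after_matmul, atol=1e-6))  # True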

@younesbelkada (Contributor):

Done addressing the comments @Datta0 ! Let us know if all is good

@Datta0 (Collaborator) left a comment:

LGTM. Great work :)

@dhiaEddineRhaiem (Contributor, Author):

Many thanks @Datta0, @danielhanchen, @younesbelkada!

@danielhanchen (Contributor):

Great work!

@danielhanchen merged commit a8f3d69 into unslothai:main on Jun 28, 2025.