@@ -23,6 +23,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
@@ -53,6 +54,8 @@
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+remat = nn_partitioning.remat
+
 
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
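For reference, `flax.linen.partitioning.remat` wraps a module class so that its intermediate activations are recomputed during the backward pass instead of being stored, trading compute for memory. A minimal sketch of the pattern, using a hypothetical `ToyMLP` module that is not part of this diff:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from flax.linen import partitioning as nn_partitioning

    class ToyMLP(nn.Module):  # hypothetical stand-in for a transformer layer
        features: int

        @nn.compact
        def __call__(self, x, deterministic: bool):
            return nn.relu(nn.Dense(self.features)(x))

    # `static_argnums` counts __call__ arguments after `self`; index 1 marks the
    # Python bool `deterministic` as static so remat does not try to trace it.
    RematMLP = nn_partitioning.remat(ToyMLP, static_argnums=(1,))

    layer = RematMLP(features=16)
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 8)), True)
    out = layer.apply(params, jnp.ones((2, 8)), True)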
@@ -391,12 +394,20 @@ def __call__(
 class FlaxWhisperEncoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.encoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
+            self.layers = [
+                FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
         self.layerdrop = self.config.encoder_layerdrop
 
     def __call__(
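A note on the `static_argnums` values: under `jax.checkpoint` (which `remat` builds on), arguments that drive Python control flow must stay concrete Python values rather than traced arrays. Indices are counted after `self`, so `(2, 3)` here marks the encoder layer's `output_attentions` and `deterministic` booleans, and the decoder collection below uses `(4, 5, 6)` for its `init_cache`, `output_attentions`, and `deterministic` flags. A tiny illustration of the same idea with plain `jax.checkpoint` (the function `f` is hypothetical):

    import jax
    import jax.numpy as jnp

    def f(x, apply_sin):
        # A Python `if` on `apply_sin` needs a concrete bool, not a tracer.
        return jnp.sin(x) if apply_sin else x

    g = jax.checkpoint(f, static_argnums=(1,))  # keep `apply_sin` static
    print(jax.grad(g)(1.0, True))  # cos(1.0) ≈ 0.5403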
@@ -535,12 +546,20 @@ def __call__(
 class FlaxWhisperDecoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.decoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
+            self.layers = [
+                FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
         self.layerdrop = self.config.decoder_layerdrop
 
     def __call__(
@@ -605,6 +624,7 @@ def __call__(
 class FlaxWhisperEncoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.conv1 = nn.Conv(
@@ -628,6 +648,7 @@ def setup(self) -> None:
         self.layers = FlaxWhisperEncoderLayerCollection(
             self.config,
             dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
         )
         self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
 
@@ -689,12 +710,13 @@ def __call__(
 class FlaxWhisperDecoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
         self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
 
-        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
+        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
 
@@ -753,10 +775,11 @@ def __call__(
 class FlaxWhisperModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
-        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
+        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
     def __call__(
         self,
@@ -821,11 +844,19 @@ def __init__(
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs,
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_features = jnp.zeros(input_shape, dtype="f4")
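With the pieces above in place, checkpointing can be toggled on an already-instantiated model through the new method; a brief usage sketch (the training step itself is out of scope here):

    from transformers import FlaxWhisperForConditionalGeneration

    model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.enable_gradient_checkpointing()  # rebuild the module with remat-wrapped layers

The rebuild is cheap because Flax modules are stateless: parameters live in a separate pytree on the wrapper, so swapping `self._module` does not touch the loaded weights.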
@@ -1137,9 +1168,10 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
 class FlaxWhisperForConditionalGenerationModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
+        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
         self.lm_head = nn.Dense(
             self.config.vocab_size,
             use_bias=False,