@@ -493,13 +493,14 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
493493 valid_candidates = valid_candidates .astype (np .float32 )
494494 num_masked_position = F .np .maximum (
495495 1 , F .np .minimum (N , round (valid_lengths * self ._mask_prob )))
496+
496497 # The categorical distribution takes normalized probabilities as input
497498 # softmax is used here instead of log_softmax
498499 sample_probs = F .npx .softmax (
499- self ._proposal_distribution * valid_candidates , axis = - 1 ) # (B, L)
500- # Top-k Sampling is an alternative solution to avoid duplicates positions
500+ self ._proposal_distribution * valid_candidates , axis = - 1 ) # (B, L)
501501 masked_positions = F .npx .random .categorical (
502502 sample_probs , shape = N , dtype = np .int32 )
503+
503504 masked_weights = F .npx .sequence_mask (
504505 F .np .ones_like (masked_positions ),
505506 sequence_length = num_masked_position ,
@@ -508,7 +509,7 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
508509 length_masks = F .npx .sequence_mask (
509510 F .np .ones_like (input_ids , dtype = np .float32 ),
510511 sequence_length = valid_lengths ,
511- use_sequence_length = True , axis = 1 , value = 0 ). astype ( np . float32 )
512+ use_sequence_length = True , axis = 1 , value = 0 )
512513 unmasked_tokens = select_vectors_by_position (
513514 F , input_ids , masked_positions ) * masked_weights
514515 masked_weights = masked_weights .astype (np .float32 )
@@ -518,11 +519,8 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
518519 F .np .zeros_like (masked_positions ),
519520 F .np .ones_like (masked_positions )) > self ._mask_prob ) * masked_positions
520521 # deal with multiple zeros
521- filled = F .np .where (
522- replaced_positions ,
523- self .vocab .mask_id ,
524- masked_positions ).astype (np .int32 )
525- masked_input_ids , _ = updated_vectors_by_position (F , input_ids , filled , replaced_positions )
522+ filled = F .np .where (replaced_positions , self .vocab .mask_id , masked_positions )
523+ masked_input_ids = updated_vectors_by_position (F , input_ids , filled , replaced_positions )
526524 masked_input = self .MaskedInput (input_ids = masked_input_ids ,
527525 masks = length_masks ,
528526 unmasked_tokens = unmasked_tokens ,
0 commit comments