Skip to content

Commit 34ee884

Browse files
committed
index_update
1 parent 74178e2 commit 34ee884

3 files changed

Lines changed: 22 additions & 29 deletions

File tree

scripts/pretraining/pretraining_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,14 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
493493
valid_candidates = valid_candidates.astype(np.float32)
494494
num_masked_position = F.np.maximum(
495495
1, F.np.minimum(N, round(valid_lengths * self._mask_prob)))
496+
496497
# The categorical distribution takes normalized probabilities as input
497498
# softmax is used here instead of log_softmax
498499
sample_probs = F.npx.softmax(
499-
self._proposal_distribution * valid_candidates, axis=-1) # (B, L)
500-
# Top-k Sampling is an alternative solution to avoid duplicate positions
500+
self._proposal_distribution * valid_candidates, axis=-1) # (B, L)
501501
masked_positions = F.npx.random.categorical(
502502
sample_probs, shape=N, dtype=np.int32)
503+
503504
masked_weights = F.npx.sequence_mask(
504505
F.np.ones_like(masked_positions),
505506
sequence_length=num_masked_position,
@@ -508,7 +509,7 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
508509
length_masks = F.npx.sequence_mask(
509510
F.np.ones_like(input_ids, dtype=np.float32),
510511
sequence_length=valid_lengths,
511-
use_sequence_length=True, axis=1, value=0).astype(np.float32)
512+
use_sequence_length=True, axis=1, value=0)
512513
unmasked_tokens = select_vectors_by_position(
513514
F, input_ids, masked_positions) * masked_weights
514515
masked_weights = masked_weights.astype(np.float32)
@@ -518,11 +519,8 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
518519
F.np.zeros_like(masked_positions),
519520
F.np.ones_like(masked_positions)) > self._mask_prob) * masked_positions
520521
# deal with multiple zeros
521-
filled = F.np.where(
522-
replaced_positions,
523-
self.vocab.mask_id,
524-
masked_positions).astype(np.int32)
525-
masked_input_ids, _ = updated_vectors_by_position(F, input_ids, filled, replaced_positions)
522+
filled = F.np.where(replaced_positions, self.vocab.mask_id, masked_positions)
523+
masked_input_ids = updated_vectors_by_position(F, input_ids, filled, replaced_positions)
526524
masked_input = self.MaskedInput(input_ids=masked_input_ids,
527525
masks=length_masks,
528526
unmasked_tokens=unmasked_tokens,

src/gluonnlp/models/electra.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from mxnet import use_np
3737
from mxnet.gluon import HybridBlock, nn
3838
from ..registry import BACKBONE_REGISTRY
39-
from ..op import gumbel_softmax, select_vectors_by_position, updated_vectors_by_position
39+
from ..op import gumbel_softmax, select_vectors_by_position, add_vectors_by_position, updated_vectors_by_position
4040
from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
4141
from ..layers import PositionalEmbedding, get_activation
4242
from .transformer import TransformerEncoderLayer
@@ -833,13 +833,14 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log
833833
use_np_gumbel=False)
834834
corrupted_tokens = F.np.argmax(prob, axis=-1).astype(np.int32)
835835

836-
# Following the official ELECTRA implementation to deal with duplicate positions as
837-
# https://github.com/google-research/electra/issues/41
838-
original_data, updates_mask = updated_vectors_by_position(F,
836+
original_data = updated_vectors_by_position(F,
839837
inputs, unmasked_tokens, masked_positions)
840-
fake_data, _ = updated_vectors_by_position(F,
838+
fake_data = updated_vectors_by_position(F,
841839
inputs, corrupted_tokens, masked_positions)
842-
840+
updates_mask = add_vectors_by_position(F, F.np.zeros_like(inputs),
841+
F.np.ones_like(masked_positions), masked_positions)
842+
# Dealing with duplicate positions
843+
updates_mask = F.np.minimum(updates_mask, 1)
843844
labels = updates_mask * F.np.not_equal(fake_data, original_data)
844845
return corrupted_tokens, fake_data, labels
845846

src/gluonnlp/op.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def updated_vectors_by_position(F, base, data, positions):
100100
"""
101101
Update each batch with the given positions. Considered as a reversed process of
102102
"select_vectors_by_position", this is an advanced operator of add_vectors_by_position
103-
that updates the results instead of add and avoids duplicate positions.
103+
that updates the results instead of adding.
104104
Once advanced indexing can be hybridized, we can revise the implementation.
105105
106106
updates[i, positions[i, j], :] = data[i, j, :]
@@ -127,22 +127,16 @@ def updated_vectors_by_position(F, base, data, positions):
127127
out
128128
The updated result.
129129
Shape (batch_size, seq_length)
130-
updates_mask
131-
The update state for the whole sequence
132-
1 -> updated, 0 -> not updated.
133-
Shape (batch_size, seq_length)
134130
"""
135-
# TODO(zheyuye), update when npx.index_update implemented
136-
updates = add_vectors_by_position(F, F.np.zeros_like(base), data, positions)
137-
updates_mask = add_vectors_by_position(F, F.np.zeros_like(base),
138-
F.np.ones_like(positions), positions)
139-
updates = (updates / F.np.maximum(1, updates_mask)).astype(np.int32)
140-
141-
out = F.np.where(updates, updates, base)
142-
updates_mask = F.np.minimum(updates_mask, 1)
143-
144-
return out, updates_mask
131+
positions = positions.astype(np.int32)
132+
# batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
133+
batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
134+
axis=1).astype(np.int32)
135+
batch_idx = batch_idx + F.np.zeros_like(positions)
136+
indices = F.np.stack([batch_idx.reshape(-1), positions.reshape(-1)])
145137

138+
out = F.npx.index_update(base, indices, data.reshape(-1))
139+
return out
146140

147141
@use_np
148142
def gumbel_softmax(F, logits, temperature: float = 1.0, eps: float = 1E-10,

0 commit comments

Comments
 (0)