python/paddle/fluid/tests/unittests/test_imperative_transformer.py (60 changes: 36 additions & 24 deletions)
@@ -116,7 +116,7 @@ class ModelHyperParams(object):
     # to process after each sub-layer
     postprocess_cmd = "da" # dropout + residual connection
     # random seed used in dropout for CE.
-    dropout_seed = 1
+    dropout_seed = None
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
@@ -166,15 +166,21 @@ def create_data(is_static=False):
         ]
     else:
         enc_inputs = [
-            to_variable(src_word_np), to_variable(src_pos_np),
-            to_variable(src_slf_attn_bias_np)
+            to_variable(
+                src_word_np, name='src_word'), to_variable(
+                    src_pos_np, name='src_pos'), to_variable(
+                        src_slf_attn_bias_np, name='src_slf_attn_bias')
         ]
         dec_inputs = [
-            to_variable(trg_word_np), to_variable(trg_pos_np),
-            to_variable(trg_slf_attn_bias_np), to_variable(trg_src_attn_bias_np)
+            to_variable(
+                trg_word_np, name='trg_word'), to_variable(
+                    trg_pos_np, name='trg_pos'), to_variable(
+                        trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+            to_variable(
+                trg_src_attn_bias_np, name='trg_src_attn_bias')
         ]
-        label = to_variable(lbl_word_np)
-        weight = to_variable(lbl_weight_np)
+        label = to_variable(lbl_word_np, name='lbl_word')
+        weight = to_variable(lbl_weight_np, name='lbl_weight')
         return enc_inputs, dec_inputs, label, weight
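For context on the change above: attaching an explicit `name` to each imperative variable gives dygraph tensors stable identifiers, which helps when their values are later lined up against their static-graph counterparts. A minimal sketch of the call, assuming the fluid.dygraph-era import path (the array shape and name here are illustrative):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable  # import path assumed

with fluid.dygraph.guard():
    arr = np.random.randint(1, 100, size=(4, 8, 1), dtype='int64')
    # name= attaches a stable identifier, as the test now does for
    # src_word, src_pos, and the attention bias inputs.
    var = to_variable(arr, name='src_word')
    print(var.name, var.shape)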
@@ -211,7 +217,7 @@ def make_all_inputs(input_fields):
 # The placeholder for batch_size in compile time. Must be -1 currently to be
 # consistent with some ops' infer-shape output in compile time, such as the
 # sequence_expand op used in beamsearch decoder.
-batch_size = 32
+batch_size = -1
 # The placeholder for squence length in compile time.
 seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
@@ -304,35 +310,40 @@ def make_all_inputs(input_fields):
 
 batch_num = 5
 
-np.random.seed = 1
+np.random.seed = 90
 src_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 src_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 trg_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 trg_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
-trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
+trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 lbl_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size * seq_len, 1),
+    size=(TrainTaskConfig.batch_size * seq_len, 1),
     dtype='int64')
-lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
+lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len,
+                                1).astype('float32')
+
 
 pos_inp1 = position_encoding_init(ModelHyperParams.max_length,
                                   ModelHyperParams.d_model)
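One caveat in the block above: `np.random.seed = 90` rebinds NumPy's `seed` attribute to the integer 90 instead of seeding the generator, so it does not by itself make the arrays deterministic. Seeding as usually intended is a call:

import numpy as np

np.random.seed(90)  # seeds the global RNG; subsequent draws are reproducible
demo = np.random.randint(1, 100, size=(2, 4, 1), dtype='int64')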
@@ -447,7 +458,7 @@ def forward(self, queries, keys, values, attn_bias):
             x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
         transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
 
-        #scale dot product attention
+        # scale dot product attention
         product = fluid.layers.matmul(
             x=transpose_q,
             y=transpose_k,
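The comment fixed above labels standard scaled dot-product attention. A minimal NumPy sketch of what the surrounding matmul/softmax layers compute, with shapes following the test's [batch, n_head, seq_len, d] layout (names and the softmax spelling are illustrative, not the PR's code):

import numpy as np

def scaled_dot_product_attention(q, k, v, attn_bias, d_key):
    # [batch, head, seq, seq] scores, scaled by d_key**-0.5 for a stable softmax
    product = np.matmul(q, k.transpose(0, 1, 3, 2)) * d_key**-0.5
    product += attn_bias  # large negative bias masks padding positions
    weights = np.exp(product - product.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return np.matmul(weights, v)  # [batch, head, seq, d_value]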
@@ -971,13 +982,15 @@ def test_transformer_float32(self):
                 enc_inputs, dec_inputs, label, weights = create_data()
                 dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                     enc_inputs, dec_inputs, label, weights)
+
                 if i == 0:
                     for param in transformer.parameters():
                         dy_param_init[param.name] = param._numpy()
 
                 dy_avg_cost._backward()
                 optimizer.minimize(dy_avg_cost)
                 transformer.clear_gradients()
+
                 if i == batch_num - 1:
                     for param in transformer.parameters():
                         dy_param_updated[param.name] = param._numpy()
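The dygraph loop above follows the usual forward/backward/minimize/clear cycle. Condensed, the step looks like this (`transformer`, `optimizer`, `create_data`, and `batch_num` are the test's own names; `_backward` and `_numpy` are the private helpers of this era, later made public as `backward` and `numpy`):

for i in range(batch_num):
    enc_inputs, dec_inputs, label, weights = create_data()
    _, dy_avg_cost, _, _ = transformer(enc_inputs, dec_inputs, label, weights)
    dy_avg_cost._backward()          # accumulate gradients eagerly
    optimizer.minimize(dy_avg_cost)  # apply the parameter update
    transformer.clear_gradients()    # zero gradients before the next batch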
@@ -1024,7 +1037,6 @@ def test_transformer_float32(self):
             static_param_name_list = list()
             static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
                 enc_inputs, dec_inputs, label, weights)
-
             optimizer.minimize(static_avg_cost)
             for param in transformer.parameters():
                 static_param_name_list.append(param.name)
@@ -1042,8 +1054,8 @@ def test_transformer_float32(self):
                     static_sum_cost, static_avg_cost, static_predict,
                     static_token_num
                 ]
-                fetch_list.extend(static_param_name_list)
 
+                fetch_list.extend(static_param_name_list)
                 out = exe.run(fluid.default_main_program(),
                               feed=feed_dict,
                               fetch_list=fetch_list)
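On the last hunk: extending fetch_list with the parameter names lets a single exe.run call return the loss values and every parameter together. A small self-contained sketch of fetching by name, assuming fluid 1.x static-graph APIs (a toy fc program stands in for the transformer; all names are illustrative):

import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.fc(input=x, size=2, param_attr=fluid.ParamAttr(name='fc_w'))
avg_cost = fluid.layers.reduce_mean(y)
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out = exe.run(fluid.default_main_program(),
              feed={'x': np.ones((8, 4), dtype='float32')},
              fetch_list=[avg_cost.name, 'fc_w'])  # fetch loss and params by name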