Merged
Changes from 2 commits
74 changes: 35 additions & 39 deletions tests/test_modeling_tf_common.py
@@ -461,14 +461,14 @@ def test_compile_tf_model(self):
if self.is_encoder_decoder:
input_ids = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
batch_shape=(2, 512), name="decoder_input_ids", dtype="int32"
),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 512), name="input_ids", dtype="int32"),
}
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
input_ids = tf.keras.Input(batch_shape=(4, 2, 512), name="input_ids", dtype="int32")
else:
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
input_ids = tf.keras.Input(batch_shape=(2, 512), name="input_ids", dtype="int32")

@jplu (Contributor, Author) commented on Nov 12, 2020:
Replace 2000 with 512 here because all the models except Longformer take at most 512 tokens. Keeping 2000 here will also, in the future, raise an error in the embedding layers of the models limited to 512 tokens.
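As a side note, here is a minimal, self-contained sketch (not part of the PR, with illustrative sizes) of why inputs longer than max_position_embeddings break these models: a BERT-style absolute position-embedding table only has max_position_embeddings rows, so longer sequences index past it.

import tensorflow as tf

max_position_embeddings = 512   # illustrative limit, not a real model config
hidden_size = 32                # illustrative width

# BERT-style absolute position embeddings: one table row per allowed position.
position_embeddings = tf.keras.layers.Embedding(max_position_embeddings, hidden_size)

in_range_positions = tf.range(512)    # 0..511: covered by the table
too_long_positions = tf.range(2000)   # 512..1999: outside the table

print(position_embeddings(in_range_positions).shape)  # (512, 32)
# Looking up the out-of-range positions is invalid: on CPU TensorFlow raises
# an InvalidArgumentError, and on GPU the result is undefined.
# position_embeddings(too_long_positions)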

Member commented:
Maybe we could set it to self.model_tester.max_position_embeddings so that it is model-specific?

Contributor (Author) commented:
Very good idea! Just did it!

@jplu (Contributor, Author) commented on Nov 12, 2020:
@LysandreJik apparently XLNet, TransfoXL, and T5 do not have a max_position_embeddings attribute in their config (which is expected). Should I add the parameter manually in the test class?

Member commented:
Ah, I did not take this into account. Giving them a max_position_embeddings wouldn't be wise, though, as they actually do not have such a limitation.

I guess the best approach here would be to look up max_position_embeddings and, if a model doesn't have it (because it uses relative position embeddings), fall back to a default of 512 for the sake of this test only.

That means something like getattr(self.model_tester, "max_position_embeddings", 512).

What do you think? This way it stays model-agnostic and doesn't add a max_position_embeddings attribute to the model_tester, where it doesn't belong.
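To illustrate how that fallback resolves (with made-up tester classes, not the real ones): testers that define max_position_embeddings drive the test with their own limit, while testers for relative-position models fall back to 512.

class AbsolutePositionTester:      # hypothetical stand-in for a model_tester
    max_position_embeddings = 256  # model-specific limit

class RelativePositionTester:      # e.g. an XLNet/TransfoXL/T5-style tester
    pass                           # no max_position_embeddings attribute

for tester in (AbsolutePositionTester(), RelativePositionTester()):
    seq_len = getattr(tester, "max_position_embeddings", 512)
    print(type(tester).__name__, seq_len)
# AbsolutePositionTester 256
# RelativePositionTester 512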

Contributor (Author) commented:
Perfect! I like it 👍

Contributor (Author) commented:
Like this?
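(The actual hunk is collapsed above, so what follows is only a self-contained sketch of the pattern under discussion; DummyTester is a hypothetical stand-in for self.model_tester, and the shapes mirror the ones used in test_compile_tf_model.)

import tensorflow as tf

class DummyTester:                 # hypothetical stand-in for self.model_tester
    max_position_embeddings = 512

model_tester = DummyTester()
max_input = getattr(model_tester, "max_position_embeddings", 512)

# Symbolic Keras inputs with the shapes used in test_compile_tf_model.
encoder_decoder_inputs = {
    "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
    "decoder_input_ids": tf.keras.Input(
        batch_shape=(2, max_input), name="decoder_input_ids", dtype="int32"
    ),
}
multiple_choice_input = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
default_input = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")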

# Prepare our model
model = model_class(config)
@@ -508,81 +508,77 @@ def test_keyword_and_dict_args(self):
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True

decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["use_cache"] = False
config.output_hidden_states = False
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
def check_decoder_attentions_output(outputs):
out_len = len(outputs)
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs.decoder_attentions
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)

def check_encoder_attentions_output(outputs):
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)

for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["use_cache"] = False
config.output_hidden_states = False
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
out_len = len(outputs)
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)

if self.is_encoder_decoder:
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs.decoder_attentions
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(config.output_hidden_states, False)
check_decoder_attentions_output(outputs)

# Check that output attentions can also be changed via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)

# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))

self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_hidden_states, True)

attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
check_encoder_attentions_output(outputs)

def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

def check_hidden_states_output(config, inputs_dict, model_class):
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
hidden_states = [t.numpy() for t in outputs[-1]]
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)

hidden_states = outputs[-1]
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
58 changes: 20 additions & 38 deletions tests/test_modeling_tf_longformer.py
@@ -133,23 +133,21 @@ def create_and_check_attention_mask_determinism(
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerModel(config=config)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)

result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerModel(config=config)
half_input_mask_length = shape_list(input_mask)[-1] // 2
global_attention_mask = tf.concat(
@@ -160,59 +158,43 @@ def create_and_check_longformer_model_with_global_attention_mask(
axis=-1,
)

sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, global_attention_mask=global_attention_mask)

result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerForMaskedLM(config=config)
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])

def create_and_check_longformer_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerForQuestionAnswering(config=config)
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length])

self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()