src/transformers/models/dpr/modeling_tf_dpr.py (6 changes: 3 additions & 3 deletions)
@@ -82,7 +82,7 @@ class TFDPRContextEncoderOutput(ModelOutput):
heads.
"""

- pooler_output: tf.Tensor
+ pooler_output: tf.Tensor = None

@patrickvonplaten (Contributor) commented on Nov 19, 2020:

@LysandreJik @sgugger - for some reason, these outputs need to be initialized with something, or else our much-loved tests test_compile_tf_model and test_saved_model_with_... fail. Tbh, I don't know why this is; it seems to be some weird TF graph problem with positional arguments. @sgugger, is this fine for you?

Collaborator:

Yes, I remember that now. Had to do the same for all tf ModelOutput. This is fine indeed!

Collaborator:

Great catch! :-)

hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None

@@ -110,7 +110,7 @@ class TFDPRQuestionEncoderOutput(ModelOutput):
heads.
"""

- pooler_output: tf.Tensor
+ pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None

@@ -141,7 +141,7 @@ class TFDPRReaderOutput(ModelOutput):
heads.
"""

- start_logits: tf.Tensor
+ start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
relevance_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
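As the review comments above explain, these TF output dataclasses now give every field a default, so an instance can be created without positional arguments when TensorFlow traces the model into a graph. Below is a minimal sketch of the resulting pattern, using a plain dataclass as a stand-in for the ModelOutput base class (which additionally provides dict-like access):

from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf

@dataclass
class TFDPRContextEncoderOutput:
    # Every field defaults to None, so the class can be instantiated with
    # keyword arguments only (or empty and filled in afterwards); per the
    # discussion above, this is what keeps test_compile_tf_model and the
    # saved-model tests from failing on positional-argument handling.
    pooler_output: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
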
tests/test_modeling_tf_dpr.py (10 changes: 4 additions & 6 deletions)
@@ -123,8 +123,6 @@ def prepare_config_and_inputs(self):
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
- # MODIFY

Contributor:

Deleted all those # MODIFY comments, at the cost of having to initialize pooler_output = None as explained in the other comment.

- return_dict=False,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

@@ -136,7 +134,7 @@ def create_and_check_dpr_context_encoder(
model = TFDPRContextEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
- result = model(input_ids, return_dict=True)  # MODIFY
+ result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_question_encoder(
@@ -145,14 +143,14 @@ def create_and_check_dpr_question_encoder(
model = TFDPRQuestionEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
- result = model(input_ids, return_dict=True)  # MODIFY
+ result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFDPRReader(config=config)
- result = model(input_ids, attention_mask=input_mask, return_dict=True)  # MODIFY
+ result = model(input_ids, attention_mask=input_mask)

self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
@@ -267,7 +265,7 @@ def test_saved_model_with_attentions_output(self):
class TFDPRModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
- model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=False)
+ model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

input_ids = tf.constant(
[[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]]
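For reference, a short usage sketch (not part of the diff) of how the question encoder is called after this change; the checkpoint name and token ids are taken from the integration test above, and the printed shape assumes the base checkpoint's hidden size of 768:

import tensorflow as tf
from transformers import TFDPRQuestionEncoder

model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

# Same token ids as in test_inference_no_head above.
input_ids = tf.constant([[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]])

# With dict-style outputs now the default (as the test changes assume), the call
# returns a TFDPRQuestionEncoderOutput and the question embedding is in pooler_output.
output = model(input_ids)
print(output.pooler_output.shape)  # expected: (1, 768) for this base-size checkpoint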