8 changes: 4 additions & 4 deletions .github/workflows/self-push.yml
@@ -48,7 +48,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[torch,sklearn,testing,onnxruntime]
+ pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets

- name: Are GPUs recognized by our DL frameworks
@@ -117,7 +117,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[tf,sklearn,testing,onnxruntime]
+ pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets

- name: Are GPUs recognized by our DL frameworks
@@ -185,7 +185,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[torch,sklearn,testing,onnxruntime]
+ pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets

- name: Are GPUs recognized by our DL frameworks
@@ -244,7 +244,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[tf,sklearn,testing,onnxruntime]
+ pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets

- name: Are GPUs recognized by our DL frameworks
16 changes: 8 additions & 8 deletions .github/workflows/self-scheduled.yml
@@ -50,7 +50,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[torch,sklearn,testing,onnxruntime]
+ pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip list

@@ -144,7 +144,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[tf,sklearn,testing,onnxruntime]
+ pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip list

@@ -223,7 +223,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[torch,sklearn,testing,onnxruntime]
+ pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip list

@@ -251,11 +251,11 @@ jobs:
RUN_SLOW: yes
run: |
source .env/bin/activate
- python -m pytest -n 1 --dist=loadfile -s --make-reports=examples_torch_multi_gpu examples
+ python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch_examples_multi_gpu examples

- name: Failure short reports
if: ${{ always() }}
- run: cat reports/examples_torch_multi_gpu_failures_short.txt
+ run: cat reports/tests_torch_examples_multi_gpu_failures_short.txt

- name: Run all pipeline tests on multi-GPU
if: ${{ always() }}
@@ -314,7 +314,7 @@ jobs:
run: |
source .env/bin/activate
pip install --upgrade pip
- pip install .[tf,sklearn,testing,onnxruntime]
+ pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip list

@@ -345,11 +345,11 @@ jobs:
RUN_PIPELINE_TESTS: yes
run: |
source .env/bin/activate
- python -m pytest -n 1 --dist=loadfile -s -m is_pipeline_test --make-reports=tests_tf_pipelines_multi_gpu tests
+ python -m pytest -n 1 --dist=loadfile -s -m is_pipeline_test --make-reports=tests_tf_pipeline_multi_gpu tests

- name: Failure short reports
if: ${{ always() }}
- run: cat reports/tests_tf_multi_gpu_pipelines_failures_short.txt
+ run: cat reports/tests_tf_pipeline_multi_gpu_failures_short.txt

- name: Test suite reports artifacts
if: ${{ always() }}
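(Note: the report renames in this file keep the id passed to `--make-reports` aligned with the `reports/<id>_failures_short.txt` path read by each failure step; in the TF pipelines job the two previously disagreed.)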
61 changes: 37 additions & 24 deletions src/transformers/models/dpr/modeling_tf_dpr.py
@@ -82,7 +82,7 @@ class TFDPRContextEncoderOutput(ModelOutput):
heads.
"""

- pooler_output: tf.Tensor
+ pooler_output: tf.Tensor = None
> patrickvonplaten (Contributor) commented on Nov 19, 2020:
> @LysandreJik @sgugger - for some reason, these outputs need to be initialized with something or else our much-loved tests test_compile_tf_model and test_saved_model_with_... fail. Tbh, I don't know why this is - seems to be some weird tf graph problem with positional arguments... -> @sgugger is this fine for you?
>
> Collaborator replied:
> Yes, I remember that now. Had to do the same for all tf ModelOutput. This is fine indeed!
>
> Collaborator replied:
> Great catch! :-)
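A minimal sketch of the constraint discussed in this thread (the dataclass below is illustrative, not the actual `TFDPRContextEncoderOutput`): a dataclass field without a default becomes a required constructor argument, so an output built with only some of its fields set raises a TypeError unless every field defaults to `None`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf


@dataclass
class EncoderOutput:
    # Every field defaults to None. Without the defaults, constructing the
    # output with only pooler_output set would raise a TypeError, which is
    # the failure mode the thread above describes for the TF saving tests.
    pooler_output: Optional[tf.Tensor] = None
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None


# Only one field is populated; the others fall back to None.
out = EncoderOutput(pooler_output=tf.zeros((2, 8)))
assert out.hidden_states is None and out.attentions is None
```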

hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None

@@ -110,7 +110,7 @@ class TFDPRQuestionEncoderOutput(ModelOutput):
heads.
"""

- pooler_output: tf.Tensor
+ pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None

@@ -141,7 +141,7 @@ class TFDPRReaderOutput(ModelOutput):
heads.
"""

- start_logits: tf.Tensor
+ start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
relevance_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -181,7 +181,7 @@ def call(
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict

outputs = self.bert_model(
- inputs=input_ids,
+ input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -228,7 +228,8 @@ def __init__(self, config: DPRConfig, *args, **kwargs):
def call(
self,
input_ids: Tensor,
- attention_mask: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -242,6 +243,7 @@ def call(
outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -474,19 +476,21 @@ def call(
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
- inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
- output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
- output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
- return_dict = inputs[5] if len(inputs) > 5 else return_dict
- assert len(inputs) <= 6, "Too many inputs."
+ token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+ inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+ output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+ output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+ return_dict = inputs[6] if len(inputs) > 6 else return_dict
+ assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
+ token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
- assert len(inputs) <= 6, "Too many inputs."
+ assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs

@@ -573,19 +577,21 @@ def call(
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
- inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
- output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
- output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
- return_dict = inputs[5] if len(inputs) > 5 else return_dict
- assert len(inputs) <= 6, "Too many inputs."
+ token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+ inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+ output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+ output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+ return_dict = inputs[6] if len(inputs) > 6 else return_dict
+ assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
+ token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
- assert len(inputs) <= 6, "Too many inputs."
+ assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs

@@ -650,6 +656,7 @@ def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
+ token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
@@ -679,19 +686,21 @@ def call(
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
- inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
- output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
- output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
- return_dict = inputs[5] if len(inputs) > 5 else return_dict
- assert len(inputs) <= 6, "Too many inputs."
+ token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+ inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+ output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+ output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+ return_dict = inputs[6] if len(inputs) > 6 else return_dict
+ assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
+ token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
- assert len(inputs) <= 6, "Too many inputs."
+ assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs

@@ -713,9 +722,13 @@ def call(
if attention_mask is None:
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)

+ if token_type_ids is None:
+     token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)

return self.span_predictor(
input_ids,
- attention_mask,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
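A usage sketch of the `token_type_ids` plumbing added above, reusing the checkpoint from the integration test further down; the tokenizer call and printed shape are illustrative, not part of this diff:

```python
from transformers import DPRQuestionEncoderTokenizer, TFDPRQuestionEncoder

tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
model = TFDPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)

inputs = tokenizer("hello, is my dog cute?", return_tensors="tf")

# token_type_ids is now accepted and forwarded through to the underlying
# BERT encoder instead of being silently dropped.
outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    token_type_ids=inputs["token_type_ids"],
)
print(outputs.pooler_output.shape)  # (1, 768)
```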
1 change: 1 addition & 0 deletions tests/test_modeling_tf_bert.py
@@ -340,6 +340,7 @@ def test_custom_load_tf_weights(self):
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])


+ @require_tf
class TFBertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
55 changes: 43 additions & 12 deletions tests/test_modeling_tf_dpr.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
+ import tempfile
import unittest

from transformers import is_tf_available
@@ -124,8 +123,6 @@ def prepare_config_and_inputs(self):
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
- # MODIFY
> Contributor commented:
> Deleted all those #MODIFY statements at the cost of having to initialize the pooler_output = None as explained in the other comment.

- return_dict=False,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

@@ -137,7 +134,7 @@ def create_and_check_dpr_context_encoder(
model = TFDPRContextEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
- result = model(input_ids, return_dict=True) # MODIFY
+ result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_question_encoder(
@@ -146,14 +143,14 @@ def create_and_check_dpr_question_encoder(
model = TFDPRQuestionEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
- result = model(input_ids, return_dict=True) # MODIFY
+ result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFDPRReader(config=config)
- result = model(input_ids, attention_mask=input_mask, return_dict=True) # MODIFY
+ result = model(input_ids, attention_mask=input_mask)

self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
@@ -214,27 +211,61 @@ def test_dpr_reader_model(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = TFDPRContextEncoder.from_pretrained(model_name, from_pt=True)
+ model = TFDPRContextEncoder.from_pretrained(model_name)
self.assertIsNotNone(model)

for model_name in TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = TFDPRContextEncoder.from_pretrained(model_name, from_pt=True)
+ model = TFDPRContextEncoder.from_pretrained(model_name)
self.assertIsNotNone(model)

for model_name in TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = TFDPRQuestionEncoder.from_pretrained(model_name, from_pt=True)
+ model = TFDPRQuestionEncoder.from_pretrained(model_name)
self.assertIsNotNone(model)

for model_name in TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = TFDPRReader.from_pretrained(model_name, from_pt=True)
+ model = TFDPRReader.from_pretrained(model_name)
self.assertIsNotNone(model)

+    @slow
+    def test_saved_model_with_attentions_output(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.output_attentions = True
+
+        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
+        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+
+        for model_class in self.all_model_classes:
+            print(model_class)
+            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+            num_out = len(model(class_inputs_dict))
+            model._saved_model_inputs_spec = None
+            model._set_save_spec(class_inputs_dict)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tf.saved_model.save(model, tmpdirname)
+                model = tf.keras.models.load_model(tmpdirname)
+                outputs = model(class_inputs_dict)
+
+                if self.is_encoder_decoder:
+                    output = outputs["encoder_attentions"] if isinstance(outputs, dict) else outputs[-1]
+                else:
+                    output = outputs["attentions"] if isinstance(outputs, dict) else outputs[-1]
+
+                attentions = [t.numpy() for t in output]
+                self.assertEqual(len(outputs), num_out)
+                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+                self.assertListEqual(
+                    list(attentions[0].shape[-3:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                )


@require_tf
class TFDPRModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=False)
model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

input_ids = tf.constant(
[[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]]
1 change: 1 addition & 0 deletions tests/test_modeling_tf_electra.py
@@ -249,6 +249,7 @@ def test_model_from_pretrained(self):
self.assertIsNotNone(model)


+ @require_tf
class TFElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):