|
15 | 15 |
|
16 | 16 |
|
17 | 17 | import copy |
18 | | -import os |
19 | | -import tempfile |
20 | 18 | import unittest |
21 | 19 |
|
22 | 20 | import numpy as np |
23 | 21 |
|
24 | | -import transformers |
25 | 22 | from transformers import LxmertConfig, is_tf_available, is_torch_available |
26 | 23 | from transformers.models.auto import get_values |
27 | | -from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device |
| 24 | +from transformers.testing_utils import require_torch, slow, torch_device |
28 | 25 |
|
29 | 26 | from ..test_configuration_common import ConfigTester |
30 | 27 | from ..test_modeling_common import ModelTesterMixin, ids_tensor |
@@ -527,6 +524,8 @@ def prepare_config_and_inputs_for_common(self, return_obj_labels=False): |
527 | 524 |
|
528 | 525 | if return_obj_labels: |
529 | 526 | inputs_dict["obj_labels"] = obj_labels |
| 527 | + else: |
| 528 | + config.task_obj_predict = False |
530 | 529 |
|
531 | 530 | return config, inputs_dict |
532 | 531 |
|
@@ -740,121 +739,30 @@ def test_retain_grad_hidden_states_attentions(self): |
740 | 739 | self.assertIsNotNone(hidden_states_vision.grad) |
741 | 740 | self.assertIsNotNone(attentions_vision.grad) |
742 | 741 |
|
743 | | - @is_pt_tf_cross_test |
744 | | - def test_pt_tf_model_equivalence(self): |
745 | | - for model_class in self.all_model_classes: |
746 | | - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common( |
747 | | - return_obj_labels="PreTraining" in model_class.__name__ |
748 | | - ) |
749 | | - |
750 | | - tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning |
751 | | - |
752 | | - if not hasattr(transformers, tf_model_class_name): |
753 | | - # transformers does not have TF version yet |
754 | | - return |
755 | | - |
756 | | - tf_model_class = getattr(transformers, tf_model_class_name) |
757 | | - |
758 | | - config.output_hidden_states = True |
759 | | - config.task_obj_predict = False |
760 | | - |
761 | | - pt_model = model_class(config) |
762 | | - tf_model = tf_model_class(config) |
763 | | - |
764 | | - # Check we can load pt model in tf and vice-versa with model => model functions |
765 | | - pt_inputs = self._prepare_for_class(inputs_dict, model_class) |
766 | | - |
767 | | - def recursive_numpy_convert(iterable): |
768 | | - return_dict = {} |
769 | | - for key, value in iterable.items(): |
770 | | - if type(value) == bool: |
771 | | - return_dict[key] = value |
772 | | - if isinstance(value, dict): |
773 | | - return_dict[key] = recursive_numpy_convert(value) |
774 | | - else: |
775 | | - if isinstance(value, (list, tuple)): |
776 | | - return_dict[key] = ( |
777 | | - tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value |
778 | | - ) |
779 | | - else: |
780 | | - return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32) |
781 | | - return return_dict |
782 | | - |
783 | | - tf_inputs_dict = recursive_numpy_convert(pt_inputs) |
784 | | - |
785 | | - tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) |
786 | | - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device) |
787 | | - |
788 | | - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences |
789 | | - pt_model.eval() |
790 | | - |
791 | | - # Delete obj labels as we want to compute the hidden states and not the loss |
792 | | - |
793 | | - if "obj_labels" in inputs_dict: |
794 | | - del inputs_dict["obj_labels"] |
795 | | - |
796 | | - pt_inputs = self._prepare_for_class(inputs_dict, model_class) |
797 | | - tf_inputs_dict = recursive_numpy_convert(pt_inputs) |
798 | | - |
799 | | - with torch.no_grad(): |
800 | | - pto = pt_model(**pt_inputs) |
801 | | - tfo = tf_model(tf_inputs_dict, training=False) |
802 | | - tf_hidden_states = tfo[0].numpy() |
803 | | - pt_hidden_states = pto[0].cpu().numpy() |
804 | | - |
805 | | - tf_nans = np.copy(np.isnan(tf_hidden_states)) |
806 | | - pt_nans = np.copy(np.isnan(pt_hidden_states)) |
807 | | - |
808 | | - pt_hidden_states[tf_nans] = 0 |
809 | | - tf_hidden_states[tf_nans] = 0 |
810 | | - pt_hidden_states[pt_nans] = 0 |
811 | | - tf_hidden_states[pt_nans] = 0 |
812 | | - |
813 | | - max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) |
814 | | - # Debug info (remove when fixed) |
815 | | - if max_diff >= 2e-2: |
816 | | - print("===") |
817 | | - print(model_class) |
818 | | - print(config) |
819 | | - print(inputs_dict) |
820 | | - print(pt_inputs) |
821 | | - self.assertLessEqual(max_diff, 6e-2) |
822 | | - |
823 | | - # Check we can load pt model in tf and vice-versa with checkpoint => model functions |
824 | | - with tempfile.TemporaryDirectory() as tmpdirname: |
825 | | - pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") |
826 | | - torch.save(pt_model.state_dict(), pt_checkpoint_path) |
827 | | - tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) |
828 | | - |
829 | | - tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") |
830 | | - tf_model.save_weights(tf_checkpoint_path) |
831 | | - pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) |
832 | | - |
833 | | - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences |
834 | | - pt_model.eval() |
835 | | - |
836 | | - for key, value in pt_inputs.items(): |
837 | | - if key in ("visual_feats", "visual_pos"): |
838 | | - pt_inputs[key] = value.to(torch.float32) |
839 | | - else: |
840 | | - pt_inputs[key] = value.to(torch.long) |
841 | | - |
842 | | - with torch.no_grad(): |
843 | | - pto = pt_model(**pt_inputs) |
844 | | - |
845 | | - tfo = tf_model(tf_inputs_dict) |
846 | | - tfo = tfo[0].numpy() |
847 | | - pto = pto[0].cpu().numpy() |
848 | | - tf_nans = np.copy(np.isnan(tfo)) |
849 | | - pt_nans = np.copy(np.isnan(pto)) |
850 | | - |
851 | | - pto[tf_nans] = 0 |
852 | | - tfo[tf_nans] = 0 |
853 | | - pto[pt_nans] = 0 |
854 | | - tfo[pt_nans] = 0 |
855 | | - |
856 | | - max_diff = np.amax(np.abs(tfo - pto)) |
857 | | - self.assertLessEqual(max_diff, 6e-2) |
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
    """Build the TensorFlow counterpart of a dict of PyTorch model inputs.

    Every key of ``pt_inputs_dict`` is kept. Values are converted as follows:

    * ``dict`` values are converted recursively with this same method;
    * ``list`` / ``tuple`` values become a ``tuple`` whose elements are
      converted one by one with the same rules (a real tuple rather than a
      one-shot generator, so the result can be iterated more than once);
    * ``bool`` values are passed through unchanged;
    * any other value is expected to be a PyTorch tensor: it is moved to
      CPU, turned into a NumPy array and wrapped with
      ``tf.convert_to_tensor``. The dtype is ``tf.float32`` when the key is
      one of ``input_values`` / ``pixel_values`` / ``input_features`` or
      when the tensor reports ``is_floating_point()``; otherwise ``tf.int32``.

    Args:
        pt_inputs_dict: mapping from input name to a PyTorch tensor, a bool,
            or a nested dict / list / tuple of those.

    Returns:
        A new dict with the same keys holding the converted values.
    """
    # Inputs that are always handed to TF as float32, whatever their torch dtype.
    float_keys = ("input_values", "pixel_values", "input_features")

    def convert(key, value):
        # Nested mappings go back through this method so that the
        # key-specific dtype rules also apply to their entries.
        if isinstance(value, dict):
            return self.prepare_tf_inputs_from_pt_inputs(value)
        # Sequence elements inherit the key of the sequence they belong to.
        if isinstance(value, (list, tuple)):
            return tuple(convert(key, item) for item in value)
        # Plain flags are not tensors and need no conversion.
        if isinstance(value, bool):
            return value
        as_float = key in float_keys or value.is_floating_point()
        return tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32 if as_float else tf.int32)

    return {key: convert(key, value) for key, value in pt_inputs_dict.items()}
858 | 766 |
|
859 | 767 |
|
860 | 768 | @require_torch |
|
0 commit comments