diff --git a/t5/data/preprocessors.py b/t5/data/preprocessors.py index 1a1aeaf8..420c4fa4 100644 --- a/t5/data/preprocessors.py +++ b/t5/data/preprocessors.py @@ -1615,7 +1615,13 @@ def make_examples(idx, ex): if passthrough_feature_keys is not None: for feature_name in passthrough_feature_keys: - output[feature_name] = [ex[feature_name]] * len(targets) + tiled_shape = tf.concat( + [tf.expand_dims(tf.shape(targets)[0], axis=0), + tf.ones(len(ex[feature_name].shape), dtype=tf.int32)], + axis=0) + output[feature_name] = tf.tile( + tf.expand_dims(ex[feature_name], axis=0), + tiled_shape) if weight_fn is not None: output['weight'] = tf.fill(tf.shape(is_correct), weight_fn(ex)) diff --git a/t5/data/preprocessors_test.py b/t5/data/preprocessors_test.py index b0dd688a..28204f7c 100644 --- a/t5/data/preprocessors_test.py +++ b/t5/data/preprocessors_test.py @@ -1480,23 +1480,46 @@ def test_rank_classification_with_weight(self): ]) def test_rank_classification_with_passthrough_feature_keys(self): - dataset = tf.data.Dataset.from_tensors({ - 'left': 'the sky is blue', - 'right': 'cats are so cute', - 'label_idx': 1, - 'weight': 1.0, - 'starburst_allow_pass': [0.1, 0.2], - 'context_allow_pass': 'the sun is out', - 'starburst_not_allow_pass': [0.9, 0.8] - }) + def dataset_generator(): + yield { + 'left': 'the sky is blue', + 'right': 'cats are so cute', + 'label_idx': 1, + 'weight': 1.0, + 'options': ['class 0', 'class 1'], + 'starburst_allow_pass': [0.1, 0.2], + 'context_allow_pass': 'the sun is out', + 'multicontext_allow_pass': ['the sun is out', 'so i am out'], + 'starburst_not_allow_pass': [0.9, 0.8] + } + + dataset = tf.data.Dataset.from_generator( + dataset_generator, + output_signature={ + 'left': tf.TensorSpec(shape=(), dtype=tf.string), + 'right': tf.TensorSpec(shape=(), dtype=tf.string), + 'label_idx': tf.TensorSpec(shape=(), dtype=tf.int32), + 'weight': tf.TensorSpec(shape=(), dtype=tf.float32), + 'options': tf.TensorSpec(shape=(None,), dtype=tf.string), + 'starburst_allow_pass': tf.TensorSpec(shape=(2,), + dtype=tf.float32), + 'context_allow_pass': tf.TensorSpec(shape=(), dtype=tf.string), + 'multicontext_allow_pass': tf.TensorSpec(shape=(None,), + dtype=tf.string), + 'starburst_not_allow_pass': tf.TensorSpec(shape=(2,), + dtype=tf.float32) + }) + preprocessor = functools.partial( prep.rank_classification, dataset, inputs_fn=lambda features: [features['right'], features['left']], - targets_fn=lambda features: ['class 0', 'class 1'], + targets_fn=lambda features: features['options'], is_correct_fn=lambda features: [False, True], weight_fn=lambda features: features['weight'], - passthrough_feature_keys=['starburst_allow_pass', 'context_allow_pass']) + passthrough_feature_keys=['starburst_allow_pass', + 'context_allow_pass', + 'multicontext_allow_pass']) test_utils.assert_dataset( preprocessor(mode='train'), [{ @@ -1507,6 +1530,7 @@ def test_rank_classification_with_passthrough_feature_keys(self): 'weight': 1.0, 'starburst_allow_pass': [0.1, 0.2], 'context_allow_pass': 'the sun is out', + 'multicontext_allow_pass': ['the sun is out', 'so i am out'] }]) test_utils.assert_dataset( @@ -1518,6 +1542,7 @@ def test_rank_classification_with_passthrough_feature_keys(self): 'weight': 1.0, 'starburst_allow_pass': [0.1, 0.2], 'context_allow_pass': 'the sun is out', + 'multicontext_allow_pass': ['the sun is out', 'so i am out'] }, { 'idx': [0, 1], 'inputs': 'the sky is blue', @@ -1526,6 +1551,7 @@ def test_rank_classification_with_passthrough_feature_keys(self): 'weight': 1.0, 'starburst_allow_pass': [0.1, 0.2], 'context_allow_pass': 'the sun is out', + 'multicontext_allow_pass': ['the sun is out', 'so i am out'] }]) test_utils.assert_dataset( @@ -1537,7 +1563,10 @@ def test_rank_classification_with_passthrough_feature_keys(self): 'is_correct': [False, True], 'weight': [1, 1], 'starburst_allow_pass': [[0.1, 0.2], [0.1, 0.2]], - 'context_allow_pass': ['the sun is out', 'the sun is out'] + 'context_allow_pass': ['the sun is out', 'the sun is out'], + 'multicontext_allow_pass': [ + ['the sun is out', 'so i am out'], + ['the sun is out', 'so i am out']] }, ])