Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion t5/data/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,7 +1615,13 @@ def make_examples(idx, ex):

if passthrough_feature_keys is not None:
for feature_name in passthrough_feature_keys:
output[feature_name] = [ex[feature_name]] * len(targets)
tiled_shape = tf.concat(
[tf.expand_dims(tf.shape(targets)[0], axis=0),
tf.ones(len(ex[feature_name].shape), dtype=tf.int32)],
axis=0)
output[feature_name] = tf.tile(
tf.expand_dims(ex[feature_name], axis=0),
tiled_shape)

if weight_fn is not None:
output['weight'] = tf.fill(tf.shape(is_correct), weight_fn(ex))
Expand Down
53 changes: 41 additions & 12 deletions t5/data/preprocessors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,23 +1480,46 @@ def test_rank_classification_with_weight(self):
])

def test_rank_classification_with_passthrough_feature_keys(self):
dataset = tf.data.Dataset.from_tensors({
'left': 'the sky is blue',
'right': 'cats are so cute',
'label_idx': 1,
'weight': 1.0,
'starburst_allow_pass': [0.1, 0.2],
'context_allow_pass': 'the sun is out',
'starburst_not_allow_pass': [0.9, 0.8]
})
def dataset_generator():
yield {
'left': 'the sky is blue',
'right': 'cats are so cute',
'label_idx': 1,
'weight': 1.0,
'options': ['class 0', 'class 1'],
'starburst_allow_pass': [0.1, 0.2],
'context_allow_pass': 'the sun is out',
'multicontext_allow_pass': ['the sun is out', 'so i am out'],
'starburst_not_allow_pass': [0.9, 0.8]
}

dataset = tf.data.Dataset.from_generator(
dataset_generator,
output_signature={
'left': tf.TensorSpec(shape=(), dtype=tf.string),
'right': tf.TensorSpec(shape=(), dtype=tf.string),
'label_idx': tf.TensorSpec(shape=(), dtype=tf.int32),
'weight': tf.TensorSpec(shape=(), dtype=tf.float32),
'options': tf.TensorSpec(shape=(None,), dtype=tf.string),
'starburst_allow_pass': tf.TensorSpec(shape=(2,),
dtype=tf.float32),
'context_allow_pass': tf.TensorSpec(shape=(), dtype=tf.string),
'multicontext_allow_pass': tf.TensorSpec(shape=(None,),
dtype=tf.string),
'starburst_not_allow_pass': tf.TensorSpec(shape=(2,),
dtype=tf.float32)
})

preprocessor = functools.partial(
prep.rank_classification,
dataset,
inputs_fn=lambda features: [features['right'], features['left']],
targets_fn=lambda features: ['class 0', 'class 1'],
targets_fn=lambda features: features['options'],
is_correct_fn=lambda features: [False, True],
weight_fn=lambda features: features['weight'],
passthrough_feature_keys=['starburst_allow_pass', 'context_allow_pass'])
passthrough_feature_keys=['starburst_allow_pass',
'context_allow_pass',
'multicontext_allow_pass'])

test_utils.assert_dataset(
preprocessor(mode='train'), [{
Expand All @@ -1507,6 +1530,7 @@ def test_rank_classification_with_passthrough_feature_keys(self):
'weight': 1.0,
'starburst_allow_pass': [0.1, 0.2],
'context_allow_pass': 'the sun is out',
'multicontext_allow_pass': ['the sun is out', 'so i am out']
}])

test_utils.assert_dataset(
Expand All @@ -1518,6 +1542,7 @@ def test_rank_classification_with_passthrough_feature_keys(self):
'weight': 1.0,
'starburst_allow_pass': [0.1, 0.2],
'context_allow_pass': 'the sun is out',
'multicontext_allow_pass': ['the sun is out', 'so i am out']
}, {
'idx': [0, 1],
'inputs': 'the sky is blue',
Expand All @@ -1526,6 +1551,7 @@ def test_rank_classification_with_passthrough_feature_keys(self):
'weight': 1.0,
'starburst_allow_pass': [0.1, 0.2],
'context_allow_pass': 'the sun is out',
'multicontext_allow_pass': ['the sun is out', 'so i am out']
}])

test_utils.assert_dataset(
Expand All @@ -1537,7 +1563,10 @@ def test_rank_classification_with_passthrough_feature_keys(self):
'is_correct': [False, True],
'weight': [1, 1],
'starburst_allow_pass': [[0.1, 0.2], [0.1, 0.2]],
'context_allow_pass': ['the sun is out', 'the sun is out']
'context_allow_pass': ['the sun is out', 'the sun is out'],
'multicontext_allow_pass': [
['the sun is out', 'so i am out'],
['the sun is out', 'so i am out']]
},
])

Expand Down