Commit 5d1a9f1

Dummy labels no longer on by default in to_tf_dataset (#2951)
* Dummy labels no longer on by default
* Style pass
1 parent f7d50b6 commit 5d1a9f1

1 file changed

Lines changed: 3 additions & 8 deletions

src/datasets/arrow_dataset.py

@@ -257,9 +257,9 @@ def to_tf_dataset(
             label_cols (:obj:`List[str]` or :obj:`str`, default ``None``): Dataset column(s) to load as
                 labels. Note that many models compute loss internally rather than letting Keras do it, in which case it is
                 not necessary to actually pass the labels here, as long as they're in the input `columns`.
-            dummy_labels (:obj:`bool`, default ``True``): If no `label_cols` are set, output an array of "dummy" labels
-                with each batch. This setting ensures that Keras `fit()` or `train_on_batch()` does not get confused
-                by the missing labels.
+            dummy_labels (:obj:`bool`, default ``False``): If no `label_cols` are set, output an array of "dummy" labels
+                with each batch. This can avoid problems with `fit()` or `train_on_batch()` that expect labels to be
+                a Tensor or np.ndarray, but should (hopefully) not be necessary with our standard train_step().
             prefetch (:obj:`bool`, default ``True``): Whether to run the dataloader in a separate thread and maintain
                 a small buffer of batches for training. Improves performance by allowing data to be loaded in the
                 background while the model is training.
@@ -389,11 +389,6 @@ def split_features_and_labels(input_batch):
             tf_dataset = tf_dataset.map(lambda x: list(x.values())[0])
 
         if dummy_labels and not label_cols:
-            print(
-                "Warning: No label_cols specified - adding some dummy labels to ensure fit() works correctly. If you "
-                "only want to use this dataset with predict() or custom training loops, you can disable this "
-                "behaviour by setting dummy_labels to False."
-            )
 
             def add_dummy_labels(input_batch):
                 return input_batch, tf.zeros(tf.shape(input_batch[columns[0]])[0])
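
The effect of the new default can be illustrated with a minimal, framework-free sketch. Plain Python lists stand in for tensors here, and `maybe_add_dummy_labels` is a hypothetical helper written for illustration only; it is not part of `datasets` and it ignores the real label-splitting path used when `label_cols` is set.

```python
# Framework-free sketch of the dummy-label branch; plain lists stand in for tensors.
# `maybe_add_dummy_labels` is a hypothetical helper, not part of `datasets`.


def maybe_add_dummy_labels(batch, columns, label_cols=None, dummy_labels=False):
    """Return `batch` unchanged, or `(batch, zeros)` when dummy labels are requested."""
    if dummy_labels and not label_cols:
        # Mirrors tf.zeros(tf.shape(input_batch[columns[0]])[0]): one zero per example.
        batch_size = len(batch[columns[0]])
        return batch, [0.0] * batch_size
    return batch


batch = {"input_ids": [[101, 7592, 102], [101, 2088, 102]]}

# New default (dummy_labels=False): the batch passes through untouched.
default_output = maybe_add_dummy_labels(batch, columns=["input_ids"])

# Opting back in restores the old behaviour: a (features, labels) tuple.
features, labels = maybe_add_dummy_labels(batch, columns=["input_ids"], dummy_labels=True)
print(labels)  # [0.0, 0.0]
```

Callers who relied on the old default would presumably need to pass `dummy_labels=True` explicitly to keep receiving the `(features, labels)` tuple.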
