
Conversation

@lewtun
Member

@lewtun lewtun commented May 21, 2021

This PR implements the idea discussed in #2389 to update the labels of the TextClassification template in DatasetInfo.__post_init__. The main reason for doing so is to avoid duplicating the label definitions in both DatasetInfo.features and DatasetInfo.task_templates.

To avoid storing state in DatasetInfo.__post_init__, the current implementation flushes DatasetInfo.task_templates before the features are cast in Dataset.prepare_for_task (thanks to @mariosasko for this idea!).

Here is an example of the current workflow:

ds1 = load_dataset("./datasets/emotion/")
# cast features and flush templates
ds2 = ds1.prepare_for_task("text-classification")
assert ds2.info.task_templates is None

Note that if users want to pass a TextClassification template to prepare_for_task, we require them to set TextClassification.labels to match the label names of the dataset feature referenced by label_column:

ds1 = load_dataset("./datasets/emotion/")
# TextClassification.labels is None by default => invalid template
task = TextClassification(text_column="text", label_column="label")
# Raises ValueError
ds1.prepare_for_task(task)
# Specifying the labels => valid template
task = TextClassification(text_column="text", label_column="label", labels=['anger', 'fear', 'joy', 'love', 'sadness', 'surprise'])
ds1.prepare_for_task(task)
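To illustrate why the template needs labels, here is a toy validator (hypothetical names, not the actual datasets code): without label names, prepare_for_task cannot build the ClassLabel feature of the new schema.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyTextClassification:
    # Toy stand-in for the real TextClassification template
    text_column: str
    label_column: str
    labels: Optional[Tuple[str, ...]] = None

def validate_template(template: ToyTextClassification) -> None:
    # Without labels we cannot cast the label column to ClassLabel,
    # so a template created without them is rejected up front.
    if template.labels is None:
        raise ValueError(
            f"Field `labels` is None: set it to the label names of the "
            f"'{template.label_column}' column before preparing the task."
        )

# A template with labels passes validation silently
validate_template(
    ToyTextClassification("text", "label", ("anger", "fear", "joy"))
)
```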

This PR also adds:

  • New tests + fixed some old tests that weren't testing assertRaises properly
  • A decorator to share docstrings across common functions. This allows us to document DatasetDict.prepare_for_task and Dataset.prepare_for_task in one place.
  • Fixes to avoid side-effects from in-place replacements of DatasetInfo.task_templates in DatasetInfo.__post_init__. Thanks to @lhoestq for figuring this out!
  • Removal of FeaturesWithLazyClassLabel since we now create a new instance of TextClassification in DatasetInfo.__post_init__ and avoid the side-effects first pointed out by @mariosasko
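As a rough illustration, a docstring-sharing decorator could look something like this (hypothetical helper name; the actual implementation in this PR may differ):

```python
# Copy one canonical docstring onto several functions so that e.g.
# Dataset.prepare_for_task and DatasetDict.prepare_for_task can be
# documented in a single place.
def share_docstring(source):
    def decorator(func):
        func.__doc__ = source.__doc__
        return func
    return decorator

def prepare_for_task(self, task):
    """Prepare the dataset for the given task by casting its features."""

@share_docstring(prepare_for_task)
def dict_prepare_for_task(self, task):
    # Applies prepare_for_task to each split; docstring is inherited
    pass
```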

PR Description from original WIP

Hi @yjernite and @lhoestq, here's a first stab at the suggestion discussed in #2389 to update the labels of the TextClassification template in the DatasetInfo.__post_init__.

One problem I've spotted is that my current implementation introduces state into the __post_init__:

  • When we call load_dataset, DatasetInfo.features are the "raw" features without any casting so we can access the column names by the label_column specified in TextClassification
  • When we call Dataset.prepare_for_task we run into a problem because the DatasetInfo.features are first cast into the new schema which triggers a KeyError when we update the infos here.

Here's an explicit example of what I mean with the stack trace appended below:

from datasets import load_dataset

# this works 
ds = load_dataset("emotion")
# we can verify the task template is correctly set
ds["train"].info.task_templates # returns [TextClassification(labels=('sadness', 'joy', 'love', 'anger', 'fear', 'surprise'), text_column='text', label_column='label')]
# but this fails because __post_init__ is looking for the original column names
ds.prepare_for_task("text-classification")
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-54a43019b319> in <module>
----> 1 ds.prepare_for_task("text-classification")

~/git/datasets/src/datasets/dataset_dict.py in prepare_for_task(self, task)
    807         """
    808         self._check_values_type()
--> 809         return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})

~/git/datasets/src/datasets/dataset_dict.py in <dictcomp>(.0)
    807         """
    808         self._check_values_type()
--> 809         return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})

~/git/datasets/src/datasets/arrow_dataset.py in prepare_for_task(self, task)
   1421         dataset = self.remove_columns(columns_to_drop)
   1422         dataset = dataset.rename_columns(column_mapping)
-> 1423         dataset = dataset.cast(features=template.features)
   1424         return dataset
   1425 

~/git/datasets/src/datasets/arrow_dataset.py in cast(self, features, batch_size, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, num_proc)
    970         format = self.format
    971         dataset = self.with_format("arrow")
--> 972         dataset = dataset.map(
    973             lambda t: t.cast(schema),
    974             batched=True,

~/git/datasets/src/datasets/arrow_dataset.py in map(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint)
   1583 
   1584         if num_proc is None or num_proc == 1:
-> 1585             return self._map_single(
   1586                 function=function,
   1587                 with_indices=with_indices,

~/git/datasets/src/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    173         }
    174         # apply actual function
--> 175         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    176         datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    177         # re-apply format to the output

~/git/datasets/src/datasets/fingerprint.py in wrapper(*args, **kwargs)
    338             # Call actual function
    339 
--> 340             out = func(self, *args, **kwargs)
    341 
    342             # Update fingerprint of in-place transforms + update in-place history of transforms

~/git/datasets/src/datasets/arrow_dataset.py in _map_single(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
   1959         if update_data:
   1960             # Create new Dataset from buffer or file
-> 1961             info = self.info.copy()
   1962             info.features = writer._features
   1963             if buf_writer is None:

~/git/datasets/src/datasets/info.py in copy(self)
    274 
    275     def copy(self) -> "DatasetInfo":
--> 276         return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
    277 
    278 

~/git/datasets/src/datasets/info.py in __init__(self, description, citation, homepage, license, features, post_processed, supervised_keys, task_templates, builder_name, config_name, version, splits, download_checksums, download_size, post_processing_size, dataset_size, size_in_bytes)

~/git/datasets/src/datasets/info.py in __post_init__(self)
    174                     # The reason is that Dataset.prepare_for_task calls Dataset.cast which converts the
    175                     # DatasetInfo.features to the new schema and thus template.label_column is no longer a valid key
--> 176                     object.__setattr__(template, "labels", tuple(self.features[template.label_column].names))
    177                     template.label_schema["labels"] = ClassLabel(names=template.labels)
    178                     self.task_templates[idx] = template

KeyError: 'label'

What do you think? I did this a bit quickly, so maybe I'm overlooking something obvious :) One alternative would be to update the labels of the task template only on load, but this seems a bit hacky IMO

@mariosasko
Collaborator

If I'm not mistaken, one way to fix this would be to drop the task templates when copying the info by inserting dataset.info.task_templates = None before the Dataset.cast call in Dataset.prepare_for_task. Moreover, we should do this change independently of the KeyError being raised because currently the following is possible:

dset = load_dataset("some_dataset") # let's say 'some_dataset' supports text classification and question answering
dset_tc = dset.prepare_for_task("text-classification")
dset_tc.prepare_for_task("question-answering") # this should raise an error because the schema is no longer valid for this task; currently this fails on 'rename_columns'

I see 2 options:

  1. to drop the task templates after the first Dataset.prepare_for_task call
  2. to save only the tasks compatible with the new schema after Dataset.prepare_for_task (but then we have to update the column names of the compatible tasks to make sure the column mapping is still valid)

@lewtun
Member Author

lewtun commented May 23, 2021

If I'm not mistaken, one way to fix this would be to drop the task templates when copying the info by inserting dataset.info.task_templates = None before the Dataset.cast call in Dataset.prepare_for_task. Moreover, we should do this change independently of the KeyError being raised because currently the following is possible:

dset = load_dataset("some_dataset") # let's say 'some_dataset' supports text classification and question answering
dset_tc = dset.prepare_for_task("text-classification")
dset_tc.prepare_for_task("question-answering") # this should raise an error because the schema is no longer valid for this task; currently this fails on 'rename_columns'

I see 2 options:

  1. to drop the task templates after the first Dataset.prepare_for_task call
  2. to save only the tasks compatible with the new schema after Dataset.prepare_for_task (but then we have to update the column names of the compatible tasks to make sure the column mapping is still valid)

thanks for the great idea @mariosasko and for spotting the problem with sequential task preparation! i am in favour of your option (1) since it is simple and saves us from having to keep track of the column mappings across multiple steps.
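here's a toy sketch (not the real datasets API) of what option (1) buys us: flushing info.task_templates before the cast means a second prepare_for_task call fails loudly instead of reusing stale templates.

```python
class ToyInfo:
    # Toy stand-in for DatasetInfo, holding only the task templates
    def __init__(self, task_templates):
        self.task_templates = task_templates

class ToyDataset:
    def __init__(self):
        self.info = ToyInfo(["text-classification", "question-answering"])

    def prepare_for_task(self, task):
        # Templates were flushed by a previous call => error out early
        if self.info.task_templates is None:
            raise ValueError("Dataset was already prepared; no templates left.")
        if task not in self.info.task_templates:
            raise ValueError(f"Task {task} is not supported.")
        # ... removing/renaming columns for `task` would happen here ...
        self.info.task_templates = None  # flush before casting
        return self

ds = ToyDataset().prepare_for_task("text-classification")
```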

i've implemented the change and refactored the tests to account for the new approach (including a new test that the templates are flushed after we call prepare_for_task). perhaps the slightly inelegant aspect here is that if we want to allow the user to set labels in the TextClassification template, then we have two places (DatasetInfo.__post_init__ and TextClassification.__post_init__) where we need to update label_schema.

on the other hand, dropping labels from the TextClassification signature would have the nice effect that users only have to think about column names when defining their tasks.

in any case, i think it would be a good idea to merge #2376 soon as the current PR is touching a lot of the same places in the codebase 😄

@lewtun
Member Author

lewtun commented May 25, 2021

cc @SBrandeis who might also be interested in this feature :)

@lhoestq
Member

lhoestq commented May 25, 2021

Tests are failing only because the emotion dataset card doesn't pass our dataset card validator (tags are missing), you can ignore this since it's unrelated to this PR.

Member

@lhoestq lhoestq left a comment


Looks great ! I added 2 comments:

@lewtun lewtun changed the title [WIP] Update text classification template labels in DatasetInfo __post_init__ Update text classification template labels in DatasetInfo __post_init__ May 26, 2021
@lewtun lewtun requested a review from SBrandeis May 26, 2021 16:25
@lewtun
Member Author

lewtun commented May 26, 2021

@lhoestq @SBrandeis i've fixed the tests and think this is now in a good state for another review :)

Member

@lhoestq lhoestq left a comment


Looks all good ! Thanks for the fix :)

Comment on lines 706 to 710
task_template = TextClassification(text_column="text", label_column="labels", labels=labels)
info = DatasetInfo(
    features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}),
    task_templates=task_template,
)
Member


Maybe add a test to make sure that the DatasetInfo post init does pass the labels to the task templates, so that we can instantiate the task_template without labels as in the demo ?

Member Author


good idea! i've now set up the unit test to cover the following cases:

  1. DatasetInfo defined with a TextClassification template without labels
  2. DatasetInfo defined with a TextClassification template with labels (this one might be overkill, but best to be safe)
  3. Passing a TextClassification template with labels directly to prepare_for_task (the no labels case is covered in test_task_with_incompatible_templates)
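as a toy illustration of case (1) (simplified, assumed names — not the real test code): a template created without labels should get them filled in from the features during DatasetInfo.__post_init__.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class ToyTemplate:
    # Toy stand-in for TextClassification
    text_column: str
    label_column: str
    labels: Optional[Tuple[str, ...]] = None

@dataclass
class ToyDatasetInfo:
    # Toy stand-in for DatasetInfo: fills template labels from features
    features: dict
    task_templates: List[ToyTemplate] = field(default_factory=list)

    def __post_init__(self):
        for idx, template in enumerate(self.task_templates):
            names = tuple(self.features[template.label_column])
            self.task_templates[idx] = ToyTemplate(
                template.text_column, template.label_column, labels=names
            )

info = ToyDatasetInfo(
    features={"text": "string", "label": ["neg", "pos"]},
    task_templates=[ToyTemplate("text", "label")],  # no labels given
)
assert info.task_templates[0].labels == ("neg", "pos")
```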

@lhoestq
Member

lhoestq commented May 27, 2021

Maybe @SBrandeis you can also take a look to make sure you're fine with it ?

Collaborator

@mariosasko mariosasko left a comment


Few suggestions.

Comment on lines 176 to 181
for idx, template in enumerate(self.task_templates):
    if isinstance(template, TextClassification) and self.features is not None:
        labels = self.features[template.label_column].names
        self.task_templates[idx] = TextClassification(
            text_column=template.text_column, label_column=template.label_column, labels=labels
        )
Collaborator


A small nit. It's cleaner to have the self.features is not None check outside the loop:

Suggested change
-for idx, template in enumerate(self.task_templates):
-    if isinstance(template, TextClassification) and self.features is not None:
-        labels = self.features[template.label_column].names
-        self.task_templates[idx] = TextClassification(
-            text_column=template.text_column, label_column=template.label_column, labels=labels
-        )
+if self.features is not None:
+    for idx, template in enumerate(self.task_templates):
+        if isinstance(template, TextClassification):
+            labels = self.features[template.label_column].names
+            self.task_templates[idx] = TextClassification(
+                text_column=template.text_column, label_column=template.label_column, labels=labels
+            )

Member Author

@lewtun lewtun May 28, 2021


good idea! done.

assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
self.__dict__["label_schema"] = copy.deepcopy(self.label_schema)
Collaborator


Should be safe to remove the FeaturesWithLazyClassLabel descriptor after this to reduce the code complexity (and then the deepcopy call can be replaced with self.label_schema.copy()).

Member Author


yes i agree, thanks for the suggestion! since we now create a new instance of TextClassification in DatasetInfo.__post_init__ I think we should be safe from the side-effects you spotted earlier

Contributor

@SBrandeis SBrandeis left a comment


LGTM! Thanks @lewtun !

@lewtun lewtun merged commit caf050c into huggingface:master May 28, 2021
@lewtun lewtun deleted the refactor-text-clf-template branch May 28, 2021 11:37
lhoestq added a commit that referenced this pull request May 28, 2021
* Insert task templates

* Fix style

* Insert templates for datasets with off configs

* Fix style

* Update labels in DatasetInfo __post_init__

* Add emotion example

* Flush task templates before casting

* Add labels to TextClassification __post_init__

* Add comment about casting to tuple

* Fix capitalisation

* Refactor tests to account for label update in `DatasetInfo`, add test

* Update label schema in post_init

* Use __dict__ instead of __setattr__ to update task template labels

* Raise ValueError if TextClassification template has None or incompatible labels

* Remove task templates from emotion demo

* Add decorator to share docstrings across multiple functions

* Update docstring for prepare_for_task

* Reorder TextClassification args for better intuition

* fix missing "task" field in json + edit copy of objects instead of modifying in-place

* style

* Fix failing tests due to new DatasetInfo.__post_init__

* Refactor TextClassification test to cover templates w / w-out labels

* Refactor use of label names in task template concatenation test

* Add separate test for template with labels in DatasetInfo

* Fix log message

* Fix comments

* Remove custom feature with lazy classlabel

No longer needed since we create a new instance of the task template during the `DatasetInfo.__post_init__`

* Move conditional check of features to outer if statement

* Move feature is not None check to inner if-statement

* Revert task template insertion to account for API changes in PR #2392

* Insert task template to allocine dataset

* Revert "Insert task template to allocine dataset"

This reverts commit c577149.

* Simplify args for text classification template insertion

* Add datasets with text classification templates

* Fix style

* Exclude caner dataset from injection

Co-authored-by: Quentin Lhoest <[email protected]>
