Commit 6db8613

danieldk authored and jikanter committed
Update TextCatBOW to use the fixed SparseLinear layer (explosion#13149)
* Update `TextCatBOW` to use the fixed `SparseLinear` layer

  A while ago, we fixed the `SparseLinear` layer to use all available
  parameters: explosion/thinc#754

  This change updates `TextCatBOW` to `v3`, which uses the new
  `SparseLinear_v2` layer. This results in a sizeable improvement on a text
  categorization task that was tested.

  While at it, this `spacy.TextCatBOW.v3` also adds the `length_exponent`
  option to make it possible to change the hidden size. Ideally, we'd just
  have an option called `length`. But the way that `TextCatBOW` uses hashes
  results in a non-uniform distribution of parameters when the length is not
  a power of two.

* Replace TextCatBOW `length_exponent` parameter by `length`

  We now round up the length to the next power of two if it isn't a power of
  two.

* Remove some tests for TextCatBOW.v2

* Fix missing import
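For readers who want to see the rounding rule concretely, here is a minimal sketch in plain Python. It is not the code from this commit: `next_power_of_two` and `reachable_buckets` are hypothetical helper names, and the bucketing scheme (`hash & (length - 1)`) is an assumption used only to illustrate why a length that is not a power of two can leave parameters unevenly used. The length check mirrors the wording of error `E1056`, but the real validation in spaCy may differ.

```python
def next_power_of_two(length: int) -> int:
    """Round `length` up to the nearest power of two (hypothetical helper)."""
    if length < 1:
        # Same condition that error E1056 describes; the message is paraphrased.
        raise ValueError(f"expected a length of at least 1, was {length}")
    # (length - 1).bit_length() is the number of bits needed to store length - 1,
    # so shifting 1 by that amount gives the smallest power of two >= length.
    return 1 << (length - 1).bit_length()


def reachable_buckets(length: int, n_hashes: int = 4096) -> int:
    """Count the distinct buckets hit when indices are `hash & (length - 1)`.

    Assumption: bit masking is used for bucketing. Consecutive integers
    stand in for hash values, which is enough to cover every low-bit pattern.
    """
    mask = length - 1
    return len({h & mask for h in range(n_hashes)})


rounded = next_power_of_two(1000)
print(rounded)                    # 1024
print(reachable_buckets(1000))    # 256: the mask 999 has only 8 bits set
print(reachable_buckets(rounded)) # 1024: every bucket is reachable
```

Under the masking assumption, a requested length of 1000 would leave most rows of the weight table unused, whereas rounding up to 1024 makes every row reachable.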
1 parent a2ee6a7 commit 6db8613

3 files changed: 44 additions and 7 deletions


spacy/errors.py

Lines changed: 0 additions & 3 deletions
@@ -974,9 +974,6 @@ class Errors(metaclass=ErrorsWithCodes):
     E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
              "but only callbacks with one or three parameters are supported")
     E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
-    E1057 = ("The `TextCatReduce` architecture must be used with at least one "
-             "reduction. Please enable one of `use_reduce_first`, "
-             "`use_reduce_last`, `use_reduce_max` or `use_reduce_mean`.")

     # v4 error strings
     E4000 = ("Expected a Doc as input, but got: '{type}'")

spacy/tests/pipeline/test_textcat.py

Lines changed: 4 additions & 4 deletions
@@ -499,9 +499,9 @@ def test_resize(name, textcat_config):
         ("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
         ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
         ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
-        # REDUCE
-        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
-        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
+        # CNN
+        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
     ],
 )
 # fmt: on
@@ -749,7 +749,7 @@ def test_overfitting_IO_multi():
         # ENSEMBLE V2
         ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
         ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
-        # CNN V2 (legacy)
+        # CNN V2
         ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
         ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
         # PARAMETRIC ATTENTION V1

website/docs/api/architectures.mdx

Lines changed: 40 additions & 0 deletions
@@ -1020,6 +1020,46 @@ but used an internal `tok2vec` instead of taking it as argument:

 ### spacy.TextCatBOW.v3 {id="TextCatBOW"}

+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.TextCatCNN.v2"
+> exclusive_classes = false
+> nO = null
+>
+> [model.tok2vec]
+> @architectures = "spacy.HashEmbedCNN.v2"
+> pretrained_vectors = null
+> width = 96
+> depth = 4
+> embed_size = 2000
+> window_size = 1
+> maxout_pieces = 3
+> subword_features = true
+> ```
+
+A neural network model where token vectors are calculated using a CNN. The
+vectors are mean pooled and used as features in a feed-forward network. This
+architecture is usually less accurate than the ensemble, but runs faster.
+
+| Name                | Description                                                                                                                                                                                    |
+| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
+| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                        |
+| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
+| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
+
+<Accordion title="spacy.TextCatCNN.v1 definition" spaced>
+
+[TextCatCNN.v1](/api/legacy#TextCatCNN_v1) had the exact same signature, but was
+not yet resizable. Since v2, new labels can be added to this component, even
+after training.
+
+</Accordion>
+
+### spacy.TextCatBOW.v3 {id="TextCatBOW"}
+
 > #### Example Config
 >
 > ```ini
