9 changes: 9 additions & 0 deletions docs/source/_static/css/huggingface.css
@@ -2,6 +2,15 @@

/* Colab dropdown */

table.center-aligned-table td {
text-align: center;
}

table.center-aligned-table th {
text-align: center;
vertical-align: middle;
}

.colab-dropdown {
position: relative;
display: inline-block;
95 changes: 93 additions & 2 deletions docs/source/index.rst
@@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime:
- Move a single model between TF2.0/PyTorch frameworks at will
- Seamlessly pick the right framework for training, evaluation, production

Experimental support for Flax is available for a few models right now and is expected to grow in the coming months.

Contents
-----------------------------------------------------------------------------------------------------------------------

@@ -52,8 +54,8 @@ The documentation is organized in five parts:
- **MODELS** for the classes and functions related to each model implemented in the library.
- **INTERNAL HELPERS** for the classes and functions we use internally.

The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and
conversion utilities for the following models:
The library currently contains PyTorch, TensorFlow and Flax implementations, pretrained model weights, usage scripts
and conversion utilities for the following models:

..
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
@@ -166,6 +168,95 @@ conversion utilities for the following models:
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
<https://huggingface.co/users>`__.


The table below represents the current support in the library for each of those models: whether they have a Python
tokenizer (called "slow"), a "fast" tokenizer backed by the 🤗 Tokenizers library, and whether they have support in
PyTorch, TensorFlow and/or Flax.

..
This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually!

.. rst-class:: center-aligned-table

+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model                       | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax support |
+=============================+================+================+=================+====================+==============+
| ALBERT                      | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART                        | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT                        | ✅              | ✅              | ✅               | ✅                  | ✅            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Bert Generation             | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot                  | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL                        | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CamemBERT                   | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR                         | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa                     | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT                  | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA                     | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder             | ❌              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FlauBERT                    | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer          | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT                      | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM                    | ✅              | ✅              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Longformer                  | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Marian                      | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| MobileBERT                  | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT                  | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT-2                | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Pegasus                     | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ProphetNet                  | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RAG                         | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Reformer                    | ✅              | ✅              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RetriBERT                   | ✅              | ✅              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoBERTa                     | ✅              | ✅              | ✅               | ✅                  | ✅            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT                 | ✅              | ✅              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5                          | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL              | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM                         | ✅              | ❌              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM-RoBERTa                 | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLMProphetNet               | ✅              | ❌              | ✅               | ❌                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLNet                       | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mBART                       | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mT5                         | ✅              | ✅              | ✅               | ✅                  | ❌            |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+


.. toctree::
:maxdepth: 2
:caption: Get started
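As a quick illustration of the experimental Flax support announced above, here is a minimal sketch of loading one of the covered models. It assumes jax and flax are installed and that the example checkpoint `bert-base-cased` has Flax weights on the model hub; both are assumptions, not part of this PR.

```python
# Minimal sketch: load a BERT checkpoint with the experimental Flax backend.
# Assumes jax/flax are installed and Flax weights exist for this checkpoint.
from transformers import AutoTokenizer, FlaxBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Flax models consume numpy-style arrays rather than torch tensors.
inputs = tokenizer("Hello, Flax!", return_tensors="np")
outputs = model(**inputs)
```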
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -90,6 +90,7 @@
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING,
MODEL_NAMES_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoTokenizer,
@@ -880,6 +881,7 @@


if is_flax_available():
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import FlaxBertModel
from .models.roberta import FlaxRobertaModel
else:
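The new top-level import sits behind the `is_flax_available()` guard, and user code can apply the same pattern to degrade gracefully when jax/flax are missing. A minimal sketch, assuming only the `is_flax_available` helper from `transformers.file_utils` (imported the same way in the next file) and an example checkpoint name:

```python
# Sketch of the availability-guard pattern in user code.
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxBertModel as ModelClass
else:
    # Assumes PyTorch is installed as the fallback backend.
    from transformers import BertModel as ModelClass

model = ModelClass.from_pretrained("bert-base-cased")  # assumed checkpoint
```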
7 changes: 5 additions & 2 deletions src/transformers/models/auto/__init__.py
@@ -2,8 +2,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from ...file_utils import is_flax_available, is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer


@@ -57,3 +57,6 @@
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)

if is_flax_available():
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
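With `FLAX_MODEL_MAPPING` and `FlaxAutoModel` exported here, they can be used like their PyTorch and TensorFlow counterparts. A hedged sketch; the checkpoint name is an assumption:

```python
# Sketch of the newly exported auto classes for Flax.
from transformers import FLAX_MODEL_MAPPING, FlaxAutoModel

# Inspect which configurations currently dispatch to a Flax model.
for config_cls, model_cls in FLAX_MODEL_MAPPING.items():
    print(config_cls.__name__, "->", model_cls.__name__)

# AutoConfig resolves the configuration, then the mapping picks the class.
model = FlaxAutoModel.from_pretrained("roberta-base")  # assumed checkpoint
```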
10 changes: 5 additions & 5 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -36,7 +36,7 @@
    for key, value in pretrained_map.items()
)

MODEL_MAPPING = OrderedDict(
FLAX_MODEL_MAPPING = OrderedDict(
[
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
@@ -79,13 +79,13 @@ def from_config(cls, config):
model = FlaxAutoModel.from_config(config)
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class(config)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
)

@classmethod
@@ -173,11 +173,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
)
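The rename makes the dispatch pattern easier to follow: `FLAX_MODEL_MAPPING` maps config classes to Flax model classes, and both `from_config` and `from_pretrained` walk it with `isinstance` checks until the first match. A standalone sketch of that pattern, with hypothetical stand-in classes rather than the library's own:

```python
# Standalone sketch of the config-class -> model-class dispatch used above.
# ToyConfig and FlaxToyModel are hypothetical stand-ins.
from collections import OrderedDict


class ToyConfig:
    pass


class FlaxToyModel:
    def __init__(self, config):
        self.config = config


TOY_MAPPING = OrderedDict([(ToyConfig, FlaxToyModel)])


def from_config(config):
    # The first isinstance match wins, so more specific config subclasses
    # must precede their base classes; hence an ordered mapping.
    for config_cls, model_cls in TOY_MAPPING.items():
        if isinstance(config, config_cls):
            return model_cls(config)
    raise ValueError(f"Unrecognized configuration class {config.__class__}")


model = from_config(ToyConfig())
assert isinstance(model, FlaxToyModel)
```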
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -2,6 +2,18 @@
from ..file_utils import requires_flax


FLAX_MODEL_MAPPING = None


class FlaxAutoModel:
def __init__(self, *args, **kwargs):
requires_flax(self)

@classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_flax(cls)


class FlaxBertModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
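For context on the file above: the dummy objects keep `from transformers import FlaxAutoModel` importable when flax is absent, deferring the failure to first use with an informative error. A sketch of the pattern in isolation; this `requires_flax` is a hypothetical stand-in for the real helper in `transformers.file_utils`:

```python
# Isolated sketch of the dummy-object pattern used by dummy_flax_objects.py.
def requires_flax(obj):
    # Hypothetical stand-in for the library helper of the same name.
    name = obj if isinstance(obj, str) else type(obj).__name__
    raise ImportError(f"{name} requires the flax library, but it was not found.")


class FlaxAutoModel:
    def __init__(self, *args, **kwargs):
        requires_flax(self)  # import succeeds; instantiation fails loudly


try:
    FlaxAutoModel()
except ImportError as err:
    print(err)  # FlaxAutoModel requires the flax library, but it was not found.
```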