9 changes: 9 additions & 0 deletions docs/source/_static/css/huggingface.css
@@ -2,6 +2,15 @@

/* Colab dropdown */

table.center-aligned-table td {
text-align: center;
}

table.center-aligned-table th {
text-align: center;
vertical-align: middle;
}

.colab-dropdown {
position: relative;
display: inline-block;
95 changes: 93 additions & 2 deletions docs/source/index.rst
@@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime:
- Move a single model between TF2.0/PyTorch frameworks at will
- Seamlessly pick the right framework for training, evaluation, production

Experimental support for Flax is available for a few models right now and is expected to grow in the coming months.
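
As a quick illustration, here is a minimal sketch of what that support looks like for one of the covered models; the checkpoint name is illustrative, and exact signatures may still change while the support is experimental:

```python
# Minimal sketch of the experimental Flax support; requires `pip install flax`.
from transformers import BertTokenizer, FlaxBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Flax models consume numpy/JAX arrays, so ask the tokenizer for numpy tensors.
inputs = tokenizer("Hello, Flax!", return_tensors="np")
outputs = model(**inputs)
```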

Contents
-----------------------------------------------------------------------------------------------------------------------

@@ -52,8 +54,8 @@ The documentation is organized in five parts:
- **MODELS** for the classes and functions related to each model implemented in the library.
- **INTERNAL HELPERS** for the classes and functions we use internally.

The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and
conversion utilities for the following models:
The library currently contains PyTorch, TensorFlow and (for some models) Flax implementations, pretrained model
weights, usage scripts and conversion utilities for the following models:

..
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
@@ -166,6 +168,95 @@ conversion utilities for the following models:
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
<https://huggingface.co/users>`__.


The table below represents the current support in the library for each of those models: whether they have a Python
tokenizer (called "slow") or a "fast" tokenizer backed by the 🤗 Tokenizers library, and whether they have support in
PyTorch, TensorFlow and/or Flax.

..
This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually!

.. rst-class:: center-aligned-table

+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model                       | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax support |
+=============================+================+================+=================+====================+==============+
| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART                        | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT                        | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Bert Generation             | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot                  | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL                        | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CamemBERT                   | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR                         | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa                     | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT                  | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA                     | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder             | ❌             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FlauBERT                    | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer          | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT                      | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM                    | ✅             | ✅             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Longformer                  | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Marian                      | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| MobileBERT                  | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT                  | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT-2                | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Pegasus                     | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ProphetNet                  | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RAG                         | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Reformer                    | ✅             | ✅             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RetriBERT                   | ✅             | ✅             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoBERTa                     | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT                 | ✅             | ✅             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5                          | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL              | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM                         | ✅             | ❌             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM-RoBERTa                 | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLMProphetNet               | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLNet                       | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mBART                       | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mT5                         | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
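
Since the table above is generated by `make fix-copies`, here is a rough sketch of how rows like these can be derived from the auto mappings; this is illustrative only, not the actual `utils/check_copies.py` logic, and it assumes torch, TensorFlow and flax are all installed so the mapping modules import:

```python
# Illustrative sketch: derive the support flags of the table from the auto mappings.
from transformers.models.auto.configuration_auto import MODEL_NAMES_MAPPING
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
from transformers.models.auto.modeling_auto import MODEL_MAPPING
from transformers.models.auto.modeling_tf_auto import TF_MODEL_MAPPING
from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_MAPPING

def mark(supported):
    return "✅" if supported else "❌"

# TOKENIZER_MAPPING maps each config class to a (slow, fast) tokenizer pair.
for config_class, (slow_tok, fast_tok) in TOKENIZER_MAPPING.items():
    name = MODEL_NAMES_MAPPING[config_class.model_type]
    print(
        name,
        mark(slow_tok is not None),
        mark(fast_tok is not None),
        mark(config_class in MODEL_MAPPING),
        mark(config_class in TF_MODEL_MAPPING),
        mark(config_class in FLAX_MODEL_MAPPING),
    )
```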


.. toctree::
:maxdepth: 2
:caption: Get started
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -90,6 +90,7 @@
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING,
MODEL_NAMES_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoTokenizer,
@@ -880,6 +881,7 @@


if is_flax_available():
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import FlaxBertModel
from .models.roberta import FlaxRobertaModel
else:
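
Downstream code can use the same availability guard as this `__init__`; a small sketch, assuming flax may or may not be installed:

```python
# Sketch: guard optional Flax usage the same way the library's __init__ does.
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxAutoModel

    model = FlaxAutoModel.from_pretrained("bert-base-cased")
else:
    # Without flax, the top-level names resolve to dummy objects that raise
    # an informative error when used (see dummy_flax_objects.py below).
    print("Flax is not available; install it with `pip install flax`.")
```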
7 changes: 5 additions & 2 deletions src/transformers/models/auto/__init__.py
@@ -2,8 +2,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from ...file_utils import is_flax_available, is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer


@@ -57,3 +57,6 @@
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)

if is_flax_available():
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
10 changes: 5 additions & 5 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -36,7 +36,7 @@
for key, value, in pretrained_map.items()
)

MODEL_MAPPING = OrderedDict(
FLAX_MODEL_MAPPING = OrderedDict(
[
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
@@ -79,13 +79,13 @@ def from_config(cls, config):
model = FlaxAutoModel.from_config(config)
            # The model is randomly initialized from the config; no pretrained weights are loaded
"""
for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class(config)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
)

@classmethod
@@ -173,11 +173,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
)
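
A short usage sketch for the renamed mapping and the auto class (checkpoint name illustrative; requires flax to be installed):

```python
# Sketch: the two entry points touched by the MODEL_MAPPING -> FLAX_MODEL_MAPPING rename.
from transformers import AutoConfig, FlaxAutoModel
from transformers.models.auto import FLAX_MODEL_MAPPING

config = AutoConfig.from_pretrained("roberta-base")
model = FlaxAutoModel.from_config(config)              # random init from the config
model = FlaxAutoModel.from_pretrained("roberta-base")  # load pretrained weights

# FLAX_MODEL_MAPPING drives the dispatch: config class -> Flax model class.
for config_class, model_class in FLAX_MODEL_MAPPING.items():
    print(config_class.__name__, "->", model_class.__name__)
```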
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -2,6 +2,18 @@
from ..file_utils import requires_flax


FLAX_MODEL_MAPPING = None


class FlaxAutoModel:
def __init__(self, *args, **kwargs):
requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_flax(cls)


class FlaxBertModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
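
For reference, a sketch of how these dummies behave in an environment without flax, assuming `requires_flax` raises an `ImportError` like the other `requires_*` helpers:

```python
# Sketch: without flax installed, `from transformers import FlaxAutoModel`
# resolves to the dummy class above, and any use raises an informative error.
from transformers import FlaxAutoModel

try:
    FlaxAutoModel.from_pretrained("bert-base-cased")
except ImportError as err:
    print(err)  # points the user to the Flax installation instructions
```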