src/transformers/modeling_flax_bert.py (16 additions, 22 deletions)
@@ -20,7 +20,6 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.linen import compact

from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxBertLayerNorm(nn.Module):
"""

epsilon: float = 1e-6
-dtype: jnp.dtype = jnp.float32
-bias: bool = True
-scale: bool = True
+dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+bias: bool = True # If True, bias (beta) is added.
+scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
+# (also e.g. nn.relu), this can be disabled since the scaling will be
+# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones

-@compact
+@nn.compact
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ def __call__(self, x):

Args:
x: the inputs
-epsilon: A small float added to variance to avoid dividing by zero.
-dtype: the dtype of the computation (default: float32).
-bias: If True, bias (beta) is added.
-scale: If True, multiply by scale (gamma). When the next layer is linear
-(also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
-bias_init: Initializer for bias, by default, zero.
-scale_init: Initializer for scale, by default, one

Returns:
Normalized inputs (the same shape as inputs).
@@ -157,7 +151,7 @@ class FlaxBertEmbedding(nn.Module):
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)

-@compact
+@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
@@ -171,7 +165,7 @@ class FlaxBertEmbeddings(nn.Module):
type_vocab_size: int
max_length: int

-@compact
+@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):

# Embed
@@ -198,7 +192,7 @@ class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
@@ -211,15 +205,15 @@ def __call__(self, hidden_state, attention_mask):
class FlaxBertIntermediate(nn.Module):
output_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
return gelu(dense)


class FlaxBertOutput(nn.Module):
-@compact
+@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -231,7 +225,7 @@ class FlaxBertLayer(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
@@ -250,7 +244,7 @@ class FlaxBertLayerCollection(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"

@@ -270,7 +264,7 @@ class FlaxBertEncoder(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxBertLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -279,7 +273,7 @@ def __call__(self, hidden_state, attention_mask):


class FlaxBertPooler(nn.Module):
-@compact
+@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -296,7 +290,7 @@ class FlaxBertModule(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):

# Embedding
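Both files make the same two changes: the standalone `from flax.linen import compact` import is dropped in favor of referencing the decorator through the existing `flax.linen as nn` alias as `@nn.compact`, and the layer-norm option descriptions move from the `__call__` docstring to inline comments on the dataclass fields. Below is a minimal sketch of the `@nn.compact` pattern for reference; the module name, fields, and shapes are illustrative only and not part of this PR.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyLayerNorm(nn.Module):
    """Illustrative layer norm written in the @nn.compact style used in this PR."""

    epsilon: float = 1e-6  # small constant added to the variance to avoid division by zero
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact  # referenced via the `nn` namespace, so no separate `compact` import is needed
    def __call__(self, x):
        # `compact` lets parameters be declared inline in __call__
        scale = self.param("gamma", nn.initializers.ones, (x.shape[-1],))
        bias = self.param("beta", nn.initializers.zeros, (x.shape[-1],))
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
        y = (x - mean) / jnp.sqrt(var + self.epsilon)
        return y.astype(self.dtype) * scale + bias


# Parameter init and a forward pass:
x = jnp.ones((2, 8))
params = ToyLayerNorm().init(jax.random.PRNGKey(0), x)
y = ToyLayerNorm().apply(params, x)
```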
src/transformers/modeling_flax_roberta.py (16 additions, 22 deletions)
@@ -19,7 +19,6 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.linen import compact

from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxRobertaLayerNorm(nn.Module):
"""

epsilon: float = 1e-6
-dtype: jnp.dtype = jnp.float32
-bias: bool = True
-scale: bool = True
+dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+bias: bool = True # If True, bias (beta) is added.
+scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
+# (also e.g. nn.relu), this can be disabled since the scaling will be
+# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones

-@compact
+@nn.compact
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ def __call__(self, x):

Args:
x: the inputs
-epsilon: A small float added to variance to avoid dividing by zero.
-dtype: the dtype of the computation (default: float32).
-bias: If True, bias (beta) is added.
-scale: If True, multiply by scale (gamma). When the next layer is linear
-(also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
-bias_init: Initializer for bias, by default, zero.
-scale_init: Initializer for scale, by default, one

Returns:
Normalized inputs (the same shape as inputs).
@@ -158,7 +152,7 @@ class FlaxRobertaEmbedding(nn.Module):
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)

-@compact
+@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
@@ -173,7 +167,7 @@ class FlaxRobertaEmbeddings(nn.Module):
type_vocab_size: int
max_length: int

-@compact
+@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):

# Embed
@@ -201,7 +195,7 @@ class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
@@ -215,7 +209,7 @@ def __call__(self, hidden_state, attention_mask):
class FlaxRobertaIntermediate(nn.Module):
output_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
@@ -224,7 +218,7 @@ def __call__(self, hidden_state):

# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
-@compact
+@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -236,7 +230,7 @@ class FlaxRobertaLayer(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
hidden_state, attention_mask
@@ -258,7 +252,7 @@ class FlaxRobertaLayerCollection(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"

@@ -279,7 +273,7 @@ class FlaxRobertaEncoder(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxRobertaLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -289,7 +283,7 @@ def __call__(self, hidden_state, attention_mask):

# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
-@compact
+@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -307,7 +301,7 @@ class FlaxRobertaModule(nn.Module):
head_size: int
intermediate_size: int

-@compact
+@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):

# Embedding
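The inline comment on `scale`, carried over from the old docstring, notes that the multiplicative gamma can be disabled when the next layer is linear, since that layer's weights absorb the scaling. A hedged sketch of that composition using Flax's built-in `nn.LayerNorm` (not the `FlaxRobertaLayerNorm` defined above); the module name, feature size, and shapes are assumptions for illustration:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class NormThenDense(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, x):
        # use_scale=False: the following Dense layer can absorb the per-feature scaling,
        # as described by the inline comment on the `scale` field above.
        x = nn.LayerNorm(epsilon=1e-6, use_scale=False)(x)
        return nn.Dense(self.features)(x)


x = jnp.ones((4, 8))
variables = NormThenDense().init(jax.random.PRNGKey(0), x)
out = NormThenDense().apply(variables, x)  # shape (4, 16)
```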