Merged
28 changes: 27 additions & 1 deletion src/transformers/generation_logits_process.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC
from typing import Iterable, List
from typing import Callable, Iterable, List

import numpy as np
import torch
@@ -372,3 +373,28 @@ def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_toke
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces that only specified sequences can be generated.

Args:
prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`):
A function that takes the batch ID :obj:`batch_id` and the previously generated tokens :obj:`inputs_ids`
as arguments and returns a list with the tokens allowed at the next generation step, conditioned on that
prefix and that batch entry.

"""

def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
Review comment (Contributor): In a future PR we could probably speed this up by using torch.Tensor operations instead of Python loops. Python loops apparently really slow down the computation on GPU (see #6064). But we can do this in a future PR as well.

Reply (@nicola-decao, Contributor Author, Nov 16, 2020): I wanted to keep the same signature as in fairseq, so that anyone who has already implemented such a function can reuse it.

for beam_id, sent in enumerate(beam_sent):
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0

return scores + mask
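
A rough sketch of the reviewer's suggestion above (an assumption, not part of this diff): the allowed ids still have to come from the Python callback row by row, but the per-row masked assignments can be collapsed into a single advanced-indexing write on the scores tensor. The helper name build_prefix_mask is illustrative only.

import math
import torch

def build_prefix_mask(input_ids, scores, prefix_allowed_tokens_fn, num_beams):
    # Collect (row, token) index pairs from the Python callback, then write them
    # in one indexed assignment instead of one assignment per beam. This is only a
    # partial vectorization, since the callback itself still runs in a Python loop.
    mask = torch.full_like(scores, -math.inf)
    rows, cols = [], []
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            rows.extend([batch_id * num_beams + beam_id] * len(allowed))
            cols.extend(allowed)
    row_idx = torch.tensor(rows, dtype=torch.long, device=scores.device)
    col_idx = torch.tensor(cols, dtype=torch.long, device=scores.device)
    mask[row_idx, col_idx] = 0
    return scores + mask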
16 changes: 15 additions & 1 deletion src/transformers/generation_utils.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import functional as F
@@ -26,6 +26,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -249,6 +250,8 @@ def _get_logits_processor(
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -276,6 +279,8 @@
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
return processors

@torch.no_grad()
@@ -300,6 +305,7 @@ def generate(
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -366,6 +372,12 @@
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`, defaults to :obj:`None`):
If provided, this function constrains beam search to allowed tokens only at each generation step. If not
provided, no constraint is applied. The function takes two arguments: the batch ID :obj:`batch_id` and
:obj:`inputs_ids`, and it has to return a list with the tokens allowed at the next generation step,
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`.
This argument is useful for constrained generation conditioned on the prefix.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -485,6 +497,8 @@
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)

if is_greedy_gen_mode:
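
For reference, a minimal usage sketch of the new `prefix_allowed_tokens_fn` argument to `generate` (not part of the diff; the checkpoint and the allowed-token choice are assumptions for illustration only):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

# Hypothetical constraint: only these token ids (plus EOS) may ever be generated.
allowed_ids = tokenizer("yes no maybe", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # input_ids is the decoder prefix generated so far for this beam;
    # return the token ids allowed at the next step.
    return allowed_ids

input_ids = tokenizer("Is the sky blue?", return_tensors="pt").input_ids
output = model.generate(input_ids, num_beams=2, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
print(tokenizer.batch_decode(output, skip_special_tokens=True))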
11 changes: 10 additions & 1 deletion src/transformers/modeling_rag.py
@@ -15,7 +15,7 @@
"""RAG model implementation."""

from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import torch

@@ -1234,6 +1234,7 @@ def generate(
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
**model_kwargs
):
"""
@@ -1307,6 +1308,12 @@
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`, defaults to :obj:`None`):
If provided, this function constrains beam search to allowed tokens only at each generation step. If not
provided, no constraint is applied. The function takes two arguments: the batch ID :obj:`batch_id` and
:obj:`inputs_ids`, and it has to return a list with the tokens allowed at the next generation step,
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`.
This argument is useful for constrained generation conditioned on the prefix.

Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1400,6 +1407,8 @@ def extend_enc_output(tensor, num_beams=None):
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)

if num_beams == 1:
21 changes: 21 additions & 0 deletions tests/test_generation_logits_process.py
@@ -31,6 +31,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -281,3 +282,23 @@ def test_processor_list(self):

# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())

def test_prefix_constrained_logits_processor(self):
vocab_size = 5
batch_size = 2

input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)

def prefix_allowed_tokens_fn(batch_id, inputs_ids):
return [[0, 1], [2, 3]][batch_id]

prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)

filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())

# batch 1: only the 1st and 2nd tokens (ids 0, 1) are allowed
# batch 2: only the 3rd and 4th tokens (ids 2, 3) are allowed
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)