Adding PrefixConstrainedLogitsProcessor #8529
Changes from 11 commits
```diff
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from abc import ABC
-from typing import Iterable, List
+from typing import Callable, Iterable, List
 
 import numpy as np
 import torch
@@ -372,3 +373,28 @@ def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_toke
         )
         scores = scores.masked_fill(banned_mask, -float("inf"))
         return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces that only specified sequences can be generated.
+
+    Args:
+        prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`):
+            A function that takes the batch ID :obj:`batch_id` and :obj:`input_ids` as arguments and returns
+            the list of tokens allowed for the next generation step, conditioned on the previously generated
+            tokens :obj:`input_ids` and the batch ID :obj:`batch_id`.
+
+    """
+
+    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
+        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self._num_beams = num_beams
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        mask = torch.full_like(scores, -math.inf)
+        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+            for beam_id, sent in enumerate(beam_sent):
+                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+        return scores + mask
```
**Contributor** (on the `input_ids.view(...)` loop):

> In a future PR we could probably speed this up by just using …

**Author:**

> I wanted to keep the same signature as in fairseq, so that anyone who has already implemented such a function there can reuse it here.
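To make the masking concrete, here is a minimal usage sketch. Everything in it is hypothetical (the vocabulary size, the token IDs, and the `allow_even_tokens` helper are made up for illustration), and it assumes `PrefixConstrainedLogitsProcessor` is in scope, e.g. imported from the module patched above:

```python
import torch

vocab_size = 10  # hypothetical toy vocabulary

def allow_even_tokens(batch_id, input_ids):
    # Toy constraint: only even token IDs may be generated next, whatever the prefix.
    return [t for t in range(vocab_size) if t % 2 == 0]

# Greedy search, i.e. a single beam per batch element.
processor = PrefixConstrainedLogitsProcessor(allow_even_tokens, num_beams=1)

input_ids = torch.tensor([[1, 2], [3, 4]])  # (batch_size * num_beams, seq_len)
scores = torch.zeros(2, vocab_size)         # (batch_size * num_beams, vocab_size)

filtered = processor(input_ids, scores)
# Allowed (even) token IDs keep their original score; every other token is set
# to -inf, so it can never be chosen by sampling or beam search.
assert filtered[0, 0].item() == 0.0
assert filtered[0, 1].item() == float("-inf")
```

Because the constraint is applied to the scores rather than inside the decoding loop, it composes with any other processor in a `LogitsProcessorList`.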
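Picking up the fairseq point above: the usual application is restricting generation to a closed set of sequences (entity names, canned responses, and so on). The sketch below is one assumed way to write such a callable; the token-ID sequences are invented, and it pretends generation starts from an empty prefix, whereas in practice the prompt portion of `input_ids` would have to be stripped before matching:

```python
# Invented token-ID sequences that generation is allowed to produce.
ALLOWED_SEQUENCES = [
    [5, 3, 9],
    [5, 7],
    [2, 8, 8],
]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # input_ids is the 1-D tensor of tokens generated so far for one hypothesis.
    prefix = input_ids.tolist()
    # Next token of every allowed sequence that extends the current prefix.
    allowed = {
        seq[len(prefix)]
        for seq in ALLOWED_SEQUENCES
        if len(seq) > len(prefix) and seq[: len(prefix)] == prefix
    }
    # NOTE: if nothing extends the prefix (e.g. a sequence was completed), you
    # would typically return the EOS token ID here rather than an empty list.
    return sorted(allowed)
```

For large constraint sets the linear scan would normally be replaced by a trie lookup, and because the signature matches fairseq's, an existing fairseq implementation can be reused unchanged. In released versions of transformers the same callable can also be passed to `generate()` as `prefix_allowed_tokens_fn`.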