17 changes: 14 additions & 3 deletions src/transformers/pipelines/text_generation.py
@@ -44,10 +44,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, return_full_text=True, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.check_model_type(self.ALLOWED_MODELS)
+        self.return_full_text = return_full_text
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
     def _parse_and_tokenize(self, *args, **kwargs):
@@ -65,6 +66,7 @@ def __call__(
         text_inputs,
         return_tensors=False,
         return_text=True,
+        return_full_text=None,
         clean_up_tokenization_spaces=False,
         prefix=None,
         **generate_kwargs
@@ -79,6 +81,9 @@ def __call__(
                 Whether or not to include the tensors of predictions (as token indices) in the outputs.
             return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to include the decoded texts in the outputs.
+            return_full_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                If set to :obj:`False`, only the added text is returned; otherwise the full text is returned.
+                Only meaningful if :obj:`return_text` is set to :obj:`True`.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
@@ -94,14 +99,15 @@ def __call__(
             - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the generated text.
         """
+        prefix = prefix if prefix is not None else self.model.config.prefix
+        return_full_text = return_full_text if return_full_text is not None else self.return_full_text
+
         if isinstance(text_inputs, str):
             text_inputs = [text_inputs]
         results = []
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                prefix = prefix if prefix is not None else self.model.config.prefix
                 if prefix is None and self.model.__class__.__name__ in [
                     "XLNetLMHeadModel",
                     "TransfoXLLMHeadModel",
@@ -168,7 +174,12 @@ def __call__(
                         )
                     )
 
-                    record["generated_text"] = prompt_text + text[prompt_length:]
+                    if return_full_text:
+                        all_text = prompt_text + text[prompt_length:]
+                    else:
+                        all_text = text[prompt_length:]
+
+                    record["generated_text"] = all_text
 
                 result.append(record)
             results += [result]
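For context, here is a minimal usage sketch of the new flag. The model name "gpt2" and the continuation shown in the comments are illustrative assumptions; generation is non-deterministic, and any causal LM supported by the pipeline works:

    from transformers import pipeline

    # Per-call override: strip the prompt and return only the generated continuation.
    generator = pipeline(task="text-generation", model="gpt2")  # model name is illustrative
    outputs = generator("Hello, I am", return_full_text=False)
    # e.g. [{"generated_text": " a developer who likes ..."}] -- prompt not included

    # Pipeline-level default: set once in __init__ and applied to every call,
    # unless a call-time return_full_text overrides it.
    generator = pipeline(task="text-generation", model="gpt2", return_full_text=False)
    outputs = generator("Hello, I am")  # continuation only, same as above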
19 changes: 19 additions & 0 deletions tests/test_pipelines_text_generation.py
@@ -15,6 +15,7 @@
 import unittest
 
 from transformers import pipeline
+from transformers.testing_utils import require_torch
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin

@@ -41,3 +42,21 @@ def test_simple_generation(self):
         self.assertEqual(type(outputs[0][0]["generated_text"]), str)
         self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
         self.assertEqual(type(outputs[1][0]["generated_text"]), str)
+
+    @require_torch
+    def test_generation_output_style(self):
+        text_generator = pipeline(task="text-generation", model=self.small_models[0])
+        # text-generation is non-deterministic by nature, we can't fully test the output
+
+        outputs = text_generator("This is a test")
+        self.assertIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=False)
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
+        outputs = text_generator("This is a test")
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=True)
+        self.assertIn("This is a test", outputs[0]["generated_text"])