diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 5ace38930ae6..12c1e3b4a4fa 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -44,10 +44,11 @@ class TextGenerationPipeline(Pipeline): "TFCTRLLMHeadModel", ] - def __init__(self, *args, **kwargs): + def __init__(self, *args, return_full_text=True, **kwargs): super().__init__(*args, **kwargs) self.check_model_type(self.ALLOWED_MODELS) + self.return_full_text = return_full_text # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments def _parse_and_tokenize(self, *args, **kwargs): @@ -65,6 +66,7 @@ def __call__( text_inputs, return_tensors=False, return_text=True, + return_full_text=None, clean_up_tokenization_spaces=False, prefix=None, **generate_kwargs @@ -79,6 +81,9 @@ def __call__( Whether or not to include the tensors of predictions (as token indices) in the outputs. return_text (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to include the decoded texts in the outputs. + return_full_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`False` only added text is returned, otherwise the full text is returned Only meaningful + if `return_text` is set to True. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to clean up the potential extra spaces in the text output. prefix (:obj:`str`, `optional`): @@ -94,6 +99,8 @@ def __call__( - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) -- The token ids of the generated text. """ + prefix = prefix if prefix is not None else self.model.config.prefix + return_full_text = return_full_text if return_full_text is not None else self.return_full_text if isinstance(text_inputs, str): text_inputs = [text_inputs] @@ -101,7 +108,6 @@ def __call__( for prompt_text in text_inputs: # Manage correct placement of the tensors with self.device_placement(): - prefix = prefix if prefix is not None else self.model.config.prefix if prefix is None and self.model.__class__.__name__ in [ "XLNetLMHeadModel", "TransfoXLLMHeadModel", @@ -168,7 +174,12 @@ def __call__( ) ) - record["generated_text"] = prompt_text + text[prompt_length:] + if return_full_text: + all_text = prompt_text + text[prompt_length:] + else: + all_text = text[prompt_length:] + + record["generated_text"] = all_text result.append(record) results += [result] diff --git a/tests/test_pipelines_text_generation.py b/tests/test_pipelines_text_generation.py index ee6157f8ef09..24602b6460dd 100644 --- a/tests/test_pipelines_text_generation.py +++ b/tests/test_pipelines_text_generation.py @@ -15,6 +15,7 @@ import unittest from transformers import pipeline +from transformers.testing_utils import require_torch from .test_pipelines_common import MonoInputPipelineCommonMixin @@ -41,3 +42,21 @@ def test_simple_generation(self): self.assertEqual(type(outputs[0][0]["generated_text"]), str) self.assertEqual(list(outputs[1][0].keys()), ["generated_text"]) self.assertEqual(type(outputs[1][0]["generated_text"]), str) + + @require_torch + def test_generation_output_style(self): + text_generator = pipeline(task="text-generation", model=self.small_models[0]) + # text-generation is non-deterministic by nature, we can't fully test the output + + outputs = text_generator("This is a test") + self.assertIn("This is a test", outputs[0]["generated_text"]) + + outputs = text_generator("This is a test", return_full_text=False) + self.assertNotIn("This is a test", outputs[0]["generated_text"]) + + text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False) + outputs = text_generator("This is a test") + self.assertNotIn("This is a test", outputs[0]["generated_text"]) + + outputs = text_generator("This is a test", return_full_text=True) + self.assertIn("This is a test", outputs[0]["generated_text"])