55
66from llama_index .llms import ChatMessage , MessageRole
77from llama_index .llms .llama_utils import (
8- DEFAULT_SYSTEM_PROMPT ,
98 completion_to_prompt ,
109 messages_to_prompt ,
1110)
@@ -29,7 +28,6 @@ class AbstractPromptStyle(abc.ABC):
2928 series of messages into a prompt.
3029 """
3130
32- @abc .abstractmethod
3331 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
3432 logger .debug ("Initializing prompt_style=%s" , self .__class__ .__name__ )
3533
@@ -52,15 +50,6 @@ def completion_to_prompt(self, completion: str) -> str:
5250 return prompt
5351
5452
55- class AbstractPromptStyleWithSystemPrompt (AbstractPromptStyle , abc .ABC ):
56- _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
57-
58- def __init__ (self , default_system_prompt : str | None ) -> None :
59- super ().__init__ ()
60- logger .debug ("Got default_system_prompt='%s'" , default_system_prompt )
61- self .default_system_prompt = default_system_prompt
62-
63-
6453class DefaultPromptStyle (AbstractPromptStyle ):
6554 """Default prompt style that uses the defaults from llama_utils.
6655
@@ -83,7 +72,7 @@ def _completion_to_prompt(self, completion: str) -> str:
8372 return ""
8473
8574
86- class Llama2PromptStyle (AbstractPromptStyleWithSystemPrompt ):
75+ class Llama2PromptStyle (AbstractPromptStyle ):
8776 """Simple prompt style that just uses the default llama_utils functions.
8877
8978 It transforms the sequence of messages into a prompt that should look like:
@@ -94,18 +83,14 @@ class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
9483 ```
9584 """
9685
97- def __init__ (self , default_system_prompt : str | None = None ) -> None :
98- # If no system prompt is given, the default one of the implementation is used.
99- super ().__init__ (default_system_prompt = default_system_prompt )
100-
10186 def _messages_to_prompt (self , messages : Sequence [ChatMessage ]) -> str :
102- return messages_to_prompt (messages , self . default_system_prompt )
87+ return messages_to_prompt (messages )
10388
10489 def _completion_to_prompt (self , completion : str ) -> str :
105- return completion_to_prompt (completion , self . default_system_prompt )
90+ return completion_to_prompt (completion )
10691
10792
108- class TagPromptStyle (AbstractPromptStyleWithSystemPrompt ):
93+ class TagPromptStyle (AbstractPromptStyle ):
10994 """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
11095
11196 It transforms the sequence of messages into a prompt that should look like:
@@ -119,37 +104,8 @@ class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
119104 FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
120105 """
121106
122- def __init__ (self , default_system_prompt : str | None = None ) -> None :
123- # We have to define a default system prompt here as the LLM will not
124- # use the default llama_utils functions.
125- default_system_prompt = default_system_prompt or self ._DEFAULT_SYSTEM_PROMPT
126- super ().__init__ (default_system_prompt )
127- self .system_prompt : str = default_system_prompt
128-
129107 def _messages_to_prompt (self , messages : Sequence [ChatMessage ]) -> str :
130- messages = list (messages )
131- if messages [0 ].role != MessageRole .SYSTEM :
132- logger .info (
133- "Adding system_promt='%s' to the given messages as there are none given in the session" ,
134- self .system_prompt ,
135- )
136- messages = [
137- ChatMessage (content = self .system_prompt , role = MessageRole .SYSTEM ),
138- * messages ,
139- ]
140- return self ._format_messages_to_prompt (messages )
141-
142- def _completion_to_prompt (self , completion : str ) -> str :
143- return (
144- f"<|system|>: { self .system_prompt .strip ()} \n "
145- f"<|user|>: { completion .strip ()} \n "
146- "<|assistant|>: "
147- )
148-
149- @staticmethod
150- def _format_messages_to_prompt (messages : list [ChatMessage ]) -> str :
151108 """Format message to prompt with `<|ROLE|>: MSG` style."""
152- assert messages [0 ].role == MessageRole .SYSTEM
153109 prompt = ""
154110 for message in messages :
155111 role = message .role
@@ -161,19 +117,24 @@ def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
161117 prompt += "<|assistant|>: "
162118 return prompt
163119
120+ def _completion_to_prompt (self , completion : str ) -> str :
121+ return self ._messages_to_prompt (
122+ [ChatMessage (content = completion , role = MessageRole .USER )]
123+ )
124+
164125
165126def get_prompt_style (
166127 prompt_style : Literal ["default" , "llama2" , "tag" ] | None
167- ) -> type [ AbstractPromptStyle ] :
128+ ) -> AbstractPromptStyle :
168129 """Get the prompt style to use from the given string.
169130
170131 :param prompt_style: The prompt style to use.
171132 :return: The prompt style to use.
172133 """
173134 if prompt_style is None or prompt_style == "default" :
174- return DefaultPromptStyle
135+ return DefaultPromptStyle ()
175136 elif prompt_style == "llama2" :
176- return Llama2PromptStyle
137+ return Llama2PromptStyle ()
177138 elif prompt_style == "tag" :
178- return TagPromptStyle
139+ return TagPromptStyle ()
179140 raise ValueError (f"Unknown prompt_style='{ prompt_style } '" )
0 commit comments