30 | 30 |
31 | 31 | from distilabel.llms.base import LLM |
32 | 32 | from distilabel.llms.chat_templates import CHATML_TEMPLATE |
33 | | -from distilabel.llms.mixins import CudaDevicePlacementMixin |
| 33 | +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin |
| 34 | +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin |
34 | 35 | from distilabel.llms.typing import GenerateOutput |
35 | 36 | from distilabel.mixins.runtime_parameters import RuntimeParameter |
36 | 37 | from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType |
39 | 40 | from transformers import PreTrainedTokenizer |
40 | 41 | from vllm import LLM as _vLLM |
41 | 42 |
| 43 | + from distilabel.steps.tasks.typing import StandardInput |
| 44 | + |
42 | 45 |
43 | 46 | SamplingParams = None |
44 | 47 |
45 | 48 |
46 | | -class vLLM(LLM, CudaDevicePlacementMixin): |
| 49 | +class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): |
47 | 50 | """`vLLM` library LLM implementation. |
48 | 51 |
49 | 52 | Attributes: |
@@ -75,6 +78,12 @@ class vLLM(LLM, CudaDevicePlacementMixin): |
75 | 78 | _tokenizer: the tokenizer instance used to format the prompt before passing it to |
76 | 79 | the `LLM`. This attribute is meant to be used internally and should not be |
77 | 80 | accessed directly. It will be set in the `load` method. |
| 81 | + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query |
| 82 | + template. Defaults to `False`. |
| 83 | + magpie_pre_query_template: the pre-query template to be applied to the prompt or |
| 84 | + sent to the LLM to generate an instruction or a follow up user message. Valid |
| 85 | + values are "llama3", "qwen2" or another pre-query template provided. Defaults |
| 86 | + to `None`. |
78 | 87 |
79 | 88 | References: |
80 | 89 | - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py |
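For context on the new `MagpieChatTemplateMixin` base class and the two attributes documented above, a hedged usage sketch follows. The attribute names (`use_magpie_template`, `magpie_pre_query_template`) are taken from this diff; the import path, model name, and `generate` call are illustrative assumptions, not something this commit verifies.

```python
# Sketch only: assumes `vLLM` is re-exported from `distilabel.llms` and that a
# GPU environment suitable for the underlying vLLM engine is available.
from distilabel.llms import vLLM

llm = vLLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model name
    use_magpie_template=True,                     # enable the Magpie pre-query template
    magpie_pre_query_template="llama3",           # or "qwen2", or a custom template string
)
llm.load()

# With Magpie enabled, even an empty conversation yields a non-empty prompt: only the
# pre-query template is sent, so the model generates the user instruction itself.
outputs = llm.generate(inputs=[[]], num_generations=1)
```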
@@ -213,15 +222,26 @@ def model_name(self) -> str: |
213 | 222 | """Returns the model name used for the LLM.""" |
214 | 223 | return self.model |
215 | 224 |
216 | | - def prepare_input(self, input: "FormattedInput") -> str: |
217 | | - """Prepares the input by applying the chat template to the input, which is formatted |
218 | | - as an OpenAI conversation, and adding the generation prompt. |
| 225 | + def prepare_input(self, input: "StandardInput") -> str: |
| 226 | + """Prepares the input (applying the chat template and tokenization) for the provided |
| 227 | + input. |
| 228 | +
| 229 | + Args: |
| 230 | + input: the input list containing chat items. |
| 231 | +
| 232 | + Returns: |
| 233 | + The prompt to send to the LLM. |
219 | 234 | """ |
220 | | - return self._tokenizer.apply_chat_template( # type: ignore |
221 | | - input, # type: ignore |
222 | | - tokenize=False, |
223 | | - add_generation_prompt=True, # type: ignore |
| 235 | + prompt: str = ( |
| 236 | + self._tokenizer.apply_chat_template( # type: ignore |
| 237 | + input, # type: ignore |
| 238 | + tokenize=False, |
| 239 | + add_generation_prompt=True, # type: ignore |
| 240 | + ) |
| 241 | + if input |
| 242 | + else "" |
224 | 243 | ) |
| 244 | + return super().apply_magpie_pre_query_template(prompt, input) |
225 | 245 |
226 | 246 | def _prepare_batches( |
227 | 247 | self, inputs: List[FormattedInput] |
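To make the rewritten `prepare_input` concrete, here is a standalone re-creation of the same flow: apply the chat template (falling back to an empty string for an empty conversation), then append the Magpie pre-query suffix. The exact llama3 pre-query string is an assumption about what `apply_magpie_pre_query_template` adds; it is not spelled out in this diff.

```python
from transformers import AutoTokenizer

# Assumed value of the "llama3" pre-query template; the mixin's real constant may differ.
LLAMA3_PRE_QUERY = "<|start_header_id|>user<|end_header_id|>\n\n"


def prepare_input_sketch(conversation, tokenizer, use_magpie_template=False):
    """Mirror of the diffed method: chat template first, then the optional
    Magpie pre-query suffix so the model writes the next user turn."""
    prompt = (
        tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        if conversation
        else ""  # empty input: only the pre-query template (if any) is sent
    )
    return prompt + LLAMA3_PRE_QUERY if use_magpie_template else prompt


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(prepare_input_sketch([], tokenizer, use_magpie_template=True))
```

The `if input else ""` fallback in the diffed method is what lets Magpie-style tasks pass an empty conversation without the tokenizer's chat template ever being applied to an empty message list.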
@@ -304,14 +324,13 @@ def generate( # type: ignore |
304 | 324 | if extra_sampling_params is None: |
305 | 325 | extra_sampling_params = {} |
306 | 326 | structured_output = None |
307 | | - needs_sorting = False |
308 | 327 |
309 | 328 | if isinstance(inputs[0], tuple): |
310 | 329 | prepared_batches, sorted_indices = self._prepare_batches(inputs) |
311 | | - needs_sorting = True |
312 | 330 | else: |
313 | 331 | # Simulate a batch without the structured output content |
314 | 332 | prepared_batches = [([self.prepare_input(input) for input in inputs], None)] |
| 333 | + sorted_indices = None |
315 | 334 |
316 | 335 | # In case we have a single structured output for the dataset, we can |
317 | 336 | logits_processors = None |
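The `needs_sorting` flag removed above becomes redundant because `sorted_indices` can act as the sentinel itself: it is a list of positions when `_prepare_batches` regrouped the inputs by structured-output configuration, and `None` on the plain path. A hedged sketch of that grouping idea (the helper below only mirrors the return shape implied by the diff; it is not the library's `_prepare_batches`):

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

StructuredOutput = Optional[Dict[str, Any]]


def group_by_structured_output(
    inputs: List[Tuple[str, StructuredOutput]],
) -> Tuple[List[Tuple[List[str], StructuredOutput]], List[int]]:
    """Group (prompt, structured_output) rows so each batch shares one schema,
    and record the original row positions so outputs can be restored later."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for index, (_, structured_output) in enumerate(inputs):
        groups[str(structured_output)].append(index)

    prepared_batches, sorted_indices = [], []
    for indices in groups.values():
        prompts = [inputs[i][0] for i in indices]
        prepared_batches.append((prompts, inputs[indices[0]][1]))
        sorted_indices.extend(indices)
    return prepared_batches, sorted_indices
```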
@@ -348,7 +367,7 @@ def generate( # type: ignore |
348 | 367 |
349 | 368 | # If logits_processor is set, we need to sort the outputs back to the original order |
350 | 369 | # (would be needed only if we have multiple structured outputs in the dataset) |
351 | | - if needs_sorting: |
| 370 | + if sorted_indices is not None: |
352 | 371 | batched_outputs = _sort_batches( |
353 | 372 | batched_outputs, sorted_indices, num_generations=num_generations |
354 | 373 | ) |
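And the inverse step now guarded by `sorted_indices is not None`: outputs come back in grouped order, so each row is put back at the position its input originally occupied. A minimal illustration of the idea, simplified to one generation per input (the real `_sort_batches` also accounts for `num_generations`):

```python
from typing import List


def restore_original_order(
    batched_outputs: List[List[str]], sorted_indices: List[int]
) -> List[List[str]]:
    """`batched_outputs[k]` belongs to the input originally at `sorted_indices[k]`;
    invert that permutation."""
    restored: List[List[str]] = [[] for _ in batched_outputs]
    for position, original_index in enumerate(sorted_indices):
        restored[original_index] = batched_outputs[position]
    return restored


# Inputs 0 and 2 shared a schema and were generated together first.
print(restore_original_order([["out-0"], ["out-2"], ["out-1"]], sorted_indices=[0, 2, 1]))
# -> [['out-0'], ['out-1'], ['out-2']]
```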