Skip to content

Commit b22a494

Browse files
authored
Fix some issues related to Magpie task (#783)
* Fix input columns not included in output * Include `model_name` column as output * Include `model_name` column * Fix unit tests magpie * Fix typo in docstring
1 parent 4fc569d commit b22a494

File tree

4 files changed

+55
-22
lines changed

4 files changed

+55
-22
lines changed

src/distilabel/steps/tasks/magpie/base.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ def _append_messages_to_conversations(
109109
conversation.append({"role": role, "content": instruction})
110110
return conversations
111111

112+
def _generate_instruction(
113+
self, inputs: List[Dict[str, Any]]
114+
) -> List[Dict[str, Any]]:
115+
prepared_inputs = self._prepare_inputs_for_instruction_generation(inputs)
116+
outputs = self.llm.generate(
117+
inputs=prepared_inputs,
118+
num_generations=1,
119+
**self.llm.generation_kwargs, # type: ignore
120+
)
121+
return [{"instruction": output[0]} for output in outputs]
122+
112123
def _generate_multi_turn_conversation(
113124
self, inputs: List[Dict[str, Any]]
114125
) -> List[Dict[str, Any]]:
@@ -156,17 +167,16 @@ def _generate_with_pre_query_template(
156167
Returns:
157168
The list of generated conversations.
158169
"""
170+
outputs = (
171+
self._generate_instruction(inputs)
172+
if self.only_instruction
173+
else self._generate_multi_turn_conversation(inputs)
174+
)
159175

160-
if self.only_instruction:
161-
prepared_inputs = self._prepare_inputs_for_instruction_generation(inputs)
162-
outputs = self.llm.generate(
163-
inputs=prepared_inputs,
164-
num_generations=1,
165-
**self.llm.generation_kwargs, # type: ignore
166-
)
167-
return [{"instruction": output[0]} for output in outputs]
168-
169-
return self._generate_multi_turn_conversation(inputs)
176+
return [
177+
{**input, **output, "model_name": self.llm.model_name}
178+
for input, output in zip(inputs, outputs)
179+
]
170180

171181

172182
class Magpie(Task, MagpieBase):
@@ -209,8 +219,9 @@ class Magpie(Task, MagpieBase):
209219
210220
Output columns:
211221
- conversation (`ChatType`): the generated conversation which is a list of chat
212-
items with a role and a message. Only if `only_instructions=False`.
222+
items with a role and a message. Only if `only_instruction=False`.
213223
- instruction (`str`): the generated instructions if `only_instruction=True`.
224+
- model_name (`str`): The model name used to generate the `conversation` or `instruction`.
214225
215226
Categories:
216227
- text-generation
@@ -352,8 +363,8 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
352363
def outputs(self) -> List[str]:
353364
"""Either a multi-turn conversation or the instruction generated."""
354365
if self.only_instruction:
355-
return ["instruction"]
356-
return ["conversation"]
366+
return ["instruction", "model_name"]
367+
return ["conversation", "model_name"]
357368

358369
def format_output(
359370
self,

src/distilabel/steps/tasks/magpie/generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
6262
- conversation (`ChatType`): the generated conversation which is a list of chat
6363
items with a role and a message.
6464
- instruction (`str`): the generated instructions if `only_instruction=True`.
65+
- model_name (`str`): The model name used to generate the `conversation` or `instruction`.
6566
6667
Categories:
6768
- text-generation
@@ -215,8 +216,8 @@ def format_output(
215216
def outputs(self) -> List[str]:
216217
"""Either a multi-turn conversation or the instruction generated."""
217218
if self.only_instruction:
218-
return ["instruction"]
219-
return ["conversation"]
219+
return ["instruction", "model_name"]
220+
return ["conversation", "model_name"]
220221

221222
def process(self, offset: int = 0) -> "GeneratorStepOutput":
222223
"""Generates the desired number of instructions or conversations using Magpie.

tests/unit/steps/tasks/magpie/test_base.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
3030
def test_outputs(self) -> None:
3131
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
3232

33-
assert task.outputs == ["conversation"]
33+
assert task.outputs == ["conversation", "model_name"]
3434

3535
task = Magpie(
3636
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
3737
only_instruction=True,
3838
)
3939

40-
assert task.outputs == ["instruction"]
40+
assert task.outputs == ["instruction", "model_name"]
4141

4242
def test_process(self) -> None:
4343
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1)
@@ -50,18 +50,21 @@ def test_process(self) -> None:
5050
{"role": "user", "content": "Hello Magpie"},
5151
{"role": "assistant", "content": "Hello Magpie"},
5252
],
53+
"model_name": "test",
5354
},
5455
{
5556
"conversation": [
5657
{"role": "user", "content": "Hello Magpie"},
5758
{"role": "assistant", "content": "Hello Magpie"},
5859
],
60+
"model_name": "test",
5961
},
6062
{
6163
"conversation": [
6264
{"role": "user", "content": "Hello Magpie"},
6365
{"role": "assistant", "content": "Hello Magpie"},
6466
],
67+
"model_name": "test",
6568
},
6669
]
6770

@@ -79,6 +82,7 @@ def test_process_with_n_turns(self) -> None:
7982
{"role": "user", "content": "Hello Magpie"},
8083
{"role": "assistant", "content": "Hello Magpie"},
8184
],
85+
"model_name": "test",
8286
},
8387
{
8488
"conversation": [
@@ -88,6 +92,7 @@ def test_process_with_n_turns(self) -> None:
8892
{"role": "user", "content": "Hello Magpie"},
8993
{"role": "assistant", "content": "Hello Magpie"},
9094
],
95+
"model_name": "test",
9196
},
9297
{
9398
"conversation": [
@@ -97,6 +102,7 @@ def test_process_with_n_turns(self) -> None:
97102
{"role": "user", "content": "Hello Magpie"},
98103
{"role": "assistant", "content": "Hello Magpie"},
99104
],
105+
"model_name": "test",
100106
},
101107
]
102108

@@ -115,31 +121,37 @@ def test_process_with_system_prompt_per_row(self) -> None:
115121
)
116122
) == [
117123
{
124+
"system_prompt": "You're a math expert assistant.",
118125
"conversation": [
119126
{"role": "system", "content": "You're a math expert assistant."},
120127
{"role": "user", "content": "Hello Magpie"},
121128
{"role": "assistant", "content": "Hello Magpie"},
122129
{"role": "user", "content": "Hello Magpie"},
123130
{"role": "assistant", "content": "Hello Magpie"},
124131
],
132+
"model_name": "test",
125133
},
126134
{
135+
"system_prompt": "You're a florist expert assistant.",
127136
"conversation": [
128137
{"role": "system", "content": "You're a florist expert assistant."},
129138
{"role": "user", "content": "Hello Magpie"},
130139
{"role": "assistant", "content": "Hello Magpie"},
131140
{"role": "user", "content": "Hello Magpie"},
132141
{"role": "assistant", "content": "Hello Magpie"},
133142
],
143+
"model_name": "test",
134144
},
135145
{
146+
"system_prompt": "You're a plumber expert assistant.",
136147
"conversation": [
137148
{"role": "system", "content": "You're a plumber expert assistant."},
138149
{"role": "user", "content": "Hello Magpie"},
139150
{"role": "assistant", "content": "Hello Magpie"},
140151
{"role": "user", "content": "Hello Magpie"},
141152
{"role": "assistant", "content": "Hello Magpie"},
142153
],
154+
"model_name": "test",
143155
},
144156
]
145157

@@ -152,9 +164,18 @@ def test_process_only_instruction(self) -> None:
152164
task.load()
153165

154166
assert next(task.process(inputs=[{}, {}, {}])) == [
155-
{"instruction": "Hello Magpie"},
156-
{"instruction": "Hello Magpie"},
157-
{"instruction": "Hello Magpie"},
167+
{
168+
"instruction": "Hello Magpie",
169+
"model_name": "test",
170+
},
171+
{
172+
"instruction": "Hello Magpie",
173+
"model_name": "test",
174+
},
175+
{
176+
"instruction": "Hello Magpie",
177+
"model_name": "test",
178+
},
158179
]
159180

160181
def test_serialization(self) -> None:

tests/unit/steps/tasks/magpie/test_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
3030
def test_outputs(self) -> None:
3131
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
3232

33-
assert task.outputs == ["conversation"]
33+
assert task.outputs == ["conversation", "model_name"]
3434

3535
task = MagpieGenerator(
3636
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
3737
only_instruction=True,
3838
)
3939

40-
assert task.outputs == ["instruction"]
40+
assert task.outputs == ["instruction", "model_name"]
4141

4242
def test_serialization(self) -> None:
4343
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))

0 commit comments

Comments
 (0)