3131class DummyTask (Task ):
3232 @property
3333 def inputs (self ) -> List [str ]:
34- return ["instruction" ]
34+ return ["instruction" , "additional_info" ]
3535
3636 def format_input (self , input : Dict [str , Any ]) -> "ChatType" :
3737 return [
3838 {"role" : "system" , "content" : "" },
3939 {"role" : "user" , "content" : input ["instruction" ]},
4040 ]
4141
42- def format_output (self , output : Union [str , None ], input : Dict [str , Any ]) -> dict :
43- return {"output" : output }
42+ @property
43+ def outputs (self ) -> List [str ]:
44+ return ["output" , "info_from_input" ]
45+
46+ def format_output (
47+ self , output : Union [str , None ], input : Union [Dict [str , Any ], None ] = None
48+ ) -> Dict [str , Any ]:
49+ return {"output" : output , "info_from_input" : input ["additional_info" ]} # type: ignore
4450
4551
4652class DummyRuntimeLLM (DummyLLM ):
@@ -85,37 +91,139 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
8591 Task (name = "task" , llm = DummyLLM ()) # type: ignore
8692
8793 @pytest .mark .parametrize (
88- "group_generations, expected" ,
94+ "input, group_generations, expected" ,
8995 [
9096 (
97+ [
98+ {"instruction" : "test_0" , "additional_info" : "additional_info_0" },
99+ {"instruction" : "test_1" , "additional_info" : "additional_info_1" },
100+ {"instruction" : "test_2" , "additional_info" : "additional_info_2" },
101+ ],
91102 False ,
92103 [
93104 {
94- "instruction" : "test" ,
105+ "instruction" : "test_0" ,
106+ "additional_info" : "additional_info_0" ,
107+ "output" : "output" ,
108+ "info_from_input" : "additional_info_0" ,
109+ "model_name" : "test" ,
110+ "distilabel_metadata" : {"raw_output_task" : "output" },
111+ },
112+ {
113+ "instruction" : "test_0" ,
114+ "additional_info" : "additional_info_0" ,
115+ "output" : "output" ,
116+ "info_from_input" : "additional_info_0" ,
117+ "model_name" : "test" ,
118+ "distilabel_metadata" : {"raw_output_task" : "output" },
119+ },
120+ {
121+ "instruction" : "test_0" ,
122+ "additional_info" : "additional_info_0" ,
123+ "output" : "output" ,
124+ "info_from_input" : "additional_info_0" ,
125+ "model_name" : "test" ,
126+ "distilabel_metadata" : {"raw_output_task" : "output" },
127+ },
128+ {
129+ "instruction" : "test_1" ,
130+ "additional_info" : "additional_info_1" ,
131+ "output" : "output" ,
132+ "info_from_input" : "additional_info_1" ,
133+ "model_name" : "test" ,
134+ "distilabel_metadata" : {"raw_output_task" : "output" },
135+ },
136+ {
137+ "instruction" : "test_1" ,
138+ "additional_info" : "additional_info_1" ,
139+ "output" : "output" ,
140+ "info_from_input" : "additional_info_1" ,
141+ "model_name" : "test" ,
142+ "distilabel_metadata" : {"raw_output_task" : "output" },
143+ },
144+ {
145+ "instruction" : "test_1" ,
146+ "additional_info" : "additional_info_1" ,
95147 "output" : "output" ,
148+ "info_from_input" : "additional_info_1" ,
96149 "model_name" : "test" ,
97150 "distilabel_metadata" : {"raw_output_task" : "output" },
98151 },
99152 {
100- "instruction" : "test" ,
153+ "instruction" : "test_2" ,
154+ "additional_info" : "additional_info_2" ,
101155 "output" : "output" ,
156+ "info_from_input" : "additional_info_2" ,
102157 "model_name" : "test" ,
103158 "distilabel_metadata" : {"raw_output_task" : "output" },
104159 },
105160 {
106- "instruction" : "test" ,
161+ "instruction" : "test_2" ,
162+ "additional_info" : "additional_info_2" ,
107163 "output" : "output" ,
164+ "info_from_input" : "additional_info_2" ,
165+ "model_name" : "test" ,
166+ "distilabel_metadata" : {"raw_output_task" : "output" },
167+ },
168+ {
169+ "instruction" : "test_2" ,
170+ "additional_info" : "additional_info_2" ,
171+ "output" : "output" ,
172+ "info_from_input" : "additional_info_2" ,
108173 "model_name" : "test" ,
109174 "distilabel_metadata" : {"raw_output_task" : "output" },
110175 },
111176 ],
112177 ),
113178 (
179+ [
180+ {"instruction" : "test_0" , "additional_info" : "additional_info_0" },
181+ {"instruction" : "test_1" , "additional_info" : "additional_info_1" },
182+ {"instruction" : "test_2" , "additional_info" : "additional_info_2" },
183+ ],
114184 True ,
115185 [
116186 {
117- "instruction" : "test" ,
187+ "instruction" : "test_0" ,
188+ "additional_info" : "additional_info_0" ,
189+ "output" : ["output" , "output" , "output" ],
190+ "info_from_input" : [
191+ "additional_info_0" ,
192+ "additional_info_0" ,
193+ "additional_info_0" ,
194+ ],
195+ "model_name" : "test" ,
196+ "distilabel_metadata" : [
197+ {"raw_output_task" : "output" },
198+ {"raw_output_task" : "output" },
199+ {"raw_output_task" : "output" },
200+ ],
201+ },
202+ {
203+ "instruction" : "test_1" ,
204+ "additional_info" : "additional_info_1" ,
205+ "output" : ["output" , "output" , "output" ],
206+ "info_from_input" : [
207+ "additional_info_1" ,
208+ "additional_info_1" ,
209+ "additional_info_1" ,
210+ ],
211+ "model_name" : "test" ,
212+ "distilabel_metadata" : [
213+ {"raw_output_task" : "output" },
214+ {"raw_output_task" : "output" },
215+ {"raw_output_task" : "output" },
216+ ],
217+ },
218+ {
219+ "instruction" : "test_2" ,
220+ "additional_info" : "additional_info_2" ,
118221 "output" : ["output" , "output" , "output" ],
222+ "info_from_input" : [
223+ "additional_info_2" ,
224+ "additional_info_2" ,
225+ "additional_info_2" ,
226+ ],
119227 "model_name" : "test" ,
120228 "distilabel_metadata" : [
121229 {"raw_output_task" : "output" },
@@ -128,7 +236,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
128236 ],
129237 )
130238 def test_process (
131- self , group_generations : bool , expected : List [Dict [str , Any ]]
239+ self ,
240+ input : List [Dict [str , str ]],
241+ group_generations : bool ,
242+ expected : List [Dict [str , Any ]],
132243 ) -> None :
133244 pipeline = Pipeline (name = "unit-test-pipeline" )
134245 llm = DummyLLM ()
@@ -139,7 +250,7 @@ def test_process(
139250 group_generations = group_generations ,
140251 num_generations = 3 ,
141252 )
142- result = next (task .process ([{ "instruction" : "test" }] ))
253+ result = next (task .process (input ))
143254 assert result == expected
144255
145256 def test_process_with_runtime_parameters (self ) -> None :
0 commit comments