@@ -120,14 +120,16 @@ class GenerateMachineStep(WorkflowStep):
120120 description = "Generate a P state machine implementation"
121121 max_retries = 3
122122
123- def __init__ (self , generation_service : GenerationService , machine_name : str ):
123+ def __init__ (self , generation_service : GenerationService , machine_name : str , ensemble_size : int = 3 ):
124124 self .service = generation_service
125125 self .machine_name = machine_name
126+ self .ensemble_size = ensemble_size
126127 self .name = f"generate_machine_{ machine_name } "
127128
128129 def execute (self , context : Dict [str , Any ]) -> StepResult :
129130 design_doc = context .get ("design_doc" )
130131 project_path = context .get ("project_path" )
132+ ensemble_size = context .get ("ensemble_size" , self .ensemble_size )
131133
132134 if not design_doc :
133135 return StepResult .failure ("design_doc is required" )
@@ -148,13 +150,23 @@ def execute(self, context: Dict[str, Any]) -> StepResult:
148150 context_files [machine_file ] = value
149151
150152 try :
151- result = self .service .generate_machine (
152- machine_name = self .machine_name ,
153- design_doc = design_doc ,
154- project_path = project_path ,
155- context_files = context_files ,
156- save_to_disk = False # Preview mode
157- )
153+ if ensemble_size > 1 :
154+ result = self .service .generate_machine_ensemble (
155+ machine_name = self .machine_name ,
156+ design_doc = design_doc ,
157+ project_path = project_path ,
158+ context_files = context_files ,
159+ ensemble_size = ensemble_size ,
160+ save_to_disk = False # Preview mode
161+ )
162+ else :
163+ result = self .service .generate_machine (
164+ machine_name = self .machine_name ,
165+ design_doc = design_doc ,
166+ project_path = project_path ,
167+ context_files = context_files ,
168+ save_to_disk = False # Preview mode
169+ )
158170
159171 if result .success :
160172 return StepResult .success (
@@ -185,14 +197,16 @@ class GenerateSpecStep(WorkflowStep):
185197 description = "Generate a P specification/monitor file"
186198 max_retries = 3
187199
188- def __init__ (self , generation_service : GenerationService , spec_name : str = "Safety" ):
200+ def __init__ (self , generation_service : GenerationService , spec_name : str = "Safety" , ensemble_size : int = 3 ):
189201 self .service = generation_service
190202 self .spec_name = spec_name
203+ self .ensemble_size = ensemble_size
191204 self .name = f"generate_spec_{ spec_name } "
192205
193206 def execute (self , context : Dict [str , Any ]) -> StepResult :
194207 design_doc = context .get ("design_doc" )
195208 project_path = context .get ("project_path" )
209+ ensemble_size = context .get ("ensemble_size" , self .ensemble_size )
196210
197211 if not design_doc :
198212 return StepResult .failure ("design_doc is required" )
@@ -203,13 +217,23 @@ def execute(self, context: Dict[str, Any]) -> StepResult:
203217 context_files = self ._collect_context_files (context , project_path )
204218
205219 try :
206- result = self .service .generate_spec (
207- spec_name = self .spec_name ,
208- design_doc = design_doc ,
209- project_path = project_path ,
210- context_files = context_files ,
211- save_to_disk = False
212- )
220+ if ensemble_size > 1 :
221+ result = self .service .generate_spec_ensemble (
222+ spec_name = self .spec_name ,
223+ design_doc = design_doc ,
224+ project_path = project_path ,
225+ context_files = context_files ,
226+ ensemble_size = ensemble_size ,
227+ save_to_disk = False
228+ )
229+ else :
230+ result = self .service .generate_spec (
231+ spec_name = self .spec_name ,
232+ design_doc = design_doc ,
233+ project_path = project_path ,
234+ context_files = context_files ,
235+ save_to_disk = False
236+ )
213237
214238 if result .success :
215239 return StepResult .success (
@@ -265,14 +289,16 @@ class GenerateTestStep(WorkflowStep):
265289 description = "Generate a P test driver file"
266290 max_retries = 3
267291
268- def __init__ (self , generation_service : GenerationService , test_name : str = "TestDriver" ):
292+ def __init__ (self , generation_service : GenerationService , test_name : str = "TestDriver" , ensemble_size : int = 3 ):
269293 self .service = generation_service
270294 self .test_name = test_name
295+ self .ensemble_size = ensemble_size
271296 self .name = f"generate_test_{ test_name } "
272297
273298 def execute (self , context : Dict [str , Any ]) -> StepResult :
274299 design_doc = context .get ("design_doc" )
275300 project_path = context .get ("project_path" )
301+ ensemble_size = context .get ("ensemble_size" , self .ensemble_size )
276302
277303 if not design_doc :
278304 return StepResult .failure ("design_doc is required" )
@@ -283,13 +309,23 @@ def execute(self, context: Dict[str, Any]) -> StepResult:
283309 context_files = self ._collect_all_context (context , project_path )
284310
285311 try :
286- result = self .service .generate_test (
287- test_name = self .test_name ,
288- design_doc = design_doc ,
289- project_path = project_path ,
290- context_files = context_files ,
291- save_to_disk = False
292- )
312+ if ensemble_size > 1 :
313+ result = self .service .generate_test_ensemble (
314+ test_name = self .test_name ,
315+ design_doc = design_doc ,
316+ project_path = project_path ,
317+ context_files = context_files ,
318+ ensemble_size = ensemble_size ,
319+ save_to_disk = False
320+ )
321+ else :
322+ result = self .service .generate_test (
323+ test_name = self .test_name ,
324+ design_doc = design_doc ,
325+ project_path = project_path ,
326+ context_files = context_files ,
327+ save_to_disk = False
328+ )
293329
294330 if result .success :
295331 return StepResult .success (
0 commit comments