1717from ..kast .inner import KInner
1818from ..kast .manip import extract_subst , flatten_label , free_vars
1919from ..kast .outer import KDefinition , KFlatModule , KFlatModuleList , KImport , KRequire
20- from ..prelude .ml import is_top , mlAnd
20+ from ..prelude .ml import mlAnd
2121from ..utils import gen_file_timestamp , run_process , unique
2222from .kprint import KPrint
2323
@@ -197,7 +197,7 @@ def prove(
197197 depth : int | None = None ,
198198 haskell_log_format : KoreExecLogFormat = KoreExecLogFormat .ONELINE ,
199199 haskell_log_debug_transition : bool = True ,
200- ) -> CTerm :
200+ ) -> list [ CTerm ] :
201201 log_file = spec_file .with_suffix ('.debug-log' ) if log_axioms_file is None else log_axioms_file
202202 if log_file .exists ():
203203 log_file .unlink ()
@@ -245,13 +245,14 @@ def prove(
245245 raise RuntimeError ('kprove failed!' )
246246
247247 if dry_run :
248- return CTerm .bottom ()
248+ return [ CTerm .bottom ()]
249249
250250 debug_log = _get_rule_log (log_file )
251- final_state = CTerm .from_kast (kast_term (json .loads (proc_result .stdout ), KInner )) # type: ignore # https://github.com/python/mypy/issues/4717
252- if final_state .is_top and len (debug_log ) == 0 and not allow_zero_step :
251+ as_kast = kast_term (json .loads (proc_result .stdout ), KInner ) # type: ignore # https://github.com/python/mypy/issues/4717
252+ final_states = [CTerm .from_kast (disjunct ) for disjunct in flatten_label ('#Or' , as_kast )]
253+ if next (state .is_top for state in final_states ) and len (debug_log ) == 0 and not allow_zero_step :
253254 raise ValueError (f'Proof took zero steps, likely the LHS is invalid: { spec_file } ' )
254- return final_state
255+ return final_states
255256
256257 def prove_claim (
257258 self ,
@@ -265,7 +266,7 @@ def prove_claim(
265266 allow_zero_step : bool = False ,
266267 dry_run : bool = False ,
267268 depth : int | None = None ,
268- ) -> CTerm :
269+ ) -> list [ CTerm ] :
269270 with self ._tmp_claim_definition (claim , claim_id , lemmas = lemmas ) as (claim_path , claim_module_name ):
270271 return self .prove (
271272 claim_path ,
@@ -295,7 +296,7 @@ def prove_cterm(
295296 depth : int | None = None ,
296297 ) -> list [CTerm ]:
297298 claim , var_map = build_claim (claim_id , init_cterm , target_cterm , keep_vars = free_vars (init_cterm .kast ))
298- next_state = self .prove_claim (
299+ next_states_cterm = self .prove_claim (
299300 claim ,
300301 claim_id ,
301302 lemmas = lemmas ,
@@ -305,7 +306,8 @@ def prove_cterm(
305306 allow_zero_step = allow_zero_step ,
306307 depth = depth ,
307308 )
308- next_states = list (unique (var_map (ns ) for ns in flatten_label ('#Or' , next_state .kast ) if not is_top (ns )))
309+ # next_states = list(unique(var_map(ns) for ns in flatten_label('#Or', next_state.kast) if not is_top(ns)))
310+ next_states = list (unique (var_map (ns .kast ) for ns in next_states_cterm if not ns .is_top ))
309311 constraint_subst , _ = extract_subst (init_cterm .kast )
310312 next_states_cterm = [
311313 CTerm .from_kast (mlAnd ([constraint_subst .unapply (ns ), constraint_subst .ml_pred ])) for ns in next_states
0 commit comments