11//! This module provides a specification-friendly bit vector type.
22use super :: bit:: { Bit , MachineInteger } ;
3-
4- // TODO: this module uses `u128/i128` as mathematic integers. We should use `hax_lib::int` or bigint.
3+ use super :: funarr:: * ;
54
65use std:: fmt:: Formatter ;
76
7+ // This is required due to some hax-lib inconsistencies with versus without `cfg(hax)`.
8+ #[ cfg( hax) ]
9+ use hax_lib:: { int, ToInt } ;
10+
11+ // TODO: this module uses `u128/i128` as mathematic integers. We should use `hax_lib::int` or bigint.
12+
813/// A fixed-size bit vector type.
914///
1015/// `BitVec<N>` is a specification-friendly, fixed-length bit vector that internally
@@ -15,12 +20,14 @@ use std::fmt::Formatter;
1520/// The [`Debug`] implementation for `BitVec` pretty-prints the bits in groups of eight,
1621/// making the bit pattern more human-readable. The type also implements indexing,
1722/// allowing for easy access to individual bits.
23+ #[ hax_lib:: fstar:: before( "noeq" ) ]
1824#[ derive( Copy , Clone , Eq , PartialEq ) ]
19- pub struct BitVec < const N : usize > ( [ Bit ; N ] ) ;
25+ pub struct BitVec < const N : u64 > ( FunArray < N , Bit > ) ;
2026
2127/// Pretty prints a bit slice by group of 8
28+ #[ hax_lib:: exclude]
2229fn bit_slice_to_string ( bits : & [ Bit ] ) -> String {
23- bits. into_iter ( )
30+ bits. iter ( )
2431 . map ( |bit| match bit {
2532 Bit :: Zero => '0' ,
2633 Bit :: One => '1' ,
@@ -34,33 +41,38 @@ fn bit_slice_to_string(bits: &[Bit]) -> String {
3441 . into ( )
3542}
3643
37- impl < const N : usize > core:: fmt:: Debug for BitVec < N > {
44+ #[ hax_lib:: exclude]
45+ impl < const N : u64 > core:: fmt:: Debug for BitVec < N > {
3846 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> Result < ( ) , std:: fmt:: Error > {
39- write ! ( f, "{}" , bit_slice_to_string( & self . 0 ) )
47+ write ! ( f, "{}" , bit_slice_to_string( & self . 0 . as_vec ( ) ) )
4048 }
4149}
4250
43- impl < const N : usize > core:: ops:: Index < usize > for BitVec < N > {
51+ #[ hax_lib:: attributes]
52+ impl < const N : u64 > core:: ops:: Index < u64 > for BitVec < N > {
4453 type Output = Bit ;
45- fn index ( & self , index : usize ) -> & Self :: Output {
46- & self . 0 [ index]
54+ #[ requires( index < N ) ]
55+ fn index ( & self , index : u64 ) -> & Self :: Output {
56+ self . 0 . get ( index)
4757 }
4858}
4959
5060/// Convert a bit slice into an unsigned number.
61+ #[ hax_lib:: exclude]
5162fn u64_int_from_bit_slice ( bits : & [ Bit ] ) -> u64 {
52- bits. into_iter ( )
63+ bits. iter ( )
5364 . enumerate ( )
5465 . map ( |( i, bit) | u64:: from ( bit. clone ( ) ) << i)
5566 . sum :: < u64 > ( )
5667}
5768
5869/// Convert a bit slice into a machine integer of type `T`.
70+ #[ hax_lib:: exclude]
5971fn int_from_bit_slice < T : TryFrom < i128 > + MachineInteger + Copy > ( bits : & [ Bit ] ) -> T {
60- debug_assert ! ( bits. len( ) <= T :: BITS as usize ) ;
72+ debug_assert ! ( bits. len( ) <= T :: bits ( ) as usize ) ;
6173 let result = if T :: SIGNED {
62- let is_negative = matches ! ( bits[ T :: BITS as usize - 1 ] , Bit :: One ) ;
63- let s = u64_int_from_bit_slice ( & bits[ 0 ..T :: BITS as usize - 1 ] ) as i128 ;
74+ let is_negative = matches ! ( bits[ T :: bits ( ) as usize - 1 ] , Bit :: One ) ;
75+ let s = u64_int_from_bit_slice ( & bits[ 0 ..T :: bits ( ) as usize - 1 ] ) as i128 ;
6476 if is_negative {
6577 -s
6678 } else {
@@ -76,39 +88,181 @@ fn int_from_bit_slice<T: TryFrom<i128> + MachineInteger + Copy>(bits: &[Bit]) ->
7688 n
7789}
7890
79- impl < const N : usize > BitVec < N > {
80- /// Constructor for BitVec. `BitVec::<N>::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits.
81- pub fn from_fn < F : FnMut ( usize ) -> Bit > ( f : F ) -> Self {
82- Self ( core:: array:: from_fn ( f) )
91+ /// An F* attribute that indiquates a rewritting lemma should be applied
92+ pub const REWRITE_RULE : ( ) = { } ;
93+
94+ #[ hax_lib:: fstar:: replace(
95+ r#"
96+ let ${BitVec::<0>::from_fn::<fn(u64)->Bit>}
97+ (v_N: u64)
98+ (f: (i: u64 {v i < v v_N}) -> $:{Bit})
99+ : t_BitVec v_N =
100+ ${BitVec::<0>}(${FunArray::<0,()>::from_fn::<fn(u64)->()>} v_N f)
101+
102+ open FStar.FunctionalExtensionality
103+ let ${BitVec::<0>::pointwise}
104+ (v_N: u64) (f: t_BitVec v_N)
105+ (#[${_pointwise_apply_mk_term} (v v_N) (fun (i:nat{i < v v_N}) -> f._0 (mk_u64 i))] def: (n: nat {n < v v_N}) -> $:{Bit})
106+ : t_BitVec v_N
107+ = ${BitVec::<0>::from_fn::<fn(u64)->Bit>} v_N (on (i: u64 {v i < v v_N}) (fun i -> def (v i)))
108+
109+ let extensionality' (#a: Type) (#b: Type) (f g: FStar.FunctionalExtensionality.(a ^-> b))
110+ : Lemma (ensures (FStar.FunctionalExtensionality.feq f g <==> f == g))
111+ = ()
112+
113+ open FStar.Tactics.V2
114+ #push-options "--z3rlimit 80 --admit_smt_queries true"
115+ let ${BitVec::<128>::rewrite_pointwise} (x: $:{BitVec<128>})
116+ : Lemma (x == ${BitVec::<128>::pointwise} (${128u64}) x) =
117+ let a = x._0 in
118+ let b = (${BitVec::<128>::pointwise} (${128u64}) x)._0 in
119+ assert_norm (FStar.FunctionalExtensionality.feq a b);
120+ extensionality' a b
121+
122+ let ${BitVec::<256>::rewrite_pointwise} (x: $:{BitVec<256>})
123+ : Lemma (x == ${BitVec::<256>::pointwise} (${256u64}) x) =
124+ let a = x._0 in
125+ let b = (${BitVec::<256>::pointwise} (${256u64}) x)._0 in
126+ assert_norm (FStar.FunctionalExtensionality.feq a b);
127+ extensionality' a b
128+ #pop-options
129+
130+ let postprocess_rewrite_helper (rw_lemma: term) (): Tac unit = with_compat_pre_core 1 (fun () ->
131+ let debug_mode = ext_enabled "debug_bv_postprocess_rewrite" in
132+ let crate = match cur_module () with | crate::_ -> crate | _ -> fail "Empty module name" in
133+ // Remove indirections
134+ norm [primops; iota; delta_namespace [crate; "Libcrux_intrinsics"]; zeta_full];
135+ // Rewrite call chains
136+ let lemmas = FStar.List.Tot.map (fun f -> pack_ln (FStar.Stubs.Reflection.V2.Data.Tv_FVar f)) (lookup_attr (`${REWRITE_RULE}) (top_env ())) in
137+ l_to_r lemmas;
138+ /// Get rid of casts
139+ norm [primops; iota; delta_namespace ["Rust_primitives"; "Prims.pow2"]; zeta_full];
140+ if debug_mode then print ("[postprocess_rewrite_helper] lemmas = " ^ term_to_string (quote lemmas));
141+ if debug_mode then dump "[postprocess_rewrite_helper] After applying lemmas";
142+ // Apply pointwise rw
143+ let done = alloc false in
144+ ctrl_rewrite TopDown (fun _ -> if read done then (false, Skip) else (true, Continue))
145+ (fun _ -> (fun () -> apply_lemma_rw rw_lemma; write done true)
146+ `or_else` trefl);
147+ // Normalize as much as possible
148+ norm [primops; iota; delta_namespace ["Core"; crate; "Minicore"; "Libcrux_intrinsics"; "FStar.FunctionalExtensionality"; "Rust_primitives"]; zeta_full];
149+ // Compute the last bits
150+ compute ();
151+ // Force full normalization
152+ norm [primops; iota; delta; zeta_full];
153+ if debug_mode then dump "[postprocess_rewrite_helper] after full normalization";
154+ // Solves the goal `<normalized body> == ?u`
155+ trefl ()
156+ )
157+
158+ let ${BitVec::<256>::postprocess_rewrite} = postprocess_rewrite_helper (`${BitVec::<256>::rewrite_pointwise})
159+ let ${BitVec::<128>::postprocess_rewrite} = postprocess_rewrite_helper (`${BitVec::<128>::rewrite_pointwise})
160+ "#
161+ ) ]
162+ const _: ( ) = ( ) ;
163+
164+ #[ hax_lib:: fstar:: replace(
165+ r#"
166+ "#
167+ ) ]
168+ pub fn postprocess_normalize_128 ( ) { }
169+
170+ #[ hax_lib:: exclude]
171+ impl BitVec < 128 > {
172+ pub fn rewrite_pointwise ( self ) { }
173+ pub fn postprocess_rewrite ( ) { }
174+ }
175+ #[ hax_lib:: exclude]
176+ impl BitVec < 256 > {
177+ pub fn rewrite_pointwise ( self ) { }
178+ pub fn postprocess_rewrite ( ) { }
179+ }
180+
181+ #[ hax_lib:: exclude]
182+ impl < const N : u64 > BitVec < N > {
183+ pub fn pointwise ( self ) -> Self {
184+ self
83185 }
84186
187+ /// Constructor for BitVec. `BitVec::<N>::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits.
188+ pub fn from_fn < F : Fn ( u64 ) -> Bit > ( f : F ) -> Self {
189+ Self ( FunArray :: from_fn ( f) )
190+ }
85191 /// Convert a slice of machine integers where only the `d` least significant bits are relevant.
86- pub fn from_slice < T : Into < i128 > + MachineInteger + Copy > ( x : & [ T ] , d : usize ) -> Self {
87- Self :: from_fn ( |i| Bit :: of_int ( x[ i / d] , ( i % d) as u32 ) )
192+ pub fn from_slice < T : Into < i128 > + MachineInteger + Copy > ( x : & [ T ] , d : u64 ) -> Self {
193+ Self :: from_fn ( |i| Bit :: of_int ( x[ ( i / d) as usize ] , ( i % d) as u32 ) )
88194 }
89195
90196 /// Construct a BitVec out of a machine integer.
91197 pub fn from_int < T : Into < i128 > + MachineInteger + Copy > ( n : T ) -> Self {
92- Self :: from_slice ( & [ n. into ( ) ] , T :: BITS as usize )
198+ Self :: from_slice ( & [ n. into ( ) ] , T :: bits ( ) as u64 )
93199 }
94200
95201 /// Convert a BitVec into a machine integer of type `T`.
96202 pub fn to_int < T : TryFrom < i128 > + MachineInteger + Copy > ( self ) -> T {
97- int_from_bit_slice ( & self . 0 )
203+ int_from_bit_slice ( & self . 0 . as_vec ( ) )
98204 }
99205
100206 /// Convert a BitVec into a vector of machine integers of type `T`.
101207 pub fn to_vec < T : TryFrom < i128 > + MachineInteger + Copy > ( & self ) -> Vec < T > {
102208 self . 0
103- . chunks ( T :: BITS as usize )
209+ . as_vec ( )
210+ . chunks ( T :: bits ( ) as usize )
104211 . map ( int_from_bit_slice)
105212 . collect ( )
106213 }
107214
108215 /// Generate a random BitVec.
109216 pub fn rand ( ) -> Self {
110217 use rand:: prelude:: * ;
111- let mut rng = rand:: rng ( ) ;
112- Self :: from_fn ( |_| rng. random :: < bool > ( ) . into ( ) )
218+ let random_source: Vec < _ > = {
219+ let mut rng = rand:: rng ( ) ;
220+ ( 0 ..N ) . map ( |_| rng. random :: < bool > ( ) ) . collect ( )
221+ } ;
222+ Self :: from_fn ( |i| random_source[ i as usize ] . into ( ) )
223+ }
224+ }
225+
226+ #[ hax_lib:: attributes]
227+ impl < const N : u64 > BitVec < N > {
228+ #[ hax_lib:: requires( CHUNK > 0 && CHUNK . to_int( ) * SHIFTS . to_int( ) == N . to_int( ) ) ]
229+ pub fn chunked_shift < const CHUNK : u64 , const SHIFTS : u64 > (
230+ self ,
231+ shl : FunArray < SHIFTS , i128 > ,
232+ ) -> BitVec < N > {
233+ // TODO: this inner method is because of https://github.com/cryspen/hax-evit/issues/29
234+ #[ hax_lib:: fstar:: options( "--z3rlimit 50 --split_queries always" ) ]
235+ #[ hax_lib:: requires( CHUNK > 0 && CHUNK . to_int( ) * SHIFTS . to_int( ) == N . to_int( ) ) ]
236+ fn chunked_shift < const N : u64 , const CHUNK : u64 , const SHIFTS : u64 > (
237+ bitvec : BitVec < N > ,
238+ shl : FunArray < SHIFTS , i128 > ,
239+ ) -> BitVec < N > {
240+ BitVec :: from_fn ( |i| {
241+ let nth_bit = i % CHUNK ;
242+ let nth_chunk = i / CHUNK ;
243+ hax_lib:: assert_prop!( nth_chunk. to_int( ) <= SHIFTS . to_int( ) - int!( 1 ) ) ;
244+ hax_lib:: assert_prop!(
245+ nth_chunk. to_int( ) * CHUNK . to_int( )
246+ <= ( SHIFTS . to_int( ) - int!( 1 ) ) * CHUNK . to_int( )
247+ ) ;
248+ let shift: i128 = if nth_chunk < SHIFTS {
249+ shl[ nth_chunk]
250+ } else {
251+ 0
252+ } ;
253+ let local_index = ( nth_bit as i128 ) . wrapping_sub ( shift) ;
254+ if local_index < CHUNK as i128 && local_index >= 0 {
255+ let local_index = local_index as u64 ;
256+ hax_lib:: assert_prop!(
257+ nth_chunk. to_int( ) * CHUNK . to_int( ) + local_index. to_int( )
258+ < SHIFTS . to_int( ) * CHUNK . to_int( )
259+ ) ;
260+ bitvec[ nth_chunk * CHUNK + local_index]
261+ } else {
262+ Bit :: Zero
263+ }
264+ } )
265+ }
266+ chunked_shift :: < N , CHUNK , SHIFTS > ( self , shl)
113267 }
114268}
0 commit comments