Skip to content

Commit 0ad6af2

Browse files
authored
Merge pull request #874 from cryspen/minicore-commitment-proof
avx2/commitment: F* proofs
2 parents 2a67890 + 9c5cab3 commit 0ad6af2

25 files changed

Lines changed: 1373 additions & 561 deletions

File tree

.github/workflows/mldsa-hax.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ jobs:
3434
- uses: hacspec/hax-actions@main
3535
with:
3636
hax_reference: ${{ github.event.inputs.hax_rev || 'main' }}
37-
fstar: v2025.01.17
37+
fstar: v2025.03.25
3838

3939
- name: 🏃 Extract ML-DSA crate
4040
working-directory: libcrux-ml-dsa
41-
run: ./hax.py extract
41+
run: ./hax.sh extract
4242

4343
- name: ↑ Upload F* extraction
4444
uses: actions/upload-artifact@v4
@@ -58,7 +58,7 @@ jobs:
5858
- uses: hacspec/hax-actions@main
5959
with:
6060
hax_reference: ${{ github.event.inputs.hax_rev || 'main' }}
61-
fstar: v2025.01.17
61+
fstar: v2025.03.25
6262

6363
- uses: actions/download-artifact@v4
6464
with:
@@ -67,7 +67,7 @@ jobs:
6767

6868
- name: 🏃 Lax ML-DSA crate
6969
working-directory: libcrux-ml-dsa
70-
run: ./hax.py prove --admit
70+
run: ./hax.sh prove --admit
7171

7272
mldsa-extract-hax-status:
7373
if: ${{ always() }}

fstar-helpers/minicore/Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
[package]
22
name = "minicore"
3-
edition = "2021"
3+
version.workspace = true
4+
authors.workspace = true
5+
license.workspace = true
6+
homepage.workspace = true
7+
edition.workspace = true
8+
repository.workspace = true
9+
readme.workspace = true
410
publish = false
511

612
[dependencies]
713
rand = "0.9"
14+
hax-lib.workspace = true

fstar-helpers/minicore/src/abstractions/bit.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,25 +68,40 @@ impl From<bool> for Bit {
6868
}
6969

7070
/// A trait for types that represent machine integers.
71+
#[hax_lib::attributes]
7172
pub trait MachineInteger {
7273
/// The size of this integer type in bits.
73-
const BITS: u32;
74+
#[hax_lib::requires(true)]
75+
#[hax_lib::ensures(|bits| bits >= 8)]
76+
fn bits() -> u32;
7477

7578
/// The signedness of this integer type.
7679
const SIGNED: bool;
7780
}
7881

7982
macro_rules! generate_machine_integer_impls {
8083
($($ty:ident),*) => {
81-
$(impl MachineInteger for $ty {
82-
const BITS: u32 = $ty::BITS;
84+
$(#[hax_lib::exclude]impl MachineInteger for $ty {
85+
fn bits() -> u32 { $ty::BITS }
8386
#[allow(unused_comparisons)]
8487
const SIGNED: bool = $ty::MIN < 0;
8588
})*
8689
};
8790
}
8891
generate_machine_integer_impls!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128);
8992

93+
#[hax_lib::fstar::replace(
94+
r"
95+
instance impl_MachineInteger_poly (t: inttype): t_MachineInteger (int_t t) =
96+
{ f_bits = (fun () -> mk_u32 (bits t));
97+
f_bits_pre = (fun () -> True);
98+
f_bits_post = (fun () r -> r == mk_u32 (bits t));
99+
f_SIGNED = signed t }
100+
"
101+
)]
102+
const _: () = {};
103+
104+
#[hax_lib::exclude]
90105
impl Bit {
91106
fn of_raw_int(x: u128, nth: u32) -> Self {
92107
if x / 2u128.pow(nth) % 2 == 1 {
@@ -101,7 +116,7 @@ impl Bit {
101116
if x >= 0 {
102117
Self::of_raw_int(x as u128, nth)
103118
} else {
104-
Self::of_raw_int((2i128.pow(T::BITS) + x) as u128, nth)
119+
Self::of_raw_int((2i128.pow(T::bits()) + x) as u128, nth)
105120
}
106121
}
107122
}
Lines changed: 178 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
//! This module provides a specification-friendly bit vector type.
22
use super::bit::{Bit, MachineInteger};
3-
4-
// TODO: this module uses `u128/i128` as mathematic integers. We should use `hax_lib::int` or bigint.
3+
use super::funarr::*;
54

65
use std::fmt::Formatter;
76

7+
// This is required due to some hax-lib inconsistencies with versus without `cfg(hax)`.
8+
#[cfg(hax)]
9+
use hax_lib::{int, ToInt};
10+
11+
// TODO: this module uses `u128/i128` as mathematic integers. We should use `hax_lib::int` or bigint.
12+
813
/// A fixed-size bit vector type.
914
///
1015
/// `BitVec<N>` is a specification-friendly, fixed-length bit vector that internally
@@ -15,12 +20,14 @@ use std::fmt::Formatter;
1520
/// The [`Debug`] implementation for `BitVec` pretty-prints the bits in groups of eight,
1621
/// making the bit pattern more human-readable. The type also implements indexing,
1722
/// allowing for easy access to individual bits.
23+
#[hax_lib::fstar::before("noeq")]
1824
#[derive(Copy, Clone, Eq, PartialEq)]
19-
pub struct BitVec<const N: usize>([Bit; N]);
25+
pub struct BitVec<const N: u64>(FunArray<N, Bit>);
2026

2127
/// Pretty prints a bit slice by group of 8
28+
#[hax_lib::exclude]
2229
fn bit_slice_to_string(bits: &[Bit]) -> String {
23-
bits.into_iter()
30+
bits.iter()
2431
.map(|bit| match bit {
2532
Bit::Zero => '0',
2633
Bit::One => '1',
@@ -34,33 +41,38 @@ fn bit_slice_to_string(bits: &[Bit]) -> String {
3441
.into()
3542
}
3643

37-
impl<const N: usize> core::fmt::Debug for BitVec<N> {
44+
#[hax_lib::exclude]
45+
impl<const N: u64> core::fmt::Debug for BitVec<N> {
3846
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
39-
write!(f, "{}", bit_slice_to_string(&self.0))
47+
write!(f, "{}", bit_slice_to_string(&self.0.as_vec()))
4048
}
4149
}
4250

43-
impl<const N: usize> core::ops::Index<usize> for BitVec<N> {
51+
#[hax_lib::attributes]
52+
impl<const N: u64> core::ops::Index<u64> for BitVec<N> {
4453
type Output = Bit;
45-
fn index(&self, index: usize) -> &Self::Output {
46-
&self.0[index]
54+
#[requires(index < N)]
55+
fn index(&self, index: u64) -> &Self::Output {
56+
self.0.get(index)
4757
}
4858
}
4959

5060
/// Convert a bit slice into an unsigned number.
61+
#[hax_lib::exclude]
5162
fn u64_int_from_bit_slice(bits: &[Bit]) -> u64 {
52-
bits.into_iter()
63+
bits.iter()
5364
.enumerate()
5465
.map(|(i, bit)| u64::from(bit.clone()) << i)
5566
.sum::<u64>()
5667
}
5768

5869
/// Convert a bit slice into a machine integer of type `T`.
70+
#[hax_lib::exclude]
5971
fn int_from_bit_slice<T: TryFrom<i128> + MachineInteger + Copy>(bits: &[Bit]) -> T {
60-
debug_assert!(bits.len() <= T::BITS as usize);
72+
debug_assert!(bits.len() <= T::bits() as usize);
6173
let result = if T::SIGNED {
62-
let is_negative = matches!(bits[T::BITS as usize - 1], Bit::One);
63-
let s = u64_int_from_bit_slice(&bits[0..T::BITS as usize - 1]) as i128;
74+
let is_negative = matches!(bits[T::bits() as usize - 1], Bit::One);
75+
let s = u64_int_from_bit_slice(&bits[0..T::bits() as usize - 1]) as i128;
6476
if is_negative {
6577
-s
6678
} else {
@@ -76,39 +88,181 @@ fn int_from_bit_slice<T: TryFrom<i128> + MachineInteger + Copy>(bits: &[Bit]) ->
7688
n
7789
}
7890

79-
impl<const N: usize> BitVec<N> {
80-
/// Constructor for BitVec. `BitVec::<N>::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits.
81-
pub fn from_fn<F: FnMut(usize) -> Bit>(f: F) -> Self {
82-
Self(core::array::from_fn(f))
91+
/// An F* attribute that indiquates a rewritting lemma should be applied
92+
pub const REWRITE_RULE: () = {};
93+
94+
#[hax_lib::fstar::replace(
95+
r#"
96+
let ${BitVec::<0>::from_fn::<fn(u64)->Bit>}
97+
(v_N: u64)
98+
(f: (i: u64 {v i < v v_N}) -> $:{Bit})
99+
: t_BitVec v_N =
100+
${BitVec::<0>}(${FunArray::<0,()>::from_fn::<fn(u64)->()>} v_N f)
101+
102+
open FStar.FunctionalExtensionality
103+
let ${BitVec::<0>::pointwise}
104+
(v_N: u64) (f: t_BitVec v_N)
105+
(#[${_pointwise_apply_mk_term} (v v_N) (fun (i:nat{i < v v_N}) -> f._0 (mk_u64 i))] def: (n: nat {n < v v_N}) -> $:{Bit})
106+
: t_BitVec v_N
107+
= ${BitVec::<0>::from_fn::<fn(u64)->Bit>} v_N (on (i: u64 {v i < v v_N}) (fun i -> def (v i)))
108+
109+
let extensionality' (#a: Type) (#b: Type) (f g: FStar.FunctionalExtensionality.(a ^-> b))
110+
: Lemma (ensures (FStar.FunctionalExtensionality.feq f g <==> f == g))
111+
= ()
112+
113+
open FStar.Tactics.V2
114+
#push-options "--z3rlimit 80 --admit_smt_queries true"
115+
let ${BitVec::<128>::rewrite_pointwise} (x: $:{BitVec<128>})
116+
: Lemma (x == ${BitVec::<128>::pointwise} (${128u64}) x) =
117+
let a = x._0 in
118+
let b = (${BitVec::<128>::pointwise} (${128u64}) x)._0 in
119+
assert_norm (FStar.FunctionalExtensionality.feq a b);
120+
extensionality' a b
121+
122+
let ${BitVec::<256>::rewrite_pointwise} (x: $:{BitVec<256>})
123+
: Lemma (x == ${BitVec::<256>::pointwise} (${256u64}) x) =
124+
let a = x._0 in
125+
let b = (${BitVec::<256>::pointwise} (${256u64}) x)._0 in
126+
assert_norm (FStar.FunctionalExtensionality.feq a b);
127+
extensionality' a b
128+
#pop-options
129+
130+
let postprocess_rewrite_helper (rw_lemma: term) (): Tac unit = with_compat_pre_core 1 (fun () ->
131+
let debug_mode = ext_enabled "debug_bv_postprocess_rewrite" in
132+
let crate = match cur_module () with | crate::_ -> crate | _ -> fail "Empty module name" in
133+
// Remove indirections
134+
norm [primops; iota; delta_namespace [crate; "Libcrux_intrinsics"]; zeta_full];
135+
// Rewrite call chains
136+
let lemmas = FStar.List.Tot.map (fun f -> pack_ln (FStar.Stubs.Reflection.V2.Data.Tv_FVar f)) (lookup_attr (`${REWRITE_RULE}) (top_env ())) in
137+
l_to_r lemmas;
138+
/// Get rid of casts
139+
norm [primops; iota; delta_namespace ["Rust_primitives"; "Prims.pow2"]; zeta_full];
140+
if debug_mode then print ("[postprocess_rewrite_helper] lemmas = " ^ term_to_string (quote lemmas));
141+
if debug_mode then dump "[postprocess_rewrite_helper] After applying lemmas";
142+
// Apply pointwise rw
143+
let done = alloc false in
144+
ctrl_rewrite TopDown (fun _ -> if read done then (false, Skip) else (true, Continue))
145+
(fun _ -> (fun () -> apply_lemma_rw rw_lemma; write done true)
146+
`or_else` trefl);
147+
// Normalize as much as possible
148+
norm [primops; iota; delta_namespace ["Core"; crate; "Minicore"; "Libcrux_intrinsics"; "FStar.FunctionalExtensionality"; "Rust_primitives"]; zeta_full];
149+
// Compute the last bits
150+
compute ();
151+
// Force full normalization
152+
norm [primops; iota; delta; zeta_full];
153+
if debug_mode then dump "[postprocess_rewrite_helper] after full normalization";
154+
// Solves the goal `<normalized body> == ?u`
155+
trefl ()
156+
)
157+
158+
let ${BitVec::<256>::postprocess_rewrite} = postprocess_rewrite_helper (`${BitVec::<256>::rewrite_pointwise})
159+
let ${BitVec::<128>::postprocess_rewrite} = postprocess_rewrite_helper (`${BitVec::<128>::rewrite_pointwise})
160+
"#
161+
)]
162+
const _: () = ();
163+
164+
#[hax_lib::fstar::replace(
165+
r#"
166+
"#
167+
)]
168+
pub fn postprocess_normalize_128() {}
169+
170+
#[hax_lib::exclude]
171+
impl BitVec<128> {
172+
pub fn rewrite_pointwise(self) {}
173+
pub fn postprocess_rewrite() {}
174+
}
175+
#[hax_lib::exclude]
176+
impl BitVec<256> {
177+
pub fn rewrite_pointwise(self) {}
178+
pub fn postprocess_rewrite() {}
179+
}
180+
181+
#[hax_lib::exclude]
182+
impl<const N: u64> BitVec<N> {
183+
pub fn pointwise(self) -> Self {
184+
self
83185
}
84186

187+
/// Constructor for BitVec. `BitVec::<N>::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits.
188+
pub fn from_fn<F: Fn(u64) -> Bit>(f: F) -> Self {
189+
Self(FunArray::from_fn(f))
190+
}
85191
/// Convert a slice of machine integers where only the `d` least significant bits are relevant.
86-
pub fn from_slice<T: Into<i128> + MachineInteger + Copy>(x: &[T], d: usize) -> Self {
87-
Self::from_fn(|i| Bit::of_int(x[i / d], (i % d) as u32))
192+
pub fn from_slice<T: Into<i128> + MachineInteger + Copy>(x: &[T], d: u64) -> Self {
193+
Self::from_fn(|i| Bit::of_int(x[(i / d) as usize], (i % d) as u32))
88194
}
89195

90196
/// Construct a BitVec out of a machine integer.
91197
pub fn from_int<T: Into<i128> + MachineInteger + Copy>(n: T) -> Self {
92-
Self::from_slice(&[n.into()], T::BITS as usize)
198+
Self::from_slice(&[n.into()], T::bits() as u64)
93199
}
94200

95201
/// Convert a BitVec into a machine integer of type `T`.
96202
pub fn to_int<T: TryFrom<i128> + MachineInteger + Copy>(self) -> T {
97-
int_from_bit_slice(&self.0)
203+
int_from_bit_slice(&self.0.as_vec())
98204
}
99205

100206
/// Convert a BitVec into a vector of machine integers of type `T`.
101207
pub fn to_vec<T: TryFrom<i128> + MachineInteger + Copy>(&self) -> Vec<T> {
102208
self.0
103-
.chunks(T::BITS as usize)
209+
.as_vec()
210+
.chunks(T::bits() as usize)
104211
.map(int_from_bit_slice)
105212
.collect()
106213
}
107214

108215
/// Generate a random BitVec.
109216
pub fn rand() -> Self {
110217
use rand::prelude::*;
111-
let mut rng = rand::rng();
112-
Self::from_fn(|_| rng.random::<bool>().into())
218+
let random_source: Vec<_> = {
219+
let mut rng = rand::rng();
220+
(0..N).map(|_| rng.random::<bool>()).collect()
221+
};
222+
Self::from_fn(|i| random_source[i as usize].into())
223+
}
224+
}
225+
226+
#[hax_lib::attributes]
227+
impl<const N: u64> BitVec<N> {
228+
#[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
229+
pub fn chunked_shift<const CHUNK: u64, const SHIFTS: u64>(
230+
self,
231+
shl: FunArray<SHIFTS, i128>,
232+
) -> BitVec<N> {
233+
// TODO: this inner method is because of https://github.com/cryspen/hax-evit/issues/29
234+
#[hax_lib::fstar::options("--z3rlimit 50 --split_queries always")]
235+
#[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
236+
fn chunked_shift<const N: u64, const CHUNK: u64, const SHIFTS: u64>(
237+
bitvec: BitVec<N>,
238+
shl: FunArray<SHIFTS, i128>,
239+
) -> BitVec<N> {
240+
BitVec::from_fn(|i| {
241+
let nth_bit = i % CHUNK;
242+
let nth_chunk = i / CHUNK;
243+
hax_lib::assert_prop!(nth_chunk.to_int() <= SHIFTS.to_int() - int!(1));
244+
hax_lib::assert_prop!(
245+
nth_chunk.to_int() * CHUNK.to_int()
246+
<= (SHIFTS.to_int() - int!(1)) * CHUNK.to_int()
247+
);
248+
let shift: i128 = if nth_chunk < SHIFTS {
249+
shl[nth_chunk]
250+
} else {
251+
0
252+
};
253+
let local_index = (nth_bit as i128).wrapping_sub(shift);
254+
if local_index < CHUNK as i128 && local_index >= 0 {
255+
let local_index = local_index as u64;
256+
hax_lib::assert_prop!(
257+
nth_chunk.to_int() * CHUNK.to_int() + local_index.to_int()
258+
< SHIFTS.to_int() * CHUNK.to_int()
259+
);
260+
bitvec[nth_chunk * CHUNK + local_index]
261+
} else {
262+
Bit::Zero
263+
}
264+
})
265+
}
266+
chunked_shift::<N, CHUNK, SHIFTS>(self, shl)
113267
}
114268
}

0 commit comments

Comments
 (0)