Skip to content

Commit e33446b

Browse files
committed
Implemented a mul_div operation for Uints and reduced overflow risks in inflation computations.
1 parent ab20766 commit e33446b

2 files changed

Lines changed: 315 additions & 21 deletions

File tree

core/src/ledger/storage/masp_conversions.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,13 @@ where
126126
let noterized_inflation = if total_token_in_masp.is_zero() {
127127
0u128
128128
} else {
129-
crate::types::uint::Uint::try_into(
130-
(inflation * crate::types::uint::Uint::from(precision))
131-
/ total_token_in_masp.raw_amount(),
132-
)
133-
.unwrap()
129+
inflation
130+
.checked_mul_div(
131+
crate::types::uint::Uint::from(precision),
132+
total_token_in_masp.raw_amount(),
133+
)
134+
.0
135+
.map_or(u128::MAX, |x| x.try_into().unwrap_or(u128::MAX))
134136
};
135137

136138
tracing::debug!(
@@ -159,21 +161,17 @@ where
159161
// but we should make sure the return value's ratio matches
160162
// this new inflation rate in 'update_allowed_conversions',
161163
// otherwise we will have an inaccurate view of inflation
162-
wl_storage
163-
.write(
164-
&token::masp_last_inflation_key(addr),
165-
token::Amount::from_uint(
166-
(total_token_in_masp.raw_amount() / precision)
167-
* crate::types::uint::Uint::from(noterized_inflation),
168-
0,
169-
)
170-
.unwrap(),
164+
wl_storage.write(
165+
&token::masp_last_inflation_key(addr),
166+
token::Amount::from_uint(
167+
(total_token_in_masp.raw_amount() / precision)
168+
* crate::types::uint::Uint::from(noterized_inflation),
169+
0,
171170
)
172-
.expect("unable to encode new inflation rate (Decimal)");
171+
.unwrap(),
172+
)?;
173173

174-
wl_storage
175-
.write(&token::masp_last_locked_ratio_key(addr), locked_ratio)
176-
.expect("unable to encode new locked ratio (Decimal)");
174+
wl_storage.write(&token::masp_last_locked_ratio_key(addr), locked_ratio)?;
177175

178176
Ok((noterized_inflation, precision))
179177
}

core/src/types/uint.rs

Lines changed: 299 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,272 @@ pub const ZERO: Uint = Uint::from_u64(0);
2121
pub const ONE: Uint = Uint::from_u64(1);
2222

2323
impl Uint {
24+
const N_WORDS: usize = 4;
25+
2426
/// Convert a [`u64`] to a [`Uint`].
2527
pub const fn from_u64(x: u64) -> Uint {
2628
Uint([x.to_le(), 0, 0, 0])
2729
}
30+
31+
/// Return the least number of bits needed to represent the number
32+
#[inline]
33+
pub fn bits_512(arr: &[u64; 2 * Self::N_WORDS]) -> usize {
34+
for i in 1..arr.len() {
35+
if arr[arr.len() - i] > 0 {
36+
return (0x40 * (arr.len() - i + 1))
37+
- arr[arr.len() - i].leading_zeros() as usize;
38+
}
39+
}
40+
0x40 - arr[0].leading_zeros() as usize
41+
}
42+
43+
fn div_mod_small_512(
44+
mut slf: [u64; 2 * Self::N_WORDS],
45+
other: u64,
46+
) -> ([u64; 2 * Self::N_WORDS], Self) {
47+
let mut rem = 0u64;
48+
slf.iter_mut().rev().for_each(|d| {
49+
let (q, r) = Self::div_mod_word(rem, *d, other);
50+
*d = q;
51+
rem = r;
52+
});
53+
(slf, rem.into())
54+
}
55+
56+
fn shr_512(
57+
original: [u64; 2 * Self::N_WORDS],
58+
shift: u32,
59+
) -> [u64; 2 * Self::N_WORDS] {
60+
let shift = shift as usize;
61+
let mut ret = [0u64; 2 * Self::N_WORDS];
62+
let word_shift = shift / 64;
63+
let bit_shift = shift % 64;
64+
65+
// shift
66+
for i in word_shift..original.len() {
67+
ret[i - word_shift] = original[i] >> bit_shift;
68+
}
69+
70+
// Carry
71+
if bit_shift > 0 {
72+
for i in word_shift + 1..original.len() {
73+
ret[i - word_shift - 1] += original[i] << (64 - bit_shift);
74+
}
75+
}
76+
77+
ret
78+
}
79+
80+
fn full_shl_512(
81+
slf: [u64; 2 * Self::N_WORDS],
82+
shift: u32,
83+
) -> [u64; 2 * Self::N_WORDS + 1] {
84+
debug_assert!(shift < Self::WORD_BITS as u32);
85+
let mut u = [0u64; 2 * Self::N_WORDS + 1];
86+
let u_lo = slf[0] << shift;
87+
let u_hi = Self::shr_512(slf, Self::WORD_BITS as u32 - shift);
88+
u[0] = u_lo;
89+
u[1..].copy_from_slice(&u_hi[..]);
90+
u
91+
}
92+
93+
fn full_shr_512(
94+
u: [u64; 2 * Self::N_WORDS + 1],
95+
shift: u32,
96+
) -> [u64; 2 * Self::N_WORDS] {
97+
debug_assert!(shift < Self::WORD_BITS as u32);
98+
let mut res = [0; 2 * Self::N_WORDS];
99+
for i in 0..res.len() {
100+
res[i] = u[i] >> shift;
101+
}
102+
// carry
103+
if shift > 0 {
104+
for i in 1..=res.len() {
105+
res[i - 1] |= u[i] << (Self::WORD_BITS as u32 - shift);
106+
}
107+
}
108+
res
109+
}
110+
111+
// See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
112+
fn div_mod_knuth_512(
113+
slf: [u64; 2 * Self::N_WORDS],
114+
mut v: Self,
115+
n: usize,
116+
m: usize,
117+
) -> ([u64; 2 * Self::N_WORDS], Self) {
118+
debug_assert!(Self::bits_512(&slf) >= v.bits() && !v.fits_word());
119+
debug_assert!(n + m <= slf.len());
120+
// D1.
121+
// Make sure 64th bit in v's highest word is set.
122+
// If we shift both self and v, it won't affect the quotient
123+
// and the remainder will only need to be shifted back.
124+
let shift = v.0[n - 1].leading_zeros();
125+
v <<= shift;
126+
// u will store the remainder (shifted)
127+
let mut u = Self::full_shl_512(slf, shift);
128+
129+
// quotient
130+
let mut q = [0; 2 * Self::N_WORDS];
131+
let v_n_1 = v.0[n - 1];
132+
let v_n_2 = v.0[n - 2];
133+
134+
// D2. D7.
135+
// iterate from m downto 0
136+
for j in (0..=m).rev() {
137+
let u_jn = u[j + n];
138+
139+
// D3.
140+
// q_hat is our guess for the j-th quotient digit
141+
// q_hat = min(b - 1, (u_{j+n} * b + u_{j+n-1}) / v_{n-1})
142+
// b = 1 << WORD_BITS
143+
// Theorem B: q_hat >= q_j >= q_hat - 2
144+
let mut q_hat = if u_jn < v_n_1 {
145+
let (mut q_hat, mut r_hat) =
146+
Self::div_mod_word(u_jn, u[j + n - 1], v_n_1);
147+
// this loop takes at most 2 iterations
148+
loop {
149+
// check if q_hat * v_{n-2} > b * r_hat + u_{j+n-2}
150+
let (hi, lo) =
151+
Self::split_u128(u128::from(q_hat) * u128::from(v_n_2));
152+
if (hi, lo) <= (r_hat, u[j + n - 2]) {
153+
break;
154+
}
155+
// then iterate till it doesn't hold
156+
q_hat -= 1;
157+
let (new_r_hat, overflow) = r_hat.overflowing_add(v_n_1);
158+
r_hat = new_r_hat;
159+
// if r_hat overflowed, we're done
160+
if overflow {
161+
break;
162+
}
163+
}
164+
q_hat
165+
} else {
166+
// here q_hat >= q_j >= q_hat - 1
167+
u64::max_value()
168+
};
169+
170+
// ex. 20:
171+
// since q_hat * v_{n-2} <= b * r_hat + u_{j+n-2},
172+
// either q_hat == q_j, or q_hat == q_j + 1
173+
174+
// D4.
175+
// let's assume optimistically q_hat == q_j
176+
// subtract (q_hat * v) from u[j..]
177+
let q_hat_v = v.full_mul_u64(q_hat);
178+
// u[j..] -= q_hat_v;
179+
let c = Self::sub_slice(&mut u[j..], &q_hat_v[..n + 1]);
180+
181+
// D6.
182+
// actually, q_hat == q_j + 1 and u[j..] has overflowed
183+
// highly unlikely ~ (1 / 2^63)
184+
if c {
185+
q_hat -= 1;
186+
// add v to u[j..]
187+
let c = Self::add_slice(&mut u[j..], &v.0[..n]);
188+
u[j + n] = u[j + n].wrapping_add(u64::from(c));
189+
}
190+
191+
// D5.
192+
q[j] = q_hat;
193+
}
194+
195+
// D8.
196+
let remainder = Self::full_shr_512(u, shift);
197+
// The remainder should never exceed the capacity of Self
198+
debug_assert!(
199+
Self::bits_512(&remainder) <= Self::N_WORDS * Self::WORD_BITS
200+
);
201+
(q, Self(remainder[..Self::N_WORDS].try_into().unwrap()))
202+
}
203+
204+
/// Returns a pair `(self / other, self % other)`.
205+
///
206+
/// # Panics
207+
///
208+
/// Panics if `other` is zero.
209+
pub fn div_mod_512(
210+
slf: [u64; 2 * Self::N_WORDS],
211+
other: Self,
212+
) -> ([u64; 2 * Self::N_WORDS], Self) {
213+
let my_bits = Self::bits_512(&slf);
214+
let your_bits = other.bits();
215+
216+
assert!(your_bits != 0, "division by zero");
217+
218+
// Early return in case we are dividing by a larger number than us
219+
if my_bits < your_bits {
220+
return (
221+
[0; 2 * Self::N_WORDS],
222+
Self(slf[..Self::N_WORDS].try_into().unwrap()),
223+
);
224+
}
225+
226+
if your_bits <= Self::WORD_BITS {
227+
return Self::div_mod_small_512(slf, other.low_u64());
228+
}
229+
230+
let (n, m) = {
231+
let my_words = Self::words(my_bits);
232+
let your_words = Self::words(your_bits);
233+
(your_words, my_words - your_words)
234+
};
235+
236+
Self::div_mod_knuth_512(slf, other, n, m)
237+
}
238+
239+
/// Returns a pair `(Some((self * num) / denom), (self * num) % denom)` if
240+
/// the quotient fits into Self. Otherwise `(None, (self * num) % denom)` is
241+
/// returned.
242+
///
243+
/// # Panics
244+
///
245+
/// Panics if `denom` is zero.
246+
pub fn checked_mul_div(
247+
&self,
248+
num: Self,
249+
denom: Self,
250+
) -> (Option<Self>, Self) {
251+
let prod = uint::uint_full_mul_reg!(Uint, 4, self, num);
252+
let (quotient, remainder) = Self::div_mod_512(prod, denom);
253+
// The compiler WILL NOT inline this if you remove this annotation.
254+
#[inline(always)]
255+
fn any_nonzero(arr: &[u64]) -> bool {
256+
use uint::unroll;
257+
unroll! {
258+
for i in 0..4 {
259+
if arr[i] != 0 {
260+
return true;
261+
}
262+
}
263+
}
264+
265+
false
266+
}
267+
(
268+
if any_nonzero(&quotient[Self::N_WORDS..]) {
269+
None
270+
} else {
271+
Some(Self(quotient[0..Self::N_WORDS].try_into().unwrap()))
272+
},
273+
remainder,
274+
)
275+
}
276+
277+
/// Returns a pair `((self * num) / denom, (self * num) % denom)`.
278+
///
279+
/// # Panics
280+
///
281+
/// Panics if `denom` is zero.
282+
pub fn mul_div(&self, num: Self, denom: Self) -> (Self, Self) {
283+
let prod = uint::uint_full_mul_reg!(Uint, 4, self, num);
284+
let (quotient, remainder) = Self::div_mod_512(prod, denom);
285+
(
286+
Self(quotient[0..Self::N_WORDS].try_into().unwrap()),
287+
remainder,
288+
)
289+
}
28290
}
29291

30292
construct_uint! {
@@ -171,10 +433,9 @@ impl Uint {
171433
/// * `self` * 10^(`denom`) overflows 256 bits
172434
/// * `other` is zero (`checked_div` will return `None`).
173435
pub fn fixed_precision_div(&self, rhs: &Self, denom: u8) -> Option<Self> {
174-
let lhs = Uint::from(10)
436+
Uint::from(10)
175437
.checked_pow(Uint::from(denom))
176-
.and_then(|res| res.checked_mul(*self))?;
177-
lhs.checked_div(*rhs)
438+
.and_then(|res| res.checked_mul_div(*self, *rhs).0)
178439
}
179440

180441
/// Compute the two's complement of a number.
@@ -710,4 +971,39 @@ mod test_uint {
710971
let amount: Result<Uint, _> = serde_json::from_str(r#""1000000000.2""#);
711972
assert!(amount.is_err());
712973
}
974+
975+
#[test]
976+
fn test_mul_div() {
977+
use std::str::FromStr;
978+
let a: Uint = Uint::from_str(
979+
"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
980+
).unwrap();
981+
let b: Uint = Uint::from_str(
982+
"0x8000000000000000000000000000000000000000000000000000000000000000",
983+
).unwrap();
984+
let c: Uint = Uint::from_str(
985+
"0x4000000000000000000000000000000000000000000000000000000000000000",
986+
).unwrap();
987+
let d: Uint = Uint::from_str(
988+
"0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
989+
).unwrap();
990+
let e: Uint = Uint::from_str(
991+
"0x0000000000000000000000000000000000000000000000000000000000000001",
992+
).unwrap();
993+
let f: Uint = Uint::from_str(
994+
"0x0000000000000000000000000000000000000000000000000000000000000000",
995+
).unwrap();
996+
assert_eq!(a.mul_div(a, a), (a, Uint::zero()));
997+
assert_eq!(b.mul_div(c, b), (c, Uint::zero()));
998+
assert_eq!(a.mul_div(c, b), (d, c));
999+
assert_eq!(a.mul_div(e, e), (a, Uint::zero()));
1000+
assert_eq!(e.mul_div(c, b), (Uint::zero(), c));
1001+
assert_eq!(f.mul_div(a, e), (Uint::zero(), Uint::zero()));
1002+
assert_eq!(a.checked_mul_div(a, a), (Some(a), Uint::zero()));
1003+
assert_eq!(b.checked_mul_div(c, b), (Some(c), Uint::zero()));
1004+
assert_eq!(a.checked_mul_div(c, b), (Some(d), c));
1005+
assert_eq!(a.checked_mul_div(e, e), (Some(a), Uint::zero()));
1006+
assert_eq!(e.checked_mul_div(c, b), (Some(Uint::zero()), c));
1007+
assert_eq!(d.checked_mul_div(a, e), (None, Uint::zero()));
1008+
}
7131009
}

0 commit comments

Comments
 (0)