Skip to content

Commit ddb521c

Browse files
Div by zero (#31)
Made some minor optimizations to udivmod4. Speed is not noticeably improved, but panic statements may be more helpful and binaries may be smaller. --------- Co-authored-by: Nicholas Rodrigues Lordello <[email protected]>
1 parent 65d2645 commit ddb521c

File tree

1 file changed

+64
-20
lines changed

1 file changed

+64
-20
lines changed

src/intrinsics/native/divmod.rs

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,65 @@
1010
//! - unsigned division: <https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/udivmodti4.c>
1111
1212
use crate::{int::I256, uint::U256};
13-
use core::mem::MaybeUninit;
13+
use core::{mem::MaybeUninit, num::NonZeroU128};
1414

1515
#[inline(always)]
16-
fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128 {
16+
fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: NonZeroU128, r: &mut u128) -> u128 {
1717
const N_UDWORD_BITS: u32 = 128;
18+
19+
#[inline]
20+
unsafe fn shl_nz(x: NonZeroU128, n: u32) -> NonZeroU128 {
21+
debug_assert!(n < N_UDWORD_BITS);
22+
let res: u128 = x.get() << n;
23+
debug_assert_ne!(res, 0);
24+
NonZeroU128::new_unchecked(res)
25+
}
26+
27+
#[inline]
28+
unsafe fn shr_nz(x: NonZeroU128, n: u32) -> NonZeroU128 {
29+
debug_assert!(n < N_UDWORD_BITS);
30+
let res: u128 = x.get() >> n;
31+
debug_assert_ne!(res, 0);
32+
NonZeroU128::new_unchecked(res)
33+
}
34+
1835
const B: u128 = 1 << (N_UDWORD_BITS / 2); // Number base (128 bits)
1936
let (un1, un0): (u128, u128); // Norm. dividend LSD's
20-
let (vn1, vn0): (u128, u128); // Norm. divisor digits
37+
let (vn1, vn0): (NonZeroU128, u128); // Norm. divisor digits
2138
let (mut q1, mut q0): (u128, u128); // Quotient digits
2239
let (un128, un21, un10): (u128, u128, u128); // Dividend digit pairs
2340

41+
debug_assert!(v.get() > u1);
42+
2443
let s = v.leading_zeros();
44+
debug_assert_ne!(s, N_UDWORD_BITS);
2545
if s > 0 {
2646
// Normalize the divisor.
27-
v <<= s;
47+
v = unsafe { shl_nz(v, s) };
2848
un128 = (u1 << s) | (u0 >> (N_UDWORD_BITS - s));
2949
un10 = u0 << s; // Shift dividend left
3050
} else {
31-
// Avoid undefined behavior of (u0 >> 64).
51+
// Avoid undefined behavior of (u0 >> 128).
3252
un128 = u1;
3353
un10 = u0;
3454
}
3555

3656
// Break divisor up into two 64-bit digits.
37-
vn1 = v >> (N_UDWORD_BITS / 2);
38-
vn0 = v & 0xFFFF_FFFF_FFFF_FFFF;
57+
vn1 = unsafe { shr_nz(v, N_UDWORD_BITS / 2) };
58+
vn0 = v.get() & 0xFFFF_FFFF_FFFF_FFFF;
3959

4060
// Break right half of dividend into two digits.
4161
un1 = un10 >> (N_UDWORD_BITS / 2);
4262
un0 = un10 & 0xFFFF_FFFF_FFFF_FFFF;
4363

4464
// Compute the first quotient digit, q1.
4565
q1 = un128 / vn1;
46-
let mut rhat = un128 - q1 * vn1;
66+
let mut rhat = un128 - q1 * vn1.get();
4767

4868
// q1 has at most error 2. No more than 2 iterations.
4969
while q1 >= B || q1 * vn0 > B * rhat + un1 {
5070
q1 -= 1;
51-
rhat += vn1;
71+
rhat += vn1.get();
5272
if rhat >= B {
5373
break;
5474
}
@@ -57,16 +77,16 @@ fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128
5777
un21 = un128
5878
.wrapping_mul(B)
5979
.wrapping_add(un1)
60-
.wrapping_sub(q1.wrapping_mul(v));
80+
.wrapping_sub(q1.wrapping_mul(v.get()));
6181

6282
// Compute the second quotient digit.
6383
q0 = un21 / vn1;
64-
rhat = un21 - q0 * vn1;
84+
rhat = un21 - q0 * vn1.get();
6585

6686
// q0 has at most error 2. No more than 2 iterations.
6787
while q0 >= B || q0 * vn0 > B * rhat + un0 {
6888
q0 -= 1;
69-
rhat += vn1;
89+
rhat += vn1.get();
7090
if rhat >= B {
7191
break;
7292
}
@@ -75,7 +95,7 @@ fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128
7595
*r = (un21
7696
.wrapping_mul(B)
7797
.wrapping_add(un0)
78-
.wrapping_sub(q0.wrapping_mul(v)))
98+
.wrapping_sub(q0.wrapping_mul(v.get())))
7999
>> s;
80100
q1 * B + q0
81101
}
@@ -101,10 +121,10 @@ pub fn udivmod4(
101121
// Unfortunately, there is no 256-bit equivalent on x86_64, but we can still
102122
// shortcut if the high and low values of the operands are 0:
103123
if a.high() | b.high() == 0 {
124+
res.write(U256::from_words(0, a.low() / b.low()));
104125
if let Some(rem) = rem {
105126
rem.write(U256::from_words(0, a.low() % b.low()));
106127
}
107-
res.write(U256::from_words(0, a.low() / b.low()));
108128
return;
109129
}
110130

@@ -130,7 +150,8 @@ pub fn udivmod4(
130150
udiv256_by_128_to_128(
131151
*dividend.high(),
132152
*dividend.low(),
133-
*divisor.low(),
153+
// SAFETY: dividend.high() < divisor.low()
154+
unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
134155
remainder.low_mut(),
135156
),
136157
);
@@ -142,7 +163,8 @@ pub fn udivmod4(
142163
udiv256_by_128_to_128(
143164
dividend.high() % divisor.low(),
144165
*dividend.low(),
145-
*divisor.low(),
166+
// SAFETY: dividend.high() / divisor.low()
167+
unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
146168
remainder.low_mut(),
147169
),
148170
);
@@ -154,7 +176,8 @@ pub fn udivmod4(
154176
return;
155177
}
156178

157-
(quotient, remainder) = div_mod_knuth(&dividend, &divisor);
179+
// SAFETY: `*divisor.high() != 0`
180+
(quotient, remainder) = unsafe { div_mod_knuth(&dividend, &divisor) };
158181

159182
if let Some(rem) = rem {
160183
rem.write(remainder);
@@ -164,9 +187,18 @@ pub fn udivmod4(
164187

165188
// See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
166189
// https://skanthak.homepage.t-online.de/division.html
190+
// SAFETY: The high word of v (the divisor) must be non-zero.
167191
#[inline]
168-
pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
192+
unsafe fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
169193
const N_UDWORD_BITS: u32 = 128;
194+
debug_assert_ne!(
195+
*u.high(),
196+
0,
197+
"The second operand must be greater than u128::MAX"
198+
);
199+
if *u.high() == 0 {
200+
unsafe { core::hint::unreachable_unchecked() }
201+
}
170202

171203
#[inline]
172204
fn full_shl(a: &U256, shift: u32) -> [u128; 3] {
@@ -266,7 +298,6 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
266298
let shift = v.high().leading_zeros();
267299
debug_assert!(shift < N_UDWORD_BITS);
268300
let v = v << shift;
269-
debug_assert!(v.high() >> (N_UDWORD_BITS - 1) == 1);
270301
// u will store the remainder (shifted)
271302
let mut u = full_shl(u, shift);
272303

@@ -275,6 +306,14 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
275306
let v_n_1 = *v.high();
276307
let v_n_2 = *v.low();
277308

309+
if v_n_1 >> (N_UDWORD_BITS - 1) != 1 {
310+
debug_assert!(false);
311+
312+
// SAFETY: `v_n_1` must be normalized because input `v` has
313+
// been checked to be non-zero.
314+
unsafe { core::hint::unreachable_unchecked() }
315+
}
316+
278317
// D2. D7. - unrolled loop j == 0, n == 2, m == 0 (only one possible iteration)
279318
let mut r_hat: u128 = 0;
280319
let u_jn = u[2];
@@ -286,7 +325,12 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
286325
// Theorem B: q_hat >= q_j >= q_hat - 2
287326
let mut q_hat = if u_jn < v_n_1 {
288327
//let (mut q_hat, mut r_hat) = _div_mod_u128(u_jn, u[j + n - 1], v_n_1);
289-
let mut q_hat = udiv256_by_128_to_128(u_jn, u[1], v_n_1, &mut r_hat);
328+
let mut q_hat = udiv256_by_128_to_128(
329+
u_jn,
330+
u[1],
331+
unsafe { NonZeroU128::new_unchecked(v_n_1) },
332+
&mut r_hat,
333+
);
290334
let mut overflow: bool;
291335
// this loop takes at most 2 iterations
292336
loop {

0 commit comments

Comments
 (0)