Skip to content

Commit a0dff59

Browse files
committed
perf: branchless square root implementation
Estimated gas reduction from 798.1 to 407 gas.
1 parent bde5ec3 commit a0dff59

File tree

3 files changed

+132
-90
lines changed

3 files changed

+132
-90
lines changed

src/Common.sol

Lines changed: 30 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -330,53 +330,23 @@ function exp2(uint256 x) pure returns (uint256 result) {
330330
/// @return result The index of the most significant bit as a uint256.
331331
/// @custom:smtchecker abstract-function-nondet
332332
function msb(uint256 x) pure returns (uint256 result) {
333-
// 2^128
334333
assembly ("memory-safe") {
335-
let factor := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
336-
x := shr(factor, x)
337-
result := or(result, factor)
338-
}
339-
// 2^64
340-
assembly ("memory-safe") {
341-
let factor := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF))
342-
x := shr(factor, x)
343-
result := or(result, factor)
344-
}
345-
// 2^32
346-
assembly ("memory-safe") {
347-
let factor := shl(5, gt(x, 0xFFFFFFFF))
348-
x := shr(factor, x)
349-
result := or(result, factor)
350-
}
351-
// 2^16
352-
assembly ("memory-safe") {
353-
let factor := shl(4, gt(x, 0xFFFF))
354-
x := shr(factor, x)
355-
result := or(result, factor)
356-
}
357-
// 2^8
358-
assembly ("memory-safe") {
359-
let factor := shl(3, gt(x, 0xFF))
360-
x := shr(factor, x)
361-
result := or(result, factor)
362-
}
363-
// 2^4
364-
assembly ("memory-safe") {
365-
let factor := shl(2, gt(x, 0xF))
366-
x := shr(factor, x)
367-
result := or(result, factor)
368-
}
369-
// 2^2
370-
assembly ("memory-safe") {
371-
let factor := shl(1, gt(x, 0x3))
372-
x := shr(factor, x)
373-
result := or(result, factor)
374-
}
375-
// 2^1
376-
// No need to shift x any more.
377-
assembly ("memory-safe") {
378-
let factor := gt(x, 0x1)
379-
result := or(result, factor)
334+
// 2^128
335+
result := shl(7, lt(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, x))
336+
// 2^64
337+
result := or(result, shl(6, lt(0xFFFFFFFFFFFFFFFF, shr(result, x))))
338+
// 2^32
339+
result := or(result, shl(5, lt(0xFFFFFFFF, shr(result, x))))
340+
// 2^16
341+
result := or(result, shl(4, lt(0xFFFF, shr(result, x))))
342+
// 2^8
343+
result := or(result, shl(3, lt(0xFF, shr(result, x))))
344+
// 2^4
345+
result := or(result, shl(2, lt(0xF, shr(result, x))))
346+
// 2^2
347+
result := or(result, shl(1, lt(0x3, shr(result, x))))
348+
// 2^1
349+
result := or(result, lt(0x1, shr(result, x)))
380350
}
381351
}
382352

@@ -614,10 +584,6 @@ function mulDivSigned(int256 x, int256 y, int256 denominator) pure returns (int2
614584
/// @return result The result as a uint256.
615585
/// @custom:smtchecker abstract-function-nondet
616586
function sqrt(uint256 x) pure returns (uint256 result) {
617-
if (x == 0) {
618-
return 0;
619-
}
620-
621587
// For our first guess, we calculate the biggest power of 2 which is smaller than the square root of x.
622588
//
623589
// We know that the "msb" (most significant bit) of x is a power of 2 such that we have:
@@ -641,53 +607,27 @@ function sqrt(uint256 x) pure returns (uint256 result) {
641607
// $$
642608
//
643609
// Consequently, $2^{log_2(x) /2} is a good first approximation of sqrt(x) with at least one correct bit.
644-
uint256 xAux = uint256(x);
645-
result = 1;
646-
if (xAux >= 2 ** 128) {
647-
xAux >>= 128;
648-
result <<= 64;
649-
}
650-
if (xAux >= 2 ** 64) {
651-
xAux >>= 64;
652-
result <<= 32;
653-
}
654-
if (xAux >= 2 ** 32) {
655-
xAux >>= 32;
656-
result <<= 16;
657-
}
658-
if (xAux >= 2 ** 16) {
659-
xAux >>= 16;
660-
result <<= 8;
661-
}
662-
if (xAux >= 2 ** 8) {
663-
xAux >>= 8;
664-
result <<= 4;
665-
}
666-
if (xAux >= 2 ** 4) {
667-
xAux >>= 4;
668-
result <<= 2;
669-
}
670-
if (xAux >= 2 ** 2) {
671-
result <<= 1;
610+
unchecked {
611+
// ideally, we should use arithmetic operators, but solc is not smart enough to optimize `2**(msb(x)/2)`
612+
/// forge-lint: disable-next-line(incorrect-shift)
613+
result = 1 << (msb(x) >> 1);
672614
}
673615

674616
// At this point, `result` is an estimation with at least one bit of precision. We know the true value has at
675617
// most 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision
676618
// doubles at every iteration). We thus need at most 7 iteration to turn our partial result with one bit of
677619
// precision into the expected uint128 result.
678-
unchecked {
679-
result = (result + x / result) >> 1;
680-
result = (result + x / result) >> 1;
681-
result = (result + x / result) >> 1;
682-
result = (result + x / result) >> 1;
683-
result = (result + x / result) >> 1;
684-
result = (result + x / result) >> 1;
685-
result = (result + x / result) >> 1;
620+
assembly ("memory-safe") {
621+
// note: division by zero in EVM returns zero
622+
result := shr(1, add(result, div(x, result)))
623+
result := shr(1, add(result, div(x, result)))
624+
result := shr(1, add(result, div(x, result)))
625+
result := shr(1, add(result, div(x, result)))
626+
result := shr(1, add(result, div(x, result)))
627+
result := shr(1, add(result, div(x, result)))
628+
result := shr(1, add(result, div(x, result)))
686629

687630
// If x is not a perfect square, round the result toward zero.
688-
uint256 roundedResult = x / result;
689-
if (result >= roundedResult) {
690-
result = roundedResult;
691-
}
631+
result := sub(result, gt(result, div(x, result)))
692632
}
693633
}

test/fuzz/common/msb.t.sol

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,46 @@ import { msb } from "src/Common.sol";
55

66
import { Base_Test } from "../../Base.t.sol";
77

8+
/// @dev High level implementation of `msb`, for verifying regressions.
9+
///
10+
/// From https://gist.github.com/PaulRBerg/f932f8693f2733e30c4d479e8e980948
11+
function msbReferenceImplementation(uint256 x) pure returns (uint256 result) {
12+
unchecked {
13+
if (x >= 2 ** 128) {
14+
x >>= 128;
15+
result += 128;
16+
}
17+
if (x >= 2 ** 64) {
18+
x >>= 64;
19+
result += 64;
20+
}
21+
if (x >= 2 ** 32) {
22+
x >>= 32;
23+
result += 32;
24+
}
25+
if (x >= 2 ** 16) {
26+
x >>= 16;
27+
result += 16;
28+
}
29+
if (x >= 2 ** 8) {
30+
x >>= 8;
31+
result += 8;
32+
}
33+
if (x >= 2 ** 4) {
34+
x >>= 4;
35+
result += 4;
36+
}
37+
if (x >= 2 ** 2) {
38+
x >>= 2;
39+
result += 2;
40+
}
41+
// No need to shift x any more.
42+
if (x >= 2 ** 1) {
43+
result += 1;
44+
}
45+
}
46+
}
47+
848
/// @dev Collection of tests for the most significant bit function `msb` available in `Common.sol`.
949
contract Common_Msb_Test is Base_Test {
1050
function testFuzz_Msb_FitsUint8(uint256 x) external pure {
@@ -33,4 +73,8 @@ contract Common_Msb_Test is Base_Test {
3373
function testFuzz_Msb_Shifts2ToMoreThanX(uint256 x) external pure whenShiftLeftDoesNotOverflow(x) {
3474
assertGt(2 << msb(x), x, "2 ^ {msb(x)+1} not more than x");
3575
}
76+
77+
function testFuzz_Msb_MatchesReferenceImplementation(uint256 x) external pure {
78+
assertEq(msb(x), msbReferenceImplementation(x), "does not match reference implementation of msb");
79+
}
3680
}

test/fuzz/common/sqrt.t.sol

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,60 @@ import { sqrt, MAX_UINT128 } from "src/Common.sol";
55

66
import { Base_Test } from "../../Base.t.sol";
77

8+
/// @dev Previous implementation of `sqrt`, for verifying regressions.
9+
///
10+
/// From https://github.com/PaulRBerg/prb-math/blob/v4.1.0/src/Common.sol#L587-L675
11+
function sqrtReferenceImplemenation(uint256 x) pure returns (uint256 result) {
12+
if (x == 0) {
13+
return 0;
14+
}
15+
16+
uint256 xAux = uint256(x);
17+
result = 1;
18+
if (xAux >= 2 ** 128) {
19+
xAux >>= 128;
20+
result <<= 64;
21+
}
22+
if (xAux >= 2 ** 64) {
23+
xAux >>= 64;
24+
result <<= 32;
25+
}
26+
if (xAux >= 2 ** 32) {
27+
xAux >>= 32;
28+
result <<= 16;
29+
}
30+
if (xAux >= 2 ** 16) {
31+
xAux >>= 16;
32+
result <<= 8;
33+
}
34+
if (xAux >= 2 ** 8) {
35+
xAux >>= 8;
36+
result <<= 4;
37+
}
38+
if (xAux >= 2 ** 4) {
39+
xAux >>= 4;
40+
result <<= 2;
41+
}
42+
if (xAux >= 2 ** 2) {
43+
result <<= 1;
44+
}
45+
46+
unchecked {
47+
result = (result + x / result) >> 1;
48+
result = (result + x / result) >> 1;
49+
result = (result + x / result) >> 1;
50+
result = (result + x / result) >> 1;
51+
result = (result + x / result) >> 1;
52+
result = (result + x / result) >> 1;
53+
result = (result + x / result) >> 1;
54+
55+
uint256 roundedResult = x / result;
56+
if (result >= roundedResult) {
57+
result = roundedResult;
58+
}
59+
}
60+
}
61+
862
/// @dev Collection of tests for the square root function `sqrt` available in `Common.sol`.
963
contract Common_Sqrt_Test is Base_Test {
1064
uint256 internal constant MAX_SQRT = MAX_UINT128;
@@ -37,4 +91,8 @@ contract Common_Sqrt_Test is Base_Test {
3791
vm.assertLe(sqrt(x) ** 2, x, "incorrect sqrt of very large number");
3892
vm.assertLt(x, (sqrt(x) + 1) ** 2, "incorrect sqrt of very large number");
3993
}
94+
95+
function testFuzz_Sqrt_MatchesReferenceImplementation(uint256 x) external pure {
96+
assertEq(sqrt(x), sqrtReferenceImplemenation(x), "does not match reference implementation of sqrt");
97+
}
4098
}

0 commit comments

Comments
 (0)