Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 30 additions & 90 deletions src/Common.sol
Original file line number Diff line number Diff line change
Expand Up @@ -330,53 +330,23 @@ function exp2(uint256 x) pure returns (uint256 result) {
/// @return result The index of the most significant bit as a uint256.
/// @custom:smtchecker abstract-function-nondet
function msb(uint256 x) pure returns (uint256 result) {
// 2^128
assembly ("memory-safe") {
let factor := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^64
assembly ("memory-safe") {
let factor := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^32
assembly ("memory-safe") {
let factor := shl(5, gt(x, 0xFFFFFFFF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^16
assembly ("memory-safe") {
let factor := shl(4, gt(x, 0xFFFF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^8
assembly ("memory-safe") {
let factor := shl(3, gt(x, 0xFF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^4
assembly ("memory-safe") {
let factor := shl(2, gt(x, 0xF))
x := shr(factor, x)
result := or(result, factor)
}
// 2^2
assembly ("memory-safe") {
let factor := shl(1, gt(x, 0x3))
x := shr(factor, x)
result := or(result, factor)
}
// 2^1
// No need to shift x any more.
assembly ("memory-safe") {
let factor := gt(x, 0x1)
result := or(result, factor)
// 2^128
result := shl(7, lt(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, x))
// 2^64
result := or(result, shl(6, lt(0xFFFFFFFFFFFFFFFF, shr(result, x))))
// 2^32
result := or(result, shl(5, lt(0xFFFFFFFF, shr(result, x))))
// 2^16
result := or(result, shl(4, lt(0xFFFF, shr(result, x))))
// 2^8
result := or(result, shl(3, lt(0xFF, shr(result, x))))
// 2^4
result := or(result, shl(2, lt(0xF, shr(result, x))))
// 2^2
result := or(result, shl(1, lt(0x3, shr(result, x))))
// 2^1
result := or(result, lt(0x1, shr(result, x)))
}
}

Expand Down Expand Up @@ -614,10 +584,6 @@ function mulDivSigned(int256 x, int256 y, int256 denominator) pure returns (int2
/// @return result The result as a uint256.
/// @custom:smtchecker abstract-function-nondet
function sqrt(uint256 x) pure returns (uint256 result) {
if (x == 0) {
return 0;
}

// For our first guess, we calculate the biggest power of 2 which is smaller than the square root of x.
//
// We know that the "msb" (most significant bit) of x is a power of 2 such that we have:
Expand All @@ -641,53 +607,27 @@ function sqrt(uint256 x) pure returns (uint256 result) {
// $$
//
// Consequently, $2^{log_2(x) /2} is a good first approximation of sqrt(x) with at least one correct bit.
uint256 xAux = uint256(x);
result = 1;
if (xAux >= 2 ** 128) {
xAux >>= 128;
result <<= 64;
}
if (xAux >= 2 ** 64) {
xAux >>= 64;
result <<= 32;
}
if (xAux >= 2 ** 32) {
xAux >>= 32;
result <<= 16;
}
if (xAux >= 2 ** 16) {
xAux >>= 16;
result <<= 8;
}
if (xAux >= 2 ** 8) {
xAux >>= 8;
result <<= 4;
}
if (xAux >= 2 ** 4) {
xAux >>= 4;
result <<= 2;
}
if (xAux >= 2 ** 2) {
result <<= 1;
unchecked {
// ideally, we should use arithmetic operators, but solc is not smart enough to optimize `2**(msb(x)/2)`
/// forge-lint: disable-next-line(incorrect-shift)
result = 1 << (msb(x) >> 1);
}

// At this point, `result` is an estimation with at least one bit of precision. We know the true value has at
// most 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision
// doubles at every iteration). We thus need at most 7 iteration to turn our partial result with one bit of
// precision into the expected uint128 result.
unchecked {
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
assembly ("memory-safe") {
// note: division by zero in EVM returns zero
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))

// If x is not a perfect square, round the result toward zero.
uint256 roundedResult = x / result;
if (result >= roundedResult) {
result = roundedResult;
}
result := sub(result, gt(result, div(x, result)))
}
}
44 changes: 44 additions & 0 deletions test/fuzz/common/msb.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,46 @@ import { msb } from "src/Common.sol";

import { Base_Test } from "../../Base.t.sol";

/// @dev High level implementation of `msb`, for verifying regressions.
///
/// From https://gist.github.com/PaulRBerg/f932f8693f2733e30c4d479e8e980948
function msbReferenceImplementation(uint256 x) pure returns (uint256 result) {
unchecked {
if (x >= 2 ** 128) {
x >>= 128;
result += 128;
}
if (x >= 2 ** 64) {
x >>= 64;
result += 64;
}
if (x >= 2 ** 32) {
x >>= 32;
result += 32;
}
if (x >= 2 ** 16) {
x >>= 16;
result += 16;
}
if (x >= 2 ** 8) {
x >>= 8;
result += 8;
}
if (x >= 2 ** 4) {
x >>= 4;
result += 4;
}
if (x >= 2 ** 2) {
x >>= 2;
result += 2;
}
// No need to shift x any more.
if (x >= 2 ** 1) {
result += 1;
}
}
}

/// @dev Collection of tests for the most significant bit function `msb` available in `Common.sol`.
contract Common_Msb_Test is Base_Test {
function testFuzz_Msb_FitsUint8(uint256 x) external pure {
Expand Down Expand Up @@ -33,4 +73,8 @@ contract Common_Msb_Test is Base_Test {
function testFuzz_Msb_Shifts2ToMoreThanX(uint256 x) external pure whenShiftLeftDoesNotOverflow(x) {
assertGt(2 << msb(x), x, "2 ^ {msb(x)+1} not more than x");
}

function testFuzz_Msb_MatchesReferenceImplementation(uint256 x) external pure {
assertEq(msb(x), msbReferenceImplementation(x), "does not match reference implementation of msb");
}
}
58 changes: 58 additions & 0 deletions test/fuzz/common/sqrt.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,60 @@ import { sqrt, MAX_UINT128 } from "src/Common.sol";

import { Base_Test } from "../../Base.t.sol";

/// @dev Previous implementation of `sqrt`, for verifying regressions.
///
/// From https://github.com/PaulRBerg/prb-math/blob/v4.1.0/src/Common.sol#L587-L675
function sqrtReferenceImplemenation(uint256 x) pure returns (uint256 result) {
if (x == 0) {
return 0;
}

uint256 xAux = uint256(x);
result = 1;
if (xAux >= 2 ** 128) {
xAux >>= 128;
result <<= 64;
}
if (xAux >= 2 ** 64) {
xAux >>= 64;
result <<= 32;
}
if (xAux >= 2 ** 32) {
xAux >>= 32;
result <<= 16;
}
if (xAux >= 2 ** 16) {
xAux >>= 16;
result <<= 8;
}
if (xAux >= 2 ** 8) {
xAux >>= 8;
result <<= 4;
}
if (xAux >= 2 ** 4) {
xAux >>= 4;
result <<= 2;
}
if (xAux >= 2 ** 2) {
result <<= 1;
}

unchecked {
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;
result = (result + x / result) >> 1;

uint256 roundedResult = x / result;
if (result >= roundedResult) {
result = roundedResult;
}
}
}

/// @dev Collection of tests for the square root function `sqrt` available in `Common.sol`.
contract Common_Sqrt_Test is Base_Test {
uint256 internal constant MAX_SQRT = MAX_UINT128;
Expand Down Expand Up @@ -37,4 +91,8 @@ contract Common_Sqrt_Test is Base_Test {
vm.assertLe(sqrt(x) ** 2, x, "incorrect sqrt of very large number");
vm.assertLt(x, (sqrt(x) + 1) ** 2, "incorrect sqrt of very large number");
}

function testFuzz_Sqrt_MatchesReferenceImplementation(uint256 x) external pure {
assertEq(sqrt(x), sqrtReferenceImplemenation(x), "does not match reference implementation of sqrt");
}
}