diff --git a/src/Common.sol b/src/Common.sol index fd3c5c9..0c5896e 100644 --- a/src/Common.sol +++ b/src/Common.sol @@ -330,53 +330,23 @@ function exp2(uint256 x) pure returns (uint256 result) { /// @return result The index of the most significant bit as a uint256. /// @custom:smtchecker abstract-function-nondet function msb(uint256 x) pure returns (uint256 result) { - // 2^128 assembly ("memory-safe") { - let factor := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^64 - assembly ("memory-safe") { - let factor := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^32 - assembly ("memory-safe") { - let factor := shl(5, gt(x, 0xFFFFFFFF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^16 - assembly ("memory-safe") { - let factor := shl(4, gt(x, 0xFFFF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^8 - assembly ("memory-safe") { - let factor := shl(3, gt(x, 0xFF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^4 - assembly ("memory-safe") { - let factor := shl(2, gt(x, 0xF)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^2 - assembly ("memory-safe") { - let factor := shl(1, gt(x, 0x3)) - x := shr(factor, x) - result := or(result, factor) - } - // 2^1 - // No need to shift x any more. - assembly ("memory-safe") { - let factor := gt(x, 0x1) - result := or(result, factor) + // 2^128 + result := shl(7, lt(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, x)) + // 2^64 + result := or(result, shl(6, lt(0xFFFFFFFFFFFFFFFF, shr(result, x)))) + // 2^32 + result := or(result, shl(5, lt(0xFFFFFFFF, shr(result, x)))) + // 2^16 + result := or(result, shl(4, lt(0xFFFF, shr(result, x)))) + // 2^8 + result := or(result, shl(3, lt(0xFF, shr(result, x)))) + // 2^4 + result := or(result, shl(2, lt(0xF, shr(result, x)))) + // 2^2 + result := or(result, shl(1, lt(0x3, shr(result, x)))) + // 2^1 + result := or(result, lt(0x1, shr(result, x))) } } @@ -614,10 +584,6 @@ function mulDivSigned(int256 x, int256 y, int256 denominator) pure returns (int2 /// @return result The result as a uint256. /// @custom:smtchecker abstract-function-nondet function sqrt(uint256 x) pure returns (uint256 result) { - if (x == 0) { - return 0; - } - // For our first guess, we calculate the biggest power of 2 which is smaller than the square root of x. // // We know that the "msb" (most significant bit) of x is a power of 2 such that we have: @@ -641,53 +607,27 @@ function sqrt(uint256 x) pure returns (uint256 result) { // $$ // // Consequently, $2^{log_2(x) /2} is a good first approximation of sqrt(x) with at least one correct bit. - uint256 xAux = uint256(x); - result = 1; - if (xAux >= 2 ** 128) { - xAux >>= 128; - result <<= 64; - } - if (xAux >= 2 ** 64) { - xAux >>= 64; - result <<= 32; - } - if (xAux >= 2 ** 32) { - xAux >>= 32; - result <<= 16; - } - if (xAux >= 2 ** 16) { - xAux >>= 16; - result <<= 8; - } - if (xAux >= 2 ** 8) { - xAux >>= 8; - result <<= 4; - } - if (xAux >= 2 ** 4) { - xAux >>= 4; - result <<= 2; - } - if (xAux >= 2 ** 2) { - result <<= 1; + unchecked { + // ideally, we should use arithmetic operators, but solc is not smart enough to optimize `2**(msb(x)/2)` + /// forge-lint: disable-next-line(incorrect-shift) + result = 1 << (msb(x) >> 1); } // At this point, `result` is an estimation with at least one bit of precision. We know the true value has at // most 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision // doubles at every iteration). We thus need at most 7 iteration to turn our partial result with one bit of // precision into the expected uint128 result. - unchecked { - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; - result = (result + x / result) >> 1; + assembly ("memory-safe") { + // note: division by zero in EVM returns zero + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) + result := shr(1, add(result, div(x, result))) // If x is not a perfect square, round the result toward zero. - uint256 roundedResult = x / result; - if (result >= roundedResult) { - result = roundedResult; - } + result := sub(result, gt(result, div(x, result))) } } diff --git a/test/fuzz/common/msb.t.sol b/test/fuzz/common/msb.t.sol index c2c5509..0aee6bb 100644 --- a/test/fuzz/common/msb.t.sol +++ b/test/fuzz/common/msb.t.sol @@ -5,6 +5,46 @@ import { msb } from "src/Common.sol"; import { Base_Test } from "../../Base.t.sol"; +/// @dev High level implementation of `msb`, for verifying regressions. +/// +/// From https://gist.github.com/PaulRBerg/f932f8693f2733e30c4d479e8e980948 +function msbReferenceImplementation(uint256 x) pure returns (uint256 result) { + unchecked { + if (x >= 2 ** 128) { + x >>= 128; + result += 128; + } + if (x >= 2 ** 64) { + x >>= 64; + result += 64; + } + if (x >= 2 ** 32) { + x >>= 32; + result += 32; + } + if (x >= 2 ** 16) { + x >>= 16; + result += 16; + } + if (x >= 2 ** 8) { + x >>= 8; + result += 8; + } + if (x >= 2 ** 4) { + x >>= 4; + result += 4; + } + if (x >= 2 ** 2) { + x >>= 2; + result += 2; + } + // No need to shift x any more. + if (x >= 2 ** 1) { + result += 1; + } + } +} + /// @dev Collection of tests for the most significant bit function `msb` available in `Common.sol`. contract Common_Msb_Test is Base_Test { function testFuzz_Msb_FitsUint8(uint256 x) external pure { @@ -33,4 +73,8 @@ contract Common_Msb_Test is Base_Test { function testFuzz_Msb_Shifts2ToMoreThanX(uint256 x) external pure whenShiftLeftDoesNotOverflow(x) { assertGt(2 << msb(x), x, "2 ^ {msb(x)+1} not more than x"); } + + function testFuzz_Msb_MatchesReferenceImplementation(uint256 x) external pure { + assertEq(msb(x), msbReferenceImplementation(x), "does not match reference implementation of msb"); + } } diff --git a/test/fuzz/common/sqrt.t.sol b/test/fuzz/common/sqrt.t.sol index 748a04c..655a197 100644 --- a/test/fuzz/common/sqrt.t.sol +++ b/test/fuzz/common/sqrt.t.sol @@ -5,6 +5,60 @@ import { sqrt, MAX_UINT128 } from "src/Common.sol"; import { Base_Test } from "../../Base.t.sol"; +/// @dev Previous implementation of `sqrt`, for verifying regressions. +/// +/// From https://github.com/PaulRBerg/prb-math/blob/v4.1.0/src/Common.sol#L587-L675 +function sqrtReferenceImplemenation(uint256 x) pure returns (uint256 result) { + if (x == 0) { + return 0; + } + + uint256 xAux = uint256(x); + result = 1; + if (xAux >= 2 ** 128) { + xAux >>= 128; + result <<= 64; + } + if (xAux >= 2 ** 64) { + xAux >>= 64; + result <<= 32; + } + if (xAux >= 2 ** 32) { + xAux >>= 32; + result <<= 16; + } + if (xAux >= 2 ** 16) { + xAux >>= 16; + result <<= 8; + } + if (xAux >= 2 ** 8) { + xAux >>= 8; + result <<= 4; + } + if (xAux >= 2 ** 4) { + xAux >>= 4; + result <<= 2; + } + if (xAux >= 2 ** 2) { + result <<= 1; + } + + unchecked { + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + result = (result + x / result) >> 1; + + uint256 roundedResult = x / result; + if (result >= roundedResult) { + result = roundedResult; + } + } +} + /// @dev Collection of tests for the square root function `sqrt` available in `Common.sol`. contract Common_Sqrt_Test is Base_Test { uint256 internal constant MAX_SQRT = MAX_UINT128; @@ -37,4 +91,8 @@ contract Common_Sqrt_Test is Base_Test { vm.assertLe(sqrt(x) ** 2, x, "incorrect sqrt of very large number"); vm.assertLt(x, (sqrt(x) + 1) ** 2, "incorrect sqrt of very large number"); } + + function testFuzz_Sqrt_MatchesReferenceImplementation(uint256 x) external pure { + assertEq(sqrt(x), sqrtReferenceImplemenation(x), "does not match reference implementation of sqrt"); + } }