diff --git a/src/utils/LibSort.sol b/src/utils/LibSort.sol index 72ae9db754..360bf801fa 100644 --- a/src/utils/LibSort.sol +++ b/src/utils/LibSort.sol @@ -583,43 +583,62 @@ library LibSort { /// @dev Sorts and uniquifies `keys`. Updates `values` with the grouped sums by key. function groupSum(uint256[] memory keys, uint256[] memory values) internal pure { - uint256 m; /// @solidity memory-safe-assembly assembly { - m := mload(0x40) // Cache the free memory pointer, for freeing the memory. - if iszero(eq(mload(keys), mload(values))) { + function mswap(i_, j_) { + let t_ := mload(i_) + mstore(i_, mload(j_)) + mstore(j_, t_) + } + function sortInner(l_, h_, d_) { + let p_ := mload(l_) + let j_ := l_ + for { let i_ := add(l_, 0x20) } 1 {} { + if lt(mload(i_), p_) { + j_ := add(j_, 0x20) + mswap(i_, j_) + mswap(add(i_, d_), add(j_, d_)) + } + i_ := add(0x20, i_) + if iszero(lt(i_, h_)) { break } + } + mswap(l_, j_) + mswap(add(l_, d_), add(j_, d_)) + if iszero(gt(add(0x40, l_), j_)) { sortInner(l_, j_, d_) } + if iszero(gt(add(0x60, j_), h_)) { sortInner(add(j_, 0x20), h_, d_) } + } + let n := mload(values) + if iszero(eq(mload(keys), n)) { mstore(0x00, 0x4e487b71) mstore(0x20, 0x32) // Array out of bounds panic if the arrays lengths differ. revert(0x1c, 0x24) } - } - if (keys.length == uint256(0)) return; - (uint256[] memory oriKeys, uint256[] memory oriValues) = (copy(keys), copy(values)); - insertionSort(keys); // Optimize for small `n` and bytecode size. - uniquifySorted(keys); - /// @solidity memory-safe-assembly - assembly { - let d := sub(values, keys) - let w := not(0x1f) - let s := add(keys, 0x20) // Location of `keys[0]`. - mstore(values, mload(keys)) // Truncate. - calldatacopy(add(s, d), calldatasize(), shl(5, mload(keys))) // Zeroize. - for { let i := shl(5, mload(oriKeys)) } 1 {} { - let k := mload(add(oriKeys, i)) - let v := mload(add(oriValues, i)) - let j := s // Just do a linear scan to optimize for small `n` and bytecode size. - for {} iszero(eq(mload(j), k)) {} { j := add(j, 0x20) } - j := add(j, d) // Convert `j` to point into `values`. - mstore(j, add(mload(j), v)) - if lt(mload(j), v) { - mstore(0x00, 0x4e487b71) - mstore(0x20, 0x11) // Overflow panic if the addition overflows. - revert(0x1c, 0x24) + if iszero(lt(n, 2)) { + let d := sub(values, keys) + let x := add(keys, 0x20) + let end := add(x, shl(5, n)) + sortInner(x, end, d) + let s := mload(add(x, d)) + for { let y := add(keys, 0x40) } 1 {} { + if iszero(eq(mload(x), mload(y))) { + mstore(add(x, d), s) // Write sum. + s := 0 + x := add(x, 0x20) + mstore(x, mload(y)) + } + s := add(s, mload(add(y, d))) + if lt(s, mload(add(y, d))) { + mstore(0x00, 0x4e487b71) + mstore(0x20, 0x11) // Overflow panic if the addition overflows. + revert(0x1c, 0x24) + } + y := add(y, 0x20) + if eq(y, end) { break } } - i := add(i, w) // `sub(i, 0x20)`. - if iszero(i) { break } + mstore(add(x, d), s) // Write sum. + mstore(keys, shr(5, sub(x, keys))) // Truncate. + mstore(values, mload(keys)) // Truncate. } - mstore(0x40, m) // Frees the memory allocated for the temporary copies. } }