Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/quick-pianos-press.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`ReentrancyGuard` and `ReentrancyGuardTransient`: Add `nonReentrantView`, a read-only version of the `nonReentrant` modifier.
5 changes: 5 additions & 0 deletions contracts/mocks/ReentrancyAttack.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@ contract ReentrancyAttack is Context {
(bool success, ) = _msgSender().call(data);
require(success, "ReentrancyAttack: failed call");
}

function staticcallSender(bytes calldata data) public view {
(bool success, ) = _msgSender().staticcall(data);
require(success, "ReentrancyAttack: failed call");
}
}
9 changes: 9 additions & 0 deletions contracts/mocks/ReentrancyMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ contract ReentrancyMock is ReentrancyGuard {
_count();
}

function viewCallback() external view nonReentrantView returns (uint256) {
return counter;
}

function countLocalRecursive(uint256 n) public nonReentrant {
if (n > 0) {
_count();
Expand All @@ -36,6 +40,11 @@ contract ReentrancyMock is ReentrancyGuard {
attacker.callSender(abi.encodeCall(this.callback, ()));
}

function countAndCallView(ReentrancyAttack attacker) public nonReentrant {
_count();
attacker.staticcallSender(abi.encodeCall(this.viewCallback, ()));
}

function _count() private {
counter += 1;
}
Expand Down
9 changes: 9 additions & 0 deletions contracts/mocks/ReentrancyTransientMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient {
_count();
}

function viewCallback() external view nonReentrantView returns (uint256) {
return counter;
}

function countLocalRecursive(uint256 n) public nonReentrant {
if (n > 0) {
_count();
Expand All @@ -36,6 +40,11 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient {
attacker.callSender(abi.encodeCall(this.callback, ()));
}

function countAndCallView(ReentrancyAttack attacker) public nonReentrant {
_count();
attacker.staticcallSender(abi.encodeCall(this.viewCallback, ()));
}

function _count() private {
counter += 1;
}
Expand Down
18 changes: 18 additions & 0 deletions contracts/utils/ReentrancyGuard.sol
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ abstract contract ReentrancyGuard {
_nonReentrantAfter();
}

/**
* @dev View variant of the `nonReentrant` modifier. Can be used to prevent view functions from being called
* while the internal state of the contract is inconsistent and invariants do not hold.
*
* This being a "view" version of the modifier, it will not set the reentrancy status. This modifier should only
* be used in view functions. Payable and non-payable function should use the standard "nonReentrant" modifier.
*/
modifier nonReentrantView() {
_nonReentrantBeforeView();
_;
}

function _nonReentrantBeforeView() private view {
if (_status == ENTERED) {
revert ReentrancyGuardReentrantCall();
}
}

function _nonReentrantBefore() private {
// On the first call to nonReentrant, _status will be NOT_ENTERED
if (_status == ENTERED) {
Expand Down
18 changes: 18 additions & 0 deletions contracts/utils/ReentrancyGuardTransient.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ abstract contract ReentrancyGuardTransient {
_nonReentrantAfter();
}

/**
* @dev View variant of the `nonReentrant` modifier. Can be used to prevent view functions from being called
* while the internal state of the contract is inconsistent and invariants do not hold.
*
* This being a "view" version of the modifier, it will not set the reentrancy status. This modifier should only
* be used in view functions. Payable and non-payable function should use the standard "nonReentrant" modifier.
*/
modifier nonReentrantView() {
_nonReentrantBeforeView();
_;
}

function _nonReentrantBeforeView() private view {
if (_reentrancyGuardEntered()) {
revert ReentrancyGuardReentrantCall();
}
}

function _nonReentrantBefore() private {
// On the first call to nonReentrant, REENTRANCY_GUARD_STORAGE.asBoolean().tload() will be false
if (_reentrancyGuardEntered()) {
Expand Down
16 changes: 12 additions & 4 deletions test/utils/ReentrancyGuard.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ for (const variant of ['', 'Transient']) {
async function fixture() {
const name = `Reentrancy${variant}Mock`;
const mock = await ethers.deployContract(name);
return { name, mock };
const attacker = await ethers.deployContract('ReentrancyAttack');
return { name, mock, attacker };
}

beforeEach(async function () {
Expand All @@ -20,9 +21,16 @@ for (const variant of ['', 'Transient']) {
expect(await this.mock.counter()).to.equal(1n);
});

it('does not allow remote callback', async function () {
const attacker = await ethers.deployContract('ReentrancyAttack');
await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
it('nonReentrantView function can be called', async function () {
await this.mock.viewCallback();
});

it('does not allow remote callback to nonReentrant function', async function () {
await expect(this.mock.countAndCall(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
});

it('does not allow remote callback to nonReentrantView function', async function () {
await expect(this.mock.countAndCallView(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
});

it('_reentrancyGuardEntered should be true when guarded', async function () {
Expand Down