Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog], and this project adheres to [Semantic
### Changed (Breaking)

- Missing selector returns an error instead of panicking. #88
- Api `VMContext::current()` simplified to `VM::context()`. #91
- `evm::gas_left` and `evm::ink_left` return maximum values. #90

### Changed
Expand Down
6 changes: 3 additions & 3 deletions crates/motsu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ execution environment:

#### Chain ID

You can set the Chain ID in tests using the [`VMContext`][vm_context] API:
You can set the Chain ID in tests using the [`VM`][vm] API:

```rust
use motsu::prelude::*;
Expand All @@ -69,7 +69,7 @@ fn test_with_custom_chain_id(
// Default chain ID is 42161 (Arbitrum One)

// Set chain ID to 11155111 (Sepolia testnet)
VMContext::current().set_chain_id(11155111);
VM::context().set_chain_id(11155111);

// Now any contract code that depends on `block::chainid()`
// will use this value
Expand Down Expand Up @@ -233,7 +233,7 @@ Refer to our [Security Policy] for more details.

[address]: stylus_sdk::alloy_primitives::Address

[vm_context]: crate::prelude::VMContext
[vm]: crate::prelude::VM

[contract_sender]: crate::prelude::Contract::sender

Expand Down
112 changes: 54 additions & 58 deletions crates/motsu/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
collections::HashMap,
ops::{Deref, DerefMut},
ptr, slice,
sync::LazyLock,
thread::ThreadId,
};

Expand All @@ -14,32 +15,26 @@ use alloy_signer_local::PrivateKeySigner;
use alloy_sol_types::{SolEvent, Word};
use dashmap::{mapref::one::RefMut, DashMap};
use k256::ecdsa::SigningKey;
use once_cell::sync::Lazy;
use stylus_sdk::{
host::{WasmVM, VM},
keccak_const::Keccak256,
prelude::StorageType,
types::AddressVM,
ArbResult,
keccak_const::Keccak256, prelude::StorageType, types::AddressVM, ArbResult,
};

use crate::{
router::{VMRouter, VMRouterContext},
router::{Router, VMRouter},
storage_access::AccessStorage,
};

/// Motsu VM Storage.
///
/// A global mutable key-value store that allows concurrent access.
///
/// The key is the test [`VMContext`], an id of the test thread.
/// The key is the test [`VM`], an id of the test thread.
///
/// The value is the [`VMContextStorage`], a storage of the test case.
/// The value is the [`VMStorage`], a storage of the test case.
///
/// NOTE: The [`VMContext::storage`] will panic on lock, when the same key is
/// NOTE: The [`VM::storage`] will panic on lock, when the same key is
/// accessed twice from the same thread.
static MOTSU_VM: Lazy<DashMap<VMContext, VMContextStorage>> =
Lazy::new(DashMap::new);
static MOTSU_VM: LazyLock<DashMap<VM, VMStorage>> = LazyLock::new(DashMap::new);

// TODO: remove this after we can enable the `stylus-test` feature, which should
// happen after we refactor `motsu` to implement a mock
Expand All @@ -52,14 +47,14 @@ pub(crate) const DEFAULT_CHAIN_ID: u64 = 42161;
/// Context of Motsu test VM associated with the current test thread.
#[allow(clippy::module_name_repetitions)]
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub struct VMContext {
pub struct VM {
thread_id: ThreadId,
}

impl VMContext {
/// Get test context associated with the current test thread.
impl VM {
/// Get test `VM` associated with the current test thread.
#[must_use]
pub fn current() -> Self {
pub fn context() -> Self {
Self { thread_id: std::thread::current().id() }
}

Expand Down Expand Up @@ -153,7 +148,7 @@ impl VMContext {

/// Initialise contract's storage for the current test thread and
/// `contract_address`.
fn init_storage<ST: StorageType + VMRouter + 'static>(
fn init_storage<ST: StorageType + Router + 'static>(
self,
contract_address: Address,
) {
Expand All @@ -171,10 +166,10 @@ impl VMContext {
self.router(contract_address).init_storage::<ST>();
}

/// Reset storage for the current [`VMContext`] and `contract_address`.
/// Reset storage for the current [`VM`] and `contract_address`.
///
/// If all test contracts are removed, flush storage for the current
/// test [`VMContext`].
/// test [`VM`].
fn reset_storage(self, contract_address: Address) {
let mut storage = self.storage();
storage.persistent.contracts.remove(&contract_address);
Expand Down Expand Up @@ -538,13 +533,13 @@ impl VMContext {
}

/// Get reference to the storage for the current test thread.
fn storage(self) -> RefMut<'static, VMContext, VMContextStorage> {
fn storage(self) -> RefMut<'static, VM, VMStorage> {
MOTSU_VM.access_storage(&self)
}

/// Get router for the contract at `address`.
fn router(self, address: Address) -> VMRouterContext {
VMRouterContext::new(self.thread_id, address)
fn router(self, address: Address) -> VMRouter {
VMRouter::new(self.thread_id, address)
}

/// Get the current chain ID.
Expand Down Expand Up @@ -603,7 +598,7 @@ fn decode_selector(calldata: &[u8]) -> u32 {
}

/// Main storage for Motsu test VM.
struct VMContextStorage {
struct VMStorage {
/// Address of the message sender.
msg_sender: Option<Address>,
/// The ETH value in wei sent to the program.
Expand All @@ -622,7 +617,7 @@ struct VMContextStorage {
chain_id: u64,
}

impl Default for VMContextStorage {
impl Default for VMStorage {
fn default() -> Self {
Self {
msg_sender: None,
Expand Down Expand Up @@ -738,10 +733,9 @@ pub struct ContractCall<'a, ST: StorageType> {
impl<ST: StorageType> ContractCall<'_, ST> {
/// Preset the call parameters.
fn set_call_params(&self) {
_ = VMContext::current().replace_optional_msg_value(self.msg_value);
_ = VMContext::current().replace_msg_sender(self.msg_sender);
_ = VMContext::current()
.replace_contract_address(self.contract_ref.address);
VM::context().replace_optional_msg_value(self.msg_value);
VM::context().replace_msg_sender(self.msg_sender);
VM::context().replace_contract_address(self.contract_ref.address);
}

/// Invalidate the storage cache of stylus contract [`StorageType`], by
Expand All @@ -758,15 +752,15 @@ impl<ST: StorageType> Deref for ContractCall<'_, ST> {

#[inline]
fn deref(&self) -> &Self::Target {
VMContext::current().backup();
VM::context().backup();

// Set parameters for call such as `msg_sender`, `contract_address`,
// `msg_value`.
self.set_call_params();

// Transfer value (if any) from the `msg_sender` to `contract_address`,
// that was set on the previous step.
VMContext::current().transfer_value();
VM::context().transfer_value();

// SAFETY: We don't use `ST` contract type as intended by rust.
// We don't care about any state it has in any property.
Expand All @@ -781,15 +775,15 @@ impl<ST: StorageType> Deref for ContractCall<'_, ST> {
impl<ST: StorageType> DerefMut for ContractCall<'_, ST> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
VMContext::current().backup();
VM::context().backup();

// Set parameters for call such as `msg_sender`, `contract_address`,
// `msg_value`.
self.set_call_params();

// Transfer value (if any) from the `msg_sender` to `contract_address`,
// that was set on the previous step.
VMContext::current().transfer_value();
VM::context().transfer_value();

self.invalidate_storage_type_cache();
self.storage.get_mut()
Expand All @@ -804,17 +798,17 @@ pub struct Contract<ST: StorageType> {

impl<ST: StorageType> Drop for Contract<ST> {
fn drop(&mut self) {
VMContext::current().reset_storage(self.address);
VM::context().reset_storage(self.address);
}
}

impl<ST: StorageType + VMRouter + 'static> Default for Contract<ST> {
impl<ST: StorageType + Router + 'static> Default for Contract<ST> {
fn default() -> Self {
Contract::new_at(Address::default())
}
}

impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
impl<ST: StorageType + Router + 'static> Contract<ST> {
/// Create a new contract with default storage on the random address.
#[must_use]
pub fn new() -> Self {
Expand All @@ -824,7 +818,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// Create a new contract with the given `address`.
#[must_use]
pub fn new_at(address: Address) -> Self {
VMContext::current().init_storage::<ST>(address);
VM::context().init_storage::<ST>(address);

Self { phantom: ::core::marker::PhantomData, address }
}
Expand Down Expand Up @@ -882,7 +876,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {

/// Check if the `event` was emitted, by the contract `self`.
pub fn emitted<E: SolEvent>(&self, event: &E) -> bool {
VMContext::current().emitted_for(&self.address, event)
VM::context().emitted_for(&self.address, event)
}

/// Assert that the `event` was emitted, by the contract `self`.
Expand All @@ -895,7 +889,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// any) in the error message.
#[track_caller]
pub fn assert_emitted<E: SolEvent + Debug>(&self, event: &E) {
let context = VMContext::current();
let context = VM::context();
if context.emitted_for(&self.address, event) {
return;
}
Expand All @@ -916,7 +910,15 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// Create a default [`StorageType`] `ST` type with at [`U256::ZERO`] slot and
/// `0` offset.
pub(crate) fn create_default_storage_type<ST: StorageType>() -> ST {
unsafe { ST::new(U256::ZERO, 0, VM { host: Box::new(WasmVM {}) }) }
unsafe {
ST::new(
U256::ZERO,
0,
stylus_sdk::host::VM {
host: Box::new(stylus_sdk::host::WasmVM {}),
},
)
}
}

/// Account that can be used to interact with contracts in test environments.
Expand Down Expand Up @@ -1023,7 +1025,7 @@ pub trait Funding {

impl Funding for Address {
fn fund(&self, value: U256) {
VMContext::current().add_assign_balance(*self, value);
VM::context().add_assign_balance(*self, value);
}
}

Expand All @@ -1033,7 +1035,7 @@ impl Funding for Account {
}
}

impl<ST: StorageType + VMRouter + 'static> Funding for Contract<ST> {
impl<ST: StorageType + Router + 'static> Funding for Contract<ST> {
fn fund(&self, value: U256) {
self.address().fund(value);
}
Expand All @@ -1051,7 +1053,7 @@ impl Balance for Account {
}
}

impl<ST: StorageType + VMRouter + 'static> Balance for Contract<ST> {
impl<ST: StorageType + Router + 'static> Balance for Contract<ST> {
fn balance(&self) -> U256 {
self.address().balance()
}
Expand All @@ -1074,7 +1076,7 @@ impl FromTag for Account {
/// Also registers the tag in the test context for debugging purposes.
fn from_tag(tag: &str) -> Self {
let account = Account::from_seed(tag);
VMContext::current().set_tag(account.address(), tag.to_string());
VM::context().set_tag(account.address(), tag.to_string());
account
}
}
Expand All @@ -1086,12 +1088,12 @@ impl FromTag for Address {
fn from_tag(tag: &str) -> Self {
let hash = Keccak256::new().update(tag.as_bytes()).finalize();
let address = Address::from_slice(&hash[..20]);
VMContext::current().set_tag(address, tag.to_string());
VM::context().set_tag(address, tag.to_string());
address
}
}

impl<ST: StorageType + VMRouter + 'static> FromTag for Contract<ST> {
impl<ST: StorageType + Router + 'static> FromTag for Contract<ST> {
/// Creates a contract at an address derived from the tag string.
///
/// This allows deploying contracts to deterministic addresses for testing.
Expand All @@ -1107,7 +1109,7 @@ mod tests {
use stylus_sdk::prelude::*;

use super::{Account, Address, Contract, FromTag};
use crate::context::VMContext;
use crate::context::VM;

mod account {
use std::ops::Deref;
Expand Down Expand Up @@ -1210,10 +1212,7 @@ mod tests {

assert_eq!(expected_private_key, account.private_key);
assert_eq!(expected_address, account.address());
assert_eq!(
Some(tag),
VMContext::current().get_tag(account.address())
);
assert_eq!(Some(tag), VM::context().get_tag(account.address()));
}

#[test]
Expand All @@ -1225,7 +1224,7 @@ mod tests {
let address = Address::from_tag(&tag);

assert_eq!(expected_address, address);
assert_eq!(Some(tag), VMContext::current().get_tag(address));
assert_eq!(Some(tag), VM::context().get_tag(address));
}

#[storage]
Expand All @@ -1245,10 +1244,7 @@ mod tests {
let contract = Contract::<SomeContract>::from_tag(&tag);

assert_eq!(expected_address, contract.address());
assert_eq!(
Some(tag),
VMContext::current().get_tag(contract.address())
);
assert_eq!(Some(tag), VM::context().get_tag(contract.address()));
}

#[test]
Expand All @@ -1265,9 +1261,9 @@ mod tests {

// all addresses still map to the same tag
let tag = Some(tag.to_owned());
assert_eq!(tag, VMContext::current().get_tag(address));
assert_eq!(tag, VMContext::current().get_tag(account.address()));
assert_eq!(tag, VMContext::current().get_tag(contract.address()));
assert_eq!(tag, VM::context().get_tag(address));
assert_eq!(tag, VM::context().get_tag(account.address()));
assert_eq!(tag, VM::context().get_tag(contract.address()));
}
}
}
Loading
Loading