Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog], and this project adheres to [Semantic
### Changed (Breaking)

- Missing selector returns an error instead of panicking. #88
- Api `VMContext::current()` simplified to `VM::context()`. #91
- `evm::gas_left` and `evm::ink_left` return maximum values. #90

### Changed
Expand Down
6 changes: 3 additions & 3 deletions crates/motsu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ execution environment:

#### Chain ID

You can set the Chain ID in tests using the [`VMContext`][vm_context] API:
You can set the Chain ID in tests using the [`VM`][vm] API:

```rust
use motsu::prelude::*;
Expand All @@ -69,7 +69,7 @@ fn test_with_custom_chain_id(
// Default chain ID is 42161 (Arbitrum One)

// Set chain ID to 11155111 (Sepolia testnet)
VMContext::current().set_chain_id(11155111);
VM::context().set_chain_id(11155111);

// Now any contract code that depends on `block::chainid()`
// will use this value
Expand Down Expand Up @@ -233,7 +233,7 @@ Refer to our [Security Policy] for more details.

[address]: stylus_sdk::alloy_primitives::Address

[vm_context]: crate::prelude::VMContext
[vm]: crate::prelude::VM

[contract_sender]: crate::prelude::Contract::sender

Expand Down
112 changes: 54 additions & 58 deletions crates/motsu/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
collections::HashMap,
ops::{Deref, DerefMut},
ptr, slice,
sync::LazyLock,
thread::ThreadId,
};

Expand All @@ -14,32 +15,26 @@ use alloy_signer_local::PrivateKeySigner;
use alloy_sol_types::{SolEvent, Word};
use dashmap::{mapref::one::RefMut, DashMap};
use k256::ecdsa::SigningKey;
use once_cell::sync::Lazy;
use stylus_sdk::{
host::{WasmVM, VM},
keccak_const::Keccak256,
prelude::StorageType,
types::AddressVM,
ArbResult,
keccak_const::Keccak256, prelude::StorageType, types::AddressVM, ArbResult,
};

use crate::{
router::{VMRouter, VMRouterContext},
router::{Router, VMRouter},
storage_access::AccessStorage,
};

/// Motsu VM Storage.
///
/// A global mutable key-value store that allows concurrent access.
///
/// The key is the test [`VMContext`], an id of the test thread.
/// The key is the test [`VM`], an id of the test thread.
///
/// The value is the [`VMContextStorage`], a storage of the test case.
/// The value is the [`VMStorage`], a storage of the test case.
///
/// NOTE: The [`VMContext::storage`] will panic on lock, when the same key is
/// NOTE: The [`VM::storage`] will panic on lock, when the same key is
/// accessed twice from the same thread.
static MOTSU_VM: Lazy<DashMap<VMContext, VMContextStorage>> =
Lazy::new(DashMap::new);
static MOTSU_VM: LazyLock<DashMap<VM, VMStorage>> = LazyLock::new(DashMap::new);

// TODO: remove this after we can enable the `stylus-test` feature, which should
// happen after we refactor `motsu` to implement a mock
Expand All @@ -52,14 +47,14 @@ pub(crate) const DEFAULT_CHAIN_ID: u64 = 42161;
/// Context of Motsu test VM associated with the current test thread.
#[allow(clippy::module_name_repetitions)]
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub struct VMContext {
pub struct VM {
thread_id: ThreadId,
}

impl VMContext {
/// Get test context associated with the current test thread.
impl VM {
/// Get test `VM` associated with the current test thread.
#[must_use]
pub fn current() -> Self {
pub fn context() -> Self {
Self { thread_id: std::thread::current().id() }
}

Expand Down Expand Up @@ -153,7 +148,7 @@ impl VMContext {

/// Initialise contract's storage for the current test thread and
/// `contract_address`.
fn init_storage<ST: StorageType + VMRouter + 'static>(
fn init_storage<ST: StorageType + Router + 'static>(
self,
contract_address: Address,
) {
Expand All @@ -171,10 +166,10 @@ impl VMContext {
self.router(contract_address).init_storage::<ST>();
}

/// Reset storage for the current [`VMContext`] and `contract_address`.
/// Reset storage for the current [`VM`] and `contract_address`.
///
/// If all test contracts are removed, flush storage for the current
/// test [`VMContext`].
/// test [`VM`].
fn reset_storage(self, contract_address: Address) {
let mut storage = self.storage();
storage.persistent.contracts.remove(&contract_address);
Expand Down Expand Up @@ -538,13 +533,13 @@ impl VMContext {
}

/// Get reference to the storage for the current test thread.
fn storage(self) -> RefMut<'static, VMContext, VMContextStorage> {
fn storage(self) -> RefMut<'static, VM, VMStorage> {
MOTSU_VM.access_storage(&self)
}

/// Get router for the contract at `address`.
fn router(self, address: Address) -> VMRouterContext {
VMRouterContext::new(self.thread_id, address)
fn router(self, address: Address) -> VMRouter {
VMRouter::new(self.thread_id, address)
}

/// Get the current chain ID.
Expand Down Expand Up @@ -603,7 +598,7 @@ fn decode_selector(calldata: &[u8]) -> u32 {
}

/// Main storage for Motsu test VM.
struct VMContextStorage {
struct VMStorage {
/// Address of the message sender.
msg_sender: Option<Address>,
/// The ETH value in wei sent to the program.
Expand All @@ -622,7 +617,7 @@ struct VMContextStorage {
chain_id: u64,
}

impl Default for VMContextStorage {
impl Default for VMStorage {
fn default() -> Self {
Self {
msg_sender: None,
Expand Down Expand Up @@ -738,10 +733,9 @@ pub struct ContractCall<'a, ST: StorageType> {
impl<ST: StorageType> ContractCall<'_, ST> {
/// Preset the call parameters.
fn set_call_params(&self) {
_ = VMContext::current().replace_optional_msg_value(self.msg_value);
_ = VMContext::current().replace_msg_sender(self.msg_sender);
_ = VMContext::current()
.replace_contract_address(self.contract_ref.address);
_ = VM::context().replace_optional_msg_value(self.msg_value);
_ = VM::context().replace_msg_sender(self.msg_sender);
_ = VM::context().replace_contract_address(self.contract_ref.address);
}

/// Invalidate the storage cache of stylus contract [`StorageType`], by
Expand All @@ -758,15 +752,15 @@ impl<ST: StorageType> Deref for ContractCall<'_, ST> {

#[inline]
fn deref(&self) -> &Self::Target {
VMContext::current().backup();
VM::context().backup();

// Set parameters for call such as `msg_sender`, `contract_address`,
// `msg_value`.
self.set_call_params();

// Transfer value (if any) from the `msg_sender` to `contract_address`,
// that was set on the previous step.
VMContext::current().transfer_value();
VM::context().transfer_value();

// SAFETY: We don't use `ST` contract type as intended by rust.
// We don't care about any state it has in any property.
Expand All @@ -781,15 +775,15 @@ impl<ST: StorageType> Deref for ContractCall<'_, ST> {
impl<ST: StorageType> DerefMut for ContractCall<'_, ST> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
VMContext::current().backup();
VM::context().backup();

// Set parameters for call such as `msg_sender`, `contract_address`,
// `msg_value`.
self.set_call_params();

// Transfer value (if any) from the `msg_sender` to `contract_address`,
// that was set on the previous step.
VMContext::current().transfer_value();
VM::context().transfer_value();

self.invalidate_storage_type_cache();
self.storage.get_mut()
Expand All @@ -804,17 +798,17 @@ pub struct Contract<ST: StorageType> {

impl<ST: StorageType> Drop for Contract<ST> {
fn drop(&mut self) {
VMContext::current().reset_storage(self.address);
VM::context().reset_storage(self.address);
}
}

impl<ST: StorageType + VMRouter + 'static> Default for Contract<ST> {
impl<ST: StorageType + Router + 'static> Default for Contract<ST> {
fn default() -> Self {
Contract::new_at(Address::default())
}
}

impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
impl<ST: StorageType + Router + 'static> Contract<ST> {
/// Create a new contract with default storage on the random address.
#[must_use]
pub fn new() -> Self {
Expand All @@ -824,7 +818,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// Create a new contract with the given `address`.
#[must_use]
pub fn new_at(address: Address) -> Self {
VMContext::current().init_storage::<ST>(address);
VM::context().init_storage::<ST>(address);

Self { phantom: ::core::marker::PhantomData, address }
}
Expand Down Expand Up @@ -882,7 +876,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {

/// Check if the `event` was emitted, by the contract `self`.
pub fn emitted<E: SolEvent>(&self, event: &E) -> bool {
VMContext::current().emitted_for(&self.address, event)
VM::context().emitted_for(&self.address, event)
}

/// Assert that the `event` was emitted, by the contract `self`.
Expand All @@ -895,7 +889,7 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// any) in the error message.
#[track_caller]
pub fn assert_emitted<E: SolEvent + Debug>(&self, event: &E) {
let context = VMContext::current();
let context = VM::context();
if context.emitted_for(&self.address, event) {
return;
}
Expand All @@ -916,7 +910,15 @@ impl<ST: StorageType + VMRouter + 'static> Contract<ST> {
/// Create a default [`StorageType`] `ST` type with at [`U256::ZERO`] slot and
/// `0` offset.
pub(crate) fn create_default_storage_type<ST: StorageType>() -> ST {
unsafe { ST::new(U256::ZERO, 0, VM { host: Box::new(WasmVM {}) }) }
unsafe {
ST::new(
U256::ZERO,
0,
stylus_sdk::host::VM {
host: Box::new(stylus_sdk::host::WasmVM {}),
},
)
}
}

/// Account that can be used to interact with contracts in test environments.
Expand Down Expand Up @@ -1023,7 +1025,7 @@ pub trait Funding {

impl Funding for Address {
fn fund(&self, value: U256) {
VMContext::current().add_assign_balance(*self, value);
VM::context().add_assign_balance(*self, value);
}
}

Expand All @@ -1033,7 +1035,7 @@ impl Funding for Account {
}
}

impl<ST: StorageType + VMRouter + 'static> Funding for Contract<ST> {
impl<ST: StorageType + Router + 'static> Funding for Contract<ST> {
fn fund(&self, value: U256) {
self.address().fund(value);
}
Expand All @@ -1051,7 +1053,7 @@ impl Balance for Account {
}
}

impl<ST: StorageType + VMRouter + 'static> Balance for Contract<ST> {
impl<ST: StorageType + Router + 'static> Balance for Contract<ST> {
fn balance(&self) -> U256 {
self.address().balance()
}
Expand All @@ -1074,7 +1076,7 @@ impl FromTag for Account {
/// Also registers the tag in the test context for debugging purposes.
fn from_tag(tag: &str) -> Self {
let account = Account::from_seed(tag);
VMContext::current().set_tag(account.address(), tag.to_string());
VM::context().set_tag(account.address(), tag.to_string());
account
}
}
Expand All @@ -1086,12 +1088,12 @@ impl FromTag for Address {
fn from_tag(tag: &str) -> Self {
let hash = Keccak256::new().update(tag.as_bytes()).finalize();
let address = Address::from_slice(&hash[..20]);
VMContext::current().set_tag(address, tag.to_string());
VM::context().set_tag(address, tag.to_string());
address
}
}

impl<ST: StorageType + VMRouter + 'static> FromTag for Contract<ST> {
impl<ST: StorageType + Router + 'static> FromTag for Contract<ST> {
/// Creates a contract at an address derived from the tag string.
///
/// This allows deploying contracts to deterministic addresses for testing.
Expand All @@ -1107,7 +1109,7 @@ mod tests {
use stylus_sdk::prelude::*;

use super::{Account, Address, Contract, FromTag};
use crate::context::VMContext;
use crate::context::VM;

mod account {
use std::ops::Deref;
Expand Down Expand Up @@ -1210,10 +1212,7 @@ mod tests {

assert_eq!(expected_private_key, account.private_key);
assert_eq!(expected_address, account.address());
assert_eq!(
Some(tag),
VMContext::current().get_tag(account.address())
);
assert_eq!(Some(tag), VM::context().get_tag(account.address()));
}

#[test]
Expand All @@ -1225,7 +1224,7 @@ mod tests {
let address = Address::from_tag(&tag);

assert_eq!(expected_address, address);
assert_eq!(Some(tag), VMContext::current().get_tag(address));
assert_eq!(Some(tag), VM::context().get_tag(address));
}

#[storage]
Expand All @@ -1245,10 +1244,7 @@ mod tests {
let contract = Contract::<SomeContract>::from_tag(&tag);

assert_eq!(expected_address, contract.address());
assert_eq!(
Some(tag),
VMContext::current().get_tag(contract.address())
);
assert_eq!(Some(tag), VM::context().get_tag(contract.address()));
}

#[test]
Expand All @@ -1265,9 +1261,9 @@ mod tests {

// all addresses still map to the same tag
let tag = Some(tag.to_owned());
assert_eq!(tag, VMContext::current().get_tag(address));
assert_eq!(tag, VMContext::current().get_tag(account.address()));
assert_eq!(tag, VMContext::current().get_tag(contract.address()));
assert_eq!(tag, VM::context().get_tag(address));
assert_eq!(tag, VM::context().get_tag(account.address()));
assert_eq!(tag, VM::context().get_tag(contract.address()));
}
}
}
Loading