diff --git a/pallets/tasks/src/lib.rs b/pallets/tasks/src/lib.rs index baf2d6d91..8a265a8a1 100644 --- a/pallets/tasks/src/lib.rs +++ b/pallets/tasks/src/lib.rs @@ -42,6 +42,7 @@ pub mod pallet { fn set_send_message_task_reward() -> Weight; fn cancel_task() -> Weight; fn reset_tasks() -> Weight; + fn set_shard_task_limit() -> Weight; fn unregister_gateways() -> Weight; } @@ -86,6 +87,10 @@ pub mod pallet { Weight::default() } + fn set_shard_task_limit() -> Weight { + Weight::default() + } + fn unregister_gateways() -> Weight { Weight::default() } @@ -135,6 +140,11 @@ pub mod pallet { pub type UnassignedTasks = StorageDoubleMap<_, Blake2_128Concat, NetworkId, Blake2_128Concat, TaskId, (), OptionQuery>; + #[pallet::storage] + #[pallet::getter(fn shard_task_limit)] + pub type ShardTaskLimit = + StorageMap<_, Blake2_128Concat, NetworkId, u32, OptionQuery>; + #[pallet::storage] pub type ShardTasks = StorageDoubleMap<_, Blake2_128Concat, ShardId, Blake2_128Concat, TaskId, (), OptionQuery>; @@ -253,6 +263,8 @@ pub mod pallet { WriteTaskRewardSet(NetworkId, BalanceOf), /// Send message task reward set for network SendMessageTaskRewardSet(NetworkId, BalanceOf), + /// Set the maximum number of assigned tasks for all shards on the network + ShardTaskLimitSet(NetworkId, u32), } #[pallet::error] @@ -487,7 +499,7 @@ pub mod pallet { Ok(()) } - #[pallet::call_index(10)] + #[pallet::call_index(9)] #[pallet::weight(::WeightInfo::cancel_task())] pub fn sudo_cancel_tasks(origin: OriginFor) -> DispatchResult { ensure_root(origin)?; @@ -502,7 +514,7 @@ pub mod pallet { Ok(()) } - #[pallet::call_index(9)] + #[pallet::call_index(10)] #[pallet::weight(::WeightInfo::reset_tasks())] pub fn reset_tasks(origin: OriginFor) -> DispatchResult { ensure_root(origin)?; @@ -524,6 +536,19 @@ pub mod pallet { } #[pallet::call_index(11)] + #[pallet::weight(::WeightInfo::set_shard_task_limit())] + pub fn set_shard_task_limit( + origin: OriginFor, + network: NetworkId, + limit: u32, + ) -> DispatchResult { + ensure_root(origin)?; + ShardTaskLimit::::insert(network, limit); + Self::deposit_event(Event::ShardTaskLimitSet(network, limit)); + Ok(()) + } + + #[pallet::call_index(12)] #[pallet::weight(::WeightInfo::unregister_gateways())] pub fn unregister_gateways(origin: OriginFor) -> DispatchResult { ensure_root(origin)?; @@ -829,10 +854,17 @@ pub mod pallet { } fn schedule_tasks_shard(network: NetworkId, shard_id: ShardId) { - let tasks = ShardTasks::::iter_prefix(shard_id).count(); + let tasks = ShardTasks::::iter_prefix(shard_id) + .filter(|(t, _)| TaskOutput::::get(t).is_none()) + .count(); let shard_size = T::Shards::shard_members(shard_id).len() as u16; let is_registered = ShardRegistered::::get(shard_id).is_some(); - let capacity = 10.saturating_sub(tasks); + let shard_task_limit = ShardTaskLimit::::get(network).unwrap_or(10) as usize; + let capacity = shard_task_limit.saturating_sub(tasks); + if capacity.is_zero() { + // no new tasks assigned if capacity reached or exceeded + return; + } let tasks = UnassignedTasks::::iter_prefix(network) .filter(|(task_id, _)| { let Some(task) = Tasks::::get(task_id) else { return false }; diff --git a/pallets/tasks/src/tests.rs b/pallets/tasks/src/tests.rs index 2de42a302..ca7abb721 100644 --- a/pallets/tasks/src/tests.rs +++ b/pallets/tasks/src/tests.rs @@ -1,8 +1,9 @@ use crate::mock::*; use crate::{ Error, Event, Gateway, NetworkReadReward, NetworkSendMessageReward, NetworkShards, - NetworkWriteReward, ShardRegistered, ShardTasks, SignerPayout, TaskHash, TaskIdCounter, - TaskOutput, TaskPhaseState, TaskRewardConfig, TaskSignature, TaskSigner, UnassignedTasks, + NetworkWriteReward, ShardRegistered, ShardTaskLimit, ShardTasks, SignerPayout, TaskHash, + TaskIdCounter, TaskOutput, TaskPhaseState, TaskRewardConfig, TaskSignature, TaskSigner, + UnassignedTasks, }; use frame_support::traits::Get; use frame_support::{assert_noop, assert_ok}; @@ -1591,6 +1592,24 @@ fn register_gateway_fails_previous_shard_registration_tasks() { }); } +#[test] +fn set_shard_task_limit_updates_storage_and_emits_event() { + new_test_ext().execute_with(|| { + Shards::create_shard( + ETHEREUM, + [[0u8; 32].into(), [1u8; 32].into(), [2u8; 32].into()].to_vec(), + 1, + ); + assert_eq!(ShardTaskLimit::::get(ETHEREUM), None); + assert_ok!(Tasks::set_shard_task_limit(RawOrigin::Root.into(), ETHEREUM, 5)); + assert_eq!(ShardTaskLimit::::get(ETHEREUM), Some(5)); + System::assert_last_event(Event::::ShardTaskLimitSet(ETHEREUM, 5).into()); + assert_ok!(Tasks::set_shard_task_limit(RawOrigin::Root.into(), ETHEREUM, 50)); + assert_eq!(ShardTaskLimit::::get(ETHEREUM), Some(50)); + System::assert_last_event(Event::::ShardTaskLimitSet(ETHEREUM, 50).into()); + }); +} + #[test] fn cancel_task_sets_task_output_to_err() { new_test_ext().execute_with(|| { @@ -1618,6 +1637,42 @@ fn cancel_task_sets_task_output_to_err() { }); } +#[test] +fn set_shard_task_limit_successfully_limits_task_assignment() { + new_test_ext().execute_with(|| { + Shards::create_shard( + ETHEREUM, + [[0u8; 32].into(), [1u8; 32].into(), [2u8; 32].into()].to_vec(), + 1, + ); + ShardState::::insert(0, ShardStatus::Online); + Tasks::shard_online(0, ETHEREUM); + for _ in 0..5 { + assert_ok!(Tasks::create_task( + RawOrigin::Signed([0; 32].into()).into(), + mock_task(ETHEREUM) + )); + } + assert_eq!(ShardTasks::::iter_prefix(0).count(), 5); + assert_eq!(UnassignedTasks::::iter().collect::>().len(), 0); + assert_ok!(Tasks::set_shard_task_limit(RawOrigin::Root.into(), ETHEREUM, 5)); + assert_ok!(Tasks::create_task( + RawOrigin::Signed([0; 32].into()).into(), + mock_task(ETHEREUM) + )); + assert_eq!(ShardTasks::::iter_prefix(0).count(), 5); + assert_eq!(UnassignedTasks::::iter().collect::>().len(), 1); + assert_ok!(Tasks::set_shard_task_limit(RawOrigin::Root.into(), ETHEREUM, 6)); + assert_ok!(Tasks::create_task( + RawOrigin::Signed([0; 32].into()).into(), + mock_task(ETHEREUM) + )); + assert_eq!(ShardTasks::::iter_prefix(0).count(), 6); + assert_eq!(UnassignedTasks::::iter().collect::>().len(), 1); + assert_ok!(Tasks::set_shard_task_limit(RawOrigin::Root.into(), ETHEREUM, 6)); + }); +} + #[test] fn unregister_gateways_removes_all_gateways_and_shard_registrations() { new_test_ext().execute_with(|| { diff --git a/runtime/src/weights/tasks.rs b/runtime/src/weights/tasks.rs index c5d35db94..9c69793c3 100644 --- a/runtime/src/weights/tasks.rs +++ b/runtime/src/weights/tasks.rs @@ -237,4 +237,8 @@ impl pallet_tasks::WeightInfo for WeightInfo { fn unregister_gateways() -> Weight { Weight::from_parts(0, 0) } + + fn set_shard_task_limit() -> Weight { + Weight::from_parts(0, 0) + } }