@@ -8,6 +8,7 @@ use kvproto::{
88 metapb:: { Region , RegionEpoch } ,
99 pdpb:: CheckPolicy ,
1010 raft_cmdpb:: { ComputeHashRequest , RaftCmdRequest } ,
11+ raft_serverpb:: RaftMessage ,
1112} ;
1213use protobuf:: Message ;
1314use raft:: eraftpb;
@@ -278,6 +279,7 @@ impl_box_observer_g!(
278279 ConsistencyCheckObserver ,
279280 WrappedConsistencyCheckObserver
280281) ;
282+ impl_box_observer ! ( BoxMessageObserver , MessageObserver , WrappedMessageObserver ) ;
281283
282284/// Registry contains all registered coprocessors.
283285#[ derive( Clone ) ]
@@ -296,6 +298,7 @@ where
296298 read_index_observers : Vec < Entry < BoxReadIndexObserver > > ,
297299 pd_task_observers : Vec < Entry < BoxPdTaskObserver > > ,
298300 update_safe_ts_observers : Vec < Entry < BoxUpdateSafeTsObserver > > ,
301+ message_observers : Vec < Entry < BoxMessageObserver > > ,
299302 // TODO: add endpoint
300303}
301304
@@ -313,6 +316,7 @@ impl<E: KvEngine> Default for Registry<E> {
313316 read_index_observers : Default :: default ( ) ,
314317 pd_task_observers : Default :: default ( ) ,
315318 update_safe_ts_observers : Default :: default ( ) ,
319+ message_observers : Default :: default ( ) ,
316320 }
317321 }
318322}
@@ -381,6 +385,10 @@ impl<E: KvEngine> Registry<E> {
381385 pub fn register_update_safe_ts_observer ( & mut self , priority : u32 , qo : BoxUpdateSafeTsObserver ) {
382386 push ! ( priority, qo, self . update_safe_ts_observers) ;
383387 }
388+
389+ pub fn register_message_observer ( & mut self , priority : u32 , qo : BoxMessageObserver ) {
390+ push ! ( priority, qo, self . message_observers) ;
391+ }
384392}
385393
386394/// A macro that loops over all observers and returns early when error is found
@@ -780,6 +788,17 @@ impl<E: KvEngine> CoprocessorHost<E> {
780788 true
781789 }
782790
791+ /// Returns false if the message should not be stepped later.
792+ pub fn on_raft_message ( & self , msg : & RaftMessage ) -> bool {
793+ for observer in & self . registry . message_observers {
794+ let observer = observer. observer . inner ( ) ;
795+ if !observer. on_raft_message ( msg) {
796+ return false ;
797+ }
798+ }
799+ true
800+ }
801+
783802 pub fn on_flush_applied_cmd_batch (
784803 & self ,
785804 max_level : ObserveLevel ,
@@ -890,6 +909,7 @@ mod tests {
890909 OnUpdateSafeTs = 23 ,
891910 PrePersist = 24 ,
892911 PreWriteApplyState = 25 ,
912+ OnRaftMessage = 26 ,
893913 }
894914
895915 impl Coprocessor for TestCoprocessor { }
@@ -1132,6 +1152,14 @@ mod tests {
11321152 }
11331153 }
11341154
1155+ impl MessageObserver for TestCoprocessor {
1156+ fn on_raft_message ( & self , _: & RaftMessage ) -> bool {
1157+ self . called
1158+ . fetch_add ( ObserverIndex :: OnRaftMessage as usize , Ordering :: SeqCst ) ;
1159+ true
1160+ }
1161+ }
1162+
11351163 macro_rules! assert_all {
11361164 ( $target: expr, $expect: expr) => { {
11371165 for ( c, e) in ( $target) . iter( ) . zip( $expect) {
@@ -1168,6 +1196,8 @@ mod tests {
11681196 . register_cmd_observer ( 1 , BoxCmdObserver :: new ( ob. clone ( ) ) ) ;
11691197 host. registry
11701198 . register_update_safe_ts_observer ( 1 , BoxUpdateSafeTsObserver :: new ( ob. clone ( ) ) ) ;
1199+ host. registry
1200+ . register_message_observer ( 1 , BoxMessageObserver :: new ( ob. clone ( ) ) ) ;
11711201
11721202 let mut index: usize = 0 ;
11731203 let region = Region :: default ( ) ;
@@ -1282,6 +1312,11 @@ mod tests {
12821312 host. pre_write_apply_state ( & region) ;
12831313 index += ObserverIndex :: PreWriteApplyState as usize ;
12841314 assert_all ! ( [ & ob. called] , & [ index] ) ;
1315+
1316+ let msg = RaftMessage :: default ( ) ;
1317+ host. on_raft_message ( & msg) ;
1318+ index += ObserverIndex :: OnRaftMessage as usize ;
1319+ assert_all ! ( [ & ob. called] , & [ index] ) ;
12851320 }
12861321
12871322 #[ test]
0 commit comments