     _get_global_group,
     _warn_cur_rank_not_in_group,
 )
+from paddle.distributed.communication.serialization_utils import (
+    convert_object_to_tensor,
+    convert_tensor_to_object,
+)
 from paddle.framework.recall_error import check_naninf
 from paddle.utils import strtobool
 
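The two helpers imported above are what carry an optional string `key` attached to a tensor through the shape/dtype meta exchange below. A minimal round-trip sketch of that usage, assuming they pickle an object into a uint8 tensor and back (the key value here is made up):

import paddle
from paddle.distributed.communication.serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)

key = "decoder.layers.7.out"            # hypothetical tensor key
key_tensor, _ = convert_object_to_tensor(key)
key_data = key_tensor.numpy().tolist()  # plain ints, safe to append to the meta list

restored = convert_tensor_to_object(
    paddle.to_tensor(key_data).astype("uint8"),
    paddle.to_tensor(len(key_data)),
)
assert restored == key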
@@ -58,10 +62,12 @@ def __init__(self):
     def init_or_erase_meta(self):
         self.send_shape_message = None
         self.send_dtype_message = None
+        self.send_key_message = None
 
         self.recv_shape_message = None
         self.recv_dtype_message = None
         self.recv_stop_gradient = None
+        self.recv_key_message = None
 
         self.has_send_meta = False
         self.has_recv_meta = False
@@ -99,17 +105,31 @@ def recv_meta(self, group, reverse=False, broadcast=False):
         shapes = []
         dtypes = []
         stop_grads = []
+        keys = []
 
         for _ in range(tensor_num):
             shape_len = data.pop(0)
             shape = data[:shape_len]
             data = data[shape_len:]
             dtype_number = data.pop(0)
             stop_gradient = bool(data.pop(0))
+            # ------------------tensor key meta recv-------------
+            key_len = data.pop(0)
+            key_data = data[:key_len]
+            if key_len > 0:
+                key = convert_tensor_to_object(
+                    paddle.to_tensor(key_data).astype("uint8"),
+                    paddle.to_tensor(key_len),
+                )
+            else:
+                key = None
+            data = data[key_len:]
+            # ------------------tensor key meta recv-------------
 
             shapes.append(shape)
             dtypes.append(dtype_number)
             stop_grads.append(stop_gradient)
+            keys.append(key)
 
         assert (
             len(data) == 0
@@ -119,10 +139,12 @@ def recv_meta(self, group, reverse=False, broadcast=False):
             self.recv_shape_message = shapes[0]
             self.recv_dtype_message = dtypes[0]
             self.recv_stop_gradient = stop_grads[0]
+            self.recv_key_message = keys[0]
         else:
             self.recv_shape_message = tuple(shapes)
             self.recv_dtype_message = tuple(dtypes)
             self.recv_stop_gradient = tuple(stop_grads)
+            self.recv_key_message = tuple(keys)
 
     def send_meta(self, tensor, group, reverse=False, broadcast=False):
         if reverse:
@@ -152,12 +174,24 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
 
         for t in tensors_to_send:
             assert isinstance(t, paddle.Tensor)
+            # ------------------tensor key meta send-------------
+            if hasattr(t, "key"):
+                current_tensor_name = t.key
+                key_data_tensor, _ = convert_object_to_tensor(
+                    current_tensor_name
+                )
+                key_data = key_data_tensor.numpy().tolist()
+            else:
+                key_data = []
+            # ------------------tensor key meta send-------------
             data.extend(
                 [
                     len(t.shape),
                     *t.shape,
                     paddle_2_number(t.dtype),
                     int(t.stop_gradient),
+                    len(key_data),
+                    *key_data,
                 ]
             )
 
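For reference, each tensor now contributes the following flat entry to `data`, which is exactly what recv_meta above pops back off the front; the concrete numbers are only an illustration:

# [len(shape), *shape, dtype_number, stop_gradient, len(key_bytes), *key_bytes]
# e.g. a [4, 8] tensor with stop_gradient False whose pickled key is 21 bytes:
#     [2, 4, 8, <dtype_number>, 0, 21, <21 uint8 values>]
# A tensor without a `key` attribute contributes key_len == 0 and no key bytes,
# which recv_meta maps back to key = None.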
@@ -184,35 +218,44 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
 
     def _obtain_send_message(self, tensor):
         if isinstance(tensor, paddle.Tensor):
-            return tensor.shape, paddle_2_number(tensor.dtype)
+            key = tensor.key if hasattr(tensor, "key") else None
+            return tensor.shape, paddle_2_number(tensor.dtype), key
         else:
             shapes = []
             dtypes = []
+            keys = []
             for d in tensor:
                 assert isinstance(d, paddle.Tensor)
                 if d.stop_gradient:
                     continue
-                shape, dtype = self._obtain_send_message(d)
+                shape, dtype, key = self._obtain_send_message(d)
                 shapes.append(shape)
                 dtypes.append(dtype)
-            return tuple(shapes), tuple(dtypes)
+                keys.append(key)
+            return tuple(shapes), tuple(dtypes), tuple(keys)
 
     def set_send_message(self, tensor):
         (
             self.send_shape_message,
             self.send_dtype_message,
+            self.send_key_message,  # (key1_str, key2_str, key3_str ...)
         ) = self._obtain_send_message(tensor)
 
     def check_send_message(self, tensor):
         if self.send_shape_message is None or self.send_dtype_message is None:
             return
-        actual_shape, actual_dtype = self._obtain_send_message(tensor)
+        actual_shape, actual_dtype, actual_key = self._obtain_send_message(
+            tensor
+        )
         assert (
             self.send_shape_message == actual_shape
         ), f"send_shape_message: {self.send_shape_message}, actual_shape: {actual_shape}"
         assert (
             self.send_dtype_message == actual_dtype
         ), f"send_dtype_message: {self.send_dtype_message}, actual_dtype: {actual_dtype}"
+        assert (
+            self.send_key_message == actual_key
+        ), f"send_key_message: {self.send_key_message}, actual_key: {actual_key}"
 
     def __repr__(self):
         return f"send_shape_message: {self.send_shape_message}, send_dtype_message: {self.send_dtype_message}, recv_shape_message: {self.recv_shape_message}, recv_dtype_message: {self.recv_dtype_message}, recv_stop_gradient: {self.recv_stop_gradient}"
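A minimal usage sketch of the widened send message, assuming the enclosing class is this module's SendRecvMeta and that keys are plain strings attached by upstream code (the tensors and key name here are made up):

import paddle

a = paddle.zeros([2, 3], dtype="float32")
a.stop_gradient = False
a.key = "block_0.out"                  # hypothetical key
b = paddle.zeros([4], dtype="float32")
b.stop_gradient = False                # b carries no key attribute

meta = SendRecvMeta()                  # class defined in this module
meta.set_send_message((a, b))
# meta.send_shape_message -> ([2, 3], [4])
# meta.send_dtype_message -> (paddle_2_number(a.dtype), paddle_2_number(b.dtype))
# meta.send_key_message   -> ("block_0.out", None)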
@@ -619,9 +662,11 @@ def _p2p_helper(
     recv_shape_msg = send_recv_meta.recv_shape_message
     recv_dtype_msg = send_recv_meta.recv_dtype_message
     recv_stop_gradient = send_recv_meta.recv_stop_gradient
+    recv_key_msg = send_recv_meta.recv_key_message
 
     send_shape_msg = send_recv_meta.send_shape_message
     send_dtype_msg = send_recv_meta.send_dtype_message
+    # backward has no key meta message
 
     # model parallel message
     mp_group = _hcg.get_model_parallel_group()
@@ -636,13 +681,17 @@ def _p2p_helper(
                         shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])
                     )
                     tmp.stop_gradient = recv_stop_gradient[idx]
+                    if recv_key_msg[idx] is not None:
+                        tmp.key = recv_key_msg[idx]
                     tensor_recv_prev.append(tmp)
                 tensor_recv_prev = tuple(tensor_recv_prev)
             else:
                 tensor_recv_prev = paddle.empty(
                     shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
                 )
                 tensor_recv_prev.stop_gradient = recv_stop_gradient
+                if recv_key_msg is not None:
+                    tensor_recv_prev.key = recv_key_msg
 
     if recv_next:
         if dynamic_shape:
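The receive path above reattaches the key as a plain Python attribute on the freshly allocated buffer, so the change relies on dygraph tensors accepting ad-hoc attributes; a quick check of that assumption (attribute name as in the patch, value hypothetical):

import paddle

buf = paddle.empty(shape=[2, 3], dtype="float32")
buf.key = "decoder.layers.7.out"       # same pattern as tensor_recv_prev.key above
assert hasattr(buf, "key") and buf.key == "decoder.layers.7.out"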