@@ -167,11 +167,22 @@ def generate(
167167 bootstrap_host : Optional [Union [List [str ], str ]] = None ,
168168 bootstrap_port : Optional [Union [List [int ], int ]] = None ,
169169 bootstrap_room : Optional [Union [List [int ], int ]] = None ,
170+ data_parallel_rank : Optional [int ] = None ,
170171 ) -> Union [Dict , Iterator [Dict ]]:
171172 """
172173 The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
173174 Please refer to `GenerateReqInput` for the documentation.
174175 """
176+ if self .server_args .enable_dp_attention :
177+ if data_parallel_rank is None :
178+ logger .info ("data_parallel_rank not provided, using default dispatch" )
179+ elif data_parallel_rank < 0 :
180+ raise ValueError ("data_parallel_rank must be non-negative" )
181+ elif data_parallel_rank >= self .server_args .dp_size :
182+ raise ValueError (
183+ f"data_parallel_rank must be less than dp_size: { self .server_args .dp_size } "
184+ )
185+
175186 obj = GenerateReqInput (
176187 text = prompt ,
177188 input_ids = input_ids ,
@@ -188,6 +199,7 @@ def generate(
188199 bootstrap_host = bootstrap_host ,
189200 bootstrap_port = bootstrap_port ,
190201 bootstrap_room = bootstrap_room ,
202+ data_parallel_rank = data_parallel_rank ,
191203 )
192204 loop = asyncio .get_event_loop ()
193205 generator = self .tokenizer_manager .generate_request (obj , None )
@@ -237,11 +249,24 @@ async def async_generate(
237249 bootstrap_host : Optional [Union [List [str ], str ]] = None ,
238250 bootstrap_port : Optional [Union [List [int ], int ]] = None ,
239251 bootstrap_room : Optional [Union [List [int ], int ]] = None ,
252+ data_parallel_rank : Optional [int ] = None ,
240253 ) -> Union [Dict , AsyncIterator [Dict ]]:
241254 """
242255 The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
243256 Please refer to `GenerateReqInput` for the documentation.
244257 """
258+
259+ if self .server_args .enable_dp_attention :
260+ if data_parallel_rank is None :
261+ logger .info ("data_parallel_rank not provided, using default dispatch" )
262+ elif data_parallel_rank < 0 :
263+ raise ValueError ("data_parallel_rank must be non-negative" )
264+ elif data_parallel_rank >= self .server_args .dp_size :
265+ raise ValueError (
266+ f"data_parallel_rank must be in range [0, { self .server_args .dp_size - 1 } ]"
267+ )
268+
269+ logger .info (f"data_parallel_rank: { data_parallel_rank } " )
245270 obj = GenerateReqInput (
246271 text = prompt ,
247272 input_ids = input_ids ,
@@ -257,6 +282,7 @@ async def async_generate(
257282 bootstrap_host = bootstrap_host ,
258283 bootstrap_port = bootstrap_port ,
259284 bootstrap_room = bootstrap_room ,
285+ data_parallel_rank = data_parallel_rank ,
260286 )
261287 generator = self .tokenizer_manager .generate_request (obj , None )
262288
0 commit comments