 # The segment ID in the Kafka protocol is only the span ID.
 SegmentKey = bytes

+QueueKey = bytes
+

 def _segment_key_to_span_id(segment_key: SegmentKey) -> bytes:
     return parse_segment_key(segment_key)[2]
@@ -127,6 +129,11 @@ class OutputSpan(NamedTuple):
     payload: dict[str, Any]


+class FlushedSegment(NamedTuple):
+    queue_key: QueueKey
+    spans: list[OutputSpan]
+
+
 class SpansBuffer:
     def __init__(
         self,
@@ -204,13 +211,13 @@ def process_spans(self, spans: Sequence[Span], now: int):
                     shard = self.assigned_shards[
                         int(span.trace_id, 16) % len(self.assigned_shards)
                     ]
-                    queue_keys.append(f"span-buf:q:{shard}")
+                    queue_keys.append(self._get_queue_key(shard))

                 results = p.execute()

         with metrics.timer("spans.buffer.process_spans.update_queue"):
-            queue_deletes: dict[str, set[bytes]] = {}
-            queue_adds: dict[str, MutableMapping[str | bytes, int]] = {}
+            queue_deletes: dict[bytes, set[bytes]] = {}
+            queue_adds: dict[bytes, MutableMapping[str | bytes, int]] = {}

             assert len(queue_keys) == len(results)

@@ -266,6 +273,9 @@ def _ensure_script(self):
         self.add_buffer_sha = self.client.script_load(add_buffer_script.script)
         return self.add_buffer_sha

+    def _get_queue_key(self, shard: int) -> bytes:
+        return f"span-buf:q:{shard}".encode("ascii")
+
     def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[Span]]:
         """
         Groups partial trees of spans by their top-most parent span ID in the
@@ -296,32 +306,33 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[

         return trees

-    def flush_segments(
-        self, now: int, max_segments: int = 0
-    ) -> tuple[int, dict[SegmentKey, list[OutputSpan]]]:
+    def flush_segments(self, now: int, max_segments: int = 0) -> dict[SegmentKey, FlushedSegment]:
         cutoff = now

+        queue_keys = []
+
         with metrics.timer("spans.buffer.flush_segments.load_segment_ids"):
             with self.client.pipeline(transaction=False) as p:
                 for shard in self.assigned_shards:
-                    key = f"span-buf:q:{shard}"
+                    key = self._get_queue_key(shard)
                     p.zrangebyscore(
                         key, 0, cutoff, start=0 if max_segments else None, num=max_segments or None
                     )
                     p.zcard(key)
+                    queue_keys.append(key)

                 result = iter(p.execute())

-        segment_keys = []
+        segment_keys: list[tuple[QueueKey, SegmentKey]] = []
         queue_sizes = []

         with metrics.timer("spans.buffer.flush_segments.load_segment_data"):
             with self.client.pipeline(transaction=False) as p:
                 # ZRANGEBYSCORE output
-                for segment_span_ids in result:
+                for queue_key, segment_span_ids in zip(queue_keys, result):
                     # process return value of zrevrangebyscore
                     for segment_key in segment_span_ids:
-                        segment_keys.append(segment_key)
+                        segment_keys.append((queue_key, segment_key))
                         p.smembers(segment_key)

                 # ZCARD output
@@ -340,10 +351,10 @@ def flush_segments(

         num_has_root_spans = 0

-        for segment_key, segment in zip(segment_keys, segments):
+        for (queue_key, segment_key), segment in zip(segment_keys, segments):
             segment_span_id = _segment_key_to_span_id(segment_key).decode("ascii")

-            return_segment = []
+            output_spans = []
             has_root_span = False
             metrics.timing("spans.buffer.flush_segments.num_spans_per_segment", len(segment))
             for payload in segment:
@@ -369,30 +380,30 @@ def flush_segments(
                         },
                     )

-                return_segment.append(OutputSpan(payload=val))
+                output_spans.append(OutputSpan(payload=val))

-            return_segments[segment_key] = return_segment
+            return_segments[segment_key] = FlushedSegment(queue_key=queue_key, spans=output_spans)
             num_has_root_spans += int(has_root_span)
+
         metrics.timing("spans.buffer.flush_segments.num_segments", len(return_segments))
         metrics.timing("spans.buffer.flush_segments.has_root_span", num_has_root_spans)

-        return sum(queue_sizes), return_segments
+        return return_segments

-    def done_flush_segments(self, segment_keys: dict[SegmentKey, list[OutputSpan]]):
+    def done_flush_segments(self, segment_keys: dict[SegmentKey, FlushedSegment]):
         metrics.timing("spans.buffer.done_flush_segments.num_segments", len(segment_keys))
         with metrics.timer("spans.buffer.done_flush_segments"):
             with self.client.pipeline(transaction=False) as p:
-                for segment_key, output_spans in segment_keys.items():
+                for segment_key, flushed_segment in segment_keys.items():
                     hrs_key = b"span-buf:hrs:" + segment_key
                     p.delete(hrs_key)
                     p.unlink(segment_key)

                     project_id, trace_id, _ = parse_segment_key(segment_key)
                     redirect_map_key = b"span-buf:sr:{%s:%s}" % (project_id, trace_id)
-                    shard = self.assigned_shards[int(trace_id, 16) % len(self.assigned_shards)]
-                    p.zrem(f"span-buf:q:{shard}".encode("ascii"), segment_key)
+                    p.zrem(flushed_segment.queue_key, segment_key)

-                    for span_batch in itertools.batched(output_spans, 100):
+                    for span_batch in itertools.batched(flushed_segment.spans, 100):
                         p.hdel(
                             redirect_map_key,
                             *[output_span.payload["span_id"] for output_span in span_batch],
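
Net effect of the diff: flush_segments now returns dict[SegmentKey, FlushedSegment] instead of a (queue size, segments) tuple, and each FlushedSegment carries the queue_key of the per-shard queue it was read from, so done_flush_segments can ZREM the segment from exactly that queue instead of re-deriving the shard from the trace ID. Below is a minimal sketch of a flush cycle against this API; the already-constructed SpansBuffer instance and the produce() callback are illustrative assumptions, not part of the diff.

import time

def flush_cycle(buffer: SpansBuffer, produce) -> None:
    # Read all segments whose flush deadline (the score in the per-shard
    # span-buf:q:{shard} sorted sets) has passed.
    flushed = buffer.flush_segments(now=int(time.time()))

    # Hand each span payload to a downstream producer (hypothetical callback).
    for flushed_segment in flushed.values():
        for output_span in flushed_segment.spans:
            produce(output_span.payload)

    # Acknowledge: delete the segment data and remove each segment key from
    # the exact queue it was flushed from (FlushedSegment.queue_key).
    buffer.done_flush_segments(flushed)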