Skip to content

Commit 84e6221

Browse files
committed
fix: preserve stream candidate order in done events
1 parent 05676a4 commit 84e6221

2 files changed

Lines changed: 129 additions & 1 deletion

File tree

rust-genai/src/models.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ impl GenerateContentEventStream {
122122
None => {
123123
self.finished = true;
124124
if self.saw_done.load(Ordering::Relaxed) {
125-
if let Some(response) = self.aggregate_response.take() {
125+
if let Some(mut response) = self.aggregate_response.take() {
126+
normalize_stream_candidate_order(&mut response);
126127
return Ok(Some(GenerateContentStreamEvent::Done(response)));
127128
}
128129
}
@@ -243,6 +244,39 @@ fn late_index_stream_position(
243244
.filter(|&candidate_position| aggregate_candidates[candidate_position].index.is_none())
244245
}
245246

247+
fn normalize_stream_candidate_order(response: &mut GenerateContentResponse) {
248+
if response.candidates.len() < 2 {
249+
return;
250+
}
251+
252+
let mut ordered = vec![None; response.candidates.len()];
253+
let mut unindexed = VecDeque::new();
254+
let mut overflow = VecDeque::new();
255+
256+
for candidate in std::mem::take(&mut response.candidates) {
257+
match candidate
258+
.index
259+
.and_then(|index| usize::try_from(index).ok())
260+
.filter(|&index| index < ordered.len())
261+
{
262+
Some(index) if ordered[index].is_none() => ordered[index] = Some(candidate),
263+
Some(_) => overflow.push_back(candidate),
264+
None if candidate.index.is_none() => unindexed.push_back(candidate),
265+
None => overflow.push_back(candidate),
266+
}
267+
}
268+
269+
for slot in &mut ordered {
270+
if slot.is_none() {
271+
*slot = unindexed.pop_front().or_else(|| overflow.pop_front());
272+
}
273+
}
274+
275+
response.candidates = ordered.into_iter().flatten().collect();
276+
response.candidates.extend(unindexed);
277+
response.candidates.extend(overflow);
278+
}
279+
246280
fn merge_candidate(
247281
existing: &mut rust_genai_types::response::Candidate,
248282
next: &rust_genai_types::response::Candidate,

rust-genai/src/models/tests.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,100 @@ fn test_merge_stream_response_merges_late_index_into_sparse_multi_candidate() {
15651565
);
15661566
}
15671567

1568+
#[test]
1569+
fn test_normalize_stream_candidate_order_reorders_indexed_candidates() {
1570+
let mut response = GenerateContentResponse {
1571+
sdk_http_response: None,
1572+
candidates: vec![
1573+
Candidate {
1574+
content: Some(Content::from_parts(vec![Part::text("second")], Role::Model)),
1575+
citation_metadata: None,
1576+
finish_message: None,
1577+
token_count: None,
1578+
finish_reason: None,
1579+
avg_logprobs: None,
1580+
grounding_metadata: None,
1581+
index: Some(1),
1582+
logprobs_result: None,
1583+
safety_ratings: Vec::new(),
1584+
url_context_metadata: None,
1585+
},
1586+
Candidate {
1587+
content: Some(Content::from_parts(vec![Part::text("first")], Role::Model)),
1588+
citation_metadata: None,
1589+
finish_message: None,
1590+
token_count: None,
1591+
finish_reason: None,
1592+
avg_logprobs: None,
1593+
grounding_metadata: None,
1594+
index: Some(0),
1595+
logprobs_result: None,
1596+
safety_ratings: Vec::new(),
1597+
url_context_metadata: None,
1598+
},
1599+
],
1600+
create_time: None,
1601+
automatic_function_calling_history: None,
1602+
prompt_feedback: None,
1603+
usage_metadata: None,
1604+
model_version: None,
1605+
response_id: None,
1606+
};
1607+
1608+
normalize_stream_candidate_order(&mut response);
1609+
1610+
assert_eq!(response.text().as_deref(), Some("first"));
1611+
assert_eq!(response.candidates[0].index, Some(0));
1612+
assert_eq!(response.candidates[1].index, Some(1));
1613+
}
1614+
1615+
#[test]
1616+
fn test_normalize_stream_candidate_order_preserves_unindexed_gap_positions() {
1617+
let mut response = GenerateContentResponse {
1618+
sdk_http_response: None,
1619+
candidates: vec![
1620+
Candidate {
1621+
content: Some(Content::from_parts(vec![Part::text("second")], Role::Model)),
1622+
citation_metadata: None,
1623+
finish_message: None,
1624+
token_count: None,
1625+
finish_reason: None,
1626+
avg_logprobs: None,
1627+
grounding_metadata: None,
1628+
index: Some(1),
1629+
logprobs_result: None,
1630+
safety_ratings: Vec::new(),
1631+
url_context_metadata: None,
1632+
},
1633+
Candidate {
1634+
content: Some(Content::from_parts(vec![Part::text("first")], Role::Model)),
1635+
citation_metadata: None,
1636+
finish_message: None,
1637+
token_count: None,
1638+
finish_reason: None,
1639+
avg_logprobs: None,
1640+
grounding_metadata: None,
1641+
index: None,
1642+
logprobs_result: None,
1643+
safety_ratings: Vec::new(),
1644+
url_context_metadata: None,
1645+
},
1646+
],
1647+
create_time: None,
1648+
automatic_function_calling_history: None,
1649+
prompt_feedback: None,
1650+
usage_metadata: None,
1651+
model_version: None,
1652+
response_id: None,
1653+
};
1654+
1655+
normalize_stream_candidate_order(&mut response);
1656+
1657+
assert_eq!(response.text().as_deref(), Some("first"));
1658+
assert_eq!(response.candidates[0].index, None);
1659+
assert_eq!(response.candidates[1].index, Some(1));
1660+
}
1661+
15681662
#[test]
15691663
fn test_stream_merge_helpers_respect_context_and_targets() {
15701664
let resolution_low = PartMediaResolution {

0 commit comments

Comments
 (0)