Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 202 additions & 2 deletions src/seq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,11 @@ pub trait IteratorRandom: Iterator + Sized {
/// available, complexity is `O(n)` where `n` is the iterator length.
/// Partial hints (where `lower > 0`) also improve performance.
///
/// Note that the output values and the the number of RNG samples used
/// Note that the output values and the number of RNG samples used
/// depends on size hints. In particular, `Iterator` combinators that don't
/// change the values yielded but change the size hints may result in
/// `choose` returning different elements.
/// `choose` returning different elements. If you want consistent results
/// and RNG usage consider using [`choose_stable`].
fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
where R: Rng + ?Sized {
let (mut lower, mut upper) = self.size_hint();
Expand Down Expand Up @@ -347,6 +348,62 @@ pub trait IteratorRandom: Iterator + Sized {
}
}

/// Choose one element at random from the iterator.
///
/// Returns `None` if and only if the iterator is empty.
///
/// This method is very similar to [`choose`] except that the result
/// only depends on the length of the iterator and the values produced by
/// `rng`. Notably for any iterator of a given length this will make the
/// same requests to `rng` and if the same sequence of values are produced
/// the same index will be selected from `self`. This may be useful if you
/// need consistent results no matter what type of iterator you are working
/// with. If you do not need this stability prefer [`choose`].
///
/// Note that this method still uses [`Iterator::size_hint`] to skip
/// constructing elements where possible, however the selection and `rng`
/// calls are the same in the face of this optimization. If you want to
/// force every element to be created regardless call `.inspect(|e| ())`.
fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
where R: Rng + ?Sized {
let mut consumed = 0;
let mut result = None;

loop {
// Currently the only way to skip elements is `nth()`. So we need to
// store what index to access next here.
// This should be replaced by `advance_by()` once it is stable:
// https://github.com/rust-lang/rust/issues/77404
let mut next = 0;

let (lower, _) = self.size_hint();
if lower >= 2 {
let highest_selected = (0..lower)
.filter(|ix| gen_index(rng, consumed+ix+1) == 0)
.last();

consumed += lower;
next = lower;

if let Some(ix) = highest_selected {
result = self.nth(ix);
next -= ix + 1;
debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
}
}

let elem = self.nth(next);
if elem.is_none() {
return result
}

if gen_index(rng, consumed+1) == 0 {
result = elem;
}
consumed += 1;
}
}

/// Collects values at random from the iterator into a supplied buffer
/// until that buffer is filled.
///
Expand Down Expand Up @@ -794,6 +851,103 @@ mod test {
assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable() {
let r = &mut crate::test::rng(109);
fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
let mut chosen = [0i32; 9];
for _ in 0..1000 {
let picked = iter.clone().choose_stable(r).unwrap();
chosen[picked] += 1;
}
for count in chosen.iter() {
// Samples should follow Binomial(1000, 1/9)
// Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
// Note: have seen 153, which is unlikely but not impossible.
assert!(
72 < *count && *count < 154,
"count not close to 1000/9: {}",
count
);
}
}

test_iter(r, 0..9);
test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
#[cfg(feature = "alloc")]
test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
test_iter(r, UnhintedIterator { iter: 0..9 });
test_iter(r, ChunkHintedIterator {
iter: 0..9,
chunk_size: 4,
chunk_remaining: 4,
hint_total_size: false,
});
test_iter(r, ChunkHintedIterator {
iter: 0..9,
chunk_size: 4,
chunk_remaining: 4,
hint_total_size: true,
});
test_iter(r, WindowHintedIterator {
iter: 0..9,
window_size: 2,
hint_total_size: false,
});
test_iter(r, WindowHintedIterator {
iter: 0..9,
window_size: 2,
hint_total_size: true,
});

assert_eq!((0..0).choose(r), None);
assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable_stability() {
fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
let r = &mut crate::test::rng(109);
let mut chosen = [0i32; 9];
for _ in 0..1000 {
let picked = iter.clone().choose_stable(r).unwrap();
chosen[picked] += 1;
}
chosen
}

let reference = test_iter(0..9);
assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference);

#[cfg(feature = "alloc")]
assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
assert_eq!(test_iter(ChunkHintedIterator {
iter: 0..9,
chunk_size: 4,
chunk_remaining: 4,
hint_total_size: false,
}), reference);
assert_eq!(test_iter(ChunkHintedIterator {
iter: 0..9,
chunk_size: 4,
chunk_remaining: 4,
hint_total_size: true,
}), reference);
assert_eq!(test_iter(WindowHintedIterator {
iter: 0..9,
window_size: 2,
hint_total_size: false,
}), reference);
assert_eq!(test_iter(WindowHintedIterator {
iter: 0..9,
window_size: 2,
hint_total_size: true,
}), reference);
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_shuffle() {
Expand Down Expand Up @@ -999,6 +1153,52 @@ mod test {
);
}

#[test]
fn value_stability_choose_stable() {
fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
let mut rng = crate::test::rng(411);
iter.choose_stable(&mut rng)
}

assert_eq!(choose([].iter().cloned()), None);
assert_eq!(choose(0..100), Some(40));
assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40));
assert_eq!(
choose(ChunkHintedIterator {
iter: 0..100,
chunk_size: 32,
chunk_remaining: 32,
hint_total_size: false,
}),
Some(40)
);
assert_eq!(
choose(ChunkHintedIterator {
iter: 0..100,
chunk_size: 32,
chunk_remaining: 32,
hint_total_size: true,
}),
Some(40)
);
assert_eq!(
choose(WindowHintedIterator {
iter: 0..100,
window_size: 32,
hint_total_size: false,
}),
Some(40)
);
assert_eq!(
choose(WindowHintedIterator {
iter: 0..100,
window_size: 32,
hint_total_size: true,
}),
Some(40)
);
}

#[test]
fn value_stability_choose_multiple() {
fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) {
Expand Down