Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0360aa9
rand_distr: Add Zipf distribution
vks Jun 13, 2021
1e1e768
Update changelog
vks Jun 13, 2021
a57247d
Zipf: Use `OpenClosed01`
vks Jun 15, 2021
718e71b
Zipf: Add benchmark
vks Jun 15, 2021
c2ecf1b
Fix value stability tests
vks Jun 15, 2021
6c27184
Rename `Zipf` to `Zeta`
vks Jun 15, 2021
b06c2f6
Don't claim `Zeta` follows Zipf's law
vks Jun 15, 2021
a07b321
rand_distr: Add Zipf (not zeta) distribution
vks Jun 15, 2021
6270248
Zipf: Fix `s = 1` special case
vks Jun 16, 2021
4d67af2
Zipf: Mention that rounding may occur
vks Jun 16, 2021
139e898
Zipf: Simplify trait bounds
vks Jun 16, 2021
f514fd6
Zipf: Simplify calculation of ratio
vks Jun 16, 2021
ccaa4de
Zipf: Update benchmarks
vks Jun 16, 2021
3cccc64
Zeta: Inline distribution methods
vks Jun 16, 2021
14d55f8
Group `Zeta` and `Zipf` with rate-related distributions
vks Jun 16, 2021
85f55b2
Zeta and Zipf: Improve docs
vks Jul 27, 2021
2a33433
Zeta: Replace likely impossible if with debug_assert
vks Jul 27, 2021
e19349c
Give credit for implementation details
vks Jul 28, 2021
a746fd2
Zipf: Fix `inv_cdf` for `s = 1`
vks Jul 28, 2021
b053683
Zipf: Correctly calculate rejection ratio
vks Jul 28, 2021
0f9243c
Zipf: Add debug_assert for invariant
vks Jul 28, 2021
e5aff9a
Zipf: Avoid division inside loop
vks Jul 30, 2021
a32cd08
Zeta: Mention algorithm in doc comment
vks Jul 30, 2021
72a6333
Zeta: Avoid division in rejection criterion
vks Jul 30, 2021
cf4b7e4
Zeta: Fix infinite loop for small `a`
vks Aug 2, 2021
fe5a6e1
Zeta: Document cases where infinity is returned
vks Aug 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
- New `Zeta` and `Zipf` distributions (#1136)

## [0.4.1] - 2021-06-15
- Empirically test PDF of normal distribution (#1121)
- Correctly document `no_std` support (#1100)
Expand Down
6 changes: 6 additions & 0 deletions rand_distr/benches/src/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ fn bench(c: &mut Criterion<CyclesPerByte>) {
distr_float!(g, "poisson", f64, Poisson::new(4.0).unwrap());
}

{
let mut g = c.benchmark_group("zipf");
distr_float!(g, "zipf", f64, Zipf::new(10, 1.5).unwrap());
distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap());
}

{
let mut g = c.benchmark_group("bernoulli");
distr!(g, "bernoulli", bool, Bernoulli::new(0.18).unwrap());
Expand Down
5 changes: 4 additions & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
//! - [`Poisson`] distribution
//! - [`Exp`]onential distribution, and [`Exp1`] as a primitive
//! - [`Weibull`] distribution
//! - [`Zeta`] distribution
//! - [`Zipf`] distribution
//! - Gamma and derived distributions:
//! - [`Gamma`] distribution
//! - [`ChiSquared`] distribution
Expand Down Expand Up @@ -115,6 +117,7 @@ pub use self::unit_circle::UnitCircle;
pub use self::unit_disc::UnitDisc;
pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zipf::{ZetaError, Zeta, ZipfError, Zipf};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use rand::distributions::{WeightedError, WeightedIndex};
Expand Down Expand Up @@ -198,4 +201,4 @@ mod unit_sphere;
mod utils;
mod weibull;
mod ziggurat_tables;

mod zipf;
328 changes: 328 additions & 0 deletions rand_distr/src/zipf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
// Copyright 2021 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The Zeta and related distributions.
use num_traits::Float;
use crate::{Distribution, Standard};
use rand::{Rng, distributions::OpenClosed01};
use core::fmt;

/// Samples integers according to the [zeta distribution].
///
/// The zeta distribution is a limit of the [`Zipf`] distribution. Sometimes it
/// is called one of the following: discrete Pareto, Riemann-Zeta, Zipf, or
/// Zipf–Estoup distribution.
///
/// It has the density function `f(k) = k^(-a) / C(a)` for `k >= 1`, where `a`
/// is the parameter and `C(a)` is the Riemann zeta function.
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::Zeta;
///
/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
/// println!("{}", val);
/// ```
///
/// [zeta distribution]: https://en.wikipedia.org/wiki/Zeta_distribution
#[derive(Clone, Copy, Debug)]
pub struct Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
{
a_minus_1: F,
b: F,
}

/// Error type returned from `Zeta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZetaError {
/// `a <= 1` or `nan`.
ATooSmall,
}

impl fmt::Display for ZetaError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
ZetaError::ATooSmall => "a <= 1 or is NaN in Zeta distribution",
})
}
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZetaError {}

impl<F> Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
{
/// Construct a new `Zeta` distribution with given `a` parameter.
#[inline]
pub fn new(a: F) -> Result<Zeta<F>, ZetaError> {
if !(a > F::one()) {
return Err(ZetaError::ATooSmall);
}
let a_minus_1 = a - F::one();
let two = F::one() + F::one();
Ok(Zeta {
a_minus_1,
b: two.powf(a_minus_1),
})
}
}

impl<F> Distribution<F> for Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// This is based on https://doi.org/10.1007/978-1-4613-8643-8.
loop {
let u = rng.sample(OpenClosed01);
let x = u.powf(-F::one() / self.a_minus_1).floor();
debug_assert!(x >= F::one());

let t = (F::one() + F::one() / x).powf(self.a_minus_1);

let v = rng.sample(Standard);
if v * x * (t - F::one()) / (self.b - F::one()) <= t / self.b {
return x;
}
}
}
}

/// Samples integers according to the Zipf distribution.
///
/// The samples follow Zipf's law: The frequency of each sample from a finite
/// set of size `n` is inversely proportional to a power of its frequency rank
/// (with exponent `s`).
///
/// For large `n`, this converges to the [`Zeta`] distribution.
///
/// For `s = 0`, this becomes a uniform distribution.
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::Zipf;
///
/// let val: f64 = thread_rng().sample(Zipf::new(10, 1.5).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Zipf<F>
where F: Float, Standard: Distribution<F> {
n: F,
s: F,
t: F,
}

/// Error type returned from `Zipf::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZipfError {
/// `s < 0` or `nan`.
STooSmall,
/// `n < 1`.
NTooSmall,
}

impl fmt::Display for ZipfError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution",
ZipfError::NTooSmall => "n < 1 in Zipf distribution",
})
}
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZipfError {}

impl<F> Zipf<F>
where F: Float, Standard: Distribution<F> {
/// Construct a new `Zipf` distribution for a set with `n` elements and a
/// frequency rank exponent `s`.
///
/// For large `n`, rounding may occur to fit the number into the float type.
#[inline]
pub fn new(n: u64, s: F) -> Result<Zipf<F>, ZipfError> {
if !(s >= F::zero()) {
return Err(ZipfError::STooSmall);
}
if n < 1 {
return Err(ZipfError::NTooSmall);
}
let n = F::from(n).unwrap(); // This does not fail.
let t = if s != F::one() {
(n.powf(F::one() - s) - s) / (F::one() - s)
} else {
F::one() + n.ln()
};
Ok(Zipf {
n, s, t
})
}

/// Inverse cumulative density function
#[inline]
fn inv_cdf(&self, p: F) -> F {
let one = F::one();
let pt = p * self.t;
if pt <= one {
pt
} else if self.s != F::one() {
(pt * (one - self.s) + self.s).powf(one / (one - self.s))
} else {
pt.exp()
}
}
}

impl<F> Distribution<F> for Zipf<F>
where F: Float, Standard: Distribution<F>
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let one = F::one();
loop {
let inv_b = self.inv_cdf(rng.sample(Standard));
let x = (inv_b + one).floor();
let mut ratio = x.powf(-self.s) * self.t;
if x > one {
ratio = ratio * inv_b.powf(self.s)
};

let y = rng.sample(Standard);
if y < ratio {
return x;
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
distr: D, zero: F, expected: &[F],
) {
let mut rng = crate::test::rng(213);
let mut buf = [zero; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}

#[test]
#[should_panic]
fn zeta_invalid() {
Zeta::new(1.).unwrap();
}

#[test]
#[should_panic]
fn zeta_nan() {
Zeta::new(core::f64::NAN).unwrap();
}

#[test]
fn zeta_sample() {
let a = 2.0;
let d = Zeta::new(a).unwrap();
let mut rng = crate::test::rng(1);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}

#[test]
fn zeta_value_stability() {
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[
1.0, 2.0, 1.0, 1.0,
]);
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[
2.0, 1.0, 1.0, 1.0,
]);
}

#[test]
#[should_panic]
fn zipf_s_too_small() {
Zipf::new(10, -1.).unwrap();
}

#[test]
#[should_panic]
fn zipf_n_too_small() {
Zipf::new(0, 1.).unwrap();
}

#[test]
#[should_panic]
fn zipf_nan() {
Zipf::new(10, core::f64::NAN).unwrap();
}

#[test]
fn zipf_sample() {
let d = Zipf::new(10, 0.5).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}

#[test]
fn zipf_sample_s_1() {
let d = Zipf::new(10, 1.).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}

#[test]
fn zipf_sample_s_0() {
let d = Zipf::new(10, 0.).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
// TODO: verify that this is a uniform distribution
}

#[test]
fn zipf_sample_large_n() {
let d = Zipf::new(core::u64::MAX, 1.5).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
// TODO: verify that this is a zeta distribution
}

#[test]
fn zipf_value_stability() {
test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[
10.0, 2.0, 6.0, 7.0
]);
test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[
1.0, 2.0, 3.0, 2.0
]);
}
}