Skip to content

Commit 368a765

Browse files
Rollup merge of rust-lang#149663 - quaternic:gather-scatter-bits-opt, r=Mark-Simulacrum
Optimized implementation for uN::{gather,scatter}_bits

Feature gate: #![feature(uint_gather_scatter_bits)]
Tracking issue: rust-lang#149069
Accepted ACP: rust-lang/libs-team#695

Implements the methods using the parallel-suffix strategy mentioned in the ACP discussion. The referenced source material provides C implementations, though this PR makes improvements over those, cutting the instruction count by a third:

https://rust.godbolt.org/z/rn5naYnK4 (this PR)
https://c.godbolt.org/z/WzYd5WbsY (Hacker's Delight)

This was initially based on the code for `gather_bits` that @okaneco provided in rust-lang/libs-team#695 (comment). I wanted to understand how it worked, and later noticed some opportunities for improvement, which eventually led to this PR.
2 parents 4dbbf26 + 79d792f commit 368a765

3 files changed

Lines changed: 176 additions & 40 deletions

File tree

library/core/src/num/int_bits.rs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
//! Implementations for `uN::gather_bits` and `uN::scatter_bits`
//!
//! For the purposes of this implementation, the operations can be thought
//! of as operating on the input bits as a list, starting from the least
//! significant bit. Gathering is like `Vec::retain` that deletes bits
//! where the mask has a zero. Scattering is like doing the inverse by
//! inserting the zeros that gathering would delete.
//!
//! Key observation: Each bit that is gathered/scattered needs to be
//! shifted by the count of zeros up to the corresponding mask bit.
//!
//! With that in mind, the general idea is to decompose the operation into
//! a sequence of stages in `0..log2(BITS)`, where each stage shifts some
//! of the bits by `n = 1 << stage`. The masks for each stage are computed
//! via prefix counts of zeros in the mask.
//!
//! # Gathering
//!
//! Consider the input as a sequence of runs of data (bitstrings A,B,C,...),
//! split by fixed-width groups of zeros ('.'), initially at width `n = 1`.
//! Counting the groups of zeros, each stage shifts the odd-indexed runs of
//! data right by `n`, effectively swapping them with the preceding zeros.
//! For the next stage, `n` is doubled as all the zeros are now paired.
//! ```text
//! .A.B.C.D.E.F.G.H
//! ..AB..CD..EF..GH
//! ....ABCD....EFGH
//! ........ABCDEFGH
//! ```
//! What makes this nontrivial is that the lengths of the bitstrings are not
//! the same. Using lowercase for individual bits, the above might look like
//! ```text
//! .a.bbb.ccccc.dd.e..g.hh
//! ..abbb..cccccdd..e..ghh
//! ....abbbcccccdd....eghh
//! ........abbbcccccddeghh
//! ```
//!
//! # Scattering
//!
//! For `scatter_bits`, the stages are reversed. We start with a single run of
//! data in the low bits. Each stage then splits each run of data in two by
//! shifting part of it left by `n`, which is halved each stage.
//! ```text
//! ........ABCDEFGH
//! ....ABCD....EFGH
//! ..AB..CD..EF..GH
//! .A.B.C.D.E.F.G.H
//! ```
//!
//! # Stage masks
//!
//! To facilitate the shifts at each stage, we compute a mask that covers both
//! the bitstrings to shift, and the zeros they shift into.
//! ```text
//! .A.B.C.D.E.F.G.H
//! ## ## ## ##
//! ..AB..CD..EF..GH
//! ####  ####
//! ....ABCD....EFGH
//! ########
//! ........ABCDEFGH
//! ```
macro_rules! uint_impl {
66+
($U:ident) => {
67+
pub(super) mod $U {
68+
const STAGES: usize = $U::BITS.ilog2() as usize;
69+
#[inline]
70+
const fn prepare(sparse: $U) -> [$U; STAGES] {
71+
// We'll start with `zeros` as a mask of the bits to be removed,
72+
// and compute into `masks` the parts that shift at each stage.
73+
let mut zeros = !sparse;
74+
let mut masks = [0; STAGES];
75+
let mut stage = 0;
76+
while stage < STAGES {
77+
let n = 1 << stage;
78+
// Suppose `zeros` has bits set at ranges `{ a..a+n, b..b+n, ... }`.
79+
// Then `parity` will be computed as `{ a.. } XOR { b.. } XOR ...`,
80+
// which will be the ranges `{ a..b, c..d, e.. }`.
81+
let mut parity = zeros;
82+
let mut len = n;
83+
while len < $U::BITS {
84+
parity ^= parity << len;
85+
len <<= 1;
86+
}
87+
masks[stage] = parity;
88+
89+
// Toggle off the bits that are shifted into:
90+
// { a..a+n, b..b+n, ... } & !{ a..b, c..d, e.. }
91+
// == { b..b+n, d..d+n, ... }
92+
zeros &= !parity;
93+
// Expand the remaining ranges down to the bits that were
94+
// shifted from: { b-n..b+n, d-n..d+n, ... }
95+
zeros ^= zeros >> n;
96+
97+
stage += 1;
98+
}
99+
masks
100+
}
101+
102+
#[inline(always)]
103+
pub(in super::super) const fn gather_impl(mut x: $U, sparse: $U) -> $U {
104+
let masks = prepare(sparse);
105+
x &= sparse;
106+
let mut stage = 0;
107+
while stage < STAGES {
108+
let n = 1 << stage;
109+
// Consider each two runs of data with their leading
110+
// groups of `n` 0-bits. Suppose that the run that is
111+
// shifted right has length `a`, and the other one has
112+
// length `b`. Assume that only zeros are shifted in.
113+
// ```text
114+
// [0; n], [X; a], [0; n], [Y; b] // x
115+
// [0; n], [X; a], [0; n], [0; b] // q
116+
// [0; n], [0; a + n], [Y; b] // x ^= q
117+
// [0; n + n], [X; a], [0; b] // q >> n
118+
// [0; n], [0; n], [X; a], [Y; b] // x ^= q << n
119+
// ```
120+
// Only zeros are shifted out, satisfying the assumption
121+
// for the next group.
122+
123+
// In effect, the upper run of data is swapped with the
124+
// group of `n` zeros below it.
125+
let q = x & masks[stage];
126+
x ^= q;
127+
x ^= q >> n;
128+
129+
stage += 1;
130+
}
131+
x
132+
}
133+
#[inline(always)]
134+
pub(in super::super) const fn scatter_impl(mut x: $U, sparse: $U) -> $U {
135+
let masks = prepare(sparse);
136+
let mut stage = STAGES;
137+
while stage > 0 {
138+
stage -= 1;
139+
let n = 1 << stage;
140+
// Consider each run of data with the `2 * n` arbitrary bits
141+
// above it. Suppose that the run has length `a + b`, with
142+
// `a` being the length of the part that needs to be
143+
// shifted. Assume that only zeros are shifted in.
144+
// ```text
145+
// [_; n], [_; n], [X; a], [Y; b] // x
146+
// [0; n], [_; n], [X; a], [0; b] // q
147+
// [_; n], [0; n + a], [Y; b] // x ^= q
148+
// [_; n], [X; a], [0; b + n] // q << n
149+
// [_; n], [X; a], [0; n], [Y; b] // x ^= q << n
150+
// ```
151+
// Only zeros are shifted out, satisfying the assumption
152+
// for the next group.
153+
154+
// In effect, `n` 0-bits are inserted somewhere in each run
155+
// of data to spread it, and the two groups of `n` bits
156+
// above are XOR'd together.
157+
let q = x & masks[stage];
158+
x ^= q;
159+
x ^= q << n;
160+
}
161+
x & sparse
162+
}
163+
}
164+
};
165+
}
166+
167+
uint_impl!(u8);
168+
uint_impl!(u16);
169+
uint_impl!(u32);
170+
uint_impl!(u64);
171+
uint_impl!(u128);

library/core/src/num/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ mod int_macros; // import int_impl!
4444
mod uint_macros; // import uint_impl!
4545

4646
mod error;
47+
mod int_bits;
4748
mod int_log10;
4849
mod int_sqrt;
4950
pub(crate) mod libm;

library/core/src/num/uint_macros.rs

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -492,27 +492,8 @@ macro_rules! uint_impl {
492492
#[must_use = "this returns the result of the operation, \
493493
without modifying the original"]
494494
#[inline]
495-
pub const fn gather_bits(self, mut mask: Self) -> Self {
496-
let mut bit_position = 1;
497-
let mut result = 0;
498-
499-
// Iterate through the mask bits, unsetting the lowest bit after
500-
// each iteration. We fill the bits in the result starting from the
501-
// least significant bit.
502-
while mask != 0 {
503-
// Find the next lowest set bit in the mask
504-
let next_mask_bit = mask.isolate_lowest_one();
505-
506-
// Retrieve the masked bit and if present, set it in the result
507-
let src_bit = (self & next_mask_bit) != 0;
508-
result |= if src_bit { bit_position } else { 0 };
509-
510-
// Unset lowest set bit in the mask, prepare next position to set
511-
mask ^= next_mask_bit;
512-
bit_position <<= 1;
513-
}
514-
515-
result
495+
pub const fn gather_bits(self, mask: Self) -> Self {
496+
crate::num::int_bits::$ActualT::gather_impl(self as $ActualT, mask as $ActualT) as $SelfT
516497
}
517498

518499
/// Returns an integer with the least significant bits of `self`
@@ -528,25 +509,8 @@ macro_rules! uint_impl {
528509
#[must_use = "this returns the result of the operation, \
529510
without modifying the original"]
530511
#[inline]
531-
pub const fn scatter_bits(mut self, mut mask: Self) -> Self {
532-
let mut result = 0;
533-
534-
// Iterate through the mask bits, unsetting the lowest bit after
535-
// each iteration and right-shifting `self` by one to get the next
536-
// bit into the least significant bit position.
537-
while mask != 0 {
538-
// Find the next bit position to potentially set
539-
let next_mask_bit = mask.isolate_lowest_one();
540-
541-
// If bit is set, deposit it at the masked bit position
542-
result |= if (self & 1) != 0 { next_mask_bit } else { 0 };
543-
544-
// Unset lowest set bit in the mask, shift in next `self` bit
545-
mask ^= next_mask_bit;
546-
self >>= 1;
547-
}
548-
549-
result
512+
pub const fn scatter_bits(self, mask: Self) -> Self {
513+
crate::num::int_bits::$ActualT::scatter_impl(self as $ActualT, mask as $ActualT) as $SelfT
550514
}
551515

552516
/// Reverses the order of bits in the integer. The least significant bit becomes the most significant bit,

0 commit comments

Comments
 (0)