Summary
Implement shape-dependent dispatch that automatically selects the optimal algorithm based on tensor dimensions, using cost models to choose between different implementations.
Inspiration: JAX (jax/_src/dispatch.py:220-228), which uses shape/size thresholds to select between strategies (e.g., tupling arguments on TPU when there are more than 2000 of them).
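The core idea fits in a few lines; the op, the cutoff, and the chunk size below are invented for this sketch and are not part of the proposal:

```rust
// Illustrative only: a size threshold picks the implementation, much like the
// JAX heuristic above. The op and the cutoff are invented for this sketch.
fn sum_squares(x: &[f32]) -> f32 {
    const CHUNKED_THRESHOLD: usize = 1 << 16; // assumed cutoff
    if x.len() < CHUNKED_THRESHOLD {
        // Small input: a plain loop beats paying any dispatch or setup overhead.
        x.iter().map(|v| v * v).sum()
    } else {
        // Large input: process in cache-sized chunks (stand-in for a SIMD/parallel kernel).
        x.chunks(1 << 12)
            .map(|c| c.iter().map(|v| v * v).sum::<f32>())
            .sum()
    }
}
```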
Problem
One-size-fits-all algorithms waste performance:
- Small matrices: BLAS call overhead exceeds the actual compute
- Large matrices: naive loops thrash the cache
- Tall/skinny vs. square vs. short/wide: each favors a different algorithm
- Contiguous vs. strided: different optimal memory access patterns
Proposed Solution
1. Shape Classification
```rust
/// Classification of tensor shapes for dispatch decisions
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeClass {
/// Scalar (0 dimensions or all dims = 1)
Scalar,
/// Tiny vector (2 - 64 elements)
TinyVector,
/// Small vector (65 - 4096 elements)
SmallVector,
/// Large vector (> 4096 elements)
LargeVector,
/// Tiny matrix (up to 64x64 elements)
TinyMatrix,
/// Small matrix (64x64 to 512x512 elements)
SmallMatrix,
/// Large matrix (> 512x512)
LargeMatrix,
/// Tall and skinny (rows >> cols)
TallSkinny,
/// Short and wide (cols >> rows)
ShortWide,
/// Batched (3+ dimensions)
Batched,
/// Irregular (non-power-of-2, prime dimensions)
Irregular,
}
impl ShapeClass {
pub fn classify(shape: &[usize]) -> Self {
match shape.len() {
0 => ShapeClass::Scalar,
1 => Self::classify_vector(shape[0]),
2 => Self::classify_matrix(shape[0], shape[1]),
_ => ShapeClass::Batched,
}
}
fn classify_vector(len: usize) -> Self {
match len {
0..=1 => ShapeClass::Scalar,
2..=64 => ShapeClass::TinyVector,
65..=4096 => ShapeClass::SmallVector,
_ => ShapeClass::LargeVector,
}
}
fn classify_matrix(rows: usize, cols: usize) -> Self {
let total = rows * cols;
let aspect_ratio = rows as f64 / cols as f64;
// Check aspect ratio first
if aspect_ratio > 8.0 {
return ShapeClass::TallSkinny;
}
if aspect_ratio < 0.125 {
return ShapeClass::ShortWide;
}
// Check size
match total {
0..=1 => ShapeClass::Scalar,
2..=4096 => ShapeClass::TinyMatrix, // Up to 64x64
4097..=262144 => ShapeClass::SmallMatrix, // Up to 512x512
_ => ShapeClass::LargeMatrix,
}
}
/// Check if shape is "nice" (power of 2, good for SIMD)
pub fn is_nice(shape: &[usize]) -> bool {
shape.iter().all(|&d| d > 0 && (d & (d - 1)) == 0)
}
/// Check if dimensions are SIMD-friendly (multiple of 8 for AVX)
pub fn is_simd_friendly(shape: &[usize], simd_width: usize) -> bool {
shape.last().map_or(true, |&d| d % simd_width == 0)
}
}
```
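For a quick sanity check of the classifier, a few classifications the sketch above would produce (the shape values are chosen here for illustration):

```rust
// Illustrative classifications produced by the ShapeClass sketch above.
fn shape_class_demo() {
    assert_eq!(ShapeClass::classify(&[]), ShapeClass::Scalar);
    assert_eq!(ShapeClass::classify(&[32]), ShapeClass::TinyVector);
    assert_eq!(ShapeClass::classify(&[1024]), ShapeClass::SmallVector);
    assert_eq!(ShapeClass::classify(&[4096, 16]), ShapeClass::TallSkinny); // 256:1 aspect ratio
    assert_eq!(ShapeClass::classify(&[256, 256]), ShapeClass::SmallMatrix);
    assert_eq!(ShapeClass::classify(&[8, 128, 128]), ShapeClass::Batched);

    // Power-of-two dims are "nice"; a last dim divisible by the SIMD width is SIMD-friendly.
    assert!(ShapeClass::is_nice(&[256, 256]));
    assert!(ShapeClass::is_simd_friendly(&[256, 250], 2));
    assert!(!ShapeClass::is_simd_friendly(&[256, 250], 8));
}
```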
2. Cost Model for Algorithm Selection

```rust
// External crates assumed as dependencies: once_cell, raw_cpuid, num_cpus.
use once_cell::sync::Lazy;

/// Cost estimate for an operation
#[derive(Debug, Clone, Copy)]
pub struct OpCost {
/// Estimated FLOPs
pub flops: u64,
/// Estimated memory accesses (bytes)
pub memory_bytes: u64,
/// Estimated overhead (setup, dispatch, etc.)
pub overhead_ns: u64,
}
impl OpCost {
/// Estimate total execution time in nanoseconds
pub fn estimate_time_ns(&self, hardware: &HardwareProfile) -> u64 {
let compute_ns = self.flops / hardware.gflops_per_ns;
let memory_ns = self.memory_bytes / hardware.gbytes_per_ns;
// Assume compute and memory partially overlap:
// total ≈ overhead + max(compute, memory) + overlap * min(compute, memory)
let overlap = 0.5;
let max_ns = compute_ns.max(memory_ns) as f64;
let min_ns = compute_ns.min(memory_ns) as f64;
self.overhead_ns + (max_ns + overlap * min_ns) as u64
}
}
/// Hardware performance characteristics
#[derive(Debug, Clone)]
pub struct HardwareProfile {
/// Peak compute throughput in FLOPs per nanosecond (numerically equal to GFLOPS)
pub gflops_per_ns: u64,
/// Memory bandwidth in bytes per nanosecond (numerically equal to GB/s)
pub gbytes_per_ns: u64,
/// L1 cache size (bytes)
pub l1_cache: usize,
/// L2 cache size (bytes)
pub l2_cache: usize,
/// L3 cache size (bytes)
pub l3_cache: usize,
/// SIMD width in floats
pub simd_width: usize,
/// Number of cores
pub num_cores: usize,
}
impl HardwareProfile {
/// Detect hardware profile at runtime
pub fn detect() -> Self {
#[cfg(target_arch = "x86_64")]
{
use raw_cpuid::CpuId;
let cpuid = CpuId::new();
// Get cache info
let cache_info = cpuid.get_cache_parameters();
let mut l1 = 32 * 1024; // Default 32KB
let mut l2 = 256 * 1024; // Default 256KB
let mut l3 = 8 * 1024 * 1024; // Default 8MB
if let Some(caches) = cache_info {
for cache in caches {
let size = cache.associativity() *
cache.physical_line_partitions() *
cache.coherency_line_size() *
cache.sets();
match cache.level() {
1 => l1 = size,
2 => l2 = size,
3 => l3 = size,
_ => {}
}
}
}
// Detect SIMD width
let simd_width = if is_x86_feature_detected!("avx512f") {
16 // 512 bits / 32 bits per float
} else if is_x86_feature_detected!("avx2") {
8 // 256 bits / 32 bits per float
} else {
4 // 128 bits SSE
};
Self {
gflops_per_ns: 100, // ~100 GFLOPS typical
gbytes_per_ns: 50, // ~50 GB/s typical DDR4
l1_cache: l1,
l2_cache: l2,
l3_cache: l3,
simd_width,
num_cores: num_cpus::get(),
}
}
#[cfg(target_arch = "aarch64")]
{
Self {
gflops_per_ns: 50,
gbytes_per_ns: 30,
l1_cache: 64 * 1024,
l2_cache: 512 * 1024,
l3_cache: 4 * 1024 * 1024,
simd_width: 4, // NEON 128-bit
num_cores: num_cpus::get(),
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
// Conservative fallback so detect() also compiles on other targets
Self {
gflops_per_ns: 10,
gbytes_per_ns: 10,
l1_cache: 32 * 1024,
l2_cache: 256 * 1024,
l3_cache: 2 * 1024 * 1024,
simd_width: 1,
num_cores: num_cpus::get(),
}
}
}
/// Check if operation is compute-bound or memory-bound
pub fn is_compute_bound(&self, cost: &OpCost) -> bool {
let arithmetic_intensity = cost.flops as f64 / cost.memory_bytes as f64;
let machine_balance = self.gflops_per_ns as f64 / self.gbytes_per_ns as f64;
arithmetic_intensity > machine_balance
}
}
// Global hardware profile (detected once)
static HARDWARE: Lazy<HardwareProfile> = Lazy::new(HardwareProfile::detect);
```
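A small illustration of the cost model (the matmul numbers below are invented for the example, not measured): a 1024x1024 f32 matmul performs 2·1024³ FLOPs over roughly 12 MB of tensor data, so its arithmetic intensity far exceeds the default profile's machine balance and it comes out compute-bound.

```rust
// Illustrative numbers only; the profile values come from HardwareProfile::detect().
fn cost_model_demo() {
    let hw = HardwareProfile::detect();

    // 1024x1024 * 1024x1024 f32 matmul: 2*n^3 FLOPs, ~3*n^2*4 bytes of tensor data.
    let cost = OpCost {
        flops: 2 * 1024 * 1024 * 1024,
        memory_bytes: 3 * 1024 * 1024 * 4,
        overhead_ns: 500,
    };

    println!("estimated time: {} ns", cost.estimate_time_ns(&hw));
    println!("compute-bound: {}", hw.is_compute_bound(&cost)); // true for this shape
}
```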
3. Algorithm Registry

```rust
// rustc_hash is assumed as a dependency for the fast FxHasher.
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::RwLock;
/// A specific algorithm implementation
pub trait Algorithm<Args, Output>: Send + Sync {
/// Name for debugging
fn name(&self) -> &'static str;
/// Estimate cost for given arguments
fn estimate_cost(&self, args: &Args) -> OpCost;
/// Execute the algorithm
fn execute(&self, args: Args) -> Output;
/// Check if this algorithm is applicable
fn is_applicable(&self, args: &Args) -> bool {
true
}
}
/// Registry of algorithms for an operation
pub struct AlgorithmRegistry<Args, Output> {
algorithms: Vec<Box<dyn Algorithm<Args, Output>>>,
/// Cache of (shape_hash -> best_algorithm_index)
cache: RwLock<HashMap<u64, usize>>,
}
impl<Args: Hash, Output> AlgorithmRegistry<Args, Output> {
pub fn new() -> Self {
Self {
algorithms: Vec::new(),
cache: RwLock::new(HashMap::new()),
}
}
pub fn register(&mut self, algo: impl Algorithm<Args, Output> + 'static) {
self.algorithms.push(Box::new(algo));
}
/// Select best algorithm based on cost model
pub fn select(&self, args: &Args) -> &dyn Algorithm<Args, Output> {
// Check cache
let shape_hash = self.hash_args(args);
if let Some(&idx) = self.cache.read().unwrap().get(&shape_hash) {
return self.algorithms[idx].as_ref();
}
// Find best algorithm
let hw = &*HARDWARE;
let mut best_idx = 0;
let mut best_time = u64::MAX;
for (idx, algo) in self.algorithms.iter().enumerate() {
if !algo.is_applicable(args) {
continue;
}
let cost = algo.estimate_cost(args);
let time = cost.estimate_time_ns(hw);
if time < best_time {
best_time = time;
best_idx = idx;
}
}
// Cache result
self.cache.write().unwrap().insert(shape_hash, best_idx);
self.algorithms[best_idx].as_ref()
}
/// Execute with automatic algorithm selection
pub fn execute(&self, args: Args) -> Output {
let algo = self.select(&args);
algo.execute(args)
}
fn hash_args(&self, args: &Args) -> u64 {
let mut hasher = rustc_hash::FxHasher::default();
args.hash(&mut hasher);
hasher.finish()
}
}
```
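A toy end-to-end use of the generic registry with a made-up vector-sum op (SumArgs, ScalarSum, and ChunkedSum are invented here purely to illustrate the trait contract; the matmul registry in the next section is the real target):

```rust
// Toy op: two variants of a vector sum registered in the generic registry.
#[derive(Clone, Hash)]
struct SumArgs {
    len: usize,
}

struct ScalarSum;
impl Algorithm<SumArgs, u64> for ScalarSum {
    fn name(&self) -> &'static str { "scalar_sum" }
    fn estimate_cost(&self, args: &SumArgs) -> OpCost {
        OpCost { flops: args.len as u64, memory_bytes: args.len as u64 * 4, overhead_ns: 5 }
    }
    fn execute(&self, args: SumArgs) -> u64 {
        (0..args.len as u64).sum()
    }
}

struct ChunkedSum;
impl Algorithm<SumArgs, u64> for ChunkedSum {
    fn name(&self) -> &'static str { "chunked_sum" }
    fn estimate_cost(&self, args: &SumArgs) -> OpCost {
        // Pretend chunking halves effective memory traffic but needs setup time.
        OpCost { flops: args.len as u64, memory_bytes: args.len as u64 * 2, overhead_ns: 1_000 }
    }
    fn execute(&self, args: SumArgs) -> u64 {
        (0..args.len as u64).sum()
    }
    fn is_applicable(&self, args: &SumArgs) -> bool {
        args.len >= 1 << 16 // only worth it for large inputs
    }
}

fn registry_demo() {
    let mut registry = AlgorithmRegistry::new();
    registry.register(ScalarSum);
    registry.register(ChunkedSum);
    // Tiny input: ChunkedSum is filtered out by is_applicable, so ScalarSum wins.
    assert_eq!(registry.select(&SumArgs { len: 100 }).name(), "scalar_sum");
    let _ = registry.execute(SumArgs { len: 1 << 20 });
}
```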
4. Matmul Dispatch Example

```rust
// Assumes the crate's Tensor and DType types plus the naive_matmul_impl,
// simd_matmul_impl, tiled_matmul_impl, and parallel_matmul_impl kernels.
// Cost estimates below assume 4-byte (f32) elements.

/// Arguments for matrix multiplication
#[derive(Clone, Hash)]
pub struct MatmulArgs {
pub m: usize,
pub k: usize,
pub n: usize,
pub dtype: DType,
pub a_transposed: bool,
pub b_transposed: bool,
}
/// Naive O(n^3) matmul for tiny matrices
struct NaiveMatmul;
impl Algorithm<MatmulArgs, Tensor> for NaiveMatmul {
fn name(&self) -> &'static str { "naive" }
fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
OpCost {
flops,
memory_bytes: memory,
overhead_ns: 10, // Very low overhead
}
}
fn execute(&self, args: MatmulArgs) -> Tensor {
// Simple triple loop
naive_matmul_impl(args)
}
fn is_applicable(&self, args: &MatmulArgs) -> bool {
// Only for tiny matrices
args.m * args.n < 4096
}
}
/// SIMD-optimized matmul for medium matrices
struct SimdMatmul;
impl Algorithm<MatmulArgs, Tensor> for SimdMatmul {
fn name(&self) -> &'static str { "simd" }
fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
OpCost {
flops,
memory_bytes: memory,
overhead_ns: 100, // Some setup for SIMD
}
}
fn execute(&self, args: MatmulArgs) -> Tensor {
simd_matmul_impl(args)
}
fn is_applicable(&self, args: &MatmulArgs) -> bool {
// Need minimum size for SIMD efficiency
args.n >= HARDWARE.simd_width && args.m >= 4
}
}
/// Tiled matmul for large matrices (cache-friendly)
struct TiledMatmul {
tile_size: usize,
}
impl Algorithm<MatmulArgs, Tensor> for TiledMatmul {
fn name(&self) -> &'static str { "tiled" }
fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
// Better memory access pattern
let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
let cache_factor = if args.m * args.k * 4 < HARDWARE.l2_cache { 0.5 } else { 1.0 };
OpCost {
flops,
memory_bytes: (memory as f64 * cache_factor) as u64,
overhead_ns: 500, // Tiling setup
}
}
fn execute(&self, args: MatmulArgs) -> Tensor {
tiled_matmul_impl(args, self.tile_size)
}
fn is_applicable(&self, args: &MatmulArgs) -> bool {
// For large matrices where tiling helps
args.m * args.k * 4 > HARDWARE.l1_cache
}
}
/// Parallel tiled matmul for very large matrices
struct ParallelMatmul;
impl Algorithm<MatmulArgs, Tensor> for ParallelMatmul {
fn name(&self) -> &'static str { "parallel" }
fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
let parallelism = HARDWARE.num_cores.min(args.m) as u64;
OpCost {
flops: flops / parallelism, // Divided by parallelism
memory_bytes: memory,
overhead_ns: 5000, // Thread spawn overhead
}
}
fn execute(&self, args: MatmulArgs) -> Tensor {
parallel_matmul_impl(args)
}
fn is_applicable(&self, args: &MatmulArgs) -> bool {
// Only worth parallelizing for large matrices
let total_work = args.m * args.k * args.n;
total_work > 1_000_000 && args.m >= HARDWARE.num_cores
}
}
// Global matmul registry
static MATMUL_REGISTRY: Lazy<AlgorithmRegistry<MatmulArgs, Tensor>> = Lazy::new(|| {
let mut registry = AlgorithmRegistry::new();
registry.register(NaiveMatmul);
registry.register(SimdMatmul);
registry.register(TiledMatmul { tile_size: 64 });
registry.register(ParallelMatmul);
registry
});
/// Public matmul function with automatic dispatch
pub fn matmul(a: &Tensor, b: &Tensor) -> Tensor {
let args = MatmulArgs {
m: a.shape()[0],
k: a.shape()[1],
n: b.shape()[1],
dtype: a.dtype(),
a_transposed: false,
b_transposed: false,
};
MATMUL_REGISTRY.execute(args)
}
```
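For illustration, a hypothetical probe of which variant the registry would pick for a few shapes (DType::F32 is assumed to exist in the crate; the expected routings are rough guesses under the cost model above, not benchmarked results):

```rust
// Hypothetical probe of the dispatcher; DType::F32 is an assumed variant.
fn print_selected(m: usize, k: usize, n: usize) {
    let args = MatmulArgs {
        m,
        k,
        n,
        dtype: DType::F32,
        a_transposed: false,
        b_transposed: false,
    };
    println!("{}x{}x{} -> {}", m, k, n, MATMUL_REGISTRY.select(&args).name());
}

fn dispatch_demo() {
    print_selected(8, 8, 8);          // expected: "naive"
    print_selected(128, 128, 128);    // expected: "simd" or "tiled"
    print_selected(4096, 4096, 4096); // expected: "parallel"
}
```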
5. Auto-Tuning with Benchmarking

```rust
// serde_json is assumed as a dependency for persisting measurements.
use std::collections::HashMap;
use std::hash::Hash;
use std::path::Path;
use std::sync::RwLock;

/// Auto-tune algorithm selection by actually running benchmarks
pub struct AutoTuner<Args, Output> {
registry: AlgorithmRegistry<Args, Output>,
/// Measured timings (shape_hash -> algo_idx -> median_ns)
measurements: RwLock<HashMap<u64, Vec<(usize, u64)>>>,
/// Number of warmup runs
warmup_runs: usize,
/// Number of measurement runs
measurement_runs: usize,
}
impl<Args: Clone + Hash, Output> AutoTuner<Args, Output> {
/// Auto-tune for specific args by benchmarking
pub fn tune(&self, args: &Args) -> usize {
let shape_hash = self.registry.hash_args(args);
// Check if already tuned
if let Some(measurements) = self.measurements.read().unwrap().get(&shape_hash) {
return measurements.iter()
.min_by_key(|(_, time)| time)
.map(|(idx, _)| *idx)
.unwrap_or(0);
}
// Benchmark each applicable algorithm
let mut results = Vec::new();
for (idx, algo) in self.registry.algorithms.iter().enumerate() {
if !algo.is_applicable(args) {
continue;
}
// Warmup
for _ in 0..self.warmup_runs {
let _ = algo.execute(args.clone());
}
// Measure
let mut times = Vec::with_capacity(self.measurement_runs);
for _ in 0..self.measurement_runs {
let start = std::time::Instant::now();
let _ = algo.execute(args.clone());
times.push(start.elapsed().as_nanos() as u64);
}
// Use median
times.sort();
let median = times[times.len() / 2];
results.push((idx, median));
}
// Cache results
let best_idx = results.iter()
.min_by_key(|(_, time)| time)
.map(|(idx, _)| *idx)
.unwrap_or(0);
self.measurements.write().unwrap().insert(shape_hash, results);
best_idx
}
/// Save tuning results to file
pub fn save(&self, path: &Path) -> std::io::Result<()> {
let data = self.measurements.read().unwrap();
let json = serde_json::to_string_pretty(&*data)?;
std::fs::write(path, json)
}
/// Load tuning results from file
pub fn load(&self, path: &Path) -> std::io::Result<()> {
let json = std::fs::read_to_string(path)?;
let data: HashMap<u64, Vec<(usize, u64)>> = serde_json::from_str(&json)?;
*self.measurements.write().unwrap() = data;
Ok(())
}
}
```
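A hypothetical use of the auto-tuner (its constructor and field wiring are not spelled out above, so the function below just takes an already-built tuner; DType::F32 is again an assumed variant):

```rust
// Hypothetical wiring; AutoTuner construction is left open in the sketch above.
fn autotune_demo(tuner: &AutoTuner<MatmulArgs, Tensor>) -> std::io::Result<()> {
    let args = MatmulArgs {
        m: 512,
        k: 512,
        n: 512,
        dtype: DType::F32, // assumed variant
        a_transposed: false,
        b_transposed: false,
    };

    // First call benchmarks every applicable algorithm; later calls hit the cache.
    let best = tuner.tune(&args);
    println!("best algorithm index for 512^3 matmul: {}", best);

    // Persist measurements so future runs can skip the benchmarking phase.
    tuner.save(std::path::Path::new("matmul_tuning.json"))?;
    Ok(())
}
```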
Acceptance Criteria
- `ShapeClass` enum with classification logic
- `OpCost` and `HardwareProfile` for cost modeling
- `AlgorithmRegistry` with cost-based selection
- Matmul dispatch with 4+ algorithm variants
- Auto-tuner with runtime benchmarking
- Persistence of tuning results
- Hardware detection (cache sizes, SIMD width)
- Benchmarks comparing dispatch overhead vs. speedup
Expected Performance Impact
| Shape | Single Algorithm | Dispatched | Improvement |
|---|---|---|---|
| 8x8 | 50ns (BLAS overhead) | 20ns (naive) | 2.5x |
| 64x64 | 15μs | 12μs (SIMD) | 1.25x |
| 512x512 | 8ms | 6ms (tiled) | 1.33x |
| 4096x4096 | 2s | 500ms (parallel) | 4x |
5-15% average improvement across workloads by avoiding algorithm mismatch.
References
- JAX shape dispatch: jax/_src/dispatch.py:220-228
- ATLAS auto-tuning: https://math-atlas.sourceforge.net/
- BLIS microkernel selection: https://github.com/flame/blis
Labels
performance, dispatch, optimization, P1-high