Skip to content

perf: Shape-Dependent Algorithm Dispatch with Cost Models #95

@noahgift

Description

@noahgift

Summary

Implement shape-dependent dispatch that automatically selects the optimal algorithm based on tensor dimensions, using cost models to choose between different implementations.

Inspiration: JAX jax/_src/dispatch.py:220-228 - uses shape thresholds to select between algorithms (e.g., tupling args on TPU when >2000).

Problem

One-size-fits-all algorithms waste performance:

  • Small matrices: BLAS call overhead exceeds the actual compute
  • Large matrices: naive loops thrash cache
  • Tall/skinny vs square vs short/wide: different optimal algorithms
  • Contiguous vs strided: different memory access patterns

Proposed Solution

1. Shape Classification

/// Classification of tensor shapes for dispatch decisions.
///
/// Size thresholds are counted in elements and are inclusive at the upper end;
/// see [`ShapeClass::classify`] for the exact rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeClass {
    /// At most one element: rank 0, every dimension equal to 1, or any
    /// dimension equal to 0.
    Scalar,
    /// Vector of 2..=64 elements.
    TinyVector,
    /// Vector of 65..=4096 elements.
    SmallVector,
    /// Vector of more than 4096 elements.
    LargeVector,
    /// Matrix of 2..=4096 elements (up to 64x64).
    TinyMatrix,
    /// Matrix of 4097..=262_144 elements (up to 512x512).
    SmallMatrix,
    /// Matrix of more than 262_144 elements.
    LargeMatrix,
    /// Matrix with `rows / cols > 8` (checked before the size classes).
    TallSkinny,
    /// Matrix with `rows / cols < 1/8` (checked before the size classes).
    ShortWide,
    /// Rank 3 or higher.
    Batched,
    /// Irregular (non-power-of-2, prime dimensions).
    ///
    /// NOTE(review): `classify` does not currently produce this variant.
    Irregular,
}

impl ShapeClass {
    /// Classifies `shape` for dispatch.
    ///
    /// Any shape holding at most one element (rank 0, a zero-length dimension,
    /// or all dimensions equal to 1) is [`ShapeClass::Scalar`], whatever its rank.
    /// Otherwise rank-1 shapes are classified by length, rank-2 shapes by aspect
    /// ratio first and then by element count, and rank ≥ 3 shapes are
    /// [`ShapeClass::Batched`].
    pub fn classify(shape: &[usize]) -> Self {
        if shape.contains(&0) || shape.iter().all(|&d| d == 1) {
            return ShapeClass::Scalar;
        }
        match *shape {
            [len] => Self::classify_vector(len),
            [rows, cols] => Self::classify_matrix(rows, cols),
            _ => ShapeClass::Batched,
        }
    }

    /// Buckets a 1-D length into the vector size classes.
    fn classify_vector(len: usize) -> Self {
        match len {
            0..=1 => ShapeClass::Scalar,
            2..=64 => ShapeClass::TinyVector,
            65..=4096 => ShapeClass::SmallVector,
            _ => ShapeClass::LargeVector,
        }
    }

    /// Buckets a 2-D shape: extreme aspect ratios win over size classes.
    fn classify_matrix(rows: usize, cols: usize) -> Self {
        // Saturate so absurdly large shapes read as "large" instead of overflowing.
        let total = rows.saturating_mul(cols);
        // At most one element (including zero-length dims): nothing to dispatch on,
        // and the aspect ratio below would be infinite or NaN.
        if total <= 1 {
            return ShapeClass::Scalar;
        }

        let aspect_ratio = rows as f64 / cols as f64;
        if aspect_ratio > 8.0 {
            return ShapeClass::TallSkinny;
        }
        if aspect_ratio < 0.125 {
            return ShapeClass::ShortWide;
        }

        match total {
            0..=4096 => ShapeClass::TinyMatrix,        // up to 64x64
            4097..=262_144 => ShapeClass::SmallMatrix, // up to 512x512
            _ => ShapeClass::LargeMatrix,
        }
    }

    /// Returns `true` if every dimension is a non-zero power of two
    /// (vacuously `true` for rank 0).
    pub fn is_nice(shape: &[usize]) -> bool {
        shape.iter().all(|&d| d > 0 && (d & (d - 1)) == 0)
    }

    /// Returns `true` if the innermost dimension is a multiple of `simd_width`
    /// (e.g. 8 for AVX2 with 32-bit floats). Rank-0 shapes are SIMD friendly.
    /// A `simd_width` of 0 means "no SIMD" and yields `false` for non-empty
    /// shapes instead of panicking on a division by zero.
    pub fn is_simd_friendly(shape: &[usize], simd_width: usize) -> bool {
        match shape.last() {
            None => true,
            Some(&d) => d.checked_rem(simd_width) == Some(0),
        }
    }
}

2. Cost Model for Algorithm Selection

/// Cost estimate for an operation.
#[derive(Debug, Clone, Copy)]
pub struct OpCost {
    /// Estimated floating-point operations.
    pub flops: u64,
    /// Estimated bytes moved to/from memory.
    pub memory_bytes: u64,
    /// Fixed cost in nanoseconds (setup, dispatch, etc.).
    pub overhead_ns: u64,
}

impl OpCost {
    /// Estimates total execution time in nanoseconds on `hardware`.
    ///
    /// Compute time is `flops / gflops_per_ns` and memory time is
    /// `memory_bytes / gbytes_per_ns`. The two are assumed to overlap
    /// partially: the slower one counts in full and the faster one at 50%.
    /// `overhead_ns` is then added on top.
    ///
    /// The arithmetic is done in floating point so that sub-nanosecond terms
    /// are not truncated to zero, and the busy time is rounded to the nearest
    /// nanosecond. A zero throughput in `hardware` makes the matching term
    /// infinite, which saturates the result at `u64::MAX` rather than
    /// panicking on a division by zero.
    pub fn estimate_time_ns(&self, hardware: &HardwareProfile) -> u64 {
        // Fraction of the faster component that is *not* hidden behind the slower one.
        const OVERLAP: f64 = 0.5;

        let compute_ns = Self::duration_ns(self.flops, hardware.gflops_per_ns);
        let memory_ns = Self::duration_ns(self.memory_bytes, hardware.gbytes_per_ns);
        let busy_ns = compute_ns.max(memory_ns) + OVERLAP * compute_ns.min(memory_ns);

        // `as` saturates huge/infinite floats; saturating_add guards the final sum.
        self.overhead_ns.saturating_add(busy_ns.round() as u64)
    }

    /// Time in nanoseconds to process `amount` units at `per_ns` units per ns.
    fn duration_ns(amount: u64, per_ns: u64) -> f64 {
        match (amount, per_ns) {
            (0, _) => 0.0,
            (_, 0) => f64::INFINITY,
            _ => amount as f64 / per_ns as f64,
        }
    }
}

/// Hardware performance characteristics consumed by the cost model.
#[derive(Debug, Clone)]
pub struct HardwareProfile {
    /// Peak compute throughput in FLOPs per nanosecond, which is numerically
    /// equal to GFLOP/s (100 means ~100 GFLOP/s). `OpCost::estimate_time_ns`
    /// divides FLOP counts by this value to get nanoseconds.
    ///
    /// NOTE(review): despite the name, the unit is FLOP/ns, not GFLOP/ns.
    pub gflops_per_ns: u64,
    /// Memory bandwidth in bytes per nanosecond, which is numerically equal to
    /// GB/s (50 means ~50 GB/s). Byte counts are divided by this value.
    pub gbytes_per_ns: u64,
    /// L1 cache size (bytes)
    pub l1_cache: usize,
    /// L2 cache size (bytes)
    pub l2_cache: usize,
    /// L3 cache size (bytes)
    pub l3_cache: usize,
    /// SIMD register width in 32-bit float lanes (e.g. 8 for AVX2, 4 for NEON).
    pub simd_width: usize,
    /// Number of cores; `detect` fills this from `num_cpus::get()`.
    pub num_cores: usize,
}

impl HardwareProfile {
    /// Detects the performance profile of the machine this code runs on.
    ///
    /// - `x86_64`: cache sizes come from the CPUID cache parameters. Defaults of
    ///   32 KiB / 256 KiB / 8 MiB are kept for any level CPUID does not report.
    ///   The SIMD width comes from runtime AVX-512F / AVX2 detection.
    /// - `aarch64`: fixed values with 128-bit NEON.
    /// - any other architecture: conservative fixed placeholders with 4-lane SIMD.
    ///
    /// On every architecture, peak compute and bandwidth are fixed typical
    /// figures, not measurements.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            use raw_cpuid::{CacheType, CpuId};
            let cpuid = CpuId::new();

            let mut l1 = 32 * 1024; // Default 32 KiB
            let mut l2 = 256 * 1024; // Default 256 KiB
            let mut l3 = 8 * 1024 * 1024; // Default 8 MiB

            if let Some(caches) = cpuid.get_cache_parameters() {
                for cache in caches {
                    // CPUID reports the L1 data and L1 instruction caches as
                    // separate level-1 entries. Skip instruction caches so the
                    // L1i size cannot overwrite the L1d size that the dispatch
                    // heuristics read.
                    if matches!(cache.cache_type(), CacheType::Instruction) {
                        continue;
                    }
                    // Total size = ways * partitions * line size * sets.
                    let size = cache.associativity()
                        * cache.physical_line_partitions()
                        * cache.coherency_line_size()
                        * cache.sets();
                    match cache.level() {
                        1 => l1 = size,
                        2 => l2 = size,
                        3 => l3 = size,
                        _ => {}
                    }
                }
            }

            // Widest float vector available at runtime, in 32-bit lanes.
            let simd_width = if is_x86_feature_detected!("avx512f") {
                16 // 512 bits / 32 bits per float
            } else if is_x86_feature_detected!("avx2") {
                8 // 256 bits / 32 bits per float
            } else {
                4 // 128 bits SSE
            };

            Self {
                gflops_per_ns: 100, // ~100 GFLOPS typical
                gbytes_per_ns: 50,  // ~50 GB/s typical DDR4
                l1_cache: l1,
                l2_cache: l2,
                l3_cache: l3,
                simd_width,
                num_cores: num_cpus::get(),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self {
                gflops_per_ns: 50,
                gbytes_per_ns: 30,
                l1_cache: 64 * 1024,
                l2_cache: 512 * 1024,
                l3_cache: 4 * 1024 * 1024,
                simd_width: 4, // NEON 128-bit
                num_cores: num_cpus::get(),
            }
        }

        // Without this arm `detect` has no body on other targets (wasm32,
        // riscv64, ...) and the crate fails to compile there.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                gflops_per_ns: 10, // conservative placeholder
                gbytes_per_ns: 10, // conservative placeholder
                l1_cache: 32 * 1024,
                l2_cache: 256 * 1024,
                l3_cache: 2 * 1024 * 1024,
                simd_width: 4, // assume 128-bit vectors
                num_cores: num_cpus::get(),
            }
        }
    }

    /// Returns `true` if `cost`'s arithmetic intensity (FLOPs per byte) exceeds
    /// this machine's balance point (`gflops_per_ns / gbytes_per_ns`), meaning
    /// the operation is limited by compute rather than by memory bandwidth.
    pub fn is_compute_bound(&self, cost: &OpCost) -> bool {
        let arithmetic_intensity = cost.flops as f64 / cost.memory_bytes as f64;
        let machine_balance = self.gflops_per_ns as f64 / self.gbytes_per_ns as f64;
        arithmetic_intensity > machine_balance
    }
}

/// Global hardware profile of the current machine.
///
/// Built once, lazily, by `HardwareProfile::detect` and then shared by every
/// cost estimate and applicability check that reads it.
static HARDWARE: Lazy<HardwareProfile> = Lazy::new(HardwareProfile::detect);

3. Algorithm Registry

use std::collections::BTreeMap;

/// A specific algorithm implementation
pub trait Algorithm<Args, Output>: Send + Sync {
    /// Name for debugging
    fn name(&self) -> &'static str;
    
    /// Estimate cost for given arguments
    fn estimate_cost(&self, args: &Args) -> OpCost;
    
    /// Execute the algorithm
    fn execute(&self, args: Args) -> Output;
    
    /// Check if this algorithm is applicable
    fn is_applicable(&self, args: &Args) -> bool {
        true
    }
}

/// Registry of interchangeable algorithms for one operation.
///
/// Algorithms are kept in registration order. That order breaks cost ties, and
/// the first registered algorithm is the fallback when none is applicable.
pub struct AlgorithmRegistry<Args, Output> {
    algorithms: Vec<Box<dyn Algorithm<Args, Output>>>,
    /// Cache of (hash of `Args` -> index of the selected algorithm).
    ///
    /// NOTE(review): keyed by a 64-bit hash only, so distinct `Args` whose
    /// hashes collide would share one cached selection.
    cache: RwLock<HashMap<u64, usize>>,
}

impl<Args: Hash, Output> Default for AlgorithmRegistry<Args, Output> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Args: Hash, Output> AlgorithmRegistry<Args, Output> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            algorithms: Vec::new(),
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Appends `algo` to the registry.
    pub fn register(&mut self, algo: impl Algorithm<Args, Output> + 'static) {
        self.algorithms.push(Box::new(algo));
    }

    /// Returns the algorithm the cost model ranks fastest for `args`.
    ///
    /// Only algorithms whose `is_applicable` returns `true` compete, and ties go
    /// to the one registered first. If no algorithm is applicable, the first
    /// registered algorithm is returned as a fallback. The choice is cached per
    /// hash of `args`.
    ///
    /// # Panics
    ///
    /// Panics if no algorithm has been registered, or if the cache lock is poisoned.
    pub fn select(&self, args: &Args) -> &dyn Algorithm<Args, Output> {
        assert!(
            !self.algorithms.is_empty(),
            "AlgorithmRegistry::select called before any algorithm was registered"
        );

        let shape_hash = self.hash_args(args);
        if let Some(&idx) = self.cache.read().unwrap().get(&shape_hash) {
            return self.algorithms[idx].as_ref();
        }

        let hw = &*HARDWARE;
        // Restricting the minimum to applicable algorithms guarantees one of them
        // wins even if every estimate saturates at u64::MAX. `min_by_key` returns
        // the first of equally cheap candidates, so registration order breaks ties.
        let best_idx = self
            .algorithms
            .iter()
            .enumerate()
            .filter(|(_, algo)| algo.is_applicable(args))
            .min_by_key(|(_, algo)| algo.estimate_cost(args).estimate_time_ns(hw))
            .map_or(0, |(idx, _)| idx);

        self.cache.write().unwrap().insert(shape_hash, best_idx);

        self.algorithms[best_idx].as_ref()
    }

    /// Selects an algorithm for `args` (see [`Self::select`]) and runs it.
    pub fn execute(&self, args: Args) -> Output {
        let algo = self.select(&args);
        algo.execute(args)
    }

    /// Hashes `args` with FxHash to form the selection-cache key.
    fn hash_args(&self, args: &Args) -> u64 {
        use std::hash::Hasher;
        let mut hasher = rustc_hash::FxHasher::default();
        args.hash(&mut hasher);
        hasher.finish()
    }
}

4. Matmul Dispatch Example

/// Arguments for matrix multiplication `C (m x n) = A (m x k) * B (k x n)`.
///
/// The derived `Hash` covers every field, and that hash is what
/// `AlgorithmRegistry` uses as its selection-cache key.
///
/// NOTE(review): only sizes, dtype and transpose flags are carried, not the
/// operand data, so an algorithm's `execute(args)` has no tensors to multiply.
/// This type presumably needs to carry (references to) `a` and `b`.
#[derive(Clone, Hash)]
pub struct MatmulArgs {
    /// Rows of A and of the output (`a.shape()[0]` in `matmul`).
    pub m: usize,
    /// Shared inner dimension (`a.shape()[1]` in `matmul`).
    pub k: usize,
    /// Columns of B and of the output (`b.shape()[1]` in `matmul`).
    pub n: usize,
    /// Element type, taken from `a` in `matmul`.
    pub dtype: DType,
    /// Transpose flag for A; `matmul` always sets it to `false`.
    pub a_transposed: bool,
    /// Transpose flag for B; `matmul` always sets it to `false`.
    pub b_transposed: bool,
}

/// Naive O(n^3) matmul for tiny matrices
struct NaiveMatmul;

impl Algorithm<MatmulArgs, Tensor> for NaiveMatmul {
    fn name(&self) -> &'static str { "naive" }
    
    fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
        let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
        let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
        OpCost {
            flops,
            memory_bytes: memory,
            overhead_ns: 10,  // Very low overhead
        }
    }
    
    fn execute(&self, args: MatmulArgs) -> Tensor {
        // Simple triple loop
        naive_matmul_impl(args)
    }
    
    fn is_applicable(&self, args: &MatmulArgs) -> bool {
        // Only for tiny matrices
        args.m * args.n < 4096
    }
}

/// SIMD-optimized matmul for medium matrices
struct SimdMatmul;

impl Algorithm<MatmulArgs, Tensor> for SimdMatmul {
    fn name(&self) -> &'static str { "simd" }
    
    fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
        let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
        let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
        OpCost {
            flops,
            memory_bytes: memory,
            overhead_ns: 100,  // Some setup for SIMD
        }
    }
    
    fn execute(&self, args: MatmulArgs) -> Tensor {
        simd_matmul_impl(args)
    }
    
    fn is_applicable(&self, args: &MatmulArgs) -> bool {
        // Need minimum size for SIMD efficiency
        args.n >= HARDWARE.simd_width && args.m >= 4
    }
}

/// Tiled matmul for large matrices (cache-friendly)
struct TiledMatmul {
    tile_size: usize,
}

impl Algorithm<MatmulArgs, Tensor> for TiledMatmul {
    fn name(&self) -> &'static str { "tiled" }
    
    fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
        let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
        // Better memory access pattern
        let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
        let cache_factor = if args.m * args.k * 4 < HARDWARE.l2_cache { 0.5 } else { 1.0 };
        OpCost {
            flops,
            memory_bytes: (memory as f64 * cache_factor) as u64,
            overhead_ns: 500,  // Tiling setup
        }
    }
    
    fn execute(&self, args: MatmulArgs) -> Tensor {
        tiled_matmul_impl(args, self.tile_size)
    }
    
    fn is_applicable(&self, args: &MatmulArgs) -> bool {
        // For large matrices where tiling helps
        args.m * args.k * 4 > HARDWARE.l1_cache
    }
}

/// Parallel tiled matmul for very large matrices.
struct ParallelMatmul;

impl Algorithm<MatmulArgs, Tensor> for ParallelMatmul {
    fn name(&self) -> &'static str {
        "parallel"
    }

    /// Same FLOP and byte counts as the serial variants (4 bytes per element),
    /// with FLOPs divided by the usable parallelism, `min(num_cores, m)` clamped
    /// to at least 1, plus a 5 µs overhead for thread spawn.
    fn estimate_cost(&self, args: &MatmulArgs) -> OpCost {
        let flops = 2 * args.m as u64 * args.k as u64 * args.n as u64;
        let memory = (args.m * args.k + args.k * args.n + args.m * args.n) as u64 * 4;
        // No more useful workers than output rows; `max(1)` keeps a degenerate
        // `m == 0` from dividing by zero (estimate_cost is public and may be
        // called without checking `is_applicable` first).
        let parallelism = HARDWARE.num_cores.min(args.m).max(1) as u64;
        OpCost {
            flops: flops / parallelism,
            memory_bytes: memory,
            overhead_ns: 5000, // Thread spawn overhead
        }
    }

    fn execute(&self, args: MatmulArgs) -> Tensor {
        parallel_matmul_impl(args)
    }

    /// Worth parallelising only when `m * k * n` exceeds one million and there
    /// is at least one output row per core.
    fn is_applicable(&self, args: &MatmulArgs) -> bool {
        // Saturating: a product too large for usize is certainly "large enough"
        // and must not panic on overflow in debug builds.
        let total_work = args.m.saturating_mul(args.k).saturating_mul(args.n);
        total_work > 1_000_000 && args.m >= HARDWARE.num_cores
    }
}

/// Process-wide matmul registry, built on first use by `build_matmul_registry`.
static MATMUL_REGISTRY: Lazy<AlgorithmRegistry<MatmulArgs, Tensor>> =
    Lazy::new(build_matmul_registry);

/// Registers every matmul variant, cheapest setup first.
///
/// Order matters: the registry keeps the earlier entry on cost ties and falls
/// back to the first entry when nothing is applicable.
fn build_matmul_registry() -> AlgorithmRegistry<MatmulArgs, Tensor> {
    let mut reg = AlgorithmRegistry::new();
    reg.register(NaiveMatmul);
    reg.register(SimdMatmul);
    reg.register(TiledMatmul { tile_size: 64 });
    reg.register(ParallelMatmul);
    reg
}

/// Multiplies `a` (m x k) by `b` (k x n), dispatching to whichever registered
/// algorithm the cost model ranks cheapest for these sizes.
///
/// # Panics
///
/// Panics if either operand is not 2-D, or if `a`'s column count differs from
/// `b`'s row count. Previously mismatched inner dimensions were silently
/// dispatched with `b`'s row count ignored.
///
/// NOTE(review): `MatmulArgs` carries only sizes and dtype, not `a`/`b`
/// themselves, so the selected `*_matmul_impl` has no access to operand data
/// as written. Also, `b`'s dtype is not checked against `a`'s.
pub fn matmul(a: &Tensor, b: &Tensor) -> Tensor {
    let a_shape = a.shape();
    let b_shape = b.shape();
    assert_eq!(a_shape.len(), 2, "matmul: left operand must be 2-D");
    assert_eq!(b_shape.len(), 2, "matmul: right operand must be 2-D");

    let (m, k) = (a_shape[0], a_shape[1]);
    let (b_rows, n) = (b_shape[0], b_shape[1]);
    assert_eq!(
        k, b_rows,
        "matmul: inner dimensions differ ({}x{} times {}x{})",
        m, k, b_rows, n
    );

    let args = MatmulArgs {
        m,
        k,
        n,
        dtype: a.dtype(),
        a_transposed: false,
        b_transposed: false,
    };

    MATMUL_REGISTRY.execute(args)
}

5. Auto-Tuning with Benchmarking

/// Auto-tunes algorithm selection by benchmarking every applicable algorithm.
///
/// Results are stored only in this tuner's `measurements`; tuning does not
/// update the wrapped registry's own selection cache.
pub struct AutoTuner<Args, Output> {
    /// Algorithms to benchmark.
    registry: AlgorithmRegistry<Args, Output>,
    /// Measured timings: hash of `Args` -> list of (algorithm index, median ns).
    measurements: RwLock<HashMap<u64, Vec<(usize, u64)>>>,
    /// Number of untimed warmup runs per algorithm.
    warmup_runs: usize,
    /// Number of timed runs per algorithm (treated as at least 1).
    measurement_runs: usize,
}

impl<Args: Clone + Hash, Output> AutoTuner<Args, Output> {
    /// Creates a tuner over `registry` with no recorded measurements.
    pub fn new(
        registry: AlgorithmRegistry<Args, Output>,
        warmup_runs: usize,
        measurement_runs: usize,
    ) -> Self {
        Self {
            registry,
            measurements: RwLock::new(HashMap::new()),
            warmup_runs,
            measurement_runs,
        }
    }

    /// Returns the registry index of the fastest applicable algorithm for `args`.
    ///
    /// Benchmarking happens on the first call for a given hash of `args`:
    /// each applicable algorithm runs `warmup_runs` untimed times, then
    /// `max(measurement_runs, 1)` timed times, and its median is recorded.
    /// Returns 0 if no algorithm is applicable.
    ///
    /// # Panics
    ///
    /// Panics if the measurement lock is poisoned.
    pub fn tune(&self, args: &Args) -> usize {
        // Same hashing as the registry's selection cache, so keys line up.
        let shape_hash = self.registry.hash_args(args);

        if let Some(recorded) = self.measurements.read().unwrap().get(&shape_hash) {
            return Self::fastest(recorded);
        }

        // Without at least one timed run there is no median to read.
        let timed_runs = self.measurement_runs.max(1);
        let mut results = Vec::new();

        for (idx, algo) in self.registry.algorithms.iter().enumerate() {
            if !algo.is_applicable(args) {
                continue;
            }

            for _ in 0..self.warmup_runs {
                let _ = algo.execute(args.clone());
            }

            let mut times = Vec::with_capacity(timed_runs);
            for _ in 0..timed_runs {
                let start = std::time::Instant::now();
                let _ = algo.execute(args.clone());
                times.push(start.elapsed().as_nanos() as u64);
            }

            times.sort_unstable();
            results.push((idx, times[times.len() / 2]));
        }

        let best_idx = Self::fastest(&results);
        self.measurements.write().unwrap().insert(shape_hash, results);
        best_idx
    }

    /// Index with the smallest median (first wins on ties), or 0 if `results` is empty.
    fn fastest(results: &[(usize, u64)]) -> usize {
        results
            .iter()
            .min_by_key(|&&(_, median_ns)| median_ns)
            .map_or(0, |&(idx, _)| idx)
    }

    /// Writes all recorded measurements to `path` as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or the file write fails.
    pub fn save(&self, path: &Path) -> std::io::Result<()> {
        // Serialize under the read lock, but release it before touching the disk.
        let json = {
            let data = self.measurements.read().unwrap();
            serde_json::to_string_pretty(&*data)?
        };
        std::fs::write(path, json)
    }

    /// Replaces the recorded measurements with those stored in `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or parsed. It also returns
    /// an `InvalidData` error if the file references an algorithm index this
    /// tuner's registry does not have (e.g. data saved from a differently built
    /// registry). In that case the existing measurements are left untouched.
    pub fn load(&self, path: &Path) -> std::io::Result<()> {
        let json = std::fs::read_to_string(path)?;
        let data: HashMap<u64, Vec<(usize, u64)>> = serde_json::from_str(&json)?;

        let count = self.registry.algorithms.len();
        if let Some(&(bad, _)) = data.values().flatten().find(|&&(idx, _)| idx >= count) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "tuning data references algorithm #{} but the registry has only {}",
                    bad, count
                ),
            ));
        }

        *self.measurements.write().unwrap() = data;
        Ok(())
    }
}

Acceptance Criteria

  • ShapeClass enum with classification logic
  • OpCost and HardwareProfile for cost modeling
  • AlgorithmRegistry with cost-based selection
  • Matmul dispatch with 4+ algorithm variants
  • Auto-tuner with runtime benchmarking
  • Persistence of tuning results
  • Hardware detection (cache sizes, SIMD width)
  • Benchmarks comparing dispatch overhead vs speedup

Expected Performance Impact

| Shape | Single algorithm | Dispatched | Improvement |
|---|---|---|---|
| 8x8 | 50 ns (BLAS overhead) | 20 ns (naive) | 2.5x |
| 64x64 | 15 μs | 12 μs (SIMD) | 1.25x |
| 512x512 | 8 ms | 6 ms (tiled) | 1.33x |
| 4096x4096 | 2 s | 500 ms (parallel) | 4x |

5-15% average improvement across workloads by avoiding algorithm mismatch.

References

Labels

performance, dispatch, optimization, P1-high

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions