11
2- first_effective_cache (:: Type{T} ) where {T} = StaticInt {FIRST__CACHE_SIZE} () ÷ static_sizeof (T)
3- second_effective_cache (:: Type{T} ) where {T} = StaticInt {SECOND_CACHE_SIZE} () ÷ static_sizeof (T)
42
53function block_sizes (:: Type{T} , _α, _β, R₁, R₂) where {T}
6- W = VectorizationBase . pick_vector_width_val (T)
4+ W = pick_vector_width_val (T)
75 α = _α * W
86 β = _β * W
9- L₁ₑ = first_effective_cache (T) * R₁
10- L₂ₑ = second_effective_cache (T) * R₂
7+ L₁ₑ = first_cache_size (T) * R₁
8+ L₂ₑ = second_cache_size (T) * R₂
119 block_sizes (W, α, β, L₁ₑ, L₂ₑ)
1210end
1311function block_sizes (W, α, β, L₁ₑ, L₂ₑ)
14- MᵣW = StaticInt {mᵣ} () * W
12+ mᵣ, nᵣ = matmul_params ()
13+ MᵣW = mᵣ * W
1514
1615 Mc = floortostaticint (√ (L₁ₑ)*√ (L₁ₑ* β + L₂ₑ* α)/√ (L₂ₑ) / MᵣW) * MᵣW
1716 Kc = roundtostaticint (√ (L₁ₑ)*√ (L₂ₑ)/√ (L₁ₑ* β + L₂ₑ* α))
18- Nc = floortostaticint (√ (L₂ₑ)*√ (L₁ₑ* β + L₂ₑ* α)/√ (L₁ₑ) / StaticInt {nᵣ} ()) * StaticInt {nᵣ} ()
17+ Nc = floortostaticint (√ (L₂ₑ)*√ (L₁ₑ* β + L₂ₑ* α)/√ (L₁ₑ) / nᵣ) * nᵣ
1918
2019 Mc, Kc, Nc
2120end
2221function block_sizes (:: Type{T} ) where {T}
23- block_sizes (T, StaticFloat { W₁Default} (), StaticFloat { W₂Default} (), StaticFloat { R₁Default} (), StaticFloat { R₂Default} ())
22+ block_sizes (T, W₁Default (), W₂Default (), R₁Default (), R₂Default ())
2423end
2524
2625"""
@@ -159,11 +158,11 @@ Note that for synchronization on `B`, all threads must have the same values for
159158independently of `M`, this algorithm guarantees all threads are on the same page.
160159"""
161160@inline function solve_block_sizes (:: Type{T} , M, K, N, _α, _β, R₂, R₃, Wfactor) where {T}
162- W = VectorizationBase . pick_vector_width_val (T)
161+ W = pick_vector_width_val (T)
163162 α = _α * W
164163 β = _β * W
165- L₁ₑ = first_effective_cache (T) * R₂
166- L₂ₑ = second_effective_cache (T) * R₃
164+ L₁ₑ = first_cache_size (T) * R₂
165+ L₂ₑ = second_cache_size (T) * R₃
167166
168167 # Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
169168 Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ)*√ (α * L₂ₑ + β * L₁ₑ))
@@ -178,11 +177,11 @@ independently of `M`, this algorithm guarantees all threads are on the same page
178177end
179178# Takes Nc, calcs Mc and Kc
180179@inline function solve_McKc (:: Type{T} , M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T}
181- W = VectorizationBase . pick_vector_width_val (T)
180+ W = pick_vector_width_val (T)
182181 α = _α * W
183182 β = _β * W
184- L₁ₑ = first_effective_cache (T) * R₂
185- L₂ₑ = second_effective_cache (T) * R₃
183+ L₁ₑ = first_cache_size (T) * R₂
184+ L₂ₑ = second_cache_size (T) * R₃
186185
187186 Kc_init⁻¹ = Base. FastMath. max_fast (√ (α/ L₁ₑ), Nc* inv (L₂ₑ))
188187 Kiter = cldapproxi (K, Kc_init⁻¹) # approximate `ceil`
@@ -201,27 +200,28 @@ end
201200"""
202201 find_first_acceptable(M, W)
203202
204- Finds first combination of `Miter` and `Niter` that doesn't make `M` too small while producing `Miter * Niter = NUM_CORES `.
203+ Finds first combination of `Miter` and `Niter` that doesn't make `M` too small while producing `Miter * Niter = num_cores() `.
205204This would be awkard if there are computers with prime numbers of cores. I should probably consider that possibility at some point.
206205"""
207206@inline function find_first_acceptable (M, W)
208- Mᵣ = StaticInt {mᵣ} () * W
209- for (miter,niter) ∈ CORE_FACTORS
210- if miter * ((MᵣW_mul_factor - One ()) * Mᵣ) ≤ M + (W + W)
207+ Mᵣ, Nᵣ = matmul_params ()
208+ factors = calc_factors ()
209+ for (miter, niter) ∈ factors
210+ if miter * ((MᵣW_mul_factor () - One ()) * Mᵣ) ≤ M + (W + W)
211211 return miter, niter
212212 end
213213 end
214- last (CORE_FACTORS )
214+ last (factors )
215215end
216216"""
217217 divide_blocks(M, Ntotal, _nspawn, W)
218218
219219Splits both `M` and `N` into blocks when trying to spawn a large number of threads relative to the size of the matrices.
220220"""
221221@inline function divide_blocks (M, Ntotal, _nspawn, W)
222- _nspawn == NUM_CORES && return find_first_acceptable (M, W)
223-
224- Miter = clamp (div_fast (M, W* StaticInt {mᵣ} () * MᵣW_mul_factor), 1 , _nspawn)
222+ _nspawn == num_cores () && return find_first_acceptable (M, W)
223+ mᵣ, nᵣ = matmul_params ()
224+ Miter = clamp (div_fast (M, W* mᵣ * MᵣW_mul_factor () ), 1 , _nspawn)
225225 nspawn = div_fast (_nspawn, Miter)
226226 if (nspawn ≤ 1 ) & (Miter < _nspawn)
227227 # rebalance Miter
0 commit comments