@@ -388,11 +388,11 @@ function matmul_pack_A_and_B!(
388388 mᵣW = mᵣ * W
389389 # atomicsync = Ref{NTuple{16,UInt}}()
390390 Mbsize, Mrem, Mremfinal, _to_spawn = split_m (M, tospawn, W) # M is guaranteed to be > W because of `W ≥ M` condition for `jmultsplitn!`...
391- atomicsync = allocref (StaticInt {2 } ()* num_cores ()* cache_linesize ())
392- p = reinterpret (Ptr{UInt }, Base. unsafe_convert (Ptr{UInt8}, atomicsync))
391+ atomicsync = allocref (( StaticInt {1 } ()+ num_cores () )* cache_linesize ())
392+ p = align ( reinterpret (Ptr{UInt32 }, Base. unsafe_convert (Ptr{UInt8}, atomicsync) ))
393393 GC. @preserve atomicsync begin
394- for i ∈ CloseOpen (2_ to_spawn )
395- _atomic_store! (p + i* cache_linesize (), zero (UInt) )
394+ for i ∈ CloseOpen (_to_spawn )
395+ _atomic_store! (reinterpret (Ptr{UInt64}, p) + i* cache_linesize (), 0x0000000000000000 )
396396 end
397397 Mblock_Mrem, Mblock_ = promote (Mbsize + W, Mbsize)
398398 u_to_spawn = _to_spawn % UInt
@@ -414,87 +414,81 @@ function matmul_pack_A_and_B!(
414414end
415415
416416function sync_mul! (
417- C:: AbstractStridedPointer{T} , A:: AbstractStridedPointer , B:: AbstractStridedPointer , α, β, M, K, N, atomicp:: Ptr{UInt } , bc:: Ptr , id:: UInt , total_ids:: UInt ,
417+ C:: AbstractStridedPointer{T} , A:: AbstractStridedPointer , B:: AbstractStridedPointer , α, β, M, K, N, atomicp:: Ptr{UInt32 } , bc:: Ptr , id:: UInt , total_ids:: UInt ,
418418 :: StaticFloat64{W₁} , :: StaticFloat64{W₂} , :: StaticFloat64{R₁} , :: StaticFloat64{R₂}
419419) where {T, W₁, W₂, R₁, R₂}
420420
421- (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter), (Nblock, Nblock_Nrem, Nrem, Niter) =
422- solve_block_sizes (Val (T), M, K, N, StaticFloat64 {W₁} (), StaticFloat64 {W₂} (), StaticFloat64 {R₁} (), StaticFloat64 {R₂} (), One ())
423-
424- # atomics = atomicp + 8sizeof(UInt)
425- sync_iters = zero (UInt)
426- myp = atomicp + id * cache_linesize ()
427- atomicp -= cache_linesize ()
428- atomics = atomicp + total_ids* cache_linesize ()
429- mys = myp + total_ids* (cache_linesize () % UInt)
430- Npackb_r_div, Npackb_r_rem = divrem_fast (Nblock_Nrem, total_ids)
431- Npackb_r_block_rem, Npackb_r_block_ = promote (Npackb_r_div + One (), Npackb_r_div)
432-
433- Npackb___div, Npackb___rem = divrem_fast (Nblock, total_ids)
434- Npackb___block_rem, Npackb___block_ = promote (Npackb___div + One (), Npackb___div)
435-
436- pack_r_offset = Npackb_r_div * id + min (id, Npackb_r_rem)
437- pack___offset = Npackb___div * id + min (id, Npackb___rem)
438-
439- pack_r_len = ifelse (id < Npackb_r_rem, Npackb_r_block_rem, Npackb_r_block_)
440- pack___len = ifelse (id < Npackb___rem, Npackb___block_rem, Npackb___block_)
441-
442- for n in CloseOpen (Niter)
443- # Krem
444- # pack kc x nc block of B
445- nfull = n < Nrem
446- nsize = ifelse (nfull, Nblock_Nrem, Nblock)
447- pack_offset = ifelse (nfull, pack_r_offset, pack___offset)
448- pack_len = ifelse (nfull, pack_r_len, pack___len)
449- let A = A, B = B
450- for k ∈ CloseOpen (Kiter)
451- ksize = ifelse (k < Krem, Kblock_Krem, Kblock)
452- _B = default_zerobased_stridedpointer (bc, (One (), ksize))
453- unsafe_copyto_turbo! (gesp (_B, (Zero (), pack_offset)), gesp (B, (Zero (), pack_offset)), ksize, pack_len)
454- # synchronize before starting the multiplication, to ensure `B` is packed
455- _mv = _atomic_add! (myp, one (UInt))
456- sync_iters += one (UInt)
457- let atomp = atomicp
458- for _ ∈ CloseOpen (total_ids)
459- atomp += cache_linesize ()
460- atomp == myp && continue
461- while _atomic_load (atomp) != sync_iters
462- pause ()
463- end
464- end
465- end
466- # multiply
467- let A = A, B = _B, C = C
468- for m in CloseOpen (Miter)
469- msize = ifelse ((m+ 1 ) == Miter, Mremfinal, ifelse (m < Mrem, Mblock_Mrem, Mblock))
470- if k == 0
471- packaloopmul! (C, A, B, α, β, msize, ksize, nsize)
472- else
473- packaloopmul! (C, A, B, α, One (), msize, ksize, nsize)
474- end
475- A = gesp (A, (msize, Zero ()))
476- C = gesp (C, (msize, Zero ()))
477- end
478- end
479- A = gesp (A, (Zero (), ksize))
480- B = gesp (B, (ksize, Zero ()))
481- # synchronize on completion so we wait until every thread is done with `Bpacked` before beginning to overwrite it
482- _mv = _atomic_add! (mys, one (UInt))
483- let atoms = atomics
484- for _ ∈ CloseOpen (total_ids)
485- atoms += cache_linesize ()
486- atoms == mys && continue
487- while _atomic_load (atoms) != sync_iters
488- pause ()
489- end
490- end
491- end
421+ (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter), (Nblock, Nblock_Nrem, Nrem, Niter) =
422+ solve_block_sizes (Val (T), M, K, N, StaticFloat64 {W₁} (), StaticFloat64 {W₂} (), StaticFloat64 {R₁} (), StaticFloat64 {R₂} (), One ())
423+
424+ sync_iters = 0x00000000
425+ myp = atomicp + id * cache_linesize ()
426+ Npackb_r_div, Npackb_r_rem = divrem_fast (Nblock_Nrem, total_ids)
427+ Npackb_r_block_rem, Npackb_r_block_ = promote (Npackb_r_div + One (), Npackb_r_div)
428+
429+ Npackb___div, Npackb___rem = divrem_fast (Nblock, total_ids)
430+ Npackb___block_rem, Npackb___block_ = promote (Npackb___div + One (), Npackb___div)
431+
432+ pack_r_offset = Npackb_r_div * id + min (id, Npackb_r_rem)
433+ pack___offset = Npackb___div * id + min (id, Npackb___rem)
434+
435+ pack_r_len = ifelse (id < Npackb_r_rem, Npackb_r_block_rem, Npackb_r_block_)
436+ pack___len = ifelse (id < Npackb___rem, Npackb___block_rem, Npackb___block_)
437+
438+ for n in CloseOpen (Niter)
439+ # Krem
440+ # pack kc x nc block of B
441+ nfull = n < Nrem
442+ nsize = ifelse (nfull, Nblock_Nrem, Nblock)
443+ pack_offset = ifelse (nfull, pack_r_offset, pack___offset)
444+ pack_len = ifelse (nfull, pack_r_len, pack___len)
445+ let A = A, B = B
446+ for k ∈ CloseOpen (Kiter)
447+ ksize = ifelse (k < Krem, Kblock_Krem, Kblock)
448+ _B = default_zerobased_stridedpointer (bc, (One (), ksize))
449+ unsafe_copyto_turbo! (gesp (_B, (Zero (), pack_offset)), gesp (B, (Zero (), pack_offset)), ksize, pack_len)
450+ # synchronize before starting the multiplication, to ensure `B` is packed
451+ _mv = _atomic_add! (myp, 0x00000001 )
452+ sync_iters += 0x00000001
453+ let atomp = atomicp
454+ for _ ∈ CloseOpen (total_ids)
455+ while _atomic_load (atomp) ≠ sync_iters
456+ pause ()
492457 end
458+ atomp += cache_linesize ()
459+ end
493460 end
494- B = gesp (B, (Zero (), nsize))
495- C = gesp (C, (Zero (), nsize))
461+ # multiply
462+ let A = A, B = _B, C = C
463+ for m in CloseOpen (Miter)
464+ msize = ifelse ((m+ 1 ) == Miter, Mremfinal, ifelse (m < Mrem, Mblock_Mrem, Mblock))
465+ if k == 0
466+ packaloopmul! (C, A, B, α, β, msize, ksize, nsize)
467+ else
468+ packaloopmul! (C, A, B, α, One (), msize, ksize, nsize)
469+ end
470+ A = gesp (A, (msize, Zero ()))
471+ C = gesp (C, (msize, Zero ()))
472+ end
473+ end
474+ _mv = _atomic_add! (myp + 4 , 0x00000001 )
475+ A = gesp (A, (Zero (), ksize))
476+ B = gesp (B, (ksize, Zero ()))
477+ # synchronize on completion so we wait until every thread is done with `Bpacked` before beginning to overwrite it
478+ let atomp = atomicp
479+ for _ ∈ CloseOpen (total_ids)
480+ while _atomic_load (atomp+ 4 ) ≠ sync_iters
481+ pause ()
482+ end
483+ atomp += cache_linesize ()
484+ end
485+ end
486+ end
496487 end
497- nothing
488+ B = gesp (B, (Zero (), nsize))
489+ C = gesp (C, (Zero (), nsize))
490+ end
491+ nothing
498492end
499493
500494function _matmul! (y:: AbstractVector{T} , A:: AbstractMatrix , x:: AbstractVector , α, β, MKN, contig_axis) where {T<: Real }
0 commit comments