@@ -226,10 +226,15 @@ initSchedule(int maxDist, int stages[SCHED_SIZE], int numStages,
   return success();
 }
 
-void createAndScheduleAsyncCopy(
-    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
-    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
-    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+struct AsyncCopyChainOps {
+  ttg::AsyncCopyGlobalToLocalOp copyOp;
+  ttg::AsyncCommitGroupOp commitOp;
+  ttg::AsyncWaitOp waitOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc,
+                                  Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(loadOp);
   Location loc = loadOp.getLoc();
 
@@ -274,9 +279,15 @@ void createAndScheduleAsyncCopy(
   auto sharedLoad =
       builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad, waitOp);
 
+  return {copyOp, commitOp, waitOp, sharedLoad};
+}
+
+void scheduleAsyncCopy(
+    const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, commitOp, waitOp, localLoadOp] = asyncOps;
   auto [loadStage, loadCluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  // Schedule new ops
   schedule.insert(copyOp, loadStage, loadCluster);
   // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the
   // later UpdateAsyncWaitCount pass can deduce better waitcnts
@@ -292,25 +303,41 @@ void createAndScheduleAsyncCopy(
                   clusters[SCHED_ASYNC_WAIT]);
 
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{sharedLoad});
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
-      sharedLoad->hasOneUse()) {
+      localLoadOp->hasOneUse()) {
     if (auto cvt =
-            dyn_cast<ttg::ConvertLayoutOp>(*sharedLoad->getUsers().begin()))
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
       schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                       clusters[SCHED_LOCAL_LOAD]);
   }
-
-  loadOp.erase();
 }
 
-void createAndScheduleStreamCopy(
+void createAndScheduleAsyncCopy(
     tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
     tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
     const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+
+  auto asyncOps = createAsyncCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{asyncOps.localLoadOp});
+
+  scheduleAsyncCopy(asyncOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
+  loadOp.erase();
+}
+
+struct StreamCopyChainOps {
+  tt::LoadOp copyOp;
+  ttg::MemDescSubviewOp subviewOp;
+  ttg::LocalStoreOp localStoreOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc,
+                                    Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(forOp);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   // Replace the load with insert/extract slice.
@@ -319,11 +346,7 @@ void createAndScheduleStreamCopy(
 
   ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
-  Operation *copy = builder.clone(*loadOp);
-
-  auto [stage, cluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  schedule.insert(copy, stage, cluster);
+  tt::LoadOp copy = cast<tt::LoadOp>(builder.clone(*loadOp));
 
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
@@ -332,43 +355,66 @@ void createAndScheduleStreamCopy(
   auto subviewTy = ttg::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
-  auto viewLoad =
+  auto subviewOp =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
   // Clean up old local caches.
   SmallVector<ttg::LocalAllocOp> allocsToErase;
   for (Operation *user : loadOp->getUsers()) {
     if (auto userAlloc = dyn_cast<ttg::LocalAllocOp>(user)) {
-      tt::replaceUsesAndPropagateType(builder, userAlloc, viewLoad.getResult());
+      tt::replaceUsesAndPropagateType(builder, userAlloc,
+                                      subviewOp.getResult());
       allocsToErase.push_back(userAlloc);
     }
   }
   for (auto allocToErase : allocsToErase)
     allocToErase.erase();
 
   // Prefetch load ahead of the dot stage if is used by the dot.
-  auto storeOp =
-      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
-  schedule.insert(viewLoad, stages[SCHED_LOCAL_STORE],
+  auto storeOp = builder.create<ttg::LocalStoreOp>(loc, copy, subviewOp);
+
+  auto sharedLoad =
+      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), subviewOp);
+
+  return {copy, subviewOp, storeOp, sharedLoad};
+}
+
+void scheduleStreamCopy(
+    const StreamCopyChainOps &streamOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, subviewOp, localStoreOp, localLoadOp] = streamOps;
+  auto [stage, cluster] = schedule[loadOp];
+  schedule.insert(copyOp, stage, cluster);
+
+  schedule.insert(subviewOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
-  schedule.insert(storeOp, stages[SCHED_LOCAL_STORE],
+  schedule.insert(localStoreOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
 
-  // Create local load
-  auto sharedLoad =
-      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
-  Value result = sharedLoad.getResult();
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{result});
-
-  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) {
-    if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(*result.getUsers().begin()))
+  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
+      localLoadOp->hasOneUse()) {
+    if (auto cvt =
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
       schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                       clusters[SCHED_LOCAL_LOAD]);
   }
+}
+
+void createAndScheduleStreamCopy(
+    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
 
+  auto streamOps = createStreamCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{streamOps.localLoadOp});
+
+  scheduleStreamCopy(streamOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
   loadOp.erase();
 }
 
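Both copy flavors now follow the same shape: build the op chain, rewire the original load's users to the new ttg.local_load, then assign stages and clusters, and only afterwards erase the load. The sketch below is not part of the patch; createAndScheduleCopy, createCopy, and scheduleCopy are placeholder names standing in for the async/stream pairs defined in the diff above, shown only to restate that ordering.

// Hypothetical summary sketch (not in the patch): the ordering shared by
// createAndScheduleAsyncCopy and createAndScheduleStreamCopy.
void createAndScheduleCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx,
                           scf::ForOp forOp, tt::CoarseSchedule &schedule,
                           const int stages[SCHED_SIZE],
                           const std::array<tt::CoarseSchedule::Cluster,
                                            SCHED_SIZE> &clusters) {
  // 1. Create the op chain without touching the schedule.
  auto ops = createCopy(loadOp, alloc, extractIdx, forOp);
  // 2. Redirect the load's users to the new local load.
  loadOp->replaceAllUsesWith(ValueRange{ops.localLoadOp});
  // 3. Assign stages/clusters to the newly created ops.
  scheduleCopy(ops, loadOp, schedule, stages, clusters);
  // 4. Drop the original load from the schedule and the IR last, since the
  //    schedule*Copy step still reads its stage/cluster via schedule[loadOp].
  schedule.erase(loadOp);
  loadOp.erase();
}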