## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives

CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
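
To make the nesting concrete, here is a minimal sketch (hypothetical names and layout assumptions, not actual CUTLASS classes) of a grid-level object that adds a threadblock-dependent offset to its base pointers before delegating to a threadblock-scoped GEMM.

```c++
// Hypothetical sketch of the nesting pattern; not actual CUTLASS classes.
struct GridLevelGemm {

  // Base pointers and leading dimensions for the entire problem (column-major assumed).
  float const *A;
  float const *B;
  int lda, ldb;
  int tile_m, tile_n;   // threadblock tile dimensions

  __device__ void run() const {
    // Each threadblock adds an offset derived from its block index to
    // partition the problem into per-threadblock tiles.
    float const *block_A = A + blockIdx.y * tile_m;         // rows of A for this block
    float const *block_B = B + blockIdx.x * tile_n * ldb;   // columns of B for this block

    // A threadblock-scoped GEMM on (block_A, block_B) would follow here; it in
    // turn partitions its work into warp-level and thread-level subtasks.
  }
};
```
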
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators

The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class containing grid-invariant state. The `Params` structure should define a constructor and an `initialize()` method, and it should include a data member corresponding to each data member in the parent class so that these too can be properly constructed in host code. The parent class should define a constructor that accepts `Params const &` as its first argument.

For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
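
A minimal sketch of this pattern is shown below; the class and member names are illustrative rather than the actual `cutlass::gemm::Gemm<>` definition.

```c++
// Illustrative sketch of the Params pattern; names and members are hypothetical.
struct GemmExample {

  // Grid-invariant state, constructed on the host and passed by value at kernel launch.
  struct Params {
    int m, n, k;          // problem dimensions
    float const *A;       // base pointers in device memory
    float const *B;
    float *C;

    Params(): m(0), n(0), k(0), A(0), B(0), C(0) { }

    // Host-side precomputation performed once, before any kernel launch.
    int initialize(int m_, int n_, int k_,
                   float const *A_, float const *B_, float *C_) {
      m = m_; n = n_; k = k_;
      A = A_; B = B_; C = C_;
      return 0;
    }
  };

  // Each data member of Params mirrors a data member of the parent class.
  Params const &params;

  // The parent class accepts Params const & as its first constructor argument.
  __device__ GemmExample(Params const &params_): params(params_) { }
};
```
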
The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular buffer in shared memory is performed by a _GlobalLoadIterator_.

**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded, which may be of use to perform type conversion or, in the case of int8 GEMM, to transpose 4x4 tiles held in registers.
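
To illustrate the division of responsibilities, the following simplified sketch (hypothetical code, not the actual CUTLASS iterators, which are templated, vectorized, and predicated) loads a tile from global memory, leaves a hook for an optional transformation, and stores the result into a shared memory allocation.

```c++
// Simplified sketch of staging one tile from global memory into shared memory.
template <int TileK, int TileM>
__device__ void stage_tile(float const *global_ptr, int ld,
                           float (&smem_tile)[TileK][TileM]) {
  // Each thread loads a fragment of the tile from global memory...
  for (int k = threadIdx.y; k < TileK; k += blockDim.y) {
    for (int m = threadIdx.x; m < TileM; m += blockDim.x) {
      float value = global_ptr[k * ld + m];

      // ...an optional transformation (e.g. type conversion) would happen here...

      // ...and the result is written into the shared memory allocation.
      smem_tile[k][m] = value;
    }
  }
}
```
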
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product. Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.

[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.

**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting Tensor Cores.
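
As a simple illustration of the thread-level case, the sketch below (illustrative fragment shapes, not the actual CUTLASS component) accumulates a sequence of outer products on register-held fragments; each inner statement maps to a fused multiply-add.

```c++
// Sketch of a thread-level matrix product on register-held fragments.
template <int M, int N, int K>
__device__ void thread_mma(float const (&a)[K][M],
                           float const (&b)[K][N],
                           float (&c)[M][N]) {
  for (int k = 0; k < K; ++k) {
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        c[m][n] += a[k][m] * b[k][n];   // compiles to fused multiply-add instructions
      }
    }
  }
}
```
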
The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them with different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:

SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and to use these bits as predicates guarding memory load and store instructions.

CUTLASS provides `PredicateVector`, defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h), to manage a statically sized vector of bits, store them in general-purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
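
As an illustration of the idea (using a plain 32-bit mask rather than `PredicateVector` itself), the sketch below computes guard predicates once and then uses them to guard the corresponding loads.

```c++
// Sketch of predicated loads guarded by a bit vector. cutlass::PredicateVector
// provides a packed, statically sized equivalent of the plain bitmask used here.
// Assumes Count <= 32.
template <int Count>
__device__ void guarded_load(float const *ptr, float (&frag)[Count], int bound) {
  unsigned predicates = 0;

  // Initialize one predicate bit per element: set if the access is in bounds.
  for (int i = 0; i < Count; ++i) {
    if (i < bound) {
      predicates |= (1u << i);
    }
  }

  // Use the predicate bits to guard each load.
  for (int i = 0; i < Count; ++i) {
    frag[i] = (predicates & (1u << i)) ? ptr[i] : 0.0f;
  }
}
```
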
# <a name="S-utilities"></a> 4. Utilities

* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)

# <a name="S-optimizations"></a> 5. Optimizations

This section describes several strategies taken to increase performance beyond what is achievable with a basic implementation of the hierarchical GEMM structure.

## Threadblock Rasterization
To maximize reuse of data held in the last level cache, CUTLASS defines several functions to affect the mapping of threadblocks to logical partitions of the GEMM problem. These map consecutively launched threadblocks to packed two-dimensional regions of the partitioned GEMM problem to increase the probability that these will access the same tiles of global memory at approximately the same time.

Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](cutlass/gemm/threadblock_swizzle.h).
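
As an illustration of the idea (a hypothetical mapping, not the actual functions defined in the header above), the following sketch remaps a linear threadblock index so that consecutively launched threadblocks fall within a packed two-dimensional group of tiles.

```c++
// Hypothetical remapping of a linear threadblock index to 2D tile coordinates.
// Consecutive block indices walk a GroupRows-tall group of tile rows column by
// column, improving the chance that concurrently running threadblocks reuse the
// same tiles of A and B. Assumes grid_m is divisible by GroupRows.
template <int GroupRows>
__device__ void swizzled_tile_coord(int block_idx, int grid_n,
                                    int &tile_m, int &tile_n) {
  int group_size   = GroupRows * grid_n;       // threadblocks per group of tile rows
  int group        = block_idx / group_size;   // which group of tile rows
  int idx_in_group = block_idx % group_size;

  tile_m = group * GroupRows + idx_in_group % GroupRows;
  tile_n = idx_in_group / GroupRows;
}
```
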
## Parallel Reductions across GEMM _K_
Matrix product computations expose parallelism among _O(MN)_ independent inner product computations. For sufficiently large problem sizes, a GEMM kernel in CUTLASS may approach the theoretical maximum computational throughput. For small problems, however, there are too few threadblocks to efficiently occupy the entire GPU.

As a recourse, parallelizing the reduction performed during the inner product computation enables more threadblocks to execute concurrently while still taking advantage of the throughput benefits of large threadblock-level GEMM tiles.

CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension and launching an additional set of threadblocks for each partition. Consequently, we refer to this strategy within CUTLASS as "parallel reduction splitK." This strategy requires the execution of two kernels: the first is called partitionedK GEMM, and the second is called batched reduction.

The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the number of partitions to apply along the K dimension of operands A and B. For example, parameters of m=128, n=128, k=4096, and partition=16 will result in 16 batched strided GEMMs, each with m=128, n=128, k=256. PartitionedK also allows scenarios where k is not divisible by the partition count. For example, m=128, n=128, k=4096, and partition=20 will result in 20 batched strided GEMMs, with the first 19 batches having m=128, n=128, k=4096/20=204 and the last batch having m=128, n=128, k=220 (the remainder).
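
The per-batch K dimensions in the examples above can be computed as follows (a hypothetical helper reproducing the arithmetic, not part of the CUTLASS API).

```c++
// Reproduces the partitionedK arithmetic from the example above.
void partition_k_sizes(int k, int partitions, int &k_per_batch, int &k_last_batch) {
  k_per_batch  = k / partitions;                       // 4096 / 20 = 204
  k_last_batch = k - (partitions - 1) * k_per_batch;   // 4096 - 19 * 204 = 220
}
```
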
The batched reduction kernel then performs the reduction along the K dimension. Thus, the input of the batched reduction kernel is the output (C) of the partitionedK GEMM. A workspace, managed by the user, stores these intermediate results.
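
A minimal sketch of such a reduction appears below (a hypothetical kernel, not the CUTLASS batched reduction): each thread sums one element of the output across all K partitions held in the workspace.

```c++
// Hypothetical batched reduction across K partitions. The workspace holds
// `partitions` copies of an m x n matrix produced by the partitionedK GEMM.
__global__ void reduce_splitk(float const *workspace, float *C,
                              int elements_per_matrix, int partitions) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < elements_per_matrix) {
    float sum = 0.0f;
    for (int p = 0; p < partitions; ++p) {
      sum += workspace[p * elements_per_matrix + idx];
    }
    C[idx] = sum;
  }
}
```
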
An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_gemm.cu).
# Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.