Matrix-matrix multiplication, from less conventional points of view
We consider matrix-matrix multiplication from less frequently seen angles: irregular and arbitrarily nested blocking; indexing by intervals. We will also see that an imperative approach could be insightful and elegant, besides being performant.
- Introduction
- Warm-up: conventional blocking, conventionally
- Beyond simple-minded indexing
- Imperative, multiply-accumulate presentation
Introduction
In a recent paper, Šinkarovs, Koopman and Scholz showed matrix-matrix multiplication from a less-seen angle: as an inductive algorithm similar to merge sort. What impresses is not just the elegance: a few lines of code in SaC (a high-performance array language) compile to truly performant code, on par with, and even exceeding, OpenBLAS and Intel's MKL -- much faster than the naive matrix-matrix multiplication, and a level that is hard to achieve. All this comes without messy and error-prone index and bounds computations, with the guaranteed absence of out-of-bound errors, and with assured correctness.
Reading their paper and listening to presentations, one cannot fail to be fascinated -- by the results, and also by some magic. It all appears too simple, too elegant, too good. This article is an attempt to understand the approach at a slower pace, and in conventional terms, without relying on SaC or dependent types -- starting from the standard, low-level definition of matrix-matrix multiplication.
We will see how Šinkarovs, Koopman and Scholz's inductive algorithm comes about -- and see beyond it, into irregular and arbitrarily nested blocking.
We also see something unexpected: the BLAS
matrix-matrix multiplication C += A*B is
not a mere generalization. It is a key to a different formulation of
matrix-matrix multiplication and the highest performance. This imperative,
multiply-and-accumulate presentation also turns out to be elegant -- perhaps
even more so than the algebraic and `pure' approach explained
before -- besides letting us reason about
performance. The algebraic derivation was not a futile exercise,
however: its lessons are applicable, albeit not as directly
as one might have thought.
References
Artjoms Šinkarovs, Thomas Koopman, Sven-Bodo Scholz:
Rank-Polymorphism for Shape-Guided Blocking
FHPNC 2023: Proc. 11th ACM SIGPLAN Intl.
Workshop on Functional High-Performance and Numerical Computing.
doi:10.1145/3609024.3609410
SaC array programming language
<http://www.sac-home.org/>
Niek Janssen and Sven-Bodo Scholz. On mapping n-dimensional data-parallelism efficiently into gpu-thread-spaces. IFL ’21, pp. 54–66, 2022
Artjoms Šinkarovs and Sven-Bodo Scholz. Parallel scan as a multi-dimensional array problem. Proc. 8th ACM SIGPLAN Intl. Workshop on Libraries, Languages and Compilers for Array Programming, ARRAY 2022, pp. 1–11, 2022
Warm-up: conventional blocking, conventionally
Let's review the conventional block matrix-matrix multiplication -- to fix the notation and to clearly see the places for generalization. The presentation here is deliberately conventional and low-level, with little abstraction: albeit boring, it must be clearly right. We consider regular blocking, which is common, more intuitive and easier to explain. Extensions and abstractions are left for further sections.
The multiplication of the N×K matrix Aik and
the K×M matrix Bkj obtaining the N×M matrix Cij is the very
familiar
Cij = Σk Aik * Bkj where i < N, j < M, k < K
(The indices i, j, etc. are implicitly non-negative.)
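The definition above transcribes directly into a triple loop. The following is a minimal sketch, not the accompanying code: the name `matmul` and the choice of `int array array` (integer entries, so that results compare exactly) are assumptions made for illustration.

```ocaml
(* Naive product C = A * B, directly following Cij = Σk Aik * Bkj.
   Matrices are int array array (an assumption, for exact comparison);
   a is N x K, b is K x M with K > 0, and the result is N x M. *)
let matmul (a : int array array) (b : int array array) : int array array =
  let n = Array.length a in
  let k = Array.length b in
  let m = Array.length b.(0) in
  Array.init n (fun i ->
    Array.init m (fun j ->
      (* s accumulates the sum over kk of a.(i).(kk) * b.(kk).(j) *)
      let s = ref 0 in
      for kk = 0 to k - 1 do
        s := !s + a.(i).(kk) * b.(kk).(j)
      done;
      !s))

(* A 2x2 example; c is [| [|19; 22|]; [|43; 50|] |] *)
let c = matmul [| [|1; 2|]; [|3; 4|] |] [| [|5; 6|]; [|7; 8|] |]
```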
Assume N, M and K are composite numbers, which factor as
N = N₁ * N₂,   M = M₁ * M₂,   K = K₁ * K₂
(where N₁, N₂ are the
factors of N, etc.).
The indices i, j and k can then be represented as
i = i₁*N₁ + i₂ where i₂ < N₁, i₁ < N₂
j = j₁*M₁ + j₂ where j₂ < M₁, j₁ < M₂
k = k₁*K₁ + k₂ where k₂ < K₁, k₁ < K₂
The rows hence get indexed by a pair
i₁i₂ rather than a single i, and so do the columns.
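The correspondence between a single index and an index pair is just integer division with remainder. A small sketch, with the hypothetical helper names `split` and `join`:

```ocaml
(* split n1 i: the pair (i1, i2) with i = i1*n1 + i2 and 0 <= i2 < n1.
   Assumes n1 > 0 and i >= 0. *)
let split n1 i = (i / n1, i mod n1)

(* join n1 (i1, i2): the single index corresponding to the pair *)
let join n1 (i1, i2) = i1 * n1 + i2

(* Example: N = 6 factored as N1 = 3, N2 = 2.
   pairs is [(0,0); (0,1); (0,2); (1,0); (1,1); (1,2)] *)
let pairs = List.init 6 (split 3)
```

The two functions are mutually inverse on the stated ranges, which is what lets the sums over k be regrouped into sums over k₁ and k₂ without losing or duplicating summands.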
Written in terms of such index
pairs, the matrix multiplication becomes:
Ci₁i₂j₁j₂ = Σk₁ Σk₂ Ai₁i₂k₁k₂ * Bk₁k₂j₁j₂
Here, Ai₁i₂k₁k₂ is the element of A at the i₁i₂ row and k₁k₂
column. The ordinary, 2D matrix
has turned into a 4D object: a tensor of order (sometimes called rank) 4.
Let us write it a bit differently: (Ai₁k₁)i₂k₂. So far, it
is a mere change of notation.
(Ci₁j₁)i₂j₂ = Σk₁ (Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂)
The inner sum Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂ should remind one of a
matrix product, of what looks like an N₁×K₁ matrix Ai₁k₁ and a K₁×M₁ matrix
Bk₁j₁.
When i₂ ranges from 0 to N₁-1 and k₂ from 0
to K₁-1, the elements (Ai₁k₁)i₂k₂ are the elements within the rows
i₁*N₁ through i₁*N₁ + N₁ -1 and columns
k₁*K₁ through k₁*K₁ + K₁ -1 of the original matrix A.
Indeed, Ai₁k₁ looks like a matrix: a rectangular sub-block of A of
size N₁×K₁, also called a tile. The notational change to
(Ai₁k₁)i₂k₂ gave a different view of the matrix, as built from
N₁×K₁-sized tiles.
The tiles are themselves arranged into an N₂×K₂
grid, and hence correspond to a matrix whose elements are tiles rather
than numbers. In (Ai₁k₁)i₂k₂, the indices i₁k₁ specify the tile;
i₂k₂ index within the tile.
The inner sum
Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂ gives the i₂j₂-th element
of the matrix product Ai₁k₁ * Bk₁j₁, which we write as
(Ai₁k₁ * Bk₁j₁)i₂j₂. The original matrix-matrix multiplication becomes
(Ci₁j₁)i₂j₂ = Σk₁(Ai₁k₁ * Bk₁j₁)i₂j₂
Here, * now means the tile (i.e., matrix-matrix) multiplication.
Recall, (Ci₁j₁)i₂j₂ is the i₂j₂-th element of the tile (Ci₁j₁)
of the result matrix C. According to the above, it is computed by
obtaining a series of matrix products Ai₁k₁ * Bk₁j₁ and summing up their
i₂j₂-th elements. In other words, the tile Ci₁j₁ is obtained by
computing the matrices Ai₁k₁ * Bk₁j₁ and
summing them up elementwise -- that is, summing them as matrices. Thus:
Ci₁j₁ = Σk₁ Ai₁k₁ * Bk₁j₁
As before, * means matrix-matrix multiplication; now,
the summation in Σk₁ means the tile (matrix) summation. This equation looks
exactly like the regular matrix multiplication, only * and
summation are interpreted as matrix ones rather than numeric.
This is the familiar block-matrix multiplication: tiled matrices (made of tiles) are multiplied as ordinary matrices, with numeric multiplication and addition replaced with tile (matrix) multiplication and addition.
We have merely played with indices and the re-grouping of summation, without changing the order of summands. We have thus relied only on the associativity of addition.
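The block formula Ci₁j₁ = Σk₁ Ai₁k₁ * Bk₁j₁ can be sketched as follows. This is an illustration only: the names `tile`, `add` and `blocked_matmul` are hypothetical, the tile sizes `tn`, `tk`, `tm` play the roles of N₁, K₁, M₁ and are assumed to divide the matrix dimensions, and entries are integers so that the comparison with the naive product is exact.

```ocaml
(* Naive reference product on int matrices *)
let matmul a b =
  let n = Array.length a and k = Array.length b in
  let m = Array.length b.(0) in
  Array.init n (fun i -> Array.init m (fun j ->
    let s = ref 0 in
    for kk = 0 to k - 1 do s := !s + a.(i).(kk) * b.(kk).(j) done;
    !s))

(* tile x r0 c0 h w: copy of the h x w tile of x with top-left corner (r0, c0) *)
let tile x r0 c0 h w =
  Array.init h (fun i -> Array.init w (fun j -> x.(r0 + i).(c0 + j)))

(* Elementwise sum of two equally sized tiles *)
let add x y = Array.map2 (fun r s -> Array.map2 (+) r s) x y

(* Blocked product: tn x tk tiles of a, tk x tm tiles of b.
   Assumes tn, tk, tm divide the respective dimensions. *)
let blocked_matmul tn tk tm a b =
  let n = Array.length a and k = Array.length b in
  let m = Array.length b.(0) in
  let c = Array.make_matrix n m 0 in
  for i1 = 0 to n / tn - 1 do
    for j1 = 0 to m / tm - 1 do
      (* the tile C_{i1 j1} is the matrix sum over k1 of A_{i1 k1} * B_{k1 j1} *)
      let acc = ref (Array.make_matrix tn tm 0) in
      for k1 = 0 to k / tk - 1 do
        let at = tile a (i1 * tn) (k1 * tk) tn tk in
        let bt = tile b (k1 * tk) (j1 * tm) tk tm in
        acc := add !acc (matmul at bt)
      done;
      (* copy the finished tile into its place in c *)
      Array.iteri (fun i2 row ->
        Array.iteri (fun j2 v -> c.(i1 * tn + i2).(j1 * tm + j2) <- v) row) !acc
    done
  done;
  c

(* A 4x6 by 6x2 example with 2x3 and 3x1 tiles *)
let a = Array.init 4 (fun i -> Array.init 6 (fun j -> i * 6 + j))
let b = Array.init 6 (fun i -> Array.init 2 (fun j -> i - 2 * j))
let same = blocked_matmul 2 3 1 a b = matmul a b
```

The sketch copies tiles for clarity; it says nothing about performance, which depends on avoiding such copies.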
Generalizations are easy to see. First, the tiles may themselves be tiled, in turn: one gets the intimations of a recursive (inductive) algorithm. Second, the tiles do not have to have the same size: tiling can be irregular and arbitrary.
Beyond simple-minded indexing
We now attempt to generalize the ordinary, regular tiling -- with the goal of reproducing Šinkarovs, Koopman and Scholz's algorithm, and going beyond. We see that indices need not be integers. We use neither Agda nor SaC, nor dependent types.
The set-up is the same: multiplying the N×K matrix Aik and
the K×M matrix Bkj obtaining the N×M matrix Cij. However,
N, K and M are no longer integers: they are non-empty
integer intervals. The multiplication looks essentially as before
Cij = Σk Aik * Bkj where i ∈ N, j ∈ M, k ∈ K
only the lower
index bound no longer has to be zero: the intervals are arbitrary.
A very attentive reader may have noted one change, however:
k ∈ K does not imply any order in choosing the indices k from the
interval K. Therefore, we now also rely on the commutativity of addition.
Let's partition the interval N into non-empty intervals. They do not
have to have the same size; they do have to be non-overlapping and
together cover the whole N. The intervals K and M are
partitioned likewise, not necessarily the same way. We will use Greek
indices to range over the intervals. Repeating the derivation in the
previous section, we obtain the multiplication algorithm
Cνμ = Σκ Aνκ * Bκμ where ν ∈ N, μ ∈ M, κ ∈ K
Here, ν ∈ N is to be read as ν ranging over the intervals that partition N,
and likewise for μ and κ. Aνκ, which is indexed by the intervals ν and κ,
is a tile of A, of the size corresponding to the sizes of ν and κ:
in other words, it is the tile made of Aik where i∈ν and
k∈κ. The multiplication and summation/addition are the matrix (tile)
multiplication and addition.
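A sketch of multiplication over irregular interval partitions follows. It is not the accompanying code. It assumes that an interval is represented as a half-open pair (lo, hi), that the whole intervals N, K and M start at 0 (so that plain OCaml arrays can hold the matrices), and that entries are integers; the names `sub`, `mul`, `add` and `tiled_matmul` are hypothetical.

```ocaml
(* An interval (lo, hi) stands for the integers lo, lo+1, ..., hi-1 (lo < hi) *)
type interval = int * int

(* sub x nu kappa: copy of the tile of x with rows in nu and columns in kappa *)
let sub x ((r0, r1) : interval) ((c0, c1) : interval) =
  Array.init (r1 - r0) (fun i ->
    Array.init (c1 - c0) (fun j -> x.(r0 + i).(c0 + j)))

(* Ordinary product and sum of matrices, used here on tiles *)
let mul a b =
  Array.map (fun row ->
    Array.init (Array.length b.(0)) (fun j ->
      let s = ref 0 in
      Array.iteri (fun k x -> s := !s + x * b.(k).(j)) row;
      !s)) a

let add x y = Array.map2 (fun r s -> Array.map2 (+) r s) x y

(* C_{nu mu} = sum over kappa of A_{nu kappa} * B_{kappa mu},
   for nu in ns, mu in ms, kappa in ks (partitions of N, M, K) *)
let tiled_matmul (ns : interval list) (ks : interval list)
    (ms : interval list) a b =
  let c = Array.make_matrix (Array.length a) (Array.length b.(0)) 0 in
  List.iter (fun ((n0, n1) as nu) ->
    List.iter (fun ((m0, m1) as mu) ->
      let zero = Array.make_matrix (n1 - n0) (m1 - m0) 0 in
      let t =
        List.fold_left
          (fun acc kappa -> add acc (mul (sub a nu kappa) (sub b kappa mu)))
          zero ks in
      (* copy the finished tile into its place in c *)
      Array.iteri (fun i row ->
        Array.iteri (fun j v -> c.(n0 + i).(m0 + j) <- v) row) t)
      ms)
    ns;
  c

let a = Array.init 5 (fun i -> Array.init 7 (fun j -> i + 2 * j))
let b = Array.init 7 (fun i -> Array.init 3 (fun j -> i * j - 1))
(* Irregular partitions: N = 2+3, K = 1+3+3, M left whole *)
let c1 = tiled_matmul [(0, 2); (2, 5)] [(0, 1); (1, 4); (4, 7)] [(0, 3)] a b
(* The same K intervals listed in another order, and M partitioned as 1+2 *)
let c2 = tiled_matmul [(2, 5); (0, 2)] [(4, 7); (0, 1); (1, 4)] [(0, 1); (1, 3)] a b
```

Both `c1` and `c2` equal `mul a b`: the partitions affect only how the summands are grouped and ordered.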
This tiling process can clearly be repeated. For example: suppose
we want to tile the input matrices into 8×8 tiles. However, their
dimensions (say, K) are not multiples of 8. Therefore, we first
partition K into two intervals: one whose size is a multiple
of 8, and the remainder. The former is further partitioned into size-8
intervals. The Šinkarovs, Koopman and Scholz paper shows a more
interesting example of nested tiling, to account for various caches.
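The two-level partitioning in this example might be represented as a small tree. The sketch below is an assumption about one possible representation, not the one used in the accompanying code; the type `part` and the functions `leaves` and `partition8` are hypothetical names, and intervals are half-open pairs (lo, hi) with lo < hi.

```ocaml
(* A (possibly nested) partition of a half-open interval [lo, hi) *)
type part =
  | Leaf of int * int        (* an interval not partitioned further *)
  | Node of part list        (* an interval partitioned into sub-intervals *)

(* The finest intervals of a partition, left to right *)
let rec leaves = function
  | Leaf (lo, hi) -> [ (lo, hi) ]
  | Node ps -> List.concat (List.map leaves ps)

(* Partition [lo, hi) into a part whose size is a multiple of 8, itself cut
   into size-8 intervals, and the remainder. Empty parts are omitted, since
   intervals must be non-empty. Assumes lo < hi. *)
let partition8 (lo, hi) =
  let cut = lo + (hi - lo) / 8 * 8 in
  let eights =
    List.init ((cut - lo) / 8) (fun t -> Leaf (lo + 8 * t, lo + 8 * t + 8)) in
  let main = if eights = [] then [] else [ Node eights ] in
  let rest = if cut < hi then [ Leaf (cut, hi) ] else [] in
  Node (main @ rest)

(* K = [0, 21): p is Node [Node [Leaf (0, 8); Leaf (8, 16)]; Leaf (16, 21)] *)
let p = partition8 (0, 21)
```

Flattening with `leaves p` gives the intervals [0,8), [8,16), [16,21), which could serve as one of the partitions in an interval-indexed multiplication.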
Overall, we obtain the inductive matrix-matrix multiplication algorithm, by induction on the interval partitioning of the matrix dimensions (in SaC terms, by induction on shape/rank). For details and the full derivation, see the accompanying code. The presented algorithm is more general than the one implemented in the Šinkarovs, Koopman and Scholz paper: our interval partitioning does not have to be regular (intervals don't have to have the same sizes).
Partitioning of the intervals changes only the indexing scheme, but does not change the arrangement of the matrix elements in memory. To increase locality, it is common therefore to rearrange the elements so that all elements of a tile are close together: so-called (re)tiling of the matrix (called pre-blocking in the Šinkarovs, Koopman and Scholz paper).
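To illustrate what such a rearrangement does to the memory layout, here is a sketch for the simplest case: a row-major n×m matrix stored in a flat array, regular n1×m1 tiles, and tile sizes that divide the dimensions. The function name `retile` and this particular tile-major layout are assumptions for illustration; the paper's pre-blocking may differ in details.

```ocaml
(* retile n m n1 m1 x: rearrange the row-major n x m matrix x so that the
   elements of each n1 x m1 tile are contiguous; tiles are laid out in
   row-major order of the tile grid, and each tile is row-major inside.
   Assumes n1 divides n and m1 divides m. *)
let retile n m n1 m1 (x : int array) : int array =
  let tiles_per_row = m / m1 in
  Array.init (n * m) (fun p ->
    (* p is a position in the new layout: tile number t, offset o within it *)
    let t = p / (n1 * m1) and o = p mod (n1 * m1) in
    let i1 = t / tiles_per_row and j1 = t mod tiles_per_row in
    let i2 = o / m1 and j2 = o mod m1 in
    (* fetch element (i1*n1 + i2, j1*m1 + j2) from the row-major original *)
    x.((i1 * n1 + i2) * m + (j1 * m1 + j2)))

(* A 4x4 matrix holding 0..15, retiled into 2x2 tiles:
   [|0; 1; 4; 5;  2; 3; 6; 7;  8; 9; 12; 13;  10; 11; 14; 15|] *)
let r = retile 4 4 2 2 (Array.init 16 (fun p -> p))
```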
There remains the all-important question of the efficient implementation of the inductive tiling algorithm: producing imperative machine code. SaC is very good at turning a notation very similar to the one we have been using into high-performance code. This is the SaC magic.
References
matmul.ml [10K]
Complete OCaml code demonstrating nested tiling, rigorously
Imperative, multiply-accumulate presentation
To dispel the SaC magic, we too should look at the imperative code. And then we notice something strange.
The BLAS matrix-matrix multiplication implements not C = A*B
(that is, multiplying A and B storing the result in C)
but C += A*B: multiplying A and B and adding the result to
C. It looks like a trivial generalization, with the original algorithm being
the particular case of C initially zero. In our notation:
Cij = Cij + Σk Aik * Bkj where i ∈ N, j ∈ M, k ∈ K
This trifling generalization is consequential: an element Cij now
accumulates the products Aik * Bkj for all k ∈ K.
Written imperatively,
Cij += Aik * Bkj for all i ∈ N, j ∈ M, k ∈ K
it looks elegant: in fact, more elegant than
the `pure' algorithm. To compute C += A*B, we thus execute a series of
multiply-accumulate operations Cij += Aik * Bkj -- in any
order. Thanks to associativity and commutativity of addition, the
order does not matter. So long as we execute all of these
multiply-accumulate operations, we get the right result.
Since the order of executing multiply-accumulate operations does not matter for correctness, we have the flexibility of choosing the order that makes the best use of locality and caches, and hence gives the best performance. The order (traversal) can truly be arbitrary, in zigzag or roundabout.
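The order-independence can be checked on a small example. The sketch below runs the same set of multiply-accumulate steps in three different orders; the names `mac`, `triples` and `run` are hypothetical, and the entries are integers, an assumption that makes the three results exactly equal.

```ocaml
(* One multiply-accumulate step: Cij += Aik * Bkj *)
let mac a b c (i, j, k) = c.(i).(j) <- c.(i).(j) + a.(i).(k) * b.(k).(j)

(* All index triples (i, j, k) with i < n, j < m, k < kdim, in loop order *)
let triples n m kdim =
  List.concat (List.init n (fun i ->
    List.concat (List.init m (fun j ->
      List.init kdim (fun k -> (i, j, k))))))

let a = [| [|1; 2; 3|]; [|4; 5; 6|] |]          (* 2 x 3 *)
let b = [| [|7; 8|]; [|9; 10|]; [|11; 12|] |]   (* 3 x 2 *)

let order1 = triples 2 2 3
(* An arbitrary other traversal: the same triples, sorted by an ad hoc key *)
let order2 =
  List.sort (fun (i, j, k) (i', j', k') ->
    compare ((5 * k + 3 * j + i) mod 7, k, j, i)
            ((5 * k' + 3 * j' + i') mod 7, k', j', i'))
    order1

(* Start from the zero C and execute every step of the given traversal *)
let run order =
  let c = Array.make_matrix 2 2 0 in
  List.iter (mac a b c) order;
  c

let c1 = run order1             (* [| [|58; 64|]; [|139; 154|] |] *)
let c2 = run order2
let c3 = run (List.rev order1)
```

All three results coincide. With floating-point entries, addition is only approximately associative, so different traversals may differ in the last bits of the result.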