Matrix-matrix multiplication, unconventionally

Matrix-matrix multiplication, from less conventional points of view

We consider matrix-matrix multiplication from less frequently seen angles: irregular and arbitrarily nested blocking; indexing by intervals. We will also see that an imperative approach could be insightful and elegant, besides being performant.


Introduction

In a recent paper, Šinkarovs, Koopman and Scholz showed matrix-matrix multiplication from a less-seen angle: as an inductive algorithm similar to merge sort. What impresses is not just the elegance: a few lines of code in SaC (a high-performance array language) compile to truly performant code, on par with and even exceeding OpenBLAS and Intel's MKL -- much faster than the naive matrix-matrix multiplication, and a level of performance that is hard to achieve. All this comes without messy and error-prone index and bounds computations, with the guaranteed absence of out-of-bound errors, and with assured correctness.

Reading their paper and listening to presentations, one cannot fail to be fascinated -- by the results, and also by some magic. It all appears too simple, too elegant, too good. This article is an attempt to understand the approach at a slower pace, and in conventional terms, without relying on SaC or dependent types -- starting from the standard, low-level definition of matrix-matrix multiplication.

We will see how Šinkarovs, Koopman and Scholz's inductive algorithm comes about -- and see beyond it, into irregular and arbitrarily nested blocking.

We also see something unexpected: the BLAS matrix-matrix multiplication C += A*B is not a mere generalization. It is a key to a different formulation of matrix-matrix multiplication and the highest performance. This imperative, multiply-and-accumulate presentation also turns out elegant -- perhaps even more so than the algebraic and `pure' approach explained before -- besides letting us reason about performance. The algebraic derivation was not a futile exercise, however: its lessons are applicable, albeit not as directly as one might have thought.

References

Artjoms Šinkarovs, Thomas Koopman, Sven-Bodo Scholz: Rank-Polymorphism for Shape-Guided Blocking
FHPNC 2023: Proc. 11th ACM SIGPLAN Intl. Workshop on Functional High-Performance and Numerical Computing. doi:10.1145/3609024.3609410

SaC array programming language
<http://www.sac-home.org/>

Niek Janssen and Sven-Bodo Scholz. On mapping n-dimensional data-parallelism efficiently into gpu-thread-spaces. IFL ’21, pp. 54–66, 2022

Artjoms Šinkarovs and Sven-Bodo Scholz. Parallel scan as a multi-dimensional array problem
Proc. 8th ACM SIGPLAN Intl. Workshop on Libraries, Languages and Compilers for Array Programming, ARRAY 2022, pp. 1–11, 2022

Warm-up: conventional blocking, conventionally

Let's recall the conventional matrix-matrix block multiplication -- to fix the notation and to clearly see the places for generalization. The presentation here is deliberately conventional and low-level, with little abstraction: albeit boring, it must be clearly right. We consider regular blocking, which is common, more intuitive and easier to explain. Extensions and abstractions are left for further sections.

The multiplication of the N×K matrix Aik and the K×M matrix Bkj obtaining the N×M matrix Cij is the very familiar

    Cij = Σk Aik * Bkj     where  i < N, j < M, k < K

(The indices i, j, etc. are implicitly non-negative.)
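
As a point of reference, the definition above can be transcribed almost literally. The following is a minimal sketch in Python, assuming matrices are represented as lists of rows; the function name `matmul` and the sample matrices are illustrative only and do not come from the accompanying code.

```python
def matmul(a, b):
    """Multiply the N×K matrix a by the K×M matrix b (both lists of rows)."""
    n, k_dim, m = len(a), len(b), len(b[0])
    assert all(len(row) == k_dim for row in a), "inner dimensions must agree"
    c = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            # Cij = Σk Aik * Bkj
            c[i][j] = sum(a[i][k] * b[k][j] for k in range(k_dim))
    return c

a = [[1, 2, 3],
     [4, 5, 6]]          # N = 2, K = 3
b = [[7, 8],
     [9, 10],
     [11, 12]]           # K = 3, M = 2
print(matmul(a, b))      # [[58, 64], [139, 154]]
```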

Assume N, M and K are composite numbers, which factor as

    N = N₁ * N₂      M = M₁ * M₂      K = K₁ * K₂

(where N₁, N₂ are the factors of N, etc.). The indices i, j and k can then be represented as

    i = i₁*N₁ + i₂  where   i₂ < N₁, i₁ < N₂
    j = j₁*M₁ + j₂  where   j₂ < M₁, j₁ < M₂
    k = k₁*K₁ + k₂  where   k₂ < K₁, k₁ < K₂
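
The correspondence between a single index and an index pair is just division with remainder. Here is a small Python sketch; the factors 4 and 3 and the helper names `split` and `join` are assumptions chosen for illustration.

```python
N1, N2 = 4, 3            # hypothetical factors, so N = N1 * N2 = 12
N = N1 * N2

def split(i, n1):
    """Return (i1, i2) such that i == i1*n1 + i2 and 0 <= i2 < n1."""
    return divmod(i, n1)

def join(i1, i2, n1):
    """Inverse of split: recover the single index from the pair."""
    return i1 * n1 + i2

pairs = [split(i, N1) for i in range(N)]
# Each index corresponds to exactly one pair, within the stated bounds.
assert all(join(i1, i2, N1) == i for i, (i1, i2) in enumerate(pairs))
assert all(i1 < N2 and i2 < N1 for i1, i2 in pairs)
assert len(set(pairs)) == N

print(split(9, N1))      # (2, 1), since 9 = 2*4 + 1
```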

The rows hence get indexed by a pair i₁i₂ rather than a single i, and so do the columns. Written in terms of such index pairs, the matrix multiplication becomes:

    Ci₁i₂j₁j₂ = Σk₁ Σk₂ Ai₁i₂k₁k₂ * Bk₁k₂j₁j₂

Here, Ai₁i₂k₁k₂ is the element of A at row i₁i₂ and column k₁k₂. The ordinary, 2D matrix has turned into a 4D object: a tensor of order (sometimes called rank) 4. Let us write it a bit differently: (Ai₁k₁)i₂k₂. So far, it is a mere change of notation.

    (Ci₁j₁)i₂j₂ = Σk₁ (Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂)

The inner sum Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂ should remind one of a matrix product, of what looks like an N₁×K₁ matrix Ai₁k₁ and a K₁×M₁ matrix Bk₁j₁. When i₂ ranges from 0 to N₁-1 and k₂ from 0 to K₁-1, the elements (Ai₁k₁)i₂k₂ are the elements within the rows i₁*N₁ through i₁*N₁ + N₁ -1 and columns k₁*K₁ through k₁*K₁ + K₁ -1 of the original matrix A. Indeed, Ai₁k₁ is a matrix: a rectangular sub-block of A of size N₁×K₁, called a tile. The notational change to (Ai₁k₁)i₂k₂ gave a different view of the matrix, as built from N₁×K₁-sized tiles. The tiles are themselves arranged into an N₂×K₂ grid, and hence correspond to a matrix whose elements are tiles rather than numbers. In (Ai₁k₁)i₂k₂, the indices i₁k₁ specify the tile; i₂k₂ index within the tile.

The inner sum Σk₂ (Ai₁k₁)i₂k₂ * (Bk₁j₁)k₂j₂ gives the i₂j₂-th element of the matrix product Ai₁k₁ * Bk₁j₁, which we write as (Ai₁k₁ * Bk₁j₁)i₂j₂. The original matrix-matrix multiplication becomes

    (Ci₁j₁)i₂j₂ = Σk₁(Ai₁k₁ * Bk₁j₁)i₂j₂ 

Here, * now means the tile (i.e., matrix-matrix) multiplication.

Recall, (Ci₁j₁)i₂j₂ is the i₂j₂-th element of the tile (Ci₁j₁) of the result matrix C. According to the above, it is computed by obtaining a series of matrix products Ai₁k₁ * Bk₁j₁ and summing up their i₂j₂-th elements. In other words, the tile Ci₁j₁ is obtained by computing the matrices Ai₁k₁ * Bk₁j₁ and summing them up elementwise -- that is, summing them as matrices. Thus:

    Ci₁j₁ = Σk₁ Ai₁k₁ * Bk₁j₁

As before, * means matrix-matrix multiplication; now, the summation in Σk₁ means the tile (matrix) summation. This equation looks exactly like the regular matrix multiplication, only * and summation are interpreted as matrix ones rather than numeric.

This is the familiar block-matrix multiplication: tiled matrices (made of tiles) are multiplied as ordinary matrices, with numeric multiplication and addition replaced with tile (matrix) multiplication and addition.
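
The regular block multiplication can be checked against the plain one on a small example. The sketch below is in Python with matrices as lists of rows; the helper names (`tile`, `add`, `block_matmul`) and the tile sizes are assumptions made for illustration, and the code aims at clarity rather than speed.

```python
def matmul(a, b):
    """Plain product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(x, y):
    """Elementwise sum of two equally sized matrices."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]

def tile(mat, r, c, rows, cols):
    """The tile at grid position (r, c), when all tiles have size rows×cols."""
    return [row[c * cols:(c + 1) * cols] for row in mat[r * rows:(r + 1) * rows]]

def block_matmul(a, b, n1, k1, m1):
    """C tile (i1, j1) = Σk1 A tile (i1, k1) * B tile (k1, j1)."""
    n2, k2, m2 = len(a) // n1, len(b) // k1, len(b[0]) // m1
    c = []
    for i1 in range(n2):
        tiles_row = []
        for j1 in range(m2):
            acc = [[0] * m1 for _ in range(n1)]
            for kk in range(k2):
                acc = add(acc, matmul(tile(a, i1, kk, n1, k1),
                                      tile(b, kk, j1, k1, m1)))
            tiles_row.append(acc)
        # Stitch one row of result tiles back into ordinary matrix rows.
        for i2 in range(n1):
            c.append([x for t in tiles_row for x in t[i2]])
    return c

a = [[i * 6 + k for k in range(6)] for i in range(4)]   # 4×6
b = [[k * 4 + j for j in range(4)] for k in range(6)]   # 6×4
assert block_matmul(a, b, 2, 3, 2) == matmul(a, b)      # 2×3 and 3×2 tiles
```

Choosing tiles that cover a whole matrix (here `n1=4, k1=6, m1=4`) or tiles of size 1 degenerates into the ordinary multiplication, as expected.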

We have merely played with indices and the re-grouping of summation, without changing the order of summands. We have thus relied only on the associativity of addition.

Generalizations are easy to see. First, the tiles may themselves be tiled, in turn: one gets the intimations of a recursive (inductive) algorithm. Second, the tiles do not have to have the same size: tiling can be irregular and arbitrary.

Beyond simple-minded indexing

We now attempt to generalize the ordinary, regular tiling -- with the goal of reproducing Šinkarovs, Koopman and Scholz's algorithm, and going beyond. We see that indices need not be integers. We use neither Agda nor SaC, nor dependent types.

The set-up is the same: multiplying the N×K matrix Aik and the K×M matrix Bkj obtaining the N×M matrix Cij. However, N, K and M are no longer integers: they are non-empty integer intervals. The multiplication looks essentially as before

    Cij = Σk Aik * Bkj     where  i ∈ N, j ∈ M, k ∈ K

only the lower index bound no longer has to be zero: the intervals are arbitrary. A very attentive reader may have noted one change, however: k ∈ K does not imply any order in choosing the indices k from the interval K. Therefore, we rely now on commutativity of addition.

Let's partition the interval N into non-empty intervals. They do not have to have the same size; they do have to be non-overlapping and together cover the whole N. The intervals K and M are partitioned likewise, not necessarily the same way. We will use Greek indices to range over the intervals. Repeating the derivation in the previous section, we obtain the multiplication algorithm

    Cνμ = Σκ Aνκ * Bκμ     where  ν ∈ N, μ ∈ M, κ ∈ K

Here ν, μ and κ range over the intervals making up the partitions of N, M and K, respectively. Aνκ, which is indexed by the intervals ν and κ, is a tile of A, of the size corresponding to the sizes of ν and κ: in other words, it is the tile made of Aik where i∈ν and k∈κ. The multiplication and summation/addition are the matrix (tile) multiplication and addition.
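
To make the interval indexing concrete, here is a Python sketch in which intervals are `range` objects and the result is a dictionary of tiles keyed by the interval pair (ν, μ). This representation, the function names and the particular irregular partitions are assumptions for illustration; the accompanying OCaml code may be organized quite differently.

```python
def matmul(a, b):
    """Plain product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(x, y):
    """Elementwise sum of two equally sized matrices."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]

def tile(mat, rows, cols):
    """The tile of mat selected by the row interval and the column interval."""
    return [[mat[i][j] for j in cols] for i in rows]

def interval_matmul(a, b, n_parts, k_parts, m_parts):
    """Cνμ = Σκ Aνκ * Bκμ, returned as a dict from (ν, μ) to the tile."""
    c = {}
    for nu in n_parts:
        for mu in m_parts:
            acc = [[0] * len(mu) for _ in nu]
            for kappa in k_parts:
                acc = add(acc, matmul(tile(a, nu, kappa), tile(b, kappa, mu)))
            c[(nu, mu)] = acc
    return c

# Irregular partitions of N = [0,5), K = [0,7) and M = [0,3).
n_parts = [range(0, 2), range(2, 5)]
k_parts = [range(0, 1), range(1, 4), range(4, 7)]
m_parts = [range(0, 3)]

a = [[i + 2 * k for k in range(7)] for i in range(5)]   # 5×7
b = [[k - j for j in range(3)] for k in range(7)]       # 7×3
c = interval_matmul(a, b, n_parts, k_parts, m_parts)

full = matmul(a, b)
for (nu, mu), t in c.items():
    assert t == tile(full, nu, mu)   # each tile agrees with the plain product
```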

This tiling process can clearly be repeated. For example, suppose we want to tile the input matrices into 8×8 tiles, but their dimensions (say, K) are not multiples of 8. We then first partition K into two intervals: one whose size is a multiple of 8, and the remainder. The former is further partitioned into size-8 intervals. The Šinkarovs, Koopman and Scholz paper shows a more interesting example of nested tiling, to account for various caches.
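
The two-step partitioning in this example can be sketched as follows, in Python with intervals as `range` objects. The function names are hypothetical, and for brevity the result is a flat list of intervals rather than an explicitly nested structure.

```python
def split_multiple(interval, size):
    """Split into (main, rest), where len(main) is the largest multiple of size."""
    cut = interval.start + (len(interval) // size) * size
    return range(interval.start, cut), range(cut, interval.stop)

def chunks(interval, size):
    """Partition an interval whose length is a multiple of size into size-long pieces."""
    return [range(s, s + size) for s in range(interval.start, interval.stop, size)]

def partition(interval, size=8):
    """First split off the remainder, then cut the main part into equal pieces."""
    main, rest = split_multiple(interval, size)
    parts = chunks(main, size)
    if len(rest) > 0:            # keep only non-empty intervals
        parts.append(rest)
    return parts

print(partition(range(0, 21)))   # [range(0, 8), range(8, 16), range(16, 21)]
```

The lower bound of the interval need not be zero: `partition(range(3, 19))` gives two size-8 intervals starting at 3 and 11.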

Overall, we obtain the inductive matrix-matrix multiplication algorithm, by induction on the interval partitioning of the matrix dimensions (in SaC terms, by induction on shape/rank). For the details and the full derivation, see the accompanying code. The presented algorithm is more general than the one implemented in the Šinkarovs, Koopman and Scholz paper: our interval partitioning does not have to be regular (intervals don't have to have the same sizes).

Partitioning of the intervals changes only the indexing scheme, but does not change the arrangement of the matrix elements in memory. To increase locality, it is common therefore to rearrange the elements so that all elements of a tile are close together: so-called (re)tiling of the matrix (called pre-blocking in the Šinkarovs, Koopman and Scholz paper).
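
One possible rearrangement is sketched below in Python: the elements are copied into a flat list tile by tile, so that each tile occupies a contiguous run. The tile-by-tile, row-major order and the name `retile` are assumptions for illustration; the layout actually used in the paper may differ.

```python
def retile(mat, row_parts, col_parts):
    """Flat list of the elements of mat, with each tile stored contiguously."""
    flat = []
    for rows in row_parts:           # tiles in row-major order of the tile grid
        for cols in col_parts:
            for i in rows:           # elements in row-major order within a tile
                for j in cols:
                    flat.append(mat[i][j])
    return flat

mat = [[10 * i + j for j in range(4)] for i in range(4)]   # element ij is 10*i + j
halves = [range(0, 2), range(2, 4)]
print(retile(mat, halves, halves))
# [0, 1, 10, 11, 2, 3, 12, 13, 20, 21, 30, 31, 22, 23, 32, 33]
```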

There remains the all-important question of the efficient implementation of the inductive tiling algorithm: producing imperative machine code. SaC is very good at turning notation very similar to the one we have been using into high-performance code. This is the SaC magic.

References

matmul.ml [10K]
Complete OCaml code demonstrating nested tiling, rigorously

Imperative, multiply-accumulate presentation

To dispel the SaC magic, we too should look at the imperative code. And then we notice something strange.

The BLAS matrix-matrix multiplication implements not C = A*B (that is, multiplying A and B and storing the result in C) but C += A*B: multiplying A and B and adding the result to C. It looks like a trivial generalization, with the original algorithm being the particular case of C initially zero. In our notation:

    Cij = Cij + Σk Aik * Bkj     where  i ∈ N, j ∈ M, k ∈ K

This trifling generalization is consequential: an element Cij now accumulates the products Aik * Bkj for all k ∈ K. Written imperatively,

    Cij += Aik * Bkj  for all i ∈ N, j ∈ M, k ∈ K

it looks elegant: in fact, more elegant than the `pure' algorithm. To compute C += A*B, we thus execute a series of multiply-accumulate operations Cij += Aik * Bkj -- in any order. Thanks to associativity and commutativity of addition, the order does not matter. So long as we execute all of these multiply-accumulate operations, we get the right result.

Since the order of executing multiply-accumulate operations does not matter for correctness, we are free to choose the order that makes the best use of locality and caches, and hence gives the best performance. The order (traversal) can truly be arbitrary: zigzag or roundabout.
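
The order independence is easy to try out. The Python sketch below performs the same set of multiply-accumulate operations once in lexicographic order and once in a shuffled order, starting from a non-zero C. The names and sample data are illustrative assumptions. Integer entries are used so that the comparison is exact; with floating-point numbers, different orders may differ in rounding.

```python
import itertools
import random

def multiply_accumulate(a, b, c, order):
    """Execute c[i][j] += a[i][k] * b[k][j] for each (i, j, k) in order, in place."""
    for i, j, k in order:
        c[i][j] += a[i][k] * b[k][j]
    return c

a = [[1, 2, 3], [4, 5, 6]]           # N = 2, K = 3
b = [[7, 8], [9, 10], [11, 12]]      # K = 3, M = 2

# All index triples (i, j, k) with i ∈ N, j ∈ M, k ∈ K.
triples = list(itertools.product(range(2), range(2), range(3)))

in_order = multiply_accumulate(a, b, [[1, 1], [1, 1]], triples)

shuffled = triples[:]
random.Random(0).shuffle(shuffled)   # an arbitrary traversal order
any_order = multiply_accumulate(a, b, [[1, 1], [1, 1]], shuffled)

assert in_order == any_order == [[59, 65], [140, 155]]
```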