tinygrad-notes

Pattern matcher

The general idea of the pattern matcher is to express the entire computation as a nested tree and look for patterns that can be rewritten for optimization. The idea isn’t new and has been implemented to some extent for a while, although you may not have recognized it.

Code generation

Suppose we have this script:

from tinygrad import Tensor

a = Tensor.empty(4, 4)
b = a + 1
b.realize()

We want to see the generated kernel code (GPU code) for the non-optimized case. To do that, run NOOPT=1 DEBUG=5 python script.py. NOOPT=1 disables certain optimizations, which we will compare against later, and DEBUG=5 outputs the generated code as well as the AST. Among the output, look for this C++-looking code:

#include <metal_stdlib>
using namespace metal;
kernel void E_16(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 16 */
  float val0 = *(data1+gidx0);
  *(data0+gidx0) = (val0+1.0f);
}

You may see something slightly different if you are using CUDA, but it should remain more or less the same. Look at the last line, where we store a single-element result (val0+1.0f) into the output buffer (data0) at the offset gidx0, which is its grid position.

This code is generated by walking the AST. The output from the script shows one version of this AST:

UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
  UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
    x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(Ops.ADD, dtypes.float, arg=None, src=(
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
         x2,)),
      UOp(Ops.WHERE, dtypes.float, arg=None, src=(
        UOp(Ops.VALID, dtypes.bool, arg=None, src=(
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),
        UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
        UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),)),))

You should be able to map some parts of it to the generated code. For example, Ops.STORE with a src of DEFINE_GLOBAL maps to data0, and dtypes.float.ptr() results in the pointer dereference (*). Since it is part of the src of a STORE operation, an equals sign (=) should follow, after which we generate code for its siblings, which form the right-hand side. The right-hand side is the result of an ADD (Ops.ADD) operation with two operands (two items in its src). The first is a load from the parameter data1; the second is a WHERE op. You get the idea. Let me hand-write a simpler version of this, just for the last line of our code:

UOp(op=STORE, src=(
  UOp(op=DEFINE_GLOBAL),
  UOp(op=ADD, src=(
    UOp(op=LOAD, src=(
      UOp(op=DEFINE_GLOBAL),
    )),
    UOp(op=CONST, arg=1.0),
  )),
))

This is just to illustrate the idea. If you look at it, it should intuitively translate to:

*(data0+gidx0) = (val0+1.0f);

There are some details missing, but I hope the idea makes sense.
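To make the mapping concrete, here is a toy recursive walk over that simplified tree. It is only a sketch: Node is a hypothetical stand-in for UOp, the gidx0 offset is hardcoded, and the load is inlined instead of being assigned to val0 first, so it is not how tinygrad actually renders code.

```python
from dataclasses import dataclass

# Hypothetical stand-in for tinygrad's UOp: an op name, an optional arg, and child nodes.
@dataclass(frozen=True)
class Node:
    op: str
    arg: object = None
    src: tuple = ()

def render_expr(node: Node) -> str:
    """Recursively turn the toy tree into a C-like string."""
    if node.op == "DEFINE_GLOBAL":
        return f"data{node.arg}"                       # buffer parameter, e.g. data0
    if node.op == "CONST":
        return f"{node.arg}f"                          # float literal, e.g. 1.0f
    if node.op == "LOAD":
        return f"*({render_expr(node.src[0])}+gidx0)"  # read at the (hardcoded) offset
    if node.op == "ADD":
        lhs, rhs = (render_expr(s) for s in node.src)
        return f"({lhs}+{rhs})"
    if node.op == "STORE":
        target, value = node.src
        return f"*({render_expr(target)}+gidx0) = {render_expr(value)};"
    raise ValueError(f"unhandled op {node.op}")

tree = Node("STORE", src=(
    Node("DEFINE_GLOBAL", arg=0),
    Node("ADD", src=(
        Node("LOAD", src=(Node("DEFINE_GLOBAL", arg=1),)),
        Node("CONST", arg=1.0),
    )),
))
line = render_expr(tree)
print(line)  # *(data0+gidx0) = (*(data1+gidx0)+1.0f);
```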

Now, you may naively implement the code generation with a lot of if-else statements, for example:

code = [] 
uop = UOp(op=STORE, src=()) # given the example above
if uop.op is STORE:
  code.append("=")
elif uop.op is ADD:
  code.append("+")
elif uop.op is CONST:
  code.append(uop.arg)
#...

and have it buried in some sort of recursion and clever dictionary mappings. In fact, that’s how code generation was previously implemented. But that leads to a lot of repetitive code. We can do better by isolating the patterns into a list of, uh, patterns and writer functions:

patterns = [
  (STORE, lambda uop: "="),
  (CONST, lambda uop: f" {uop.arg} "),
  (ADD, lambda uop: f" + "),
]

and then have a function iterate over it:

def render_code(uop):
  code = []
  for _uop in uop: # Suppose you already did a DFS/BFS so the tree is flattened
    for pattern in patterns:
      if _uop.op == pattern[0]:
        _code = pattern[1](_uop)
        code.append(_code)
  return code

Again, this code is overly simplified to illustrate the main idea, but you can see the separation between the matcher and the code generation spec. Our patterns can be further extended to match not just the op but also data types, if you just populate the tuples a bit more:

patterns = [
  (STORE, dtypes.long, ()), # (op, dtype, src)
]

And the matcher will just have a few more if statements for each category.
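Here is a runnable sketch of that table-driven idea, matching on both op and dtype. Everything in it is made up for illustration: Node is a hypothetical stand-in for UOp, ops and dtypes are plain strings, and a dtype of None means "match any dtype".

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                       # hypothetical stand-in for a UOp
    op: str
    dtype: str
    arg: object = None

# Each entry: (op to match, dtype to match or None for "any", writer function).
patterns = [
    ("CONST", "float", lambda n: f"{n.arg}f"),   # float literals get an f suffix
    ("CONST", None,    lambda n: str(n.arg)),    # any other constant
    ("ADD",   None,    lambda n: "+"),
    ("STORE", None,    lambda n: "="),
]

def render_code(nodes):
    """Render each node with the first pattern whose op and dtype both match."""
    code = []
    for node in nodes:
        for op, dtype, writer in patterns:
            if node.op == op and (dtype is None or node.dtype == dtype):
                code.append(writer(node))
                break             # first match wins
    return code

flat = [Node("CONST", "float", 1.0), Node("CONST", "int", 2),
        Node("ADD", "float"), Node("STORE", "void")]
print(render_code(flat))  # ['1.0f', '2', '+', '=']
```

Note how the matching loop never changes when a new case is added; only the table grows.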

Let’s give them their proper names. The patterns list is the class PatternMatcher, which receives the list as its initializer argument. The matching half of each tuple we haphazardly defined is the class UPat, and it has proper arguments like op, dtype, and src. If you look into the actual code, you’ll see that it can also match things recursively and pass the matched elements to the lambda function. For example, for our STORE uop:

  UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
    ...
    UOp(Ops.ADD, dtypes.float, arg=None, src=(
      ...
    )

You can write the following to match it:

(UPat(Ops.STORE, name="x", src=(UPat.var("define_global"), UPat.var("addition"))), lambda x, define_global, addition: ...)

The STORE uop will be passed to the lambda in the argument specified by name; in this case x will be the uop. The matched source elements carry the names we specified in UPat.var, so you can access the two elements via define_global and addition. There are more interesting and powerful ways to use it, but you do have to read the code to discover them.
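If the name-binding part feels magical, here is a minimal pure-Python sketch of how it could work. Pat, var, and Matcher are hypothetical stand-ins for UPat, UPat.var, and PatternMatcher; the real classes do considerably more, so treat this as an illustration of the mechanism only.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                      # hypothetical stand-in for UOp
    op: str
    arg: object = None
    src: tuple = ()

@dataclass(frozen=True)
class Pat:                       # hypothetical stand-in for UPat
    op: object = None            # None matches any op
    name: object = None          # capture name, if any
    src: object = None           # tuple of child patterns, or None for "don't care"

    def match(self, node, store):
        """Return True on a match, recording named captures in `store`."""
        if self.op is not None and node.op != self.op:
            return False
        if self.src is not None:
            if len(self.src) != len(node.src):
                return False
            if not all(p.match(n, store) for p, n in zip(self.src, node.src)):
                return False
        if self.name is not None:
            store[self.name] = node
        return True

def var(name):                   # plays the role of UPat.var: match anything, capture it
    return Pat(name=name)

class Matcher:                   # hypothetical stand-in for PatternMatcher
    def __init__(self, rules):
        self.rules = rules

    def rewrite(self, node):
        for pat, fn in self.rules:
            store = {}           # fresh captures for every rule
            if pat.match(node, store):
                return fn(**store)   # captures arrive as keyword arguments
        return None                  # no rule applied

matcher = Matcher([
    (Pat("STORE", name="x", src=(var("define_global"), var("addition"))),
     lambda x, define_global, addition: f"{x.op}({define_global.op}, {addition.op})"),
])
store_node = Node("STORE", src=(Node("DEFINE_GLOBAL", arg=0), Node("ADD")))
print(matcher.rewrite(store_node))  # STORE(DEFINE_GLOBAL, ADD)
```

The keyword-argument call is the whole trick: whatever names appear in the pattern become the parameter names of the lambda.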

Now if you check out the code generation implementation, you can see how it works. The render function iterates through the linearized uop tree and matches each uop against the patterns specified in the PatternMatcher. For each match it outputs a string, and the strings are combined at the end into the final rendered code.

Pattern rewrite

If we can generate code from the AST, we can also rewrite the AST. In fact, that’s the original use case for the pattern matcher. All we have to do is change the implementation of our lambda function: instead of returning a string, it returns another UOp.
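As a toy illustration of a rule that returns a node instead of a string, the sketch below folds x + 0 into x. The Node class, the fold_add_zero rule, and the bottom-up rewrite driver are all hypothetical; tinygrad's own rewrite machinery is more involved.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                      # hypothetical stand-in for UOp
    op: str
    arg: object = None
    src: tuple = ()

def fold_add_zero(node):
    """Rule: ADD(x, CONST 0) -> x. Returns None when the rule does not apply."""
    if node.op == "ADD" and len(node.src) == 2:
        x, y = node.src
        if y.op == "CONST" and y.arg == 0:
            return x
    return None

def rewrite(node, rules):
    """Rewrite the children first, then apply rules to the node until none fires."""
    node = Node(node.op, node.arg, tuple(rewrite(s, rules) for s in node.src))
    changed = True
    while changed:
        changed = False
        for rule in rules:
            new = rule(node)
            if new is not None:
                node, changed = new, True
                break
    return node

# (LOAD + 0) + 0 collapses all the way down to LOAD.
tree = Node("ADD", src=(
    Node("ADD", src=(Node("LOAD"), Node("CONST", 0))),
    Node("CONST", 0),
))
result = rewrite(tree, [fold_add_zero])
print(result.op)  # LOAD
```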

Recall that in our example script, we ran with NOOPT=1. This disables the optimization, so each element is added in a separate thread (we launched a total of 16 of them for the 4 by 4 matrix). You may have learned elsewhere that a vectorized form is more efficient, so instead of launching 16 threads, we can launch only 4 and have each process 4 elements in vectorized form. Now let’s run the script without NOOPT=1, just DEBUG=5 python script.py:

#include <metal_stdlib>
using namespace metal;
kernel void E_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int lidx0 = lid.x; /* 4 */
  int alu0 = (lidx0<<2);
  float4 val0 = *(device float4*)((data1+alu0));
  *(device float4*)((data0+alu0)) = float4((val0.x+1.0f),(val0.y+1.0f),(val0.z+1.0f),(val0.w+1.0f));
}
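If the indexing is not obvious, this plain Python simulation mimics both kernels on a list of 16 floats. The function names are invented for the demo and the loops stand in for GPU threads, so it only shows that the two indexing schemes cover the same elements.

```python
def run_scalar(data1):
    """Mimics E_16: 16 threads, each handling one element."""
    data0 = [0.0] * 16
    for gidx0 in range(16):
        data0[gidx0] = data1[gidx0] + 1.0
    return data0

def run_vectorized(data1):
    """Mimics E_4_4: 4 threads, each handling a float4-sized chunk."""
    data0 = [0.0] * 16
    for lidx0 in range(4):
        alu0 = lidx0 << 2                                # same as lidx0 * 4
        val0 = data1[alu0:alu0 + 4]                      # one float4 load
        data0[alu0:alu0 + 4] = [v + 1.0 for v in val0]   # one float4 store
    return data0

data1 = [float(i) for i in range(16)]
offsets = [lidx0 << 2 for lidx0 in range(4)]
print(offsets)                                    # [0, 4, 8, 12]
print(run_scalar(data1) == run_vectorized(data1)) # True
```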

The pattern that generates float4 is written as such:

  (UPat(Ops.VECTORIZE, name="x"),
   lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),

It may be hard to read because it packs in a lot of other important details, but for our purposes, focus on the float4 string, which maps directly to the float4 in the generated code. It then loops through each element of the vector (float4 means there are four of them) and joins them with commas.
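Stripped of the dtype and device handling, the join boils down to something like the sketch below. render_float4 is a made-up helper, and the input strings stand in for the already-rendered sources that the real lambda looks up through ctx.

```python
def render_float4(rendered_srcs, dtype_name="float4"):
    """Join already-rendered lane expressions into a vector constructor call."""
    return f"{dtype_name}({','.join(rendered_srcs)})"

# One rendered expression per lane, as in the optimized kernel above.
srcs = [f"(val0.{lane}+1.0f)" for lane in "xyzw"]
vec = render_float4(srcs)
print(vec)  # float4((val0.x+1.0f),(val0.y+1.0f),(val0.z+1.0f),(val0.w+1.0f))
```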

Now, if you compare the UOp output sections of our two runs, you should note that the latter has a VECTORIZE entry:

   0 Ops.CONST           : dtypes.float                   []                               1.0
   1 Ops.CONST           : dtypes.int                     []                               2
   2 Ops.DEFINE_GLOBAL   : dtypes.float.ptr()             []                               0
   3 Ops.DEFINE_GLOBAL   : dtypes.float.ptr()             []                               1
   4 Ops.SPECIAL         : dtypes.int                     []                               ('lidx0', 4)
   5 Ops.SHL             : dtypes.int                     [4, '2']                         None
   6 Ops.INDEX           : dtypes.float.ptr()             [2, 5]                           None
   7 Ops.CAST            : dtypes.float.vec(4).ptr()      [6]                              None
   8 Ops.INDEX           : dtypes.float.ptr()             [3, 5]                           None
   9 Ops.CAST            : dtypes.float.vec(4).ptr()      [8]                              None
  10 Ops.LOAD            : dtypes.float.vec(4)            [9]                              None
  11 Ops.GEP             : dtypes.float                   [10]                             (0,)
  12 Ops.GEP             : dtypes.float                   [10]                             (1,)
  13 Ops.GEP             : dtypes.float                   [10]                             (2,)
  14 Ops.GEP             : dtypes.float                   [10]                             (3,)
  15 Ops.ADD             : dtypes.float                   [11, '1.0']                      None
  16 Ops.ADD             : dtypes.float                   [12, '1.0']                      None
  17 Ops.ADD             : dtypes.float                   [13, '1.0']                      None
  18 Ops.ADD             : dtypes.float                   [14, '1.0']                      None
  19 Ops.VECTORIZE       : dtypes.float.vec(4)            [15, 16, 17, 18]                 None
  20 Ops.STORE           : dtypes.void                    [7, 19]                          None

(By the time you read this, the exact output may have changed. The key here is just the Ops.VECTORIZE entry; the rest of the details may go out of date.)

How do we end up with VECTORIZE? With rewrite rules. With some additional print statements, you may discover that the VECTORIZE originates from this section of the uop:

UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.INDEX, dtypes.float.ptr(), arg=None, src=(
      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
      x3:=UOp(Ops.ADD, dtypes.int, arg=None, src=(

After a series of pattern matcher rewrites, it becomes:

UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.INDEX, dtypes.float.ptr().vec(4), arg=None, src=(
      UOp(Ops.VECTORIZE, dtypes.float.ptr().vec(4), arg=None, src=(
        x3:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
         x3,
         x3,
         x3,)),

And the actual pattern:

  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
         Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),

It matches Ops.STORE among all the accepted ops (the op argument can take a tuple and matches if any one of them is found), and the STORE becomes the root argument passed to the do_expand function. This function returns a new version of the UOp:

def do_expand(root): #root is the UOp(Ops.STORE, dtypes.void, arg=None, src=(...))
  #...
    new_srcs.append(UOp(Ops.VECTORIZE,
                      src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
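The tuple(...)*expand_sz part is plain Python tuple repetition, and it explains why x3 shows up four times in the dump above. The sketch below uses strings in place of UOps, and it assumes that a scalar source contributes a single lane that is just the source itself; expand_srcs is a hypothetical helper, not a tinygrad function.

```python
def expand_srcs(lanes, expand_sz):
    """Repeat the per-lane elements of one source expand_sz times.

    `lanes` plays the role of (src.gep(i) for i in range(src.dtype.count)).
    """
    return tuple(lanes) * expand_sz

scalar_case = expand_srcs(["x3"], 4)
vector_case = expand_srcs(["a.x", "a.y"], 2)
print(scalar_case)  # ('x3', 'x3', 'x3', 'x3')
print(vector_case)  # ('a.x', 'a.y', 'a.x', 'a.y')
```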

Small example

Here’s an actual example you can toy with:

from tinygrad import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.renderer.cstyle import MetalRenderer

metal_renderer = MetalRenderer()
const = UOp(Ops.CONST, dtypes.float, arg=1.0)
define_global = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
special = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', 16), src=())
added = UOp(Ops.ADD, dtypes.long, arg=None, src=(define_global, special))
store = UOp(Ops.STORE, dtypes.void, arg=None, src=(added, const))
uops = [const, define_global, special, added, store]

rendered = metal_renderer.render("rendered", uops)
print(rendered)

"""
#include <metal_stdlib>
using namespace metal;
kernel void rendered(device float* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 16 */
  *(data0+gidx0) = 1.0f;
}
"""

It generates the kernel code that writes 1.0 to a 16-element tensor. Note that the argument passed to render must be linearized, hence the hand-written list.

Let’s see what a pattern rewrite would look like. Suppose we don’t like the fact that it writes 1.0f directly, and we want it to calculate 0.5f + 0.5f in the output (for no other reason than a demo). We can define a pattern matcher as such:

from tinygrad.ops import PatternMatcher, UPat

const_1 = UOp(Ops.CONST, dtypes.float, arg=0.5)
const_2 = UOp(Ops.CONST, dtypes.float, arg=0.5)

matcher = PatternMatcher([
  (UPat(Ops.CONST, dtypes.float, name="x"), lambda ctx, x: UOp(Ops.ADD, dtypes.float, src=(const_1, const_2))),
])

const = UOp(Ops.CONST, dtypes.float, arg=1.0)
const_rewritten = matcher.rewrite(const)

This will rewrite the original const into an addition of two 0.5s. Now, to make it work with the rest of the script, a few changes are needed. First, const_1 and const_2 must be added to the uops list because of the linearization requirement. Normally this is taken care of by a dedicated linearize function, but here you can see what it actually does. Also, const_rewritten becomes the src of the STORE.
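The linearization requirement is easier to see with a toy version: every uop has to appear in the list after all of its sources. A post-order depth-first walk gives such an ordering. The sketch below uses a hypothetical Node class with names instead of real UOps, and tinygrad's actual linearizer does a good deal more than this.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                      # hypothetical stand-in for UOp
    name: str
    src: tuple = ()

def linearize(sink):
    """Post-order DFS: every node appears once, after all of its sources."""
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for s in node.src:
            visit(s)
        order.append(node)
    visit(sink)
    return order

# Same shape as the rewritten graph: STORE(ADD(global, special), ADD(0.5, 0.5)).
const_1, const_2 = Node("const_1"), Node("const_2")
const_rewritten = Node("const_rewritten", (const_1, const_2))
define_global, special = Node("define_global"), Node("special")
added = Node("added", (define_global, special))
store = Node("store", (added, const_rewritten))

names = [n.name for n in linearize(store)]
print(names)
# ['define_global', 'special', 'added', 'const_1', 'const_2', 'const_rewritten', 'store']
```

The order differs from a hand-written list that puts the constants first, but both satisfy the same constraint: sources come before the uops that consume them.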

Here’s the full script:

from tinygrad import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.renderer.cstyle import MetalRenderer

from tinygrad.ops import PatternMatcher, UPat

const_1 = UOp(Ops.CONST, dtypes.float, arg=0.5)
const_2 = UOp(Ops.CONST, dtypes.float, arg=0.5)

matcher = PatternMatcher([
  (UPat(Ops.CONST, dtypes.float, name="x"), lambda ctx, x: UOp(Ops.ADD, dtypes.float, src=(const_1, const_2))),
])

metal_renderer = MetalRenderer()
const = UOp(Ops.CONST, dtypes.float, arg=1.0)

const_rewritten = matcher.rewrite(const)
define_global = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
special = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', 16), src=())
added = UOp(Ops.ADD, dtypes.long, arg=None, src=(define_global, special))
store = UOp(Ops.STORE, dtypes.void, arg=None, src=(added, const_rewritten))
uops = [const_1, const_2, const_rewritten, define_global, special, added, store]

rendered = metal_renderer.render("rendered", uops)
print(rendered)
"""
#include <metal_stdlib>
using namespace metal;
kernel void rendered(device float* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 16 */
  *(data0+gidx0) = (0.5f+0.5f);
}
"""