The general idea of the pattern matcher is to express the entire computation as a nested tree, then look for patterns that can be rewritten for optimization. The idea isn't new and has been implemented to some extent for a while, although you may not have recognized it.
Suppose we have this script:
from tinygrad import Tensor
a = Tensor.empty(4, 4)
b = a + 1
b.realize()
We want to see the generated kernel code (GPU code) for a non-optimized case. To do that, run NOOPT=1 DEBUG=5 python script.py. NOOPT=1 disables certain optimizations, which we will compare against later, and DEBUG=5 outputs the code as well as the AST tree. Among the output, look for this C++-looking code:
#include <metal_stdlib>
using namespace metal;
kernel void E_16(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 16 */
float val0 = *(data1+gidx0);
*(data0+gidx0) = (val0+1.0f);
}
You may see something slightly different if you are using CUDA, but it should remain more or less the same. Look at the last line, where we store the single-element result (val0+1.0f) to the output buffer (data0) at the offset gidx0, which is its grid id position.
The way this code is generated is by looking at the AST. The output from the script shows one version of this AST:
UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
x2,)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),)),))
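As a side note, you can also inspect this AST programmatically instead of digging through the DEBUG=5 output. A minimal sketch, assuming the Tensor.schedule API is available in your version of tinygrad:
from tinygrad import Tensor

a = Tensor.empty(4, 4)
b = a + 1
for si in b.schedule():  # each ScheduleItem holds the AST (a UOp tree) for one kernel
    print(si.ast)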
You should be able to map some parts of it to the generated code. For example, Ops.STORE with a DEFINE_GLOBAL in its src maps to data0, and the dtypes.float.ptr() results in the pointer dereference (*). Since the DEFINE_GLOBAL is part of the src of a STORE operation, an equal sign (=) should follow, after which we generate code for its siblings, which form the right-hand side. The right-hand side is the result of an ADD (Ops.ADD) operation with two operands (two items in its src): the first is a load from the parameter data1, the second is a WHERE op. You get the idea. Let me hand-write a simpler version of this, just for the last line of our code:
UOp(op=STORE, src=(
  UOp(op=DEFINE_GLOBAL),
  UOp(op=ADD, src=(
    UOp(op=LOAD, src=(
      UOp(op=DEFINE_GLOBAL),
    )),
    UOp(op=CONST, arg=1.0),
  )),
))
This is just to illustrate the idea; if you look at it, it should intuitively translate to:
*(data0+gidx0) = (val0+1.0f);
There are some details missing, but I hope the idea makes sense.
Now, you might naively implement the code generation with a lot of if/else statements, for example:
code = []
uop = UOp(op=STORE, src=())  # given the example above
if uop.op is STORE:
    code.append("=")
elif uop.op is ADD:
    code.append("+")
elif uop.op is CONST:
    code.append(uop.arg)
# ...
and have it buried in some sort of recursion and clever dictionary mappings. In fact, that's how code generation was previously implemented. But that leads to a lot of repetitive code. We can do better by isolating the patterns into a list of, well, patterns and writer functions:
patterns = [
    ((STORE,), lambda uop: "="),
    ((CONST,), lambda uop: f" {uop.arg} "),
    ((ADD,), lambda uop: " + "),
]
and then have a function iterate over it:
def render_code(uop):
    code = []
    for _uop in uop:  # suppose you already did a DFS/BFS so the tree is flattened
        for ops, fxn in patterns:
            if _uop.op in ops:
                code.append(fxn(_uop))
    return "".join(code)
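To make this concrete, here is a self-contained toy version you can actually run. The namedtuple UOp and the string op names are stand-ins for illustration, not tinygrad's real classes:
from collections import namedtuple

UOp = namedtuple("UOp", ["op", "arg"], defaults=[None])
STORE, ADD, CONST = "STORE", "ADD", "CONST"

patterns = [
    ((STORE,), lambda uop: " = "),
    ((CONST,), lambda uop: f"{uop.arg}"),
    ((ADD,), lambda uop: " + "),
]

def render_code(uops):
    code = []
    for _uop in uops:  # assume the tree is already linearized
        for ops, fxn in patterns:
            if _uop.op in ops:
                code.append(fxn(_uop))
    return "".join(code)

print(render_code([UOp(STORE), UOp(CONST, 1.0), UOp(ADD), UOp(CONST, 2.0)]))  # " = 1.0 + 2.0"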
Again, these sketches are overly simplified to illustrate the main idea, but you can see the separation between the matcher and the code generation spec. Our patterns can be extended further to match not just the op but also data types, if you populate the tuples a bit more:
patterns = [
    ((STORE,), dtypes.long, ()),  # op, dtype, src
]
And the matcher will just have a few more if statements for each category.
Let's actually give these their proper names: the patterns list is the class PatternMatcher, and it receives the list as its initializer argument. The tuple we haphazardly defined is the class UPat, and it has proper arguments like op, dtype, and src. If you look into the actual code, you will see that it can also match things recursively and pass the matched elements to the lambda function. For example, for our STORE uop
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
...
UOp(Ops.ADD, dtypes.float, arg=None, src=(
...
)
You can write the following to match it:
UPat(Ops.STORE, name="x", src=(UPat.var("define_global"), UPat.var("addition"))), lambda x, define_global, addition: ...
The uop for the STORE will be passed to the lambda in the argument specified by name; in this case x will be the uop. The matched source elements carry the names we specified in UPat.var, so you can access the two elements via define_global and addition. There are more interesting and powerful ways to use it, but you do have to read the code to discover them.
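For instance, here is a small runnable rewrite using named captures, folding an addition of two constants into a single constant. This is a minimal sketch; it uses UPat.cvar, a helper in tinygrad.ops that matches CONST nodes, though helper names may drift between versions:
from tinygrad import dtypes
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat

# match ADD(CONST, CONST), capture the pieces by name, fold into one CONST
folder = PatternMatcher([
    (UPat(Ops.ADD, name="x", src=(UPat.cvar("a"), UPat.cvar("b"))),
     lambda x, a, b: UOp(Ops.CONST, x.dtype, arg=a.arg + b.arg)),
])

add = UOp(Ops.ADD, dtypes.float, src=(
    UOp(Ops.CONST, dtypes.float, arg=2.0),
    UOp(Ops.CONST, dtypes.float, arg=3.0)))
print(folder.rewrite(add))  # a single CONST uop with arg=5.0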
Now if you check out the code generation implementation, you can see how it works. The render function iterates through the linearized uop tree and matches each uop against the patterns specified in the PatternMatcher; for each match it outputs a string, and the strings are combined at the end into the final rendered code.
If we can generate code from the AST, we can also rewrite the AST; in fact, that's the original use case for the pattern matcher. All we have to do is change the implementation of our lambda function: instead of returning a string, it returns another UOp.
Recall that in our example script we ran with NOOPT=1. This disables the optimization, so each element is added in a separate thread (we launched a total of 16 of them, for the 4 by 4 matrix). You may have learned elsewhere that the vectorized form is more efficient: instead of using 16 threads, we can launch only 4 and have each process 4 elements in vectorized form. Now let's run the script without NOOPT=1, just DEBUG=5 python script.py:
#include <metal_stdlib>
using namespace metal;
kernel void E_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int lidx0 = lid.x; /* 4 */
int alu0 = (lidx0<<2);
float4 val0 = *(device float4*)((data1+alu0));
*(device float4*)((data0+alu0)) = float4((val0.x+1.0f),(val0.y+1.0f),(val0.z+1.0f),(val0.w+1.0f));
}
The pattern matcher rule that generates the float4 is written as such:
(UPat(Ops.VECTORIZE, name="x"),
 lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
 (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
It may be hard to read because it compacts a lot of other important details, but for our purpose, focus on the float4 string, which maps directly to the float4 in the generated code. It then loops through each element of the vector (float4 means there are four of them) and joins them with commas.
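As a quick sanity check of the dtype-to-string half of that rule, you can call the renderer's render_dtype directly, the same method the pattern above uses via ctx (assuming it stays a public method on the C-style renderers):
from tinygrad import dtypes
from tinygrad.renderer.cstyle import MetalRenderer

print(MetalRenderer().render_dtype(dtypes.float.vec(4)))  # float4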
Now, if you compare the UOp output sections between our two runs, you should note that the latter has a VECTORIZE entry:
0 Ops.CONST : dtypes.float [] 1.0
1 Ops.CONST : dtypes.int [] 2
2 Ops.DEFINE_GLOBAL : dtypes.float.ptr() [] 0
3 Ops.DEFINE_GLOBAL : dtypes.float.ptr() [] 1
4 Ops.SPECIAL : dtypes.int [] ('lidx0', 4)
5 Ops.SHL : dtypes.int [4, '2'] None
6 Ops.INDEX : dtypes.float.ptr() [2, 5] None
7 Ops.CAST : dtypes.float.vec(4).ptr() [6] None
8 Ops.INDEX : dtypes.float.ptr() [3, 5] None
9 Ops.CAST : dtypes.float.vec(4).ptr() [8] None
10 Ops.LOAD : dtypes.float.vec(4) [9] None
11 Ops.GEP : dtypes.float [10] (0,)
12 Ops.GEP : dtypes.float [10] (1,)
13 Ops.GEP : dtypes.float [10] (2,)
14 Ops.GEP : dtypes.float [10] (3,)
15 Ops.ADD : dtypes.float [11, '1.0'] None
16 Ops.ADD : dtypes.float [12, '1.0'] None
17 Ops.ADD : dtypes.float [13, '1.0'] None
18 Ops.ADD : dtypes.float [14, '1.0'] None
19 Ops.VECTORIZE : dtypes.float.vec(4) [15, 16, 17, 18] None
20 Ops.STORE : dtypes.void [7, 19] None
(By the time you read this, the exact output may have changed; the key here is just the Ops.VECTORIZE entry, the rest of the details may go out of date.)
How do we end up with VECTORIZE? With rewrite rules. With some additional print statements, you may discover that the VECTORIZE originates from this section of the uop tree:
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.INDEX, dtypes.float.ptr(), arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
x3:=UOp(Ops.ADD, dtypes.int, arg=None, src=(
After a series of pattern matches and rewrites, it becomes:
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.INDEX, dtypes.float.ptr().vec(4), arg=None, src=(
UOp(Ops.VECTORIZE, dtypes.float.ptr().vec(4), arg=None, src=(
x3:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
x3,
x3,
x3,)),
And the actual pattern:
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
It matches Ops.STORE among all the accepted ops (the op field can take a tuple and matches if any one of them is found), and the STORE becomes the root argument passed to the do_expand function. This function returns a new version of the UOp:
def do_expand(root):  # root is the UOp(Ops.STORE, dtypes.void, arg=None, src=(...))
  # ...
  new_srcs.append(UOp(Ops.VECTORIZE,
    src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
Here’s an actual example you can toy with:
from tinygrad import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.renderer.cstyle import MetalRenderer
metal_renderer = MetalRenderer()
const = UOp(Ops.CONST, dtypes.float, arg=1.0)
define_global = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
special = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', 16), src=())
added = UOp(Ops.ADD, dtypes.long, arg=None, src=(define_global, special))
store = UOp(Ops.STORE, dtypes.void, arg=None, src=(added, const))
uops = [const, define_global, special, added, store]
rendered = metal_renderer.render("rendered", uops)
print(rendered)
"""
#include <metal_stdlib>
using namespace metal;
kernel void rendered(device float* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 16 */
*(data0+gidx0) = 1.0f;
}
"""
It generates the kernel code that writes "1" to a 16-element tensor. Note that the argument passed to render must be linearized, hence the hand-written list.
Let's see what a pattern rewrite looks like. Suppose we don't like the fact that it writes 1.0f directly, and we want it to compute 0.5f + 0.5f instead (for no reason other than demonstration). We can define a pattern matcher as such:
from tinygrad.ops import PatternMatcher, UPat
const_1 = UOp(Ops.CONST, dtypes.float, arg=0.5)
const_2 = UOp(Ops.CONST, dtypes.float, arg=0.5)
matcher = PatternMatcher([
(UPat(Ops.CONST, dtypes.float, name="x"), lambda ctx, x: UOp(Ops.ADD, dtypes.float, src=(const_1, const_2))),
])
const = UOp(Ops.CONST, dtypes.float, arg=1.0)
const_rewritten = matcher.rewrite(const)
This will rewrite the original const into an addition of two 0.5s. Now, to make it work with the rest of the script, a few changes are needed. First, const_1 and const_2 must be added to the uops array, because of the linearization requirement. Normally this is taken care of by a dedicated linearize function, but here you can see what it actually does. Also, const_rewritten will be the src of the STORE.
Here’s the full script:
from tinygrad import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.renderer.cstyle import MetalRenderer
from tinygrad.ops import PatternMatcher, UPat
const_1 = UOp(Ops.CONST, dtypes.float, arg=0.5)
const_2 = UOp(Ops.CONST, dtypes.float, arg=0.5)
matcher = PatternMatcher([
(UPat(Ops.CONST, dtypes.float, name="x"), lambda ctx, x: UOp(Ops.ADD, dtypes.float, src=(const_1, const_2))),
])
metal_renderer = MetalRenderer()
const = UOp(Ops.CONST, dtypes.float, arg=1.0)
const_rewritten = matcher.rewrite(const)
define_global = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
special = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', 16), src=())
added = UOp(Ops.ADD, dtypes.long, arg=None, src=(define_global, special))
store = UOp(Ops.STORE, dtypes.void, arg=None, src=(added, const_rewritten))
uops = [const_1, const_2, const_rewritten, define_global, special, added, store]
rendered = metal_renderer.render("rendered", uops)
print(rendered)
"""
#include <metal_stdlib>
using namespace metal;
kernel void rendered(device float* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 16 */
*(data0+gidx0) = (0.5f+0.5f);
}
"""