tinygrad-notes

Beam search explained

One key feature of tinygrad, or any ML framework, is generating efficient kernel code. For example, a naive implementation of a sum over dimension 1:

from tinygrad import Tensor

Tensor.empty(4, 4).sum(1).realize()

may result in kernel code like this (Metal is shown here):

kernel void r_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 4 */
  float acc0 = 0.0f;
  for (int ridx0 = 0; ridx0 < 4; ridx0++) {
    float val0 = *(data1+((gidx0<<2)+ridx0));
    acc0 = (acc0+val0);
  }
  *(data0+gidx0) = acc0;
}

This can be optimized by removing the for loop and handling the four elements inline:

kernel void r_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 4 */
  float4 val0 = *((device float4*)((data1+(gidx0<<2))));
  *(data0+gidx0) = (val0.w+val0.z+val0.x+val0.y);
}

Alternatively, we can launch just one kernel and let it compute all four row sums:

kernel void r_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  float acc0 = 0.0f;
  float acc1 = 0.0f;
  float acc2 = 0.0f;
  float acc3 = 0.0f;
  for (int ridx0 = 0; ridx0 < 4; ridx0++) {
    float val0 = *(data1+ridx0);
    float val1 = *(data1+(ridx0+4));
    float val2 = *(data1+(ridx0+8));
    float val3 = *(data1+(ridx0+12));
    acc0 = (acc0+val0);
    acc1 = (acc1+val1);
    acc2 = (acc2+val2);
    acc3 = (acc3+val3);
  }
  *((device float4*)((data0+0))) = float4(acc0,acc1,acc2,acc3);
}

How do we know which one is faster? There are two approaches in general. One is to hand-code heuristics based on experience; there are lots of articles online that explain how to make your matrix multiplication kernel go faster, such as this one. The other is to automatically search over all the possible optimization techniques and their parameters, compare the speeds, and pick the fastest one. The search algorithm currently used is beam search, which is just one particular instance of the general search approach (others include Monte Carlo tree search, or even neural-net-guided search). As pointed out by Sutton’s “The Bitter Lesson”, the search approach is likely to triumph.

That’s actually all there is to the beam search! When a kernel is generated, the search goes through the possible optimizations, measures the speed of each, and picks the quickest one. Quite simple.

The more complicated part is the set of options available. In order for the search to work, we need to define a set of optimizations for it to search over. What I have illustrated in the examples above are “UNROLL” and “UPCAST”, and there are many more.

General API usage and UNROLL

We start off with the AST that is passed to the kernel; you can use the DEBUG=3 flag to see it. In our sum example:

ast = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
  UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),))

An optimization is constructed with the Opt class, which takes an op, an axis, and an amount:

unroll_opt = Opt(OptOps.UNROLL, 0, 4)

The opt is applied to the kernel, and you can print out the resulting source:

kernel = Kernel(ast)
kernel.apply_opt(unroll_opt)
p = kernel.to_program()
print(p.src)

You should get output identical to the unrolled (float4) kernel illustrated earlier.
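
For reference, here is a self-contained version of the above. This is a minimal sketch: the import path of Kernel/Opt/OptOps and the schedule API are assumptions that may differ across tinygrad versions, so adjust as needed.

from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel, Opt, OptOps  # path may vary by version

# grab the AST of the scheduled sum kernel instead of pasting the DEBUG=3 dump
ast = Tensor.empty(4, 4).sum(1).schedule()[-1].ast

kernel = Kernel(ast)
kernel.apply_opt(Opt(OptOps.UNROLL, 0, 4))
print(kernel.to_program().src)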

The axis 0 means we are unrolling the reduce axis (the last axis here); amt 4 means we unroll it four times, so the entire iteration is now inlined. We can also unroll it just twice with Opt(OptOps.UNROLL, 0, 2):

kernel void r_4_2_2(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 4 */
  float acc0 = 0.0f;
  for (int ridx0 = 0; ridx0 < 2; ridx0++) {
    float2 val0 = *((device float2*)((data1+((gidx0<<2)+(ridx0<<1)))));
    acc0 = (acc0+val0.x+val0.y);
  }
  *(data0+gidx0) = acc0;
}

Note that certain values don’t make sense; for example, if a loop iterates four times, you can’t unroll it by three.
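
In fact, apply_opt should reject such a request. A quick way to see this (using a broad except, since the exact error class may differ across tinygrad versions):

k = Kernel(ast)
try:
    k.apply_opt(Opt(OptOps.UNROLL, 0, 3))  # 3 does not divide the loop of 4
except Exception as e:
    print("rejected:", e)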

UPCAST

Upcasting refers to making each kernel invocation process more elements. One example is shown earlier, where we launched one kernel to process all four rows. Here’s another example; suppose we are doing an elementwise add:

a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
c = a + b
c.realize()

The resulting ast becomes (DEBUG=3):

ast = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
  UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
    x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(Ops.ADD, dtypes.float, arg=None, src=(
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
         x2,)),
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
         x2,)),)),)),))

If we don’t apply any optimization, this would be the kernel:

kernel void E_16(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 16 */
  float val0 = *(data1+gidx0);
  float val1 = *(data2+gidx0);
  *(data0+gidx0) = (val0+val1);
}

Let’s apply upcast. Since the elements are contiguous, the kernel’s shape is just one dimensional, so the upcast is on axis 0. Currently each invocation processes 1 element; applying UPCAST with an amt of 8 means each invocation will process 8 elements, and the total number of invocations is now 2.

opt = Opt(OptOps.UPCAST, 0, 8)
kernel.apply_opt(opt)

The resulting kernel:

kernel void E_2_8(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 2 */
  int alu0 = (gidx0<<3);
  int alu1 = (alu0+4);
  float4 val0 = *((device float4*)((data1+alu0)));
  float4 val1 = *((device float4*)((data2+alu0)));
  float4 val2 = *((device float4*)((data1+alu1)));
  float4 val3 = *((device float4*)((data2+alu1)));
  *((device float4*)((data0+alu0))) = float4((val0.x+val1.x),(val0.y+val1.y),(val0.z+val1.z),(val0.w+val1.w));
  *((device float4*)((data0+alu1))) = float4((val2.x+val3.x),(val2.y+val3.y),(val2.z+val3.z),(val2.w+val3.w));
}
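
The same recipe as before applies here; a hedged sketch that builds the add AST via the schedule (same version caveats as above):

ast = (Tensor.empty(4, 4) + Tensor.empty(4, 4)).schedule()[-1].ast
k = Kernel(ast)
k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
print(k.to_program().src)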

Note that the axis specified by upcast must not be a reduce axis.

LOCAL

Our previous examples all launched kernels in the block/global space. Recall that the GPU execution model distinguishes between blocks and threads. This shows up as threadgroup_position_in_grid vs thread_position_in_threadgroup in Metal, and blockIdx vs threadIdx in CUDA. By default tinygrad splits work at the block level; if we want to split work at the thread level, we apply a LOCAL opt. Using the same “addition” example, we can do:

opt = Opt(OptOps.LOCAL, 0, 2)
kernel.apply_opt(opt)

This will launch two threads per block and split the rest of the work across eight blocks:

kernel void E_8_2(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 8 */
  int lidx0 = lid.x; /* 2 */
  int alu0 = (lidx0+(gidx0<<1));
  float val0 = *(data1+alu0);
  float val1 = *(data2+alu0);
  *(data0+alu0) = (val0+val1);
}

If we want the work done entirely by threads within just one block, we can pass 16 as the amount, Opt(OptOps.LOCAL, 0, 16):

kernel void E_16(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int lidx0 = lid.x; /* 16 */
  float val0 = *(data1+lidx0);
  float val1 = *(data2+lidx0);
  *(data0+lidx0) = (val0+val1);
}
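
You can also see where the work went by inspecting the launch dimensions of the compiled program. The global_size/local_size attributes below are an assumption about the program object and may differ across tinygrad versions:

k = Kernel(ast)  # the add AST from the earlier sketch
k.apply_opt(Opt(OptOps.LOCAL, 0, 16))
p = k.to_program()
print(p.global_size, p.local_size)  # all 16 elements should land in the local (thread) dimension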

GROUP

This optimization refers to the technique of using shared memory (__shared__ in CUDA and threadgroup in Metal). Here’s a good explanation outlining its usage and advantages.

It’s an optimization for the reduce axis, so I will use the AST from our sum example and apply the opt as Opt(OptOps.GROUP, 0, 4); the resulting kernel is:

kernel void r_4_4(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  threadgroup float temp0[4];
  int gidx0 = gid.x; /* 4 */
  int lidx0 = lid.x; /* 4 */
  float val0 = *(data1+(lidx0+(gidx0<<2)));
  *(temp0+lidx0) = val0;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if ((((bool)(lidx0))!=1)) {
    float acc0 = 0.0f;
    for (int ridx0 = 0; ridx0 < 4; ridx0++) {
      float val1 = *(temp0+ridx0);
      acc0 = (acc0+val1);
    }
    *(data0+gidx0) = acc0;
  }
}

As you can see in the code, we launch four blocks, and each block handles one row of the matrix and calculates its sum. Within each block there are four threads. We first have each thread read one value from the input tensor data1 and store it in the shared array temp0. We make sure all four threads have done this by synchronizing them with the threadgroup_barrier call (in CUDA, this would be __syncthreads). Then only the zeroth thread iterates over the values, calculates the sum (the other three threads simply exit), and writes the result.

PADTO

It’s often a good idea to have data whose shape is a power of 2, so if you are adding two matrices of shape 3 by 3, it may be preferable to pad them to 4 by 4, with the additional row and column holding empty values. This can be handled at the tensor level, by padding/reshaping the tensors, or alternatively at the kernel level using the PADTO opt.

Here’s the example Tensor.empty(3, 3) + Tensor.empty(3, 3) and the resulting ast:

ast = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=0, dont_use_locals=False), src=(
  UOp(Ops.STORE, dtypes.void, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
    x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(Ops.ADD, dtypes.float, arg=None, src=(
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
         x2,)),
      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
         x2,)),)),)),))

The naive kernel looks essentially the same as before, just with 9 elements:

kernel void E_9(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 9 */
  float val0 = *(data1+gidx0);
  float val1 = *(data2+gidx0);
  *(data0+gidx0) = (val0+val1);
}

We can pad it so that we launch 12 invocations (9 rounded up to the next multiple of 4) by applying Opt(OptOps.PADTO, 0, 4):

kernel void E_12(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 12 */
  bool alu0 = (gidx0<9);
  float val0 = (alu0?*(data1+gidx0):0.0f);
  float val1 = (alu0?*(data2+gidx0):0.0f);
  if (alu0) {
    *(data0+gidx0) = (val0+val1);
  }
}

Here you can see the extra conditional: loads of the “padded” elements are replaced with zeros, and their store is skipped. This does incur some overhead, though.

GROUPTOP

GROUP has a variation called GROUPTOP. The two behave identically when the group covers the whole reduce axis: in our earlier example, using Opt(OptOps.GROUPTOP, 0, 4) would give the same result. However, if we set the amount to 2, things are different.

Opt(OptOps.GROUP, 0, 2):

kernel void r_4_2_2(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  threadgroup float temp0[2];
  int gidx0 = gid.x; /* 4 */
  int lidx0 = lid.x; /* 2 */
  float acc0 = 0.0f;
  for (int ridx0 = 0; ridx0 < 2; ridx0++) {
    float val0 = *(data1+(lidx0+(gidx0<<2)+(ridx0<<1)));
    acc0 = (acc0+val0);
  }
  *(temp0+lidx0) = acc0;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if ((((bool)(lidx0))!=1)) {
    float acc1 = 0.0f;
    for (int ridx1 = 0; ridx1 < 2; ridx1++) {
      float val1 = *(temp0+ridx1);
      acc1 = (acc1+val1);
    }
    *(data0+gidx0) = acc1;
  }
}

In comparison, Opt(OptOps.GROUPTOP, 0, 2) leads to:

kernel void r_4_2_2(device float* data0, device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  threadgroup float temp0[2];
  int gidx0 = gid.x; /* 4 */
  int lidx0 = lid.x; /* 2 */
  float acc0 = 0.0f;
  for (int ridx0 = 0; ridx0 < 2; ridx0++) {
    float val0 = *(data1+((gidx0<<2)+(lidx0<<1)+ridx0));
    acc0 = (acc0+val0);
  }
  *(temp0+lidx0) = acc0;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if ((((bool)(lidx0))!=1)) {
    float acc1 = 0.0f;
    for (int ridx1 = 0; ridx1 < 2; ridx1++) {
      float val1 = *(temp0+ridx1);
      acc1 = (acc1+val1);
    }
    *(data0+gidx0) = acc1;
  }
}

The difference is here:

    float val0 = *(data1+(lidx0+(gidx0<<2)+(ridx0<<1))); // GROUP
    float val0 = *(data1+((gidx0<<2)+(lidx0<<1)+ridx0)); // GROUPTOP

Since there are only two threads to process four elements, there are choices in how to divide the work. With GROUP, each loop iteration (taken across the two threads) covers two consecutive elements, giving the (ridx0<<1)+lidx0 access pattern:

lidx0 ridx0 index
    0     0     0  --> this
    0     1     2
    1     0     1  --> and this take place at the same time
    1     1     3

This takes advantage of the fact that threads run in parallel: the same iteration of the loop occurs at the same time across threads, so having them access consecutive elements speeds up (coalesces) memory access.

In contrast, GROUPTOP is the opposite: it lets each thread itself access two consecutive elements, giving (lidx0<<1)+ridx0:

lidx0 ridx0 index
    0     0     0
    0     1     1
    1     0     2
    1     1     3

It may be helpful in other cases.
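
The two index tables above can be reproduced with a few lines of plain Python (no tinygrad needed):

for name, index in [("GROUP", lambda lidx0, ridx0: (ridx0 << 1) + lidx0),
                    ("GROUPTOP", lambda lidx0, ridx0: (lidx0 << 1) + ridx0)]:
    print(name)
    for lidx0 in range(2):
        for ridx0 in range(2):
            print(f"  lidx0={lidx0} ridx0={ridx0} index={index(lidx0, ridx0)}")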

The search

That covers most of the available options; here’s how the search takes place. The kernel starts with the AST, and we supply a bunch of possible actions, for example:

actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]

Then a function takes the AST, applies possible combinations of such actions, and measures the time each resulting kernel takes. The strategy for deciding which combinations to try is handled by the beam search algorithm, and I’ll omit the details. As you can see, the beam search algorithm itself isn’t the key here, and you could easily swap in some other search algorithm. Rather, it’s the possible options and the kernel code they lead to that are a bit harder to grasp.
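
To make the loop concrete, here is a minimal sketch of a beam search over these actions. This is not tinygrad’s actual implementation, and time_kernel is a hypothetical stand-in for compiling a kernel and measuring its runtime:

def time_kernel(k):
    # hypothetical helper: compile k and measure its runtime on the device
    # (tinygrad has its own timing utilities for this; details omitted)
    raise NotImplementedError

def simple_beam_search(ast, actions, beam_width=4, depth=3):
    def build(opts):
        k = Kernel(ast)
        for o in opts: k.apply_opt(o)
        return k
    best = ((), time_kernel(build(())))          # (applied opts, measured time)
    beam = [()]
    for _ in range(depth):
        candidates = []
        for opts in beam:
            for a in actions:
                try:
                    t = time_kernel(build(opts + (a,)))
                except Exception:
                    continue                     # action not valid for this kernel, skip it
                candidates.append((opts + (a,), t))
        if not candidates: break
        candidates.sort(key=lambda c: c[1])
        beam = [c[0] for c in candidates[:beam_width]]  # keep only the fastest few prefixes
        if candidates[0][1] < best[1]: best = candidates[0]
    return build(best[0])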