tinygrad-notes

Kernel Fusion part 3: the linear layer UOps

In my high level overview of kernel fusion and in the ScheduleItem post, I discussed the three layers of abstraction in generating a kernel: ScheduleItem, UOps and the final kernel code, and gave a detailed explanation of how a ScheduleItem is generated. Let me now discuss the next layer: how a ScheduleItem is turned into a linear representation (UOps) of all the operations that need to be performed, and how those operations are optimized.
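Throughout this post I will keep using the dot product example from the earlier posts. A minimal script along these lines reproduces it (the concrete tensor values are my own choice, and running it with DEBUG=3 or higher prints the AST tree, as we'll see later):

  # the running example: dot product of two small int tensors
  from tinygrad.tensor import Tensor

  a = Tensor([1, 2])
  b = Tensor([3, 4])
  c = a.dot(b)   # lazily builds a MUL followed by a SUM
  c.realize()    # kicks off corealize -> run_schedule -> linearize, discussed below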

Recall that a dot product operation will output this ScheduleItem to be converted to kernel code:

  0 ━┳ STORE MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)))
  1  ┗━┳ SUM (0,)
  2    ┗━┳ MUL
  3      ┣━━ LOAD MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=0, mask=None, contiguous=True),)))
  4      ┗━━ LOAD MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=0, mask=None, contiguous=True),)))

Remember that the AST-looking tree above is the “ScheduleItem” I discussed in the last post, and it gets converted into the following in the next step:

step  Op_name               type                      input                           arg
   0 UOps.DEFINE_GLOBAL  : ptr.dtypes.int            []                               (0, 'data0', True)
   1 UOps.DEFINE_GLOBAL  : ptr.dtypes.int            []                               (1, 'data1', False)
   2 UOps.DEFINE_GLOBAL  : ptr.dtypes.int            []                               (2, 'data2', False)
   3 UOps.DEFINE_ACC     : dtypes.int                []                               0
   4 UOps.CONST          : dtypes.int                []                               0
   5 UOps.CONST          : dtypes.int                []                               2
   6 UOps.LOOP           : dtypes.int                [4, 5]                           None
   7 UOps.LOAD           : dtypes.int                [1, 6]                           None
   8 UOps.LOAD           : dtypes.int                [2, 6]                           None
   9 UOps.ALU            : dtypes.int                [7, 8]                           BinaryOps.MUL
  10 UOps.ALU            : dtypes.int                [9, 3]                           BinaryOps.ADD
  11 UOps.PHI            : dtypes.int                [3, 10, 6]                       None
  12 UOps.ENDLOOP        :                           [6]                              None
  13 UOps.STORE          :                           [0, 4, 11]                       None

Please refer back to the high level overview for more context.

The journey starts from the corealize method in the Tensor class:

  @staticmethod
  def corealize(lst:Iterable[Tensor]):
    run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))

We take the created ScheduleItem (discussed in the previous post) and call the run_schedule function, which invokes prg = lower_schedule_item(si). Within it, if the item has a STORE operation at the top level, the process of converting the ScheduleItem to UOps starts via the get_runner method:

  if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)

Inside get_runner, it initializes a Linearizer and calls the to_program method (*ast here is our ScheduleItem):

  def get_runner(self, *ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(*ast))

The Linearizer class is what’s responsible for converting an AST (ScheduleItem) to linear operations, hence its name. Inside it we see a condition check for the NOOPT env variable, and that’s where things diverge depending on whether you have optimization enabled. We will assume NOOPT is set to true for now so we don’t go in there. Spoiler though: inside the if block, it sets up certain flags for the optimization rather than performing it directly.

    if not NOOPT:

You will also see a check for the DEBUG variable:

    if DEBUG >= 3:
      from tinygrad.features.graph import print_tree
      for op in ast: print_tree(op)

This is actually what prints out the ScheduleItem (the AST) we saw above.

Overall, get_linearizer initializes the Linearizer object with some default settings, and the result is passed to the to_program method.

to_program calls two important methods. The first is the linearize method, which is where the actual conversion happens:

    k.linearize()

The next thing to_program does is set up the actual code generator and the facility to run it:

    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
                            k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs))

You can see that self.compiler.render() is what generates the actual code, and CompiledASTRunner takes that code and exposes an exec method so it can be executed on the GPU.

I will focus on the linearize() method in this post:

  def linearize(self):

It’s actually a very long function. An important property to keep an eye on is the uops attribute on the Linearizer instance. This attribute contains a list of, you guessed it, UOps, which is what we saw earlier when I printed the entire UOps list. So by looking at when and where items are added to the list, we can make sense of how the process works.

This is where self.uops is initialized:

    self.uops:UOpGraph = UOpGraph()

And this is an example where items are appended:

        self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
                                         buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                                         (buf.idx, f"data{buf.idx}", any(buf.idx == x.idx for x in self.outbufs)))

The above calls the add method on UOpGraph and adds a DEFINE_GLOBAL uop, with its dtype set to a pointer to the buffer’s data type, zero inputs, and the argument set to a tuple of the buffer index, its name, and whether it is an output buffer. Let’s briefly look at the implementation of the add method:

  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None,
          simplify=True) -> UOp:
    ret = UOp(uop, dtype, vin, arg)
    if simplify and (rewritten:=constant_folder.rewrite(ret)) is not None:
      if rewritten in self.uops: return rewritten  # ignore cachable
      ret = rewritten
    key = (ret.uop, ret.dtype, ret.vin, ret.arg)
    if insert_before is None: insert_before = len(self.uops)
    # check if the cached expr is valid with the given insert place.
    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
    self.uops.insert(insert_before, ret)
    if cachable: self.saved_exprs[key] = ret
    return ret

You see that it constructs the UOp as a plain class instance and sets up some caches; afterwards it just does a plain list insert, self.uops.insert(insert_before, ret).
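As a toy illustration of the caching behaviour (a sketch; the import paths are my guess for the tinygrad version this post is based on), adding the same expression twice only inserts it once:

  from tinygrad.codegen.uops import UOpGraph, UOps   # paths may differ by version
  from tinygrad.dtype import dtypes

  g = UOpGraph()
  c1 = g.add(UOps.CONST, dtypes.int, (), 2)
  c2 = g.add(UOps.CONST, dtypes.int, (), 2)  # same (uop, dtype, vin, arg) key, so the cached uop comes back
  assert len(g.uops) == 1                    # only one CONST sits in the linear list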

So by adding a breakpoint on the add method, I managed to figure out exactly when each uop is added. Now think about it conceptually: how would you turn the AST into linear form?

Step 1 is to initialize the variables that will be used in the kernel: we must have an output pointer passed as an argument, and every input also needs to be passed as an argument. So if we have a kernel function like this:

  void kernel(int* data0, int* data1, int* data2) {
    ...
  }

We need three DEFINE_GLOBAL uops (note that the third element of the arg tuple marks whether the buffer is an output, which is why data0 gets True), and that’s exactly what happens:

    for i,buf in enumerate(self.bufs):
      if isinstance(buf, MemBuffer):
        self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
                                         buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                                         (buf.idx, f"data{buf.idx}", any(buf.idx == x.idx for x in self.outbufs)))

self.bufs is initialized inside the __init__ method of the Linearizer class:

    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

In our case, it has three items: the first MemBuffer is the output data pointer, and the other two are the two lists we are performing the dot product on. They come from the ScheduleItem that I covered in detail in my last post.
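Concretely, they are the three MemBuffers you can read off the ScheduleItem at the top of this post (ShapeTracker details elided):

  MemBuffer(idx=0, dtype=dtypes.int, st=...)  # the output, shape (1,)
  MemBuffer(idx=1, dtype=dtypes.int, st=...)  # first input, shape (2,)
  MemBuffer(idx=2, dtype=dtypes.int, st=...)  # second input, shape (2,)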

Next, after lots of code that has no effect for our simple dot product, we encounter the reduce setup:

    if self.reduceop is not None:
      # define indexes
      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
      fake_reduce_idxs = [x*0 for x in reduce_idxs]

      # define accumulator
      out_buf = -1 if self.group_for_reduces else 0
      acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))

      # reduce loop
      loop_ctx = render_loop(reduce_idxs)

A reduce op is something that converts multiple elements into one, for example our SUM operation. reduce_idxs contains one element, a Variable that ranges from 0 to 1. I will cover what a Variable is in a separate post, but for now, think of it as a regular integer with some special methods. Conceptually, if we want to reduce a list of elements into a single one, we need an accumulator and a loop over the list, and that is in fact what happens at this stage (before any optimization).
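In Python pseudocode, what these uops will eventually compute looks like this (the comments map each line to the uops in the table near the top; the data values are from our running example):

  data1, data2 = [1, 2], [3, 4]
  data0 = [0]
  acc = 0                                       # UOps.DEFINE_ACC
  for ridx0 in range(0, 2):                     # UOps.LOOP between CONST 0 and CONST 2
      acc = acc + data1[ridx0] * data2[ridx0]   # LOADs, ALU MUL, ALU ADD; PHI handles the SSA bookkeeping
                                                # UOps.ENDLOOP
  data0[0] = acc                                # UOps.STORE (result: 11)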

We first define an accumulator by calling global_load:

      acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))

global_load is a complicated function, but in this case only the following line has any effect:

          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), dtypes.as_const(this_const, localtype), cachable=False)

A DEFINE_ACC is added to the uops list (with arg 0, the identity value for a SUM). Next we enter the render_loop function:

      loop_ctx = render_loop(reduce_idxs)

render_loop is defined as such:

    def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
      new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
        self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
        self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
      self.loop_uops.update(new_loops)
      return tuple(new_loops.values())

The key point is that for the reduce operation we append a LOOP uop to the uops list. The loop gets a start and an end bound: the min of our reduce Variable (0) and its max plus 1 (2), which are the two CONST uops we saw in the table.

We then iterate through the .earlybufs property to load the buffers that are needed inside the reduce, before any reduce operation is performed. .earlybufs is initialized like this:

    self.earlybufs = [x.arg for x in self.reduceop.lazyops if x.op in BufferOps] if self.reduceop else []

Then we parse the entire AST for the reduce part (SUM) and fill in the remaining uops within the loop:

        self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)

Let’s look at what ast_parse does:

  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]:  # noqa: E501

x is our SUM op, acc is the DEFINE_ACC we have set up, and offs is zero (I didn’t cover it, but zero means we operate on the input without any offset). loaded_buffers derives from .earlybufs with some modification, but it essentially means the data available to our loop: put into context, the two tensors we are applying the dot product to and then summing.

We see that it goes into .src and recurses:

    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]

Recall our AST looks like this:

  0 ━┳ STORE MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)))
  1  ┗━┳ SUM (0,)
  2    ┗━┳ MUL
  3      ┣━━ LOAD MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=0, mask=None, contiguous=True),)))
  4      ┗━━ LOAD MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=0, mask=None, contiguous=True),)))

The .src of the SUM node is MUL, whose .src are the two LOAD ops, so the recursion bottoms out when it reaches the two MemBuffers, which are our input lists:

    if x.op in BufferOps: return loaded_buffers[x.arg]
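Stripping away the details, the overall shape of the recursion is roughly this (a simplified sketch, not the real signature; emit_uops_for is a hypothetical stand-in for the two branches discussed next):

  # simplified sketch of ast_parse's recursion (not the real signature)
  from tinygrad.ops import BufferOps

  def ast_parse_sketch(x, acc, loaded_buffers):
      if x.op in BufferOps:                     # base case: LOADs resolve to uops
          return loaded_buffers[x.arg]          # that were loaded ahead of time
      values = [ast_parse_sketch(v, acc, loaded_buffers) for v in x.src]
      # our MUL becomes a plain ALU uop; our SUM folds values into acc and
      # emits a PHI (both branches sketched by this hypothetical helper)
      return emit_uops_for(x.op, values, acc)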

And the MUL operation is constructed:

      ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)]

The line above reads: append an ALU (arithmetic logic unit) operation to the linear ops list; its dtype is the same as the input dtype, its inputs (val) are the corresponding elements of the two LOAD results, and the specific op is MUL (x.op). In summary, that’s how you would express the multiplication in the following snippet in linear op format:

    int val0 = *(data1+ridx0);
    int val1 = *(data2+ridx0);
    acc = val0 * val1;

And finally we return to the SUM call stack and execute these lines:

    if x.op in ops:
      ret: List[UOp] = []
      input_acc = acc[:]
      for val, off in zip(zip(*values), cast(List[int], offs)):
        acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
        ret.append(acc[off])
      for off in range(len(acc)):
        if input_acc[off] != acc[off]:
          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))

The first loop adds each input coming from our MUL recursion into the accumulator by doing self.uops.add(UOps.ALU) with a BinaryOps.ADD arg, and the second loop relates to static single assignment (SSA) and constructs a PHI op: roughly, in SSA form a value is only assigned once, so the PHI node merges the accumulator’s initial value with the value computed inside the loop body, picking the right one depending on the iteration. I don’t understand it 100% myself, so I recommend looking the concept up (LLVM is a good example).

The remaining part is more or less the same as before, but for all the remaining operations after the SUM part (of which our example has none besides the final store):

    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})

    # run late AST (without the store)
    for op in self.ast:
      val = self.ast_parse(op.src[0], acc, None, loaded_buffers)
      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

The global_store call above emits the final STORE uop (step 13 in the table), writing the accumulator into data0. And finally, we enter the optimization step (I will probably cover the optimize part in a separate post):

    self.uops.uoptimize()

and that ends the UOps generation part.
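To wrap things up, here is the kind of script I use to drive this pipeline by hand and inspect the result. Treat it as a sketch: the import path of create_schedule and the exact method names are my best guess for the tinygrad version this post is based on.

  from tinygrad.tensor import Tensor
  from tinygrad.device import Device
  from tinygrad.realize import create_schedule   # path may differ by version

  a, b = Tensor([1, 2]), Tensor([3, 4])
  out = a.dot(b)
  si = create_schedule([out.lazydata])[-1]       # the ScheduleItem whose root is the STORE

  lin = Device[Device.DEFAULT].get_linearizer(*si.ast)
  lin.linearize()
  for u in lin.uops.uops: print(u)               # the linear uops we walked through above

  # the same render call that to_program makes (the function name here is arbitrary)
  print(Device[Device.DEFAULT].compiler.render("r_dot", lin.uops))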