tinygrad-notes

Documentation on Tinygrad’s IR

Each UOp represents one operation in tinygrad's intermediate representation (IR), and together the UOps form a graph. They are added to a UOpGraph as follows:

from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.dtype import dtypes
g = UOpGraph()
c0 = g.add(UOps.CONST, dtypes.int, arg=0)

The graph can then be rendered into code for a target platform (Metal here):

from tinygrad.renderer.cstyle import uops_to_cstyle, MetalLanguage
s = uops_to_cstyle(MetalLanguage(), 'tester', g)
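Conceptually, a UOp is an operation name, a dtype, a tuple of input UOps (vin) and an arg, and the graph is the list of UOps in the order they were added. The self-contained sketch below illustrates that structure without tinygrad; `ToyUOp` and `ToyGraph` are hypothetical names, and the real UOpGraph carries more state than this.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

@dataclass(frozen=True)
class ToyUOp:
    uop: str                          # operation name, e.g. "CONST"
    dtype: Optional[str]              # result type, e.g. "int"
    vin: Tuple["ToyUOp", ...] = ()    # input UOps: the edges of the graph
    arg: Any = None                   # extra payload, e.g. a constant's value

class ToyGraph:
    def __init__(self) -> None:
        self.uops: List[ToyUOp] = []  # insertion order doubles as program order

    def add(self, uop: str, dtype: Optional[str] = None,
            vin: Tuple[ToyUOp, ...] = (), arg: Any = None) -> ToyUOp:
        node = ToyUOp(uop, dtype, vin, arg)
        self.uops.append(node)
        return node  # returned so it can be passed as an input to later UOps

g = ToyGraph()
c0 = g.add("CONST", "int", arg=0)
c1 = g.add("CONST", "int", arg=10)
loop = g.add("LOOP", "int", vin=(c0, c1))  # depends on both constants
```

Each UOp points back at its inputs, so the list plus the vin references is enough to describe a directed acyclic graph.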

The main way to build the graph is to call add with one of the following UOp types:

CONST

UOps.CONST declares a constant. It requires two parameters in add: dtype and arg (the constant's value).

Example:

c0 = g.add(UOps.CONST, dtypes.int, arg=10)

c0 can then be used as an input to other UOps by passing it in their vin tuple.
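Because UOps are referenced through vin, a graph builder can hand back an existing node when an identical one is requested (hash-consing), so two equal constants become one node. The sketch below illustrates the idea with hypothetical names; whether and when tinygrad's own add deduplicates is an assumption not covered by these notes.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical node representation: a (uop, dtype, vin, arg) tuple
Node = Tuple[str, Optional[str], tuple, Any]

class DedupGraph:
    def __init__(self) -> None:
        self.uops: List[Node] = []
        self._seen: Dict[Node, Node] = {}  # key -> the node already stored for it

    def add(self, uop: str, dtype: Optional[str] = None,
            vin: tuple = (), arg: Any = None) -> Node:
        key = (uop, dtype, vin, arg)
        if key in self._seen:        # identical UOp already exists: reuse it
            return self._seen[key]
        self._seen[key] = key
        self.uops.append(key)
        return key

g = DedupGraph()
a = g.add("CONST", "int", arg=0)
b = g.add("CONST", "int", arg=0)    # same dtype and value as `a`
c = g.add("CONST", "int", arg=10)
```

With a builder like this, adding the same CONST twice yields a single node rather than two.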

DEFINE_GLOBAL

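UOps.DEFINE_GLOBAL declares a global buffer, which becomes one of the kernel function's parameters. Its arg is a tuple (index, name, flag); in these examples the flag is True for the output buffer the kernel writes to and False for read-only inputs.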

For example, when declaring a parameter that will be passed to the kernel:

c1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(0, "data0", True))
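The kernel signatures in the rendered Metal code in these notes list one parameter per DEFINE_GLOBAL, followed by the built-in thread-position arguments. The sketch below reproduces that shape for the plain `int` dtype used here, ordering buffers by arg[0]. `render_params` is a hypothetical helper, and its formatting only mirrors the output shown in these notes.

```python
from typing import List, Tuple

# (index, name, writes_output) tuples, as passed in DEFINE_GLOBAL's arg
GlobalArg = Tuple[int, str, bool]

def render_params(globals_: List[GlobalArg], ctype: str = "int") -> str:
    # Order buffers by index so data0 comes first, then data1, ...
    params = [f"constant {ctype}& {name}" for _, name, _ in sorted(globals_)]
    # Metal's built-in position arguments, as they appear in the rendered kernels
    params += ["uint3 gid [[threadgroup_position_in_grid]]",
               "uint3 lid [[thread_position_in_threadgroup]]"]
    return ", ".join(params)

sig = render_params([(2, "data2", False), (0, "data0", True), (1, "data1", False)])
```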

LOOP and ENDLOOP

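As their names suggest, LOOP and ENDLOOP open and close a loop.

LOOP: takes two inputs in vin, the start and end bounds (the end bound is exclusive), and opens the loop.

ENDLOOP: takes the LOOP UOp it closes as its only input.

For example: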

c0 = g.add(UOps.CONST, dtypes.int, arg=0)
c1 = g.add(UOps.CONST, dtypes.int, arg=10)
loop = g.add(UOps.LOOP, dtype=dtypes.int, vin=(c0, c1))
endloop = g.add(UOps.ENDLOOP, vin=(loop,))

The rendered loop looks like this:

for (int ridx0 = 0; ridx0 < 10; ridx0++) {
}
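A rough sketch of how a LOOP/ENDLOOP pair could be turned into the C-style text above: LOOP's two inputs become the start and end bounds, and a counter names the loop variable (`ridx0`, `ridx1`, ...). The helper names and the counter scheme are hypothetical, not tinygrad functions.

```python
def render_loop(start: int, end: int, loop_id: int) -> str:
    # Hypothetical naming scheme: one counter per loop gives ridx0, ridx1, ...
    var = f"ridx{loop_id}"
    # End bound is exclusive, like Python's range(start, end)
    return f"for (int {var} = {start}; {var} < {end}; {var}++) {{"

def render_endloop() -> str:
    return "}"  # ENDLOOP only closes the block its LOOP opened

code = "\n".join([render_loop(0, 10, 0), render_endloop()])
```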

STORE

STORE writes a value to an output buffer, which is passed to the kernel function as a parameter. Its vin is (buffer, index, value).

Example:

c1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(0, "data0", True))
c2 = g.add(UOps.CONST, dtype=dtypes.int, arg=0)
c3 = g.add(UOps.CONST, dtype=dtypes.int, arg=10)
store = g.add(UOps.STORE, vin=(c1, c2, c3))

and it will render:

kernel void tester(constant int& data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  *(data0+0) = 10;
}
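STORE's three inputs map directly onto the rendered assignment. A minimal sketch of that mapping, alongside the equivalent write on a Python list, using hypothetical helpers:

```python
from typing import List

def render_store(buf: str, idx: str, value: str) -> str:
    # vin = (buffer, index, value)  ->  *(buffer+index) = value;
    return f"*({buf}+{idx}) = {value};"

def run_store(buf: List[int], idx: int, value: int) -> None:
    buf[idx] = value  # the same write, performed on a Python list

line = render_store("data0", "0", "10")
data0 = [0, 0, 0]
run_store(data0, 0, 10)
```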

LOAD

LOAD reads a value from an input buffer at a given offset. Its vin is (buffer, index).

Example (see the ALU example for the generated code):

input_value = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(2, "data2", False))
position = g.add(UOps.CONST, dtype=dtypes.int, arg=0)
loaded = g.add(UOps.LOAD, dtype=dtypes.int, vin=(input_value, position))

ALU

ALU performs arithmetic, logical, and bitwise operations; its arg selects the operation (for example, BinaryOps.ADD).

It is usually used together with other ops. For example, to load the first element of two input arrays and add them together:

from tinygrad.ops import BinaryOps
c1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(0, "data0", True))
x1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(1, "data1", False))
x2 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(2, "data2", False))
pos_input = g.add(UOps.CONST, dtype=dtypes.int, arg=0)
x1_loaded = g.add(UOps.LOAD, dtype=dtypes.int, vin=(x1, pos_input))
x2_loaded = g.add(UOps.LOAD, dtype=dtypes.int, vin=(x2, pos_input))
c4 = g.add(UOps.ALU, dtype=dtypes.int, vin=(x1_loaded, x2_loaded), arg=BinaryOps.ADD)
pos = g.add(UOps.CONST, dtype=dtypes.int, arg=0)
store = g.add(UOps.STORE, vin=(c1, pos, c4))
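To check what this example computes, here is a tiny interpreter that walks a list of UOp-like tuples in order and evaluates them against Python lists standing in for the buffers. The tuple layout, the string op names and `interpret` itself are assumptions for illustration; only ADD is supported.

```python
import operator
from typing import Any, Dict, List, Tuple

# Each op: (result_name, uop, input_names, arg)
Op = Tuple[str, str, Tuple[str, ...], Any]

ALU_FNS = {"ADD": operator.add}  # only the operation used in the example

def interpret(ops: List[Op], buffers: Dict[str, List[int]]) -> None:
    env: Dict[str, Any] = {}
    for name, uop, vin, arg in ops:
        ins = [env[v] for v in vin]
        if uop == "DEFINE_GLOBAL":
            env[name] = buffers[arg]      # arg: buffer name
        elif uop == "CONST":
            env[name] = arg
        elif uop == "LOAD":
            env[name] = ins[0][ins[1]]    # buffer[index]
        elif uop == "ALU":
            env[name] = ALU_FNS[arg](*ins)
        elif uop == "STORE":
            ins[0][ins[1]] = ins[2]       # buffer[index] = value
        else:
            raise NotImplementedError(uop)

ops = [
    ("c1", "DEFINE_GLOBAL", (), "data0"),
    ("x1", "DEFINE_GLOBAL", (), "data1"),
    ("x2", "DEFINE_GLOBAL", (), "data2"),
    ("pos_input", "CONST", (), 0),
    ("x1_loaded", "LOAD", ("x1", "pos_input"), None),
    ("x2_loaded", "LOAD", ("x2", "pos_input"), None),
    ("c4", "ALU", ("x1_loaded", "x2_loaded"), "ADD"),
    ("pos", "CONST", (), 0),
    ("store", "STORE", ("c1", "pos", "c4"), None),
]
bufs = {"data0": [0, 0], "data1": [3, 4], "data2": [5, 6]}
interpret(ops, bufs)
```

Only the first element of the output changes, since both the loads and the store use index 0.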

SPECIAL

GPU kernels usually execute in SIMT (single instruction, multiple threads) fashion: every thread runs the same code, so each thread needs to identify itself among the others to fetch the correct data. In the ALU example above, we explicitly fetch the zeroth element via a CONST UOp, but we might instead want a UOp that selects the element based on the thread's ID.
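SIMT execution can be emulated sequentially in Python to see why the thread ID matters. In this sketch (`simt_launch` and the body functions are hypothetical), the same body runs once per ID: with a constant index every thread does identical work on element 0, while indexing by the ID gives each thread its own element.

```python
from typing import Callable, List

def simt_launch(n_threads: int, body: Callable[[int], None]) -> None:
    # Sequential stand-in for a GPU launch: the same body runs once per thread ID.
    # Real hardware runs these in parallel, in no guaranteed order.
    for tid in range(n_threads):
        body(tid)

data1 = [1, 2, 3, 4]
data2 = [10, 20, 30, 40]

# Constant index (like the CONST-based ALU example): every thread does the same work.
out_const: List[int] = [0] * 4
def body_const(tid: int) -> None:
    out_const[0] = data1[0] + data2[0]

# Thread-ID index: each thread reads and writes its own element.
out_tid: List[int] = [0] * 4
def body_tid(tid: int) -> None:
    out_tid[tid] = data1[tid] + data2[tid]

simt_launch(4, body_const)
simt_launch(4, body_tid)
```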

Example:

position = g.add(UOps.SPECIAL, dtype=dtypes.int, arg=(0, "gidx0", 10))

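This declares a variable gidx0 taken from the x component of the threadgroup position in the grid (gid.x in the generated Metal code), with 10 as the size of that launch dimension, so each of the ten groups sees a distinct value from 0 to 9. We can now modify the ALU example to load by that index: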

c1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(0, "data0", True))
x1 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(1, "data1", False))
x2 = g.add(UOps.DEFINE_GLOBAL, dtype=dtypes.int, vin=(), arg=(2, "data2", False))
pos_input = g.add(UOps.SPECIAL, dtype=dtypes.int, arg=(0, "gidx0", 10))
x1_loaded = g.add(UOps.LOAD, dtype=dtypes.int, vin=(x1, pos_input))
x2_loaded = g.add(UOps.LOAD, dtype=dtypes.int, vin=(x2, pos_input))
c4 = g.add(UOps.ALU, dtype=dtypes.int, vin=(x1_loaded, x2_loaded), arg=BinaryOps.ADD)
pos = g.add(UOps.CONST, dtype=dtypes.int, arg=0)
store = g.add(UOps.STORE, vin=(c1, pos, c4))

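and the generated code becomes the following. Note that the STORE still uses the constant index 0, so every group writes its sum to data0[0]; passing pos_input to the STORE instead would write each sum to its own element: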

kernel void tester(constant int& data0, constant int& data1, constant int& data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 10 */
  int val0 = *(data1+gidx0);
  int val1 = *(data2+gidx0);
  *(data0+0) = (val0+val1);
}
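The first line of that kernel body comes from the SPECIAL UOp, whose arg reads as (dimension, name, size). The sketch below reproduces that line; a `g` name prefix maps to `gid` as shown above, while mapping an `l` prefix to `lid` (the thread position within a group) is an assumption. `render_special` is a hypothetical helper.

```python
from typing import Tuple

def render_special(arg: Tuple[int, str, int]) -> str:
    dim, name, size = arg
    # 'g' names read the threadgroup position (gid); 'l' -> lid is an assumption
    source = {"g": "gid", "l": "lid"}[name[0]]
    axis = "xyz"[dim]  # dimension 0/1/2 selects .x/.y/.z
    return f"int {name} = {source}.{axis}; /* {size} */"

line = render_special((0, "gidx0", 10))
```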