a = Tensor([1])
b = Tensor([2])
c = a + b
print(c.numpy())
will generate the following uops:
0 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (0, 'data0', True)
1 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (1, 'data1', False)
2 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (2, 'data2', False)
3 UOps.CONST : dtypes.int [] 0
4 UOps.LOAD : dtypes.int [1, 3] None
5 UOps.LOAD : dtypes.int [2, 3] None
6 UOps.ALU : dtypes.int [4, 5] BinaryOps.ADD
7 UOps.STORE : [0, 3, 6] None
which then generates:
#include <metal_stdlib>
using namespace metal;
kernel void E_(device int* data0, const device int* data1, const device int* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int val0 = *(data1+0);
int val1 = *(data2+0);
*(data0+0) = (val0+val1);
}
A load operation generates this string: *(data1+0). We know that a load
must have one input referring to the data, and another input referring to
the pointer position offset. We have just one element, so the offset is zero.
This is the load renderer:
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
out_val = f"*({buf_name}+{idx})"
return out_val
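The steps above can be checked with a minimal standalone sketch of the renderer (this is a simplification for illustration, not the full tinygrad implementation, which also handles local buffers and vector dtypes):

```python
# Minimal sketch of the load renderer: a load is just a pointer
# dereference of the buffer at the given offset expression.
def render_load(buf_name: str, idx: str) -> str:
    return f"*({buf_name}+{idx})"

# Single-element case: idx is the rendered CONST, "0".
print(render_load("data1", "0"))      # *(data1+0)
# Multi-element case: idx is the rendered SPECIAL, "lidx0".
print(render_load("data1", "lidx0"))  # *(data1+lidx0)
```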
idx will have a value of zero in this case. What if we have more than one
element? Then idx would take the values 1, 2, 3, and so on. In practice,
though, the indexing is implemented via [[thread_position_in_threadgroup]],
so if the input is:
a = Tensor([1, 2, 3])
b = Tensor([2, 5, 6])
c = a + b
print(c.numpy())
then idx will have the value lidx0, and the output code is:
kernel void E_2(device int* data0, const device int* data1, const device int* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int lidx0 = lid.x; /* 2 */
int val0 = *(data1+lidx0);
int val1 = *(data2+lidx0);
*(data0+lidx0) = (val0+val1);
}
with the uops list being:
0 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (0, 'data0', True)
1 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (1, 'data1', False)
2 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (2, 'data2', False)
3 UOps.SPECIAL : dtypes.int [] (0, 'lidx0', 2)
4 UOps.LOAD : dtypes.int [1, 3] None
5 UOps.LOAD : dtypes.int [2, 3] None
6 UOps.ALU : dtypes.int [4, 5] BinaryOps.ADD
7 UOps.STORE : [0, 3, 6] None
You can see that idx, which previously came from the CONST op with value
zero, now comes from a SPECIAL op.
Conceptually, idx takes its value from the second input, but when code gen
encounters the LOAD op, how does it grab the value from that input? Inside
uops_to_cstyle:
elif uop is UOps.LOAD:
print(r[vin[0]]) # Added for illustration
print(r[vin[1]]) # Added for illustration
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, r[vin[1]], vin[0].uop is UOps.DEFINE_LOCAL)
The data input is straightforward: it is the uop that maps to the first input,
r[vin[0]]. The idx comes from r[vin[1]]. So the SPECIAL op doesn’t have any
special effect here; the code is literally just grabbing the rendered value of
the op. If it’s a CONST op, the value will be zero; if it’s a SPECIAL, the
value will be “lidx0”. You can add the print statements and run it yourself.
You may be confused about the r variable. It starts as an empty dictionary and gets populated as we iterate through the uops list one item at a time, including the CONST and SPECIAL ops.
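This process can be sketched with a toy loop (the uop tuples and dictionary keys here are simplified stand-ins for tinygrad's actual UOp objects):

```python
# Toy sketch of how r fills up while walking the uop list: each uop's
# rendered string is keyed by the uop itself.
r = {}
uops = [("CONST", 0), ("SPECIAL", (0, "lidx0", 2))]
for op, args in uops:
    if op == "CONST":
        r[(op, args)] = str(args)   # a CONST renders to its literal value
    elif op == "SPECIAL":
        r[(op, args)] = args[1]     # a SPECIAL renders to its variable name
# r is now {('CONST', 0): '0', ('SPECIAL', (0, 'lidx0', 2)): 'lidx0'}
```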
In the single-element case, a CONST op is encountered first and handled as follows:
elif uop is UOps.CONST:
r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
with the render_const function being:
def render_const(self, x:ConstType, dtype:DType) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
elif dtype == dtypes.float64: val = f"{x}"
else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}" if dtypes.is_int(dtype) else f"{x}".lower()
return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
which, in the simplest case, just prints the value either with an ‘f’ suffix (for floats) or as a plain number. Our dictionary now becomes:
{
CONST: 0
}
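To see that rendering concretely, here is a trimmed standalone sketch covering only scalar int/float constants (the real method also handles float64, bool, and vector dtypes via render_cast):

```python
import math

# Trimmed sketch of render_const for scalar int/float values: floats get
# an 'f' suffix, ints are printed as-is, NaN/inf map to C macros.
def render_const(x, is_float: bool) -> str:
    if isinstance(x, float) and math.isnan(x): return "NAN"
    if isinstance(x, float) and math.isinf(x):
        return ("-" if x < 0 else "") + "INFINITY"
    return f"{x}f" if is_float else f"{x}"

print(render_const(0, False))   # 0
print(render_const(1.5, True))  # 1.5f
```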
In the case of the SPECIAL op,
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
with args being (0, 'lidx0', 2).
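We can sketch how that kk(...) line is assembled, assuming a code_for_workitem mapping in the style Metal uses, where 'l' (the first letter of the variable name, for "local") maps the dimension index 0/1/2 to lid.x/lid.y/lid.z:

```python
# Hypothetical stand-in for lang.code_for_workitem: 'l' picks the
# thread-position-in-threadgroup component for the given dimension.
code_for_workitem = {"l": lambda i: f"lid.{'xyz'[i]}"}

args = (0, "lidx0", 2)  # (dimension, variable name, size)
line = f"int {args[1]} = {code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */"
print(line)  # int lidx0 = lid.x; /* 2 */
```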
You can now see where the value of idx in render_load came from: it’s the
rendered input value that was added to the r dictionary. In the multi-element
case, we use a SPECIAL op to represent the pointer arithmetic, storing the
thread position within the threadgroup so each thread indexes its own element.
In the single-element case, we store the value of a CONST directly in the r
dictionary, so the load renders as *(data1+0).
Another common operation is type casting. We may start off with integers but want the output to be a float.
from tinygrad import Tensor, dtypes
a = Tensor([1, 3])
b = Tensor([4, 3])
c = (a + b).cast(dtypes.float32)
print(c.numpy())
the metal code is
int lidx0 = lid.x; /* 2 */
int val0 = *(data1+lidx0);
int val1 = *(data2+lidx0);
*(data0+lidx0) = (float)((val0+val1));
The only change is the (float) part. How is this rendered? Let’s look at the
uops:
0 UOps.DEFINE_GLOBAL : ptr.dtypes.float [] (0, 'data0', True)
1 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (1, 'data1', False)
2 UOps.DEFINE_GLOBAL : ptr.dtypes.int [] (2, 'data2', False)
3 UOps.SPECIAL : dtypes.int [] (0, 'lidx0', 2)
4 UOps.LOAD : dtypes.int [1, 3] None
5 UOps.LOAD : dtypes.int [2, 3] None
6 UOps.ALU : dtypes.int [4, 5] BinaryOps.ADD
7 UOps.CAST : dtypes.float [6] dtypes.float
8 UOps.STORE : [0, 3, 7] None
Step 7 introduces the cast operation, which operates on the result of step 6. Let’s see how code gen handles CAST:
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa('precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = lang.render_cast([precast], dtype, bitcast=True)
else:
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
In our case, it just calls render_cast with the input value stored in the
r dictionary, which holds the rendered addition. dtype would be float, since
that’s what we are casting to. Let’s see render_cast:
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
Our scenario hits the second line (the len(x) == 1 branch), so we return
(float)((val0+val1)). Note that the value of x would be ["(val0+val1)"].
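That branch can be isolated into a standalone sketch, with render_dtype stubbed out to the C name of the dtype:

```python
# Sketch of the scalar branch of render_cast: wrap the single input
# expression in a C-style cast to the target dtype.
def render_cast(x: list, var_dtype: str) -> str:
    assert len(x) == 1, "this sketch only covers the scalar branch"
    return f"({var_dtype})({x[0]})"

print(render_cast(["(val0+val1)"], "float"))  # (float)((val0+val1))
```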
Also note that what I show here is the parent method defined by CStyleLanguage,
which may not reflect what a specific GPU actually needs. That’s the case for
Metal, so MetalLanguage, which inherits from CStyleLanguage, overrides the
render_cast method:
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
Although in our simple case it falls back to the parent method, it uses the
Metal-specific as_type keyword for bitcasts. This inheritance pattern is part
of how you would extend tinygrad to a custom accelerator.
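The pattern can be illustrated with a toy pair of classes (MyAccelLanguage is a hypothetical backend, not part of tinygrad; the as_type form mirrors the Metal override above):

```python
# Toy base class: plain C-style casts.
class CStyleLanguage:
    def render_cast(self, x, var_dtype, bitcast=False):
        return f"({var_dtype})({x[0]})"

# Hypothetical accelerator backend: reuse the parent for ordinary casts,
# but use a device-specific reinterpret keyword for bitcasts.
class MyAccelLanguage(CStyleLanguage):
    def render_cast(self, x, var_dtype, bitcast=False):
        if bitcast:
            return f"as_type<{var_dtype}>({x[0]})"
        return super().render_cast(x, var_dtype)

lang = MyAccelLanguage()
print(lang.render_cast(["(val0+val1)"], "float"))         # (float)((val0+val1))
print(lang.render_cast(["val0"], "float", bitcast=True))  # as_type<float>(val0)
```

The key design choice is that backends only override the renderers whose output differs from plain C, and inherit everything else.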
All the arithmetic logic unit (ALU) operations are handled similarly. In our
case it is an addition, rendered as (val0+val1).
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}:
operands = [strip_parens(r[v]) if v.arg == args else r[v] for v in vin]
else:
operands = [r[v] for v in vin]
val = lang.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
In our case, the line operands = [strip_parens(r[v]) if v.arg == args else r[v] for v in vin]
is executed. Recall that vin holds the two load operations, and operands ends
up being ['val0', 'val1'], given how the two loads stored their rendered
values in r. args is BinaryOps.ADD.
code_for_op is defined in the shared CStyleLanguage class, unless overridden by a specific subclass:
code_for_op: Dict = {
UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype is dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
An ADD operation just concatenates the two operands with a plus sign and surrounds them with parentheses. Finally, we store the rendered result in the r dictionary (r[u] = val) so that the subsequent CAST operation can pick it up. You can see how the other operations are implemented similarly.
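A small excerpt of that table is enough to reproduce our rendering (the string keys here stand in for the BinaryOps enum members; the dtype argument is unused by ADD but kept to match the real lambdas' signature):

```python
# Excerpt of code_for_op: each entry maps an op to a lambda that
# splices rendered operand strings into a C expression.
code_for_op = {
    "ADD": lambda a, b, dtype: f"({a}+{b})",
    "MUL": lambda a, b, dtype: f"({a}*{b})",
}

operands = ["val0", "val1"]
val = code_for_op["ADD"](*operands, None)
print(val)  # (val0+val1)
```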