One of the premises of tinygrad is that it is simple to add support for a new accelerator. So if you want to try out new hardware, you can implement the support with ease.
If you have read the official doc on adding an accelerator, you may know that there are 21 operations to implement. But it doesn’t explain how things fit together in tinygrad and may be too abstract. I will try to fill in that gap in this post. I also won’t dive into too much detail on the code generation part or the details of the Metal runtime, but rather give you a good starting ground if you are looking to add an implementation yourself.
Prerequisite reading: tinygrad’s IR and the tinygrad backend.
I’ll pretend to implement a second Metal accelerator called “mymetal”. To inject it into tinygrad, we create a file with a specific naming format, so that the importlib-based lookup can pick it up and parse the name, and so that you can set MYMETAL=1 as an environment variable to instruct tinygrad to use it. The file has to go in the runtime directory and be named “ops_mymetal.py” (i.e. tinygrad/runtime/ops_mymetal.py).
Within the file, we are required to define a class whose name ends with the suffix “Device”, so the upstream machinery can find it via this function in device.py:
def __get_canonicalized_item(self, ix:str) -> Compiled:
  if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
  x = ix.split(":")[0].upper()
  return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix)  # noqa: E501
In our ops_mymetal.py, let’s define a class to test the integration.
class MyMetalDevice:
  def __init__(self, *args, **kwargs):
    print("Mymetal device is initialized with args:", args, kwargs)
Throughout the guide, I will again use the dot product example. Save the following as script.py anywhere:
from tinygrad.tensor import Tensor
from tinygrad.device import Device
a = Tensor([1.0,2.0])
b = Tensor([3.0,4.0])
c = a.dot(b)
d = c.numpy()
print(d)
Now if you run the script with MYMETAL=1 python script.py, you will see errors, but more importantly, a line printed like so:
Mymetal device is initialized with args: ('MYMETAL',)
Yay! Our custom class is now integrated into the rest of tinygrad.
Next we need to start implementing methods so it functions properly.
The entrypoint to our device is the lower_schedule_item function in realize.py:
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
  if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
  assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
  out, ast = si.outputs[0], si.ast[0]
  if ast.op is LoadOps.COPY:
    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
    return BufferCopy()
  if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
  return None
For our dot product script, the schedule roughly consists of COPY items (moving the input data onto the device) followed by a STORE item for the dot product kernel itself, so both branches above get exercised. The returned value is then used like this in run_schedule:
def run_schedule(schedule:List[ScheduleItem]):
  while len(schedule):
    si = schedule.pop(0)
    prg = lower_schedule_item(si)
    # run the function (put it in JIT)
    if prg:
      prg.exec(real_buffers, si.var_vals)
You can see that each schedule item is converted into a program that has an .exec method defined, and these are called in order. To generalize this, the returned item must be an instance of a class that inherits from JITRunner, whether it represents a kernel or a memory copy. The base JITRunner is defined as such:
class JITRunner:
  def __init__(self):
    self.op_estimate:sint = 0
    self.mem_estimate:sint = 0
  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
    var_vals = var_vals if var_vals is not None else {}
    from tinygrad.engine.jit import CacheCollector
    et = self(rawbufs, var_vals)
    if CACHECOLLECTING: CacheCollector.add(self, rawbufs, var_vals)
    return et
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    raise NotImplementedError("override this")
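If you just want to see the contract, a do-nothing runner would look like this (a hypothetical sketch, not something tinygrad ships):
class MyNoopRunner(JITRunner):
  def __call__(self, rawbufs, var_vals, wait=False, jit=False):
    # a real runner would launch a kernel or copy memory here
    print("would run over", len(rawbufs), "buffers")
    return None  # optionally return the elapsed time in seconds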
Since there are shared traits among all the accelerators, we can further abstract two classes that inherit from JITRunner. The first is CompiledASTRunner and the second is BufferCopy. Again, I pasted only the illustrative parts of the implementation:
class CompiledASTRunner(JITRunner):
  def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
               variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
    super().__init__()
    self.prg = prg
    lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
    self.lib, self.clprg, self.outcount = lib, self.device.runtime(self.name, lib), outcount
    if DEBUG >= 4: print(prg)
  @property
  def device(self): return Device[self.dname]
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
    et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
    return et

class BufferCopy(JITRunner):
  def copy(self, dest, src):
    dest.copyin(src.as_buffer(allow_zero_copy=True))  # may allocate a CPU buffer depending on allow_zero_copy
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    dest, src = rawbufs[0:2]
    self.copy(dest, src)
So JITRunner’s exec will call the actual implementation of the __call__ method. If it is a buffer copy, the data is copied; otherwise, the compiled program is executed. But note that in the program’s case, self.clprg() is implemented elsewhere, and dest.copyin is also implemented elsewhere. Let’s break down the next abstraction.
In lower_schedule_item, for a kernel (program), we see this line:
Device[si.outputs[0].device].get_runner(*si.ast)
Remember that Device[si.outputs[0].device] dynamically resolves to the matching ops_*.py file and initializes the device class, meaning our MyMetal device must implement a get_runner method that returns a CompiledASTRunner instance wrapping the program to be executed. This is a common pattern across all accelerators, so tinygrad already has a base class for this functionality. It is called Compiled, and in the reference implementations for CUDA, the existing Metal backend, and other accelerators, you can see that CUDADevice, MetalDevice and so on all inherit from it:
class Compiled:
  def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
    self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
  def synchronize(self): pass  # override this in your device
  def to_program(self, k:Linearizer) -> CompiledASTRunner:
    k.linearize()
    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
                            k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs))
    return ret
  def get_linearizer(self, *ast:LazyOp) -> Linearizer:
    from tinygrad.codegen.linearizer import Linearizer
    k = Linearizer(*ast, opts=self.compiler.compiler_opts)
    k.required_optimizations()
    return k
  def get_runner(self, *ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(*ast))
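So get_runner strings the whole pipeline together for each kernel. Roughly, the data flow looks like this (a sketch, not literal code):
# ast (LazyOps) --get_linearizer--> Linearizer --linearize--> uops
# uops --compiler.render--> source string (e.g. Metal C++)
# source --compiler.compile_cached--> binary bytes
# bytes --self.runtime(name, lib)--> callable kernel, wrapped in a CompiledASTRunner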
The actual implementation must pass the memory allocator, compiler, runtime, and optionally a graph executor to this base class. Let’s look at the actual implementation of MetalDevice:
class MetalDevice(Compiled):
  def __init__(self, device:str):
    self.device = Metal.MTLCreateSystemDefaultDevice()
    self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
    self.mtl_buffers_in_flight: List[Any] = []
    self.mv_in_metal: List[memoryview] = []
    self.track_cross_buffer: List[Any] = []
    from tinygrad.runtime.graph.metal import MetalGraph
    super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
                     functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
and that’s where the actual connection to the Metal runtime happens! Memory allocation is done through MetalAllocator(self), compiling the Metal code to binary happens in MetalCompiler(None if getenv("METAL_XCODE") else self), the runtime is functools.partial(MetalProgram, self), and the graph executor is functools.partial(MetalGraph, self).
We can see that MetalCompiler must inherit from Compiler, which looks like this:
class Compiler:
  compiler_opts: ClassVar[CompilerOptions]
  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("need a render function")
  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
  def compile_cached(self, src:str) -> bytes:
    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
      lib = self.compile(src)
      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
    return lib
That’s the class behind the self.compiler.render(to_function_name(k.name), k.uops) call in the to_program method of Compiled.
MetalAllocator subclasses Allocator, which is defined like this:
class Allocator:
  def alloc(self, size:int, options:Optional[BufferOptions]=None):
    assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
    return self._alloc(size, options if options is not None else BufferOptions())
  def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
  def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
    self._free(opaque, options if options is not None else BufferOptions())
  def _free(self, opaque, options:BufferOptions): pass  # if opaque is a Python object, you don't need a free
  def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
  def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
and the copyin and copyout methods handle the calls made in BufferCopy that we saw above (note that there is an intermediate class to optimize things: between Allocator and MetalAllocator sits LRUAllocator, but for simplicity I’m just showing the base class).
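To see that contract in isolation, here is a toy allocator that just uses host memory (purely illustrative; the real Metal allocator below hands out MTLBuffer objects instead):
from tinygrad.device import Allocator

class ToyAllocator(Allocator):
  def _alloc(self, size:int, options): return bytearray(size)  # the "opaque" handle is plain host memory
  def copyin(self, dest, src:memoryview): dest[:] = src        # host -> "device"
  def copyout(self, dest:memoryview, src): dest[:] = src       # "device" -> host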
Now let’s start the implementation. I will mostly just copy the existing Metal code, because it pertains more to the Metal runtime than to tinygrad. This is where we would start:
class MyMetalDevice:
  def __init__(self, *args):
    print("Mymetal device is initialized with args:", args)
We know that we have to subclass Compiled
and implement the required methods:
import functools
from typing import Any, List
import Metal, libdispatch
from tinygrad.device import Compiled

class MetalAllocator:
  def __init__(self, device): pass  # placeholder, filled in below

class MetalCompiler:
  def __init__(self, device): pass  # placeholder, filled in below

class MyMetalDevice(Compiled):
  def __init__(self, device:str):
    print("Mymetal device is initialized with args:", device)
    self.device = Metal.MTLCreateSystemDefaultDevice()
    self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
    self.mtl_buffers_in_flight: List[Any] = []
    self.mv_in_metal: List[memoryview] = []
    self.track_cross_buffer: List[Any] = []
    from tinygrad.runtime.graph.metal import MetalGraph
    super().__init__(
      device,
      MetalAllocator(self),
      MetalCompiler(self),
      functools.partial(MetalProgram, self),  # MetalProgram is defined later in this post
      functools.partial(MetalGraph, self)
    )
Some of this syntax is required by the Metal API. You can see we imported the Metal and libdispatch libraries, which are part of the Objective-C to Python bridge; you can refer to the doc here.
The next part is to fill in the memory allocator. We will subclass LRUAllocator, which is a more refined version of the base Allocator, and implement the required methods:
from tinygrad.device import LRUAllocator

class MetalAllocator(LRUAllocator):
  def __init__(self, device:MetalDevice):
    self.device:MetalDevice = device
    self.track_cross_device: Set[MetalDevice] = set()
    super().__init__()
  def free_cache(self):
    self.device.synchronize()
    for x in self.track_cross_device: x.synchronize()
    self.track_cross_device.clear()
    return super().free_cache()
  def _alloc(self, size:int, options) -> Any:
    ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
    if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
    return ret
  def transfer(self, dest:Any, src:Any, sz:int, src_dev: MetalDevice, **kwargs):
    src_dev.synchronize()
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.blitCommandEncoder()
    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
    encoder.endEncoding()
    command_buffer.commit()
    self.device.mtl_buffers_in_flight.append(command_buffer)
  def from_buffer(self, src:memoryview) -> Optional[Any]:
    ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, len(src), Metal.MTLResourceStorageModeShared, None)
    if ret: self.device.mv_in_metal.append(src)
    return ret
  def _free(self, opaque:Any, options): opaque.release()
  def as_buffer(self, src:Any) -> memoryview:
    self.device.synchronize()
    return src.contents().as_buffer(src.length())
  def copyin(self, dest:Any, src:memoryview): self.as_buffer(dest)[:] = src
  def copyout(self, dest:memoryview, src:Any): dest[:] = self.as_buffer(src)
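As a quick mental model of what the Buffer machinery does with this allocator, a hypothetical round trip looks like this (you would not normally call these methods yourself):
import numpy as np
alloc = device.allocator  # the MetalAllocator created in the device's __init__
buf = alloc.alloc(8)      # room for two float32 values
alloc.copyin(buf, memoryview(np.array([1.0, 2.0], dtype=np.float32).tobytes()))
print(np.frombuffer(alloc.as_buffer(buf), dtype=np.float32))  # -> [1. 2.]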
The next part is the compiler; let’s look at the render part first:
from tinygrad.device import Compiler

class MetalCompiler(Compiler):
  compiler_opts = CompilerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64", shared_max=32768)
  def __init__(self, device:Optional[MetalDevice]):
    self.device = device
    super().__init__("compile_metal")
  def render(self, name:str, uops) -> str: return MetalRenderer(name, uops)
What’s MetalRenderer? I won’t go into the details, but it is the code generation step. It takes a series of uops and converts them to Metal-flavoured C++ code. This is the definition:
class MetalLanguage(CStyleLanguage):
  kernel_prefix = "kernel "
  buffer_prefix = "device "
  smem_prefix = "threadgroup "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
  uses_ptr_arithmetic = True
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
  code_for_op = {**CStyleLanguage().code_for_op,
    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}

  def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
    return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
  return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)

MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
Remember the official guide on adding an accelerator says you have to implement 21 required operations? This is where you would do it. Because Metal is C++-style code, the implementation looks at the uops and converts them to C++ strings.
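For instance, if you were bringing up your own C-style backend, each of those operations is just a mapping from rendered operand strings to an output string. A hypothetical pair of entries could look like this (MUL and CMPLT are two of the required binary ops):
code_for_op = {
  BinaryOps.MUL:   lambda a,b,dtype: f"({a}*{b})",  # multiply two rendered operands
  BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})",  # less-than comparison
}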
After we have the rendered C++ code, we need to compile it into an actual binary via the compile method:
class MetalCompiler(Compiler):
  def compile(self, src:str) -> bytes:
    options = Metal.MTLCompileOptions.new()
    options.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
    library = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, options, None))
    return library.libraryDataContents().bytes().tobytes()
And last is the program runtime. It receives the compiled bytecode and does the actual execution; again, this is mostly Metal-API-specific stuff:
class MetalProgram:
  def __init__(self, device:MetalDevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib
    data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
    self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
    self.fxn = self.library.newFunctionWithName_(name)
    self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}")  # noqa: E501
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(self.pipeline_state)
    for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a, 0, i)
    for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
    encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    encoder.endEncoding()
    command_buffer.commit()
    if wait:
      wait_check(command_buffer)
      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
    self.device.mtl_buffers_in_flight.append(command_buffer)
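To tie the pieces together, this is roughly what happens for a single kernel once the device, compiler, and program classes exist (a sketch; the kernel name, uops, and buffers here are placeholders):
src = compiler.render("r_2", kernel_uops)  # codegen: uops -> Metal C++ source
lib = compiler.compile_cached(src)         # compile: source -> metallib bytes (disk-cached)
prg = MetalProgram(device, "r_2", lib)     # runtime: load the binary and build a pipeline state
prg(*raw_mtl_buffers, global_size=(1,1,1), local_size=(1,1,1))  # dispatch it on the GPU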
Putting them all together is what you see in ops_metal.py! If you browse other files like ops_cuda.py, you should see that they follow a similar pattern. And if everything is wired up correctly, MYMETAL=1 python script.py from the beginning of the post should now run end to end and print the dot product: 1*3 + 2*4 = 11.0.