Similar to the trick in matrix multiplication, convolution is also implemented with some clever shape movement.
Let’s create our input as a tensor with batch size 1, 1 channel, and a 4 by 4 spatial size. I will fill it with unique numbers (0 to 15) and reshape it, so we can visualize the intermediate steps better.
from tinygrad import Tensor
a = Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]).reshape((1, 1, 4, 4))
print(a.numpy())
"""
[[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]]]]
"""
If we run a convolution on this tensor with a 3 by 3 kernel of ones, keeping stride and dilation at 1 and using no padding:
weight = Tensor.ones(1, 1, 3, 3)
out = a.conv2d(weight)
print(out.shape) # 1, 1, 2, 2
print(out.numpy())
"""
[[[[45. 54.]
[81. 90.]]]]
"""
Nothing too surprising. Here’s how it works. The most basic case has no padding or special handling: a weight tensor simply slides across the input. There’s a method called “_pool” that rearranges the input into the windows the kernel needs.
pooled = a._pool(k_=(3, 3), stride=1, dilation=1)
print(pooled.numpy())
"""
[[[[[[ 0 1 2]
[ 4 5 6]
[ 8 9 10]]
[[ 1 2 3]
[ 5 6 7]
[ 9 10 11]]]
[[[ 4 5 6]
[ 8 9 10]
[12 13 14]]
[[ 5 6 7]
[ 9 10 11]
[13 14 15]]]]]]
"""
Compare it to the input, where the numbers 0 to 15 are arranged in a 4 by 4 grid, and you can see that each block is exactly what the convolution kernel sees at one position. For example, the first block sums to 45, corresponding to the top left corner of the output.
The actual implementation of _pool is just a series of expand, permute and shrink operations with the right sizes, but it’s best understood by visually checking its output.
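As a quick sanity check, here is a small plain Python sketch (no tinygrad involved) that copies the four windows from the pooled output above and sums each one. The variable names are my own, picked just for illustration.

```python
# The four 3x3 windows, copied from the pooled output above.
windows = [
    [[0, 1, 2], [4, 5, 6], [8, 9, 10]],      # top-left position
    [[1, 2, 3], [5, 6, 7], [9, 10, 11]],     # top-right position
    [[4, 5, 6], [8, 9, 10], [12, 13, 14]],   # bottom-left position
    [[5, 6, 7], [9, 10, 11], [13, 14, 15]],  # bottom-right position
]

# With an all-ones weight, the convolution reduces to summing each window.
window_sums = [sum(sum(row) for row in w) for w in windows]
print(window_sums)  # [45, 54, 81, 90]
```

These are the same four values that conv2d returned, read row by row.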
Stride refers to the step the kernel slides each time. If we set the kernel size to (2, 2) with a stride of 2:
pooled = a._pool(k_=(2, 2), stride=2, dilation=1)
print(pooled.numpy())
"""
[[[[[[ 0 1]
[ 4 5]]
[[ 2 3]
[ 6 7]]]
[[[ 8 9]
[12 13]]
[[10 11]
[14 15]]]]]]
"""
Dilation refers to how far apart the kernel’s elements are spread, so that it “grabs” elements further apart while skipping those in between. Again, it’s best explained by visualizing it:
pooled = a._pool(k_=(2, 2), stride=1, dilation=2)
print(pooled.numpy())
"""
[[[[[[ 0 2]
[ 8 10]]
[[ 1 3]
[ 9 11]]]
[[[ 4 6]
[12 14]]
[[ 5 7]
[13 15]]]]]]
"""
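To make the roles of stride and dilation concrete, here is a rough plain Python sketch of the indexing. It is not how tinygrad implements _pool (that uses expand, permute and shrink); the helper name pool2d is hypothetical, and it only handles a single 2D grid without batch or channel axes. Element (i, j) of the window at output position (oy, ox) reads the input at row oy * stride + i * dilation and column ox * stride + j * dilation.

```python
def pool2d(x, k, stride=1, dilation=1):
    """Return the windows a kernel of size k sees on the 2D list x (hypothetical helper)."""
    kh, kw = k
    h, w = len(x), len(x[0])
    # Number of valid kernel positions along each axis for the dilated kernel.
    oh = (h - dilation * (kh - 1) - 1) // stride + 1
    ow = (w - dilation * (kw - 1) - 1) // stride + 1
    return [[[[x[oy * stride + i * dilation][ox * stride + j * dilation]
               for j in range(kw)]
              for i in range(kh)]
             for ox in range(ow)]
            for oy in range(oh)]

grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # 0..15 in a 4x4 grid
strided = pool2d(grid, (2, 2), stride=2)
dilated = pool2d(grid, (2, 2), dilation=2)
print(strided[0][1])  # [[2, 3], [6, 7]]
print(dilated[0][0])  # [[0, 2], [8, 10]]
```

Both results line up with the blocks printed by _pool above.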
After getting the pooled tensor, the rest is simply multiplying it elementwise with the weight and summing across the kernel height and width axes.
We could have created our input tensor with just Tensor.arange(16).reshape((1, 1, 4, 4)) instead of typing the elements manually.
It may surprise you that arange uses the same trick as convolution. This is what allows arange to be efficient: it would be very slow if generating 16 numbers required 16 iterations of a for loop. Let’s find the generated kernel code with DEBUG=5 and NOOPT=1 (I’m turning off the optimization to make the output easier to reason about):
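The multiply and sum step can be sketched in plain Python like this. It assumes stride 1, no padding, and a single 2D grid; conv2d_single is a hypothetical helper of mine, not a tinygrad function.

```python
def conv2d_single(x, weight):
    """Stride-1, unpadded convolution of 2D list x with 2D list weight (hypothetical helper)."""
    kh, kw = len(weight), len(weight[0])
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    out = []
    for oy in range(oh):
        row = []
        for ox in range(ow):
            # Multiply the window at (oy, ox) elementwise with the weight, then sum.
            row.append(sum(x[oy + i][ox + j] * weight[i][j]
                           for i in range(kh) for j in range(kw)))
        out.append(row)
    return out

grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # 0..15 in a 4x4 grid
ones = [[1] * 3 for _ in range(3)]
print(conv2d_single(grid, ones))  # [[45, 54], [81, 90]]
```

With a weight that is 1 only in the center, the same function just picks the middle element of each window, which is an easy way to convince yourself the windows are aligned correctly.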
# Run with NOOPT=1 DEBUG=5 python script.py
Tensor.arange(15).realize()
"""
kernel void r_15_15(device int* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 15 */
*(data0+gidx0) = gidx0;
}
"""
So instead of running a loop, it launches 15 kernel threads, each of which just writes its own value to the buffer. This example is a bit trivial because the index happens to equal the value, but the kernel code can add offsets and steps so it works for more complex range arguments.
The trick is in how this kernel is generated.
The first step is to create a tensor filled with ones, with the same length as the output, so 15:
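Here is a small Python emulation of that idea, where each list element stands in for one thread. This is only a sketch under my own assumptions (a positive step, and a hypothetical helper name arange_per_index); it is not the code tinygrad generates.

```python
def arange_per_index(start, stop, step=1):
    """Emulate a one-thread-per-element arange for a positive step (hypothetical helper)."""
    # Ceiling division gives the number of elements in the range.
    n = max(0, -(-(stop - start) // step))
    # Each "thread" idx computes its own value from its index alone.
    return [start + idx * step for idx in range(n)]

print(arange_per_index(0, 15))     # same values as list(range(15))
print(arange_per_index(2, 12, 3))  # [2, 5, 8, 11]
```

No element depends on any other, which is what makes the work trivially parallel.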
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] # 15 ones here
We then pad zeros onto the front, with the number of zeros being the size minus 1, so 14:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] # 14 zeros, 15 ones
Lastly, we apply the same pool operation, with a kernel size of 15, stride 1 and dilation 1:
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]
The values may seem arbitrary, but they were in fact derived from the desired output: what we want is a triangular shape like the one above, where each row has one more “1” than the previous. Now if we sum each row, we get the range 1 to 15. Lastly, we just have to subtract 1 from it to get the desired 0 to 14 output.
The interesting part of this approach is that it is parallelizable. Each thread handles one row of data, and its value is determined entirely by its position in the output. So, given enough parallel threads, generating the values 0 to 14 should take roughly the same time as generating 0 to 149.
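The steps above can be replayed in plain Python lists. This is only an illustration of the ones, pad, pool and sum sequence, with a hypothetical helper name arange_via_pool; it says nothing about how tinygrad lowers it to a kernel.

```python
def arange_via_pool(n):
    """Build 0..n-1 from ones, front padding, sliding windows and a row sum (illustrative only)."""
    ones = [1] * n
    padded = [0] * (n - 1) + ones  # n - 1 zeros in front, then n ones
    # Stride-1 windows of length n: row i ends up with exactly i + 1 ones.
    rows = [padded[i:i + n] for i in range(n)]
    # Summing each row gives 1..n, so subtracting 1 gives 0..n-1.
    return [sum(row) - 1 for row in rows]

print(arange_via_pool(15) == list(range(15)))  # True
```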