I personally still rely on TensorFlow for most of my GPU needs at work, but that training session was incredibly helpful for understanding what was going on under the hood and demystified CUDA kernels in general.
what do you think of something like https://github.com/google/jax which i guess claims to be a jack of all trades, in the sense that it'll transform programs (through XLA?) to run on various architectures?
JAX is absolutely what I would recommend if you were trying to implement something usable in the long term and wanted to go beyond just learning about CUDA kernels!
Looking at the JAX documentation, it's _much_ better than the last time I saw it. Their tutorials seem fairly solid, in fact. I do want to point out the difference, though: you're conceptually operating at a very different level of abstraction in JAX than in Numba.
For example, consider multiplying two matrices in JAX on your GPU. In the JAX tutorial[1], that takes just a few lines of code.
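Roughly, it looks something like this (a minimal sketch from memory, not the tutorial's exact code; the shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

# Two random matrices; JAX places arrays on the GPU automatically when one is available.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
a = jax.random.normal(k1, (1024, 1024))
b = jax.random.normal(k2, (1024, 1024))

# jit traces the computation and compiles it through XLA for the target backend.
matmul = jax.jit(jnp.dot)
c = matmul(a, b)
c.block_until_ready()  # dispatch is asynchronous; wait for the result
```

Note that you never say anything about threads, blocks, or memory transfers; XLA decides all of that for you.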
On the other hand, in the Numba tutorial from GTC I mentioned earlier, notebook 4, "Writing CUDA Kernels"[2], teaches you the programming model used to write computations for GPUs.
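To give a flavor of what that notebook covers, here's a toy elementwise-add kernel in Numba's CUDA dialect (my own sketch, not the notebook's code): you write the per-thread function yourself and choose the thread/block launch configuration.

```python
from numba import cuda
import numpy as np

@cuda.jit
def add_kernel(x, y, out):
    # Each thread handles one element; cuda.grid(1) is its global index.
    i = cuda.grid(1)
    if i < out.size:  # guard threads that fall past the end of the array
        out[i] = x[i] + y[i]

n = 100_000
x = np.arange(n, dtype=np.float32)
y = 2 * x
out = np.empty_like(x)

# You pick the launch geometry explicitly, unlike in JAX.
threads_per_block = 128
blocks = (n + threads_per_block - 1) // threads_per_block
add_kernel[blocks, threads_per_block](x, y, out)  # Numba handles the host/device copies here
```

That explicit index math and grid sizing is exactly the CUDA programming model the notebook is teaching, just without the C/C++ boilerplate.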
I'm sorry I was unclear. My recommendation of Numba is not so much advocating its use in a project as suggesting it and its tutorials as an easy way to learn and experiment with CUDA kernels without jumping into the deep end with C/C++. If you actually want to write some usable CUDA code for a project, then, keeping in mind that JAX is still experimental, I would fully advocate for JAX over Numba.