My guess is that the slight overhead of interacting with Mojo led to this speed discrepancy, and if a higher factorial (one still within overflow limits, etc.) were run, the overhead would become negligible (as seen in the second example). It's similar to how JAX code is slower than NumPy for small operations but much faster for larger ones on CPUs, etc.
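That small-op overhead is easy to reproduce for the NumPy-vs-JAX case; here's a minimal timing sketch (the operation and sizes are arbitrary, and actual numbers are machine-dependent, but the small case typically shows JAX's per-call dispatch overhead while the large case amortizes it):

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp


def np_op(x):
    # plain NumPy: eager, low per-call overhead
    return np.sin(x) @ np.cos(x).T


@jax.jit
def jax_op(x):
    # same math, jit-compiled via XLA; has fixed dispatch/compile overhead
    return jnp.sin(x) @ jnp.cos(x).T


for n in (8, 1024):
    x_np = np.ones((n, n), dtype=np.float32)
    x_jx = jnp.asarray(x_np)
    jax_op(x_jx).block_until_ready()  # warm up: compile once per shape
    t_np = timeit.timeit(lambda: np_op(x_np), number=5)
    # block_until_ready() so we time the computation, not just async dispatch
    t_jx = timeit.timeit(lambda: jax_op(x_jx).block_until_ready(), number=5)
    print(f"n={n}: numpy {t_np:.5f}s, jax {t_jx:.5f}s")
```

The warm-up call matters: without it you'd be timing compilation, not the steady-state overhead being discussed.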
JAX is very much working (and in my view better, aside from the smaller community) software support, especially if you use their images (which they do).

> Tensorflow
They have been using jax/flax/etc rather than tensorflow for a while now. They don't really use pytorch from what I see on the outside from their research works. For instance, they released siglip/siglip2 with flax linen: https://github.com/google-research/big_vision
TPUs very much have software support, which is why SSI etc. use TPUs.
> They have been using jax/flax/etc rather than tensorflow for a while now
JAX has a steeper learning curve than PyTorch in my experience. Perhaps it's worth it (yay FP!), but it doesn't help adoption.
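For anyone who hasn't tried it, the friction is mostly the functional constraints: randomness is explicit, arrays are immutable, and jitted functions must be pure. A tiny sketch of the idioms that trip up PyTorch users (names here are just illustrative):

```python
import jax
import jax.numpy as jnp

# Randomness is explicit: you thread PRNG keys by hand instead of
# relying on a global RNG state as in NumPy/PyTorch.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Arrays are immutable: `x[0] = 1.0` raises an error.
# You use functional updates, which return a new array.
y = x.at[0].set(1.0)


# jit-compiled functions must be pure: no side effects, no data-dependent
# Python control flow on traced values.
@jax.jit
def loss(w, x):
    return jnp.sum((w * x) ** 2)


grad_fn = jax.grad(loss)  # transformations compose on pure functions
g = grad_fn(jnp.ones(3), y)
print(g)
```

None of this is hard individually, but unlearning in-place mutation and implicit RNG state is where the curve comes from.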
> They don't really use pytorch from what I see on the outside from their research works
Of course not; there is no outside world at Google. If internal tooling exists for a problem, their culture effectively mandates using it before anything else, no matter the difference in quality. This basically explains the whole TF1/TF2 debacle, which understandably left a bad taste in people's mouths. In any case, while they don't use PyTorch, the rest of us very much do.
Right, and in order to use them effectively you basically have to use JAX. Most researchers don't have the advantage of free compute, so they are effectively trying to buy mindshare rather than winning on quality. This is fine, but it's worth repeating, as it biases the discussion heavily: many proponents of JAX just so happen to be on TRC or have been given TPU credits via some other mechanism.
Also - getting access to a TPU on GCP (particularly when you don't have a <fancy_school>.edu email address) has historically been a _fucking nightmare_. Absolute shit show.
I am a high schooler, and easily got a TPU v4-64. No fancy school or .edu email address, just a dream of winning GeoGuessr. They are very receptive to emails; I asked for more and they got more for me.
PyTorch/XLA is barely supported in the PyTorch ecosystem. For instance, PyTorch Lightning still doesn't easily support TPU pods, with only a single short page about Colab v2-8 TPUs that is out of date, and then there are the various PyTorch libraries/implementations that hard-code .cuda(), etc. More limitations at: https://lightning.ai/docs/pytorch/stable/accelerators/tpu_fa... I haven't worked with TensorFlow, but I've heard it's a pain even on GPUs. JAX is the real deal, and does make my code transferable between GPUs and TPUs relatively easily (excluding any custom Pallas kernels for flash vs. splash attention, but that's usually not a massive code change). However, with JAX there are often not many pre-existing implementations, due to network effects, etc.
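The portability point in practice: the same jitted function runs unchanged whether JAX's default backend is CPU, GPU, or TPU, with no device-specific calls in user code. A sketch (the function is made up; what `jax.devices()` prints depends on your install):

```python
import jax
import jax.numpy as jnp

# Lists whatever backend this install targets, e.g. [CpuDevice(id=0)],
# or GPU/TPU devices on those platforms.
print(jax.devices())


@jax.jit
def step(params, x):
    # Compiled via XLA for the default backend; note there is no
    # .cuda()/.to(device) anywhere, unlike typical PyTorch code.
    return jnp.tanh(params @ x)


params = jnp.ones((4, 4))
x = jnp.arange(4.0)
out = step(params, x)
print(out)
```

This is why porting between GPU and TPU is mostly a no-op in JAX, with custom Pallas kernels being the main exception, as noted above.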