Grokking implementations in JAX/Flax and PyTorch

February 23, 2025


What is Grokking in Deep Learning?

Grokking is the phenomenon in which deep neural networks undergo a phase transition in generalization performance during training: long after training accuracy has saturated, validation accuracy suddenly climbs. It was first described in the 2022 paper Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. More recent findings point to numerical behavior of the Softmax layer as a likely cause of this effect.
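To make the numerical point concrete, here is a tiny sketch (my own illustration based on my reading of that line of work, not code from the cited paper or either repo): once a logit margin grows large enough, float32 Softmax saturates, and the cross-entropy gradient for the correct class underflows to zero.

```python
import jax
import jax.numpy as jnp

# Logits for a 3-way classification where the correct class (index 0) is
# predicted with a large margin, as happens late in memorization.
logits = jnp.array([30.0, 0.0, 0.0], dtype=jnp.float32)
probs = jax.nn.softmax(logits)

# In float32 the correct-class probability rounds to exactly 1.0, so the
# cross-entropy gradient for that logit, probs[0] - 1.0, underflows to 0.0
# and that parameter direction stops receiving any learning signal.
print(probs)           # [1.0, ~9.4e-14, ~9.4e-14]
print(probs[0] - 1.0)  # 0.0
```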

[Figure: Grokking during training. The validation accuracy curve eventually follows the training curve to approximately 100%.]
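For concreteness, the "small algorithmic datasets" from the original paper are tasks like modular addition, (a + b) mod p. Below is a minimal sketch of such a dataset in JAX; it is my own illustration (the function name modular_addition_dataset and the defaults are mine), not code from the paper or any of the repos.

```python
import jax
import jax.numpy as jnp

def modular_addition_dataset(p: int = 97, train_fraction: float = 0.5, seed: int = 0):
    """Enumerate all pairs (a, b) with label (a + b) mod p and split train/valid."""
    a, b = jnp.meshgrid(jnp.arange(p), jnp.arange(p), indexing="ij")
    inputs = jnp.stack([a.ravel(), b.ravel()], axis=1)  # shape (p*p, 2)
    labels = (inputs[:, 0] + inputs[:, 1]) % p          # shape (p*p,)

    # Grokking is most visible when only a fraction of all p*p pairs is trained on.
    perm = jax.random.permutation(jax.random.PRNGKey(seed), inputs.shape[0])
    n_train = int(train_fraction * inputs.shape[0])
    train_idx, valid_idx = perm[:n_train], perm[n_train:]
    return (inputs[train_idx], labels[train_idx]), (inputs[valid_idx], labels[valid_idx])

(train_x, train_y), (valid_x, valid_y) = modular_addition_dataset()
print(train_x.shape, valid_x.shape)  # (4704, 2) (4705, 2) for p=97 at a 50% split
```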

Port of Grokking in MLX to PyTorch and JAX/Flax

I found an MLX implementation of Grokking (using a small, modern deep learning architecture) at https://github.com/stockeh/mlx-grokking, developed by GitHub user stockeh, and ported it to PyTorch and JAX/Flax to make it available on other platforms, such as Microsoft Windows and NVIDIA CUDA GPUs.

You can find the code here:

https://github.com/atveit/torch_grokking
https://github.com/atveit/jax_grokking
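To give a feel for the JAX/Flax side, here is a hedged sketch of a training step for a small classifier on the modular-arithmetic task above. It is a minimal stand-in of my own (an MLP rather than the repos' actual architecture, with hyperparameters that are not taken from the repos); the point is the shape of the setup: integer tokens in, logits over residues out, and AdamW with noticeable weight decay, which the grokking literature ties to how quickly validation accuracy catches up.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    """Stand-in model for illustration; the repos use a different architecture."""
    p: int = 97
    hidden: int = 128

    @nn.compact
    def __call__(self, x):  # x: (batch, 2) integer tokens
        emb = nn.Embed(self.p, self.hidden)(x)                        # (batch, 2, hidden)
        h = nn.relu(nn.Dense(self.hidden)(emb.reshape(x.shape[0], -1)))
        return nn.Dense(self.p)(h)                                    # logits over residues mod p

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 2), jnp.int32))
tx = optax.adamw(1e-3, weight_decay=1.0)  # strong weight decay, in the spirit of grokking setups
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p_):
        logits = model.apply(p_, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss
```

Run long enough past the point where training accuracy hits 100%, and a setup along these lines is where the delayed jump in validation accuracy shows up.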

Best regards, Amund


