`import jax.numpy as np`, then we also get a JAX implementation after certain modifications: e.g. removing in-place index assignment, replacing unsupported functions, etc.
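To make the in-place assignment point concrete, here is a minimal sketch. JAX arrays are immutable, so a NumPy-style `x[:, 0] = v` raises an error; the usual replacement is the functional `x.at[...].set(...)` update. The function name `mask_first_token` is made up for illustration and is not taken from any of the code discussed here.

```python
import jax.numpy as jnp


def mask_first_token(scores):
    # NumPy version would mutate in place:  scores[:, 0] = -np.inf
    # JAX arrays are immutable, so .at[...].set(...) returns an updated copy.
    return scores.at[:, 0].set(-jnp.inf)


scores = jnp.zeros((2, 3))
masked = mask_first_token(scores)  # `scores` itself is left unchanged
```

Under `jit`, XLA can often turn this kind of update into an in-place write, so the functional style does not necessarily cost an extra copy.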
JAX requires a bit more work to maintain the fixed-size buffers that XLA requires, especially in the case of caching and rotary embeddings. But yeah, overall the code can be pretty similar [1].
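For the caching case, a rough sketch of what "fixed-size buffers" can look like: instead of concatenating new keys/values onto a growing array (which changes the shape and forces a recompile), preallocate the cache at a maximum length and write into it at the current position. The names `MAX_LEN`, `DIM`, and `append_to_cache` are assumptions for illustration, not something from [1].

```python
import jax
import jax.numpy as jnp

MAX_LEN, DIM = 8, 4  # hypothetical maximum sequence length and head dimension


@jax.jit
def append_to_cache(cache, new_kv, pos):
    # cache:  (MAX_LEN, DIM), preallocated once
    # new_kv: (1, DIM), the entry for the current token
    # pos:    scalar integer index of the current token
    # The output has the same shape as `cache`, so XLA sees a static shape
    # at every decoding step.
    return jax.lax.dynamic_update_slice(cache, new_kv, (pos, 0))


cache = jnp.zeros((MAX_LEN, DIM))
cache = append_to_cache(cache, jnp.ones((1, DIM)), 0)
cache = append_to_cache(cache, 2 * jnp.ones((1, DIM)), 1)
```

The attention step then has to mask out the unused tail of the cache (positions after `pos`), which is part of the extra bookkeeping compared to a plain NumPy version.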