Differentiation using JAX

JAX, amongst other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array supports the jax.jvp() and jax.vjp() JAX functions. These compute, respectively, forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products of functions that operate upon Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

  • ufunc operations like x + y

  • reducers like ak.sum()

  • slices like x[1:]
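As a reminder of what the two products are, here is a minimal sketch in plain Python (no JAX). It assumes a small linear function f(x) = J x whose Jacobian J is written out explicitly. The helpers jvp and vjp below are hypothetical stand-ins, not JAX’s API. jax.jvp() computes J·t for an input-space tangent t, and jax.vjp() gives access to v·J for an output-space cotangent v, in both cases without materializing J.

```python
# Hypothetical illustration: a linear map f(x) = J @ x on plain lists,
# so its Jacobian is exactly the matrix J (2 outputs, 3 inputs).
J = [
    [1.0, 0.0, 2.0],
    [0.0, 3.0, 0.0],
]

def jvp(jac, tangent):
    # Forward mode: push an input-space tangent through the Jacobian (J @ t).
    return [sum(row[k] * tangent[k] for k in range(len(tangent))) for row in jac]

def vjp(jac, cotangent):
    # Reverse mode: pull an output-space cotangent back through the Jacobian (v @ J).
    n_inputs = len(jac[0])
    return [sum(cotangent[i] * jac[i][k] for i in range(len(jac))) for k in range(n_inputs)]

t = [1.0, 1.0, 1.0]  # tangent: one entry per input element
v = [1.0, 1.0]       # cotangent: one entry per output element
print(jvp(J, t))  # [3.0, 3.0]
print(vjp(J, v))  # [1.0, 3.0, 2.0]
```

Forward mode is cheap when there are few inputs to perturb, and reverse mode is cheap when there are few outputs. Both satisfy v·(J t) = (v J)·t.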

How to differentiate Awkward Arrays?

For this notebook (which is evaluated on a CPU), we need to configure JAX to use only the CPU.

import jax
jax.config.update("jax_platform_name", "cpu")

Next, we must call ak.jax.register_and_check() to register Awkward’s JAX integration.

import awkward as ak
ak.jax.register_and_check()

Let’s define a simple function that accepts an Awkward Array.

def reverse_sum(array):
    return ak.sum(array[::-1], axis=0)

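We can then create an array with which to evaluate reverse_sum. The backend="jax" argument builds an Awkward Array backed by JAX array buffers (jaxlib.xla_extension.DeviceArray in older JAX releases, jax.Array in newer ones). These buffers power JAX’s automatic differentiation and JIT compilation features. JAX uses 32-bit floats by default unless the jax_enable_x64 option is set, so the result has type float32 even though the input was written with Python floats.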

array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
reverse_sum(array)
[5.0,
 7.0,
 3.0]
-----------------
type: 3 * float32
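The sketch below shows where [5.0, 7.0, 3.0] comes from. It mimics reverse_sum on ordinary nested lists in plain Python, with no Awkward or JAX; the function name reverse_sum_lists is hypothetical. It assumes that the ragged axis=0 sum adds together the elements that share the same position within their lists, which is what the output above shows.

```python
from itertools import zip_longest

def reverse_sum_lists(rows):
    """Plain-Python analogue of reverse_sum for a list of variable-length lists.

    Reverses the outer list, then sums position-by-position along axis=0.
    Rows shorter than the longest one contribute 0.0 at the missing positions.
    """
    reversed_rows = rows[::-1]
    return [sum(column) for column in zip_longest(*reversed_rows, fillvalue=0.0)]

print(reverse_sum_lists([[1.0, 2.0, 3.0], [], [4.0, 5.0]]))  # [5.0, 7.0, 3.0]
```

Position 0 sums 4.0 + 1.0, position 1 sums 5.0 + 2.0, and position 2 only has 3.0. Because addition is commutative, reversing the outer list does not change these totals.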

Computing the JVP of reverse_sum requires a tangent vector, which can also be defined as an Awkward Array. It has the same structure as array, with a 1.0 in the position of the element 5.0:

tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))

jax.jvp() returns both the value of reverse_sum evaluated at array:

value_jvp
[5.0,
 7.0,
 3.0]
-----------------
type: 3 * float32
assert value_jvp.to_list() == reverse_sum(array).to_list()

and the JVP evaluated at array for the given tangent:

jvp_grad
[0.0,
 1.0,
 0.0]
-----------------
type: 3 * float32
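A JVP is a directional derivative, so jvp_grad can be sanity-checked with central finite differences. The sketch below is plain Python (no JAX), with nested lists standing in for the Awkward Arrays above. The names reverse_sum_lists, axpy, finite_difference_jvp, array_lists and tangent_lists are hypothetical. Because reverse_sum is linear, the approximation should agree with [0.0, 1.0, 0.0] up to rounding.

```python
from itertools import zip_longest

def reverse_sum_lists(rows):
    # Plain-Python stand-in for reverse_sum: reverse the outer list, then sum by position.
    return [sum(col) for col in zip_longest(*rows[::-1], fillvalue=0.0)]

def axpy(alpha, x, y):
    # Elementwise y + alpha * x for ragged nested lists with identical structure.
    return [[yi + alpha * xi for xi, yi in zip(xr, yr)] for xr, yr in zip(x, y)]

def finite_difference_jvp(f, x, t, eps=1e-4):
    # Central difference: (f(x + eps*t) - f(x - eps*t)) / (2*eps) approximates J @ t.
    plus = f(axpy(eps, t, x))
    minus = f(axpy(-eps, t, x))
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]

array_lists = [[1.0, 2.0, 3.0], [], [4.0, 5.0]]
tangent_lists = [[0.0, 0.0, 0.0], [], [0.0, 1.0]]
approx = finite_difference_jvp(reverse_sum_lists, array_lists, tangent_lists)
print([round(a, 6) for a in approx])  # [0.0, 1.0, 0.0]
```

Finite differences cost two function evaluations per tangent and lose precision as eps shrinks. jax.jvp() instead propagates the tangent exactly alongside the forward computation.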

JAX’s own documentation encourages the user to use jax.numpy instead of the canonical numpy module when operating upon JAX arrays. However, jax.numpy does not understand Awkward Arrays, so for ak.Array objects you should use the regular ak functions and NumPy ufuncs instead.
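The reverse mode works the other way round: the vector-Jacobian product of reverse_sum can be worked out by hand. Here is a plain-Python sketch (no JAX; reverse_sum_vjp and array_lists are hypothetical names). Each output position j is the sum of element j of every row, so every input element rows[i][j] receives cotangent[j]. The result keeps the ragged structure of the input, and reversing the rows only changes the order of the terms, not which output each element feeds.

```python
def reverse_sum_vjp(rows, cotangent):
    """Hand-derived VJP of reverse_sum for ragged nested lists (hypothetical helper).

    reverse_sum is linear: output[j] sums element j of every row, so the
    gradient with respect to rows[i][j] is simply cotangent[j].
    """
    return [[cotangent[j] for j in range(len(row))] for row in rows]

array_lists = [[1.0, 2.0, 3.0], [], [4.0, 5.0]]
print(reverse_sum_vjp(array_lists, [1.0, 10.0, 100.0]))
# [[1.0, 10.0, 100.0], [], [1.0, 10.0]]
```

Pairing this with the tangent used earlier gives the same number as pairing the cotangent with the JVP [0.0, 1.0, 0.0], namely 10.0. This is the identity v·(J t) = (v J)·t.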