A Rigorous (and Not-So-Rigorous) Look at JAX's Autograd
What this post is about
I've started writing this post from scratch several times as I've tried to get the presentation just right. I settled on these two goals: (1) if you have heard the terms "vector space" and "gradient" thrown about before, you should be able to understand and write JAX primitives, and (2) if you're comfortable with vector spaces, dual spaces, and multivariable calculus, I will assuage any fears about rigor.
JAX is a high-performance numerical computing framework for Python that can wrap a normal Python function and calculate its derivative. We'll cover JAX-flavored reverse-mode automatic differentiation (autodiff), and focus on the 'differentiation' part. The 'automatic' part basically works by wrapping input ndarrays in a custom class that keeps track of the computation. That is out of the scope of this post, but you should check out the Resources section for more information (or my 110-line Python implementation).
Resources
I found plenty of high-quality resources on autodiff. The first post I read was this one, and I wrote down my takeaways here. This post is also excellent. However, they mostly cover how to directly apply the chain rule. In contrast, the JAX-flavored way of doing things involves "Jacobian-vector products" and other foreign terms. I really liked the JAX approach of providing a higher-order grad function that can wrap any function you like; that approach seemed to have a structured procedure for applying the chain rule.
In my efforts, I stumbled upon many resources specifically related to JAX. One of the main contributors behind JAX made a video about their framework (my notes on it are here). However, the most directly relevant information is the autodiff cookbook for JAX (specifically the section I linked to) and the documentation on what you need to do to implement a primitive. Their approach is grounded in the literature, specifically this paper from 2008. And finally, I found a core JAX contributor's PhD thesis on the implementation of a precursor library (pages 13-19 and 48-52 are relevant).
How does autodiff work?
First, we'll look at an example. Here is a simple neural net as a function of its input:

$$f(x) = \sum_i \big(\arctan(Wx + b) - y\big)_i^2$$
This is a single-layer neural net with an arctan activation and MSE loss w.r.t. labels $y$. This is pretty standard (aside from the maybe uncommon choice of $\arctan$). Notice that usually you'd want to take the derivative with respect to the weights (so $W$ or $b$), but I'm keeping this example simpler: we will only take the derivative with respect to the input $x$. This way, we don't need to deal with multiple dimensions.
If all the computations for the autodiff of $f$ (evaluation and gradient accumulation) happened sequentially, it would look like the following. We will execute each operation of the neural net individually so that we can use these intermediate computations in our backward pass. Note: in numpy, the .T property takes the transpose of a 2D array. Don't worry about the Jt_f(x, gt) functions; they will be explained later.
a = W @ x
b = a + bias # I've replaced `b` from f(x) with `bias`
c = np.arctan(b)
d = c - y
e = d**2
loss = np.sum(e)
# begin backward mode
gt = 1.0
e_gt = Jt_sum(e, gt) # gt * np.ones((1, e.shape[0]))
d_gt = Jt_square(d, e_gt) # e_gt * (2 * d.reshape(1,-1))
c_gt = Jt_subtract_const(c, d_gt) # d_gt
b_gt = Jt_arctan(b, c_gt) # c_gt * (1/(1 + b**2)).reshape(1, -1)
a_gt = Jt_add_const(a, b_gt) # b_gt
delta_xt = Jt_matmul(x, a_gt) # a_gt @ W
delta_x = delta_xt.T
There are several things to take away from this code snippet. We see that in backward mode we are using the values from the original computation (a, b, c, etc.). I've put comments to inline the definitions of the Jt functions. And, at the end of the computation, we get delta_x: the gradient of $f$ with respect to the input x.
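If you want to sanity-check that hand-written backward pass, one option (a self-contained sketch, assuming JAX is installed; the shapes and values below are arbitrary stand-ins) is to compare the result against what jax.grad computes for the same function:
import jax
import jax.numpy as jnp

# arbitrary small shapes, just for the check
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
W = jax.random.normal(k1, (3, 4))
bias = jax.random.normal(k2, (3,))
y = jax.random.normal(k3, (3,))
x = jax.random.normal(k4, (4,))

def f(x):
    return jnp.sum((jnp.arctan(W @ x + bias) - y) ** 2)

# should match delta_x computed as above for the same W, bias, y, x
# (up to the (n, 1) vs (n,) shape difference)
grad_x = jax.grad(f)(x)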
What are those Jt_f(x, gt) functions?
Jt_func is, loosely speaking, an inverse map of the derivative (i.e. a map from the tangent space of the output to the tangent space of the input). Strictly, it is just an efficient method of evaluating the transpose of the derivative on $g^\top$, and since the derivative isn't guaranteed to be orthogonal, it is not actually an inverse map. It just so happens that the result is the direction of steepest ascent when $g^\top$ is the value carried along in our backward-mode differentiation. I'm borrowing the transposed vector-Jacobian product notation (Jt) from the JAX documentation; I keep the notation the same so that if you read the JAX documentation, their notation will be familiar. The next few paragraphs are meant to click later in the post; please don't give up just because they don't quite make sense now.
When looking at a function call like Jt_f(x, gt), the x tells us where in the input space of some f we are, and gt, morally[1], is a direction in the output space of that f. Jt_func turns that direction in the output space into a direction of the gradient in the input space.* Here's a picture:
Figure 1: On the right we have our output space. I named a vector in this space $g^\top$ (later I will explain why I use the transpose notation). For now, we will use it as a row vector. On the left, we see the output of Jt_f(x, gt), which is an increment to the input.
You can see the direction in the output being transformed into the input space.* This is how we propagate the gradient through func. If you have the gradient of the loss at $f(x)$, you can compute

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial f(x)} \cdot \frac{\partial f}{\partial x}$$

where $\frac{\partial f}{\partial x}$ is the Jacobian matrix of $f$ at $x$. Leibniz's notation is insane, so we're going to adopt the notation from Spivak, and other people agree. Let me clarify the notation a little bit more. In math, we often confuse functions and expressions: if we define $f$ by $f(x) = x^2$, we tend to talk about "the function $x^2$", when we should be talking about the function $f$. Normally the derivative operator applies to expressions, so we write $\frac{d}{dx}f(x)$ or $\frac{d}{dx}x^2$, which means "expand $f(x)$ into an expression and then differentiate the expression". In the notation I'm using, the derivative operator applies to functions: we write $Df$, which is a new function with the same input type. We abbreviate $(Df)(x)$ as $Df(x)$. Thus, let $Df(x)$ be the Jacobian matrix. Note that I'll leave off the parens and assume that $D$ always applies to its function before we apply the function to $x$. Finally, let's rewrite the equation:

$$g^\top \, Df(x)$$

Much better. Note that $Df(x)$ is a matrix, and $g$ is any column vector (we'll talk about what $g$ really is later). Now, why does composing Jt_func work? We'll hold off on that for a bit (unless you can't wait, in which case it's here). If you're confused right now, that's entirely fair. We'll go through many examples, and then I'll explain the underlying theory. I've just said this stuff so it will click later on.
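As an aside, what this post writes as Jt_f(x, gt) is exactly what JAX exposes as a VJP (vector-Jacobian product). A minimal sketch of the correspondence (the toy f here is my own choice):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.arctan(x)

x = jnp.array([0.5, -1.0, 2.0])
gt = jnp.array([1.0, 1.0, 1.0])  # a covector in the output space

out, f_vjp = jax.vjp(f, x)       # out is f(x), f_vjp is the pullback
(x_gt,) = f_vjp(gt)              # what this post writes as Jt_arctan(x, gt)
One difference: JAX's cotangent gt has the same shape as the output (here (n,)), whereas in this post I keep gt as an explicit row vector of shape (1, n).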
The way we figure out how to implement these Jt functions is to evaluate $g^\top Df(x)$ (which you should think about as taking the direction $g^\top$ from the output space to a direction in the input space)* for a given $f$. Let's start with the functions I introduced in the simple neural net at the beginning of this article.
Implementing Jt_subtract_const
We will start with subtract_const: $f(x) = x - c$ with $c$ a constant. We can calculate the derivative by plucking it out of the Taylor series:

$$f(x + \Delta x) = f(x) + Df(x)\,\Delta x + O(\|\Delta x\|^2)$$

We collect terms, and the coefficient of the first-order term is the derivative. We can do this for our function $f$:

$$f(x + \Delta x) = x + \Delta x - c = f(x) + I\,\Delta x$$

Thus, $Df(x) = I$, the identity matrix. So, if we want to calculate Jt_subtract_const, we need to figure out what $g^\top I$ is, which is simply $g^\top$. Thus,
def Jt_subtract_const(x, gt):
"""x: ndarray shape(n,)
gt: ndarray shape(1,n)"""
return gt
Let's do a harder one.
Implementing Jt_square
Now $f(x) = x \odot x$, element-wise. Remember, to find the derivative, we take

$$f(x + \Delta x) = (x + \Delta x) \odot (x + \Delta x) = x \odot x + 2x \odot \Delta x + \Delta x \odot \Delta x$$

with $\odot$ as the element-wise product. Thus, if we name each component of the vector non-bold ($x_1, \dots, x_n$), then

$$Df(x) = \begin{pmatrix} 2x_1 & & \\ & \ddots & \\ & & 2x_n \end{pmatrix}$$

since this matrix does the same thing as multiplying element-wise by $2x$. Multiplying this on the left by $g^\top$ gives $g^\top \odot (2x)^\top$. So,
def Jt_square(x, gt):
"""x: ndarray shape(n,)
gt: ndarray shape(1,n)"""
return gt * (2 * x.reshape(1, -1))
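A cheap way to gain confidence in an implementation like this (a quick check I'm adding here, not something the method requires) is a finite-difference test: by the Taylor expansion above, Jt_square(x, gt) @ dx should approximately equal gt @ (f(x + dx) - f(x)) for a small perturbation dx:
import numpy as np

x = np.array([0.5, -1.0, 2.0])
gt = np.ones((1, 3))
dx = np.array([1e-6, 2e-6, -1e-6])  # small arbitrary perturbation

lhs = Jt_square(x, gt) @ dx      # (g^T Df(x)) dx
rhs = gt @ ((x + dx)**2 - x**2)  # g^T (f(x + dx) - f(x))
assert np.allclose(lhs, rhs)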
Implementing Jt_arctan
Now $f(x) = \arctan(x)$, element-wise. This one is much trickier. There's probably a way to implement Jt_arctan using the Taylor series, but there is an easier way in this scenario. To understand it, let's name the components of the output:

$$f(x)_i = \arctan(x_i)$$

When you first learn multivariable calculus, you're probably taught that the definition of the Jacobian is the matrix of partial derivatives, i.e.

$$Df(x)_{ij} = \frac{\partial f(x)_i}{\partial x_j}$$

We're going to exploit the fact that the Jacobian happens to equal this matrix of partial derivatives, since we know how to evaluate the scalar derivative $\frac{d}{dt}\arctan(t) = \frac{1}{1+t^2}$. Because our function is element-wise, only the diagonal will have values on it. That's because $x_j$ has no effect on $f(x)_i$ unless $i = j$ (which would put it on the diagonal). This is true in general: element-wise operations yield a diagonal Jacobian. In our scenario, since the scalar derivative is $\frac{1}{1+t^2}$, this gives us

$$Df(x) = \begin{pmatrix} \frac{1}{1+x_1^2} & & \\ & \ddots & \\ & & \frac{1}{1+x_n^2} \end{pmatrix}$$
Putting it into code we get:
def Jt_arctan(x, gt):
"""x: ndarray shape(n,)
gt: ndarray shape(1,n)"""
return gt * (1/(1 + x**2)).reshape(1, -1)
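The opening snippet also uses Jt_sum, Jt_add_const, and Jt_matmul. For completeness, here is a sketch of those three, with bodies taken from the inline comments in that snippet (Jt_matmul reaching for a global W is a simplification; a fuller treatment would also produce a cotangent for W):
import numpy as np

def Jt_sum(x, gt):
    """x: ndarray shape(n,)
    gt: scalar"""
    return gt * np.ones((1, x.shape[0]))

def Jt_add_const(x, gt):
    """x: ndarray shape(n,)
    gt: ndarray shape(1,n)"""
    return gt

def Jt_matmul(x, gt):
    """x: ndarray shape(n,)
    gt: ndarray shape(1,m)"""
    return gt @ W  # Jacobian of W @ x w.r.t. x is W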
What is going on (bringing back rigor)
You might be thinking: I'm familiar with multivariable calculus, I love a good dual space, but this all feels unjustified. To fix that, we'll look again at what a derivative really is. Given a function $f : \mathbb{R}^n \to \mathbb{R}^m$ between vector spaces, the derivative at a point is the linear map that takes incremental inputs of $f$ and outputs the incremental changes in the value of $f$. It is often represented as a matrix, but to clarify the types, we'll pretend it's just a linear function. That is, $Df : \mathbb{R}^n \to (\mathbb{R}^n \to \mathbb{R}^m)$, and $Df(x) : \mathbb{R}^n \to \mathbb{R}^m$ is linear. That is, we're currying the function when we evaluate it at $x$. It has a nice property that

$$f(x + v) \approx f(x) + Df(x)(v) \quad \text{for small } v$$

We're using slightly different notation since we're treating $Df(x)$ as a function, but remember, the application of it as a function is just left multiplication by the matrix. With our new understanding that the derivative is a linear map (which just happens to be representable as a matrix), let's look again at this figure:
Figure 2: On the left, we see the input space and the tangent space at $x$, which contains the output of Jt_f(x, gt). On the right, we have our output space. Note that $g^\top$ lives in the tangent space at $f(x)$.
There are several things to notice. First, I've drawn the dashed axes to represent the domain and codomain of $Df(x)$. Now, since we're doing reverse mode, we need to go the other direction: take incremental changes in output and turn them into incremental changes to input. Luckily, we can take the transpose of a linear map:

$$Df(x)^\top : (\mathbb{R}^m)^* \to (\mathbb{R}^n)^*$$

where $(\mathbb{R}^m)^*$ denotes the dual space of $\mathbb{R}^m$. It is important to realize $g^\top$ isn't special.* It can be any element of $(\mathbb{R}^m)^*$. We're now dealing with the dual space. Its elements are often referred to as covectors. They are linear maps from their "normal" space to the underlying field; thus, $g^\top : \mathbb{R}^m \to \mathbb{R}$. This means the types of the expression $Df(x)^\top(g^\top)$ make sense: $g^\top \in (\mathbb{R}^m)^*$, and $Df(x)^\top(g^\top) \in (\mathbb{R}^n)^*$. Often people think of the dual space as linear functions that are "taking the inner product with a vector". So

$$g^\top(v) = \langle g, v \rangle$$
This is especially useful because we can then represent a covector as a row vector, which, when combined with a column vector from the non-dual space, gives the inner product. This is the reason I've been so careful about keeping the gt of shape (1,n): these gt's are members of the dual space and are actually row vectors. From now on, I'll use the transpose, $g^\top$, as a reminder that it's a row vector.
Just to check that using the covectors as row vectors will work, let's examine our vector-Jacobian product: a row vector of shape $(1, m)$ times the $m \times n$ Jacobian gives a row vector of shape $(1, n)$, i.e. a covector on the input space.

Thus, with our newfound correspondence, let's evaluate the vector-Jacobian product (and employ the chain rule). Let $h = f_2 \circ f_1$; then

$$g^\top \, Dh(x) = g^\top \, Df_2(f_1(x)) \, Df_1(x)$$

Recall our definition from earlier:

$$\mathrm{Jt}_f(x, g^\top) = g^\top \, Df(x)$$

Then

$$g^\top \, Dh(x) = \mathrm{Jt}_{f_1}\!\big(x,\; g^\top \, Df_2(f_1(x))\big)$$

What we should notice here is that $g^\top \, Df_2(f_1(x))$ is an element of the dual space of the output of $f_1$. Thus, we can apply this strategy again, where we incorporate an element into a call to a Jt function:

$$g^\top \, Dh(x) = \mathrm{Jt}_{f_1}\!\big(x,\; \mathrm{Jt}_{f_2}(f_1(x),\; g^\top)\big)$$

This is very busy notation, so to look at it better, let $g_1^\top = \mathrm{Jt}_{f_2}(f_1(x), g^\top)$; then the final expression becomes

$$g^\top \, Dh(x) = \mathrm{Jt}_{f_1}(x, g_1^\top)$$

This should hopefully justify the program I provided at the beginning of the post. Just by knowing the Jt functions of the primitives, we end up with the ability to convert $g^\top$ into a covector of the input space of the whole composition. Now let's reflect on what this means. If we choose $g^\top = 1$ (our loss is a scalar), then we get back the single row of the derivative, $\partial_0 f$ evaluated at our inputs (borrowing notation from Sussman, $\partial_0$ is the partial derivative with respect to the first argument).
This is great because almost all loss functions in ML are scalar, so the Jacobian just has one row, and thus, we can efficiently compute the gradient for our descent.
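To make the composition concrete, here is the product that the backward pass at the top of the post evaluates when we pick $g^\top = 1$ (the labels on the factors are mine, matching the intermediate variables in the code):

$$\texttt{delta\_xt} \;=\; \underbrace{1}_{g^\top} \cdot D\,\mathrm{sum}(e) \cdot D\,\mathrm{square}(d) \cdot D\,\mathrm{subtract\_const}(c) \cdot D\,\arctan(b) \cdot D\,\mathrm{add\_const}(a) \cdot D\,\mathrm{matmul}_W(x)$$

Each factor is the Jacobian of one primitive evaluated at the value saved during the forward pass, and the Jt calls accumulate the product from left to right, which is why the backward code runs in the reverse order of the forward code.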
You may be thinking: hold on a minute, the gradient is actually a linear form (in the case of a scalar loss function); we can't add that to $x$. And you're right. We should think of the gradient as something that can turn small changes of $x$ into small changes in the loss. When we're doing gradient descent, we want a small change in $x$ that will maximally decrease the loss for a given step size. The small change in $x$ that will cause the most increase is in fact the transpose of the gradient (this makes sense since the dot product is maximal when the vectors are parallel), and stepping in the opposite direction gives the steepest decrease. So, it is okay to use the gradient for gradient descent after all.
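In terms of the code at the top (with a hypothetical step size lr), a single gradient-descent step on the input would therefore be:
lr = 0.1  # hypothetical step size
x = x - lr * delta_x.reshape(x.shape)  # step against the direction of steepest ascent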
It is important to note that the components of the Jacobian depend on the bases chosen for the domain and codomain.
Wrapping up
I hope this post helped you understand JAX's backward-mode automatic differentiation. I wrote an implementation that automatically tracks the computations, so you don't have to write out the Jt_f() functions explicitly. This is one of my most technical posts; let me know if any part of it needs a better or more in-depth explanation.
I wrote a follow-up post here, which should clarify what gradient descent actually is and why these Jacobian vector products help.
[1] morally in this sense