A Rigorous (and Not-So-Rigorous) Look at JAX's Autograd

Date: 2024-01-05 | Author: Jason Eveleth

Warning!
(Check out my new follow-up post: more gradient descent.) This post is correct but contains some misleading statements that I've labeled with an asterisk*, along with warnings at the beginning of the relevant sections. I wrote this post as I was learning differential geometry; thank you for joining me on this journey.

What this post is about

I've started writing this post from scratch several times as I've tried to get the presentation just right. I settled on these two goals: (1) if you have heard the terms "vector space" and "gradient" thrown about before, you should be able to understand and write JAX primitives, and (2) if you're comfortable with vector spaces, dual spaces, and multivariable calculus, I will assuage any fears about rigor.

JAX is a high-performance numerical computing framework for Python that can wrap a normal Python function to calculate its derivative. We'll cover JAX-flavored reverse-mode automatic differentiation (autodiff), and focus on the 'differentiation' part. The 'automatic' part basically works by wrapping input ndarrays in a custom class that keeps track of the computation. That is out of scope for this post, but you should check out the Resources section for more information (or my 110-line Python implementation).
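
For orientation, here is roughly what that higher-order grad function looks like from the user's side. This is just a toy sketch of the public API; the function f below is my own example, not one from the post.

import jax
import jax.numpy as jnp

def f(x):
    # an ordinary scalar-valued Python function of a vector input
    return jnp.sum(jnp.arctan(x) ** 2)

grad_f = jax.grad(f)                      # grad wraps f and returns a new function
print(grad_f(jnp.array([1.0, 2.0, 3.0]))) # evaluates the gradient, same shape as x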

Resources

I found plenty of high-quality resources on autodiff. The first post I read was this one, and I wrote down my takeaways here. This post is also excellent. However, they mostly cover how to directly apply the chain rule. In contrast, the JAX-flavored way of doing things involves "Jacobian-vector products" and other foreign terms. I really liked the JAX approach of providing a higher-order grad function that can wrap any function you like. That approach seemed to have a structured procedure for applying the chain rule.

In my efforts, I stumbled upon many resources specifically related to JAX. One of the main contributors behind JAX made a video about the framework (my notes on it are here). However, the most directly relevant information is the autodiff cookbook for JAX (specifically the section I linked to) and the documentation on what you need to do to implement a primitive. Their approach is grounded in the literature, specifically this paper from 2008. And finally, I found a core JAX contributor's PhD thesis on the implementation of a precursor library (pages 13-19 and 48-52 are relevant).

How does autodiff work?

First, we'll look at an example. Here is a simple neural net as a function of its input. I've colored each part of the function to correspond to its description.

f(\bm{x}) := {\color{red} \lvert\lvert} {\color{green}\arctan(}{\color{blue} W\bm{x} + \bm{b} }{\color{green})} - {\color{purple}\bm{y} } {\color{red} \rvert\rvert ^{2} }

This is a single-layer neural net with arctan activation and MSE loss w.r.t. the labels \bm{y}. This is pretty standard (aside from the maybe uncommon choice of \arctan). Notice that usually you'd want to take the derivative with respect to the weights (so W or \bm{b}), but I'm keeping this example simpler: we will only take the derivative with respect to the input \bm{x}. This way, we don't need to deal with multiple dimensions.
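
To make the running example concrete, here is one way f could be written in plain numpy. This is my sketch, assuming W is an (m, n) matrix and x, bias, and y are the corresponding vectors.

import numpy as np

def f(x, W, bias, y):
    """Single-layer net with arctan activation and squared-error loss
    against the labels y. Returns a scalar."""
    return np.sum((np.arctan(W @ x + bias) - y) ** 2)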

If all the computations for the autodiff of f (evaluation and gradient accumulation) happened sequentially, it would look like the following. We will execute each operation of the neural net individually so that we can use these intermediate computations in our backward pass. Note: in numpy, the .T property takes the transpose of a 2D array. Don't worry about the Jt_f(x, gt) functions; they will be explained later.

a = W @ x
b = a + bias # I've replaced `b` from f(x) with `bias`
c = np.arctan(b)
d = c - y
e = d**2
loss = np.sum(e)
# begin backward mode
gt = 1.0
e_gt = Jt_sum(e, gt)        # gt * np.ones((1, e.shape[0]))
d_gt = Jt_square(d, e_gt)   # e_gt * (2 * d.reshape(1,-1))
c_gt = Jt_subtract_const(c, d_gt) # d_gt
b_gt = Jt_arctan(b, c_gt)   # c_gt * (1/(1 + b**2)).reshape(1, -1)
a_gt = Jt_add_const(a, b_gt)# b_gt
delta_xt = Jt_matmul(x, a_gt) # a_gt @ W
delta_x = delta_xt.T

There are several things to take away from this code snippet. We see that in backward mode we are using the values from the original computation (a, b, c, etc.). I've put comments to inline the definitions of the Jt functions. And, at the end of the computation, we get delta_x: the gradient of f with respect to the input x.
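
As a sanity check (my addition, not part of the original listing), you can compare delta_x against a central-difference approximation of the gradient, using the f(x, W, bias, y) sketched above and the same (float-valued) W, bias, y, and x as in the listing.

eps = 1e-6
fd = np.zeros_like(x)
for i in range(x.shape[0]):
    dx = np.zeros_like(x)
    dx[i] = eps
    # central difference along coordinate i
    fd[i] = (f(x + dx, W, bias, y) - f(x - dx, W, bias, y)) / (2 * eps)
print(np.allclose(fd, delta_x.ravel(), atol=1e-4))  # should print True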

What are those Jt_f(x, gt) functions?

Warning!
I make three comments (labeled with an asterisk*) that make it seem like Jt_func is an inverse map of the derivative (i.e. a map from the tangent space of the output to the tangent space of the input). The function is just an efficient method of evaluating the transpose of the derivative on \bm{g}^{\intercal}, and since the derivative isn't guaranteed to be orthogonal, it is not an inverse map. It just so happens that \texttt{Jt\_func}(\bm{x}, \bm{g}^{\intercal})^{\intercal} is the direction of steepest ascent when \bm{g} is the value in our backward-mode differentiation.

I'm borrowing the transposed-Jacobian-vector-product (Jt) notation from the JAX documentation. I keep the notation the same so that if you read the JAX documentation, it will be familiar. The next few paragraphs are meant to click later in the post; please don't give up just because they don't quite make sense now.

When looking at a function call like Jt_f(x, gt), the \bm{x} tells us where in the input space of some f:\mathbb{R}^{n}\to \mathbb{R}^{m} we are, and, morally[1], \bm{g}^{\intercal} is a direction in the output space of that f. Jt_func turns that direction in the output space into a gradient direction in the input space.* Here's a picture:

Figure 1: On the right we have our output space. I named a vector in this space \bm{g}^{\intercal} (later I will explain why I use the transpose notation). For now, we will use it as a row vector. On the left, we see the output of Jt_f(x, gt), which is an increment to the input.

You can see the direction \bm{g}^{\intercal} in the output being transformed into the input space.* This is how we propagate the gradient through func. If you have the Jacobian of f at \bm{x}, you can compute

\texttt{Jt\_func}(\bm{x}, \bm{g}^{\intercal}) = \bm{g}^{\intercal} \left(\left.\frac{d}{d \bm{y} }f(\bm{y})\right|_{\bm{y}=\bm{x}}\right)

Where \left.\frac{d}{d\bm{y} }f(\bm{y})\right|_{\bm{y}=\bm{x}} is the Jacobian matrix of f at \bm{x}. Leibniz's notation is insane, so we're going to adopt the notation from Spivak (other people agree). Let me clarify the notation a little bit more. In math, we often confuse functions and expressions: if h(x) := 2x + 5, we talk about "the function h(x)" when we should be talking about the function h. Normally the derivative operator applies to expressions, so we write \frac{d}{dx}(2x + 5) or \frac{d}{dx}h(x), which means "expand h into an expression and then differentiate the expression". In the notation I'm using, the derivative operator applies to functions: we write \frac{d}{dx}h, which is a new function with the same input type. We abbreviate \frac{d}{dx} as D. Thus, let

(Df)(\bm{x}) = Df(\bm{x}) = \left.\frac{d}{d\bm{y} }f(\bm{y})\right|_{\bm{y}=\bm{x}}

be the Jacobian matrix. Note that I'll leave off the parens and assume that D always applies to its function before we apply the resulting function to \bm{x}. Finally, let's rewrite the equation:

\texttt{Jt\_func}(\bm{x}, \bm{g}^{\intercal}) := \bm{g}^{\intercal}Df(\bm{x}).

Much better. Note that Df(\bm{x}) is a matrix, and \bm{g}^{\intercal} is a row vector (we'll talk about what \bm{g} really is later). Now, why does composing Jt_func work? We'll hold off on that for a bit (unless you can't wait, in which case it's here). If you're confused right now, that's entirely fair. We'll go through many examples, and then I'll explain the underlying theory. I've just said this stuff so it will click later on.
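
To make the definition concrete, here is a brute-force numerical sketch of \bm{g}^{\intercal}Df(\bm{x}). This is my own code (the name Jt_numeric is mine), and it is deliberately inefficient: it builds the whole Jacobian with finite differences, which is exactly what the hand-written Jt functions below avoid.

import numpy as np

def Jt_numeric(func, x, gt, eps=1e-6):
    """Reference version of gt @ Df(x): approximate the Jacobian of func at x
    column by column with central differences, then left-multiply by gt."""
    m = np.atleast_1d(func(x)).shape[0]
    n = x.shape[0]
    J = np.zeros((m, n))
    for j in range(n):
        dx = np.zeros_like(x, dtype=float)
        dx[j] = eps
        J[:, j] = (np.atleast_1d(func(x + dx)) - np.atleast_1d(func(x - dx))) / (2 * eps)
    return gt @ J  # shape (1, n): a row covector on the input space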

The way we figure out how to implement these Jt functions is to evaluate \bm{g}^{\intercal}Df(\bm{x}) (which you should think of as taking the direction \bm{g}^{\intercal} in the output space to a direction in the input space)* for a given f. Let's start with the functions I introduced in the simple neural net at the beginning of this article.

Implementing Jt_subtract_const

We will start with subtract_const: f(\bm{x}) := \bm{x} - \bm{b} with \bm{b} a constant. We can calculate the derivative by plucking it out of the Taylor series:

f(\bm{x}+ \Delta \bm{x}) = f(\bm{x}) + Df(\bm{x})\Delta \bm{x} + O(\Delta \bm{x}^2).

We collect terms, and the coefficient of the first-order term is the derivative. We can do this for our function f:

\begin{align*} f(\bm{x}+\Delta \bm{x}) &= (\bm{x} + \Delta \bm{x}) - \bm{b}\\ &= \bm{x} - \bm{b} + \Delta \bm{x} \\ &= f(\bm{x}) + I\Delta \bm{x} \\ \end{align*}

Thus, Df(\bm{x}) = I, the identity matrix. So, if we want to calculate Jt_subtract_const, we need to figure out what \bm{g}^{\intercal}I is, which is simply \bm{g}^{\intercal}. Thus,

def Jt_subtract_const(x, gt):
    """x: ndarray shape(n,)
       gt: ndarray shape(1,n)"""
    return gt

Let's do a harder one.

Implementing Jt_square

Now f(\bm{x}) := \bm{x}^{2} element-wise. Remember, to find the derivative, we take

\begin{align*} f(\bm{x} + \Delta\bm{x}) &=(\bm{x} + \Delta\bm{x})^{2}\\ &=(\bm{x} + \Delta\bm{x})\odot(\bm{x} + \Delta\bm{x})\\ &=\bm{x}^{2}+2\bm{x} \odot \Delta\bm{x} + (\Delta\bm{x})^{2}\\ &=f(\bm{x})+2\bm{x} \odot\Delta\bm{x} + O((\Delta\bm{x})^{2}).\\ \end{align*}

With \odot denoting the element-wise product. Thus, if we name each component of the vector with non-bold letters:

\bm{x} = \begin{pmatrix} x_{1} \\ x_{2} \\ \vdots\\x_{n} \end{pmatrix},

then

Df(\bm{x}) = \begin{bmatrix} 2x_{1}& 0 & \dots& 0 \\ 0& \ddots & & \\ & & \ddots & 0\\ & & 0& 2x_{n}\\ \end{bmatrix}.

This matrix does the same thing as multiplying element-wise by 2\bm{x}. Multiplying it on the left by \bm{g}^{\intercal} gives [2x_{1}g_{1} \dots 2x_{n}g_{n}]. So,

def Jt_square(x, gt):
    """x: ndarray shape(n,)
       gt: ndarray shape(1,n)"""
    return gt * (2 * x.reshape(1, -1))
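
As a quick check (my addition), this agrees with the brute-force Jt_numeric sketched earlier:

x = np.array([1.0, 2.0, 3.0])
gt = np.array([[0.5, -1.0, 2.0]])
# compare the hand-written rule against the finite-difference reference
print(np.allclose(Jt_square(x, gt), Jt_numeric(lambda v: v ** 2, x, gt)))  # True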

Implementing Jt_arctan

Now f(\bm{x}) := \arctan(\bm{x}) element-wise. This one is much trickier. There's probably a way to implement Jt_arctan using the Taylor series, but there is an easier way in this scenario. To understand it, let's name the components of the output:

f(\bm{x})=\begin{pmatrix} f_{1}(\bm{x})\\ \vdots\\ f_{n}(\bm{x}) \end{pmatrix}=\begin{pmatrix} \arctan(x_{1})\\ \vdots\\ \arctan(x_{n}) \end{pmatrix}.

When you first learn multivariable calculus, you're probably taught that the definition of the Jacobian is the matrix of partial derivatives, i.e.

Df(\bm{x}) = \begin{bmatrix} \partial_{1}f_{1}(\bm{x})& \partial_{2}f_{1}(\bm{x}) & \dots& \partial_{n}f_{1}(\bm{x}) \\ \partial_{1}f_{2}(\bm{x})& \ddots & & \\ \dots& & \ddots & \partial_{n}f_{n-1}(\bm{x})\\ & & \partial_{n-1}f_{n}(\bm{x})& \partial_{n}f_{n}(\bm{x})\\ \end{bmatrix}.

We're going to exploit the fact that the Jacobian happens to equal this matrix of partial derivatives, since we know how to differentiate the scalar \arctan. Because our function f is element-wise, only the diagonal will have nonzero entries: x_{m} has no effect on f_{k}(\bm{x}) = \arctan(x_{k}) unless k=m (which would put the entry on the diagonal). This is true in general: element-wise operations yield a diagonal Jacobian. In our scenario, since D(\arctan)(x)=1/(1 + x^{2}), this gives us

Df(\bm{x}) = \begin{bmatrix} 1 /(1 + x_{1}^{2})& 0 & \dots& 0 \\ 0& \ddots & & \\ & & \ddots & 0\\ & & 0& 1 / (1 + x_{n}^{2})\\ \end{bmatrix}

Putting it into code we get:

def Jt_arctan(x, gt):
    """x: ndarray shape(n,)
       gt: ndarray shape(1,n)"""
    return gt * (1/(1 + x**2)).reshape(1, -1)
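
The opening listing also used Jt_sum, Jt_add_const, and Jt_matmul, which the post doesn't derive. Following the same \bm{g}^{\intercal}Df(\bm{x}) recipe, they could plausibly look like this (my sketch; Jt_matmul assumes the weight matrix W is in scope, matching the inline comment in the listing):

def Jt_sum(x, gt):
    """x: ndarray shape(n,)
       gt: scalar (or ndarray shape(1,1))
    np.sum has Jacobian [1 1 ... 1], so gt @ Df(x) is gt times a row of ones."""
    return gt * np.ones((1, x.shape[0]))

def Jt_add_const(x, gt):
    """x: ndarray shape(n,)
       gt: ndarray shape(1,n)
    Adding a constant has Jacobian I, so gt passes through unchanged."""
    return gt

def Jt_matmul(x, gt):
    """x: ndarray shape(n,)
       gt: ndarray shape(1,m)
    f(x) = W @ x has Jacobian W, so gt @ Df(x) = gt @ W (W assumed in scope)."""
    return gt @ W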

What is going on (bringing back rigor)

Warning!
I claim "\bm{g} isn't special" when it is. Otherwise, Df(\bm{x})^{\intercal}\bm{g} is just a random value in a vector space with the same dimension as the tangent space of the input. Also, while it is correct, the side quest about the dual space just makes the article a little less clear.

You might be thinking: I'm familiar with multivariable calculus, I love a good dual space, but this all feels unjustified. To fix that, we'll look again at what a derivative really is. Given a function f:V\to W between vector spaces, the derivative is the linear map that takes incremental changes in the input of f and outputs incremental changes in the value of f. It is often represented as a matrix, but to clarify the types, we'll pretend it's just a linear function. That is, Df: V \to (V \to W), and Df(\bm{x}): V\to W. In other words, we're currying the function when we evaluate it at \bm{x}. It has the nice property that

f(\bm{x}) + Df(\bm{x})(\bm{h})\approx f(\bm{x}+\bm{h}).

We're using slightly different notation since we're treating Df(\bm{x}) as a function, but remember, applying it as a function is just left multiplication by the matrix. With our new understanding that the derivative is a linear map (which just happens to be representable as a matrix), let's look again at this figure:

Figure 2: On the left, we see the input space V and the tangent space at \bm{x}, which contains \bm{h}. On the right we have our output space W. Note that Df(\bm{x})\bm{h} is in the tangent space at f(\bm{x}).

There are several things to notice. First, I've drawn the dashed axes to represent the domain and codomain of Df(\bm{x}). Now, since we're doing reverse mode, we need to go the other direction: take incremental changes in output and turn them into incremental changes to input. Luckily, we can take the transpose of a linear map:

\begin{align*} Df(\bm{x})^{\intercal}&:W^*\to V^*\\ Df(\bm{x})^{\intercal}(\bm{g})&=\bm{g}\circ Df(\bm{x}) \end{align*}

Where \bm{g}\in W^*. It is important to realize \bm{g} isn't special.* It can be any element of W^{*}. We're now dealing with the dual space; its elements are often referred to as covectors. A covector is a linear map from its "normal" space to the underlying field, so \bm{g}:W\to \mathbb{R}. This means the types in the expression make sense: \bm{g}\circ Df(\bm{x}):V\to \mathbb{R}, i.e. \bm{g}\circ Df(\bm{x})\in V^*. People often think of elements of the dual space as linear functions that "take the inner product with a fixed vector". So

\bm{g}(\bm{x}) = \langle \bm{g},\bm{x} \rangle.

This is especially useful because we can then represent \bm{g} as a row vector, which, when combined with a column vector from the non-dual space, gives the inner product. This is the reason I've been so careful to keep gt of shape (1,n): these \bm{g}s are members of the dual space and are actually row vectors. From now on, I'll use \bm{g}^{\intercal} as a reminder that it's a row vector.

Just to check that using the covectors as row vectors will work, let's examine our vector-Jacobian product

(\bm{g}\circ Df(\bm{x}))(\bm{v})=(\bm{g}^{\intercal} Df(\bm{x}))\bm{v}=\bm{g}^{\intercal} Df(\bm{x})\bm{v}

Thus, with our newfound correspondence, let's evaluate the vector-Jacobian product (and employ the chain rule). Let f = C \circ B \circ A; then

\begin{align*} Df(\bm{x})^{\intercal}(\bm{g})&=\bm{g}^{\intercal} Df(\bm{x})\\ &=\bm{g}^{\intercal} \,((DC \circ B \circ A)(\bm{x})\,\,(DB \circ A)(\bm{x})\, DA(\bm{x}))\\ &=\bm{g}^{\intercal}\, DC((B \circ A)(\bm{x}))\,DB (A(\bm{x})) \, DA(\bm{x})\\ \end{align*}

Recall our definition from earlier:

\texttt{Jt\_f}(\bm{x}, \bm{g}^{\intercal}) = \bm{g}^{\intercal}Df(\bm{x}).

Then

\begin{align*} &=\bm{g}^{\intercal} DC((B \circ A)(\bm{x}))\,\,DB (A(\bm{x}))\, DA(\bm{x})\\ &=\texttt{Jt\_C}((B\circ A)(\bm{x}), \bm{g}^{\intercal})\,DB (A(\bm{x}))\, DA(\bm{x})\\ \end{align*}

What we should notice here is that \texttt{Jt\_C}((B\circ A)(\bm{x}), \bm{g}^{\intercal}) = \bm{g}^{\intercal}DC((B\circ A)(\bm{x})) is an element of the dual space of the output of B. Thus, we can apply the same strategy again, folding that element into a call to the next Jt function:

\begin{align*} &=\texttt{Jt\_C}((B\circ A)(\bm{x}), \bm{g}^{\intercal})\,DB (A(\bm{x}))\, DA(\bm{x})\\ &= \texttt{Jt\_B}(A(\bm{x}),\,\, \texttt{Jt\_C}((B\circ A)(\bm{x}), \bm{g}^{\intercal})) \,\, DA(\bm{x})\\ &= \texttt{Jt\_A}(\bm{x},\,\, \texttt{Jt\_B}(A(\bm{x}), \texttt{Jt\_C}((B\circ A)(\bm{x}), \bm{g}^{\intercal})))\\ \end{align*}

This is very busy notation, so to make it easier to read, let \bm{a}=A(\bm{x}) and \bm{b}=B(A(\bm{x})); then the final expression becomes

\texttt{Jt\_A}(\bm{x},\, \texttt{Jt\_B}(\bm{a},\, \texttt{Jt\_C}(\bm{b},\, \bm{g}^{\intercal}))).
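
In code, this nesting is exactly the backward pass from the opening listing. Here's a tiny concrete sketch of my own, with A = element-wise arctan, B = element-wise square, and C = sum, reusing Jt_arctan and Jt_square from above and the Jt_sum sketched earlier:

import numpy as np

x = np.array([0.5, -1.0, 2.0])

# forward pass of f = C . B . A, saving the intermediates a and b
a = np.arctan(x)   # A(x)
b = a ** 2         # B(A(x))
out = np.sum(b)    # C(B(A(x)))

# backward pass: thread the covector through the Jt functions, innermost first
gt = np.array([[1.0]])                           # a covector on the (scalar) output space
xt = Jt_arctan(x, Jt_square(a, Jt_sum(b, gt)))   # a covector on the input space, shape (1, 3)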

This should hopefully justify the program I provided at the beginning of the post. Just by knowing the Jt functions of the primitives, we end up with the ability to convert \bm{g}^{\intercal} into a covector of the input space of f. Now let's reflect on what this means. If we choose \bm{g}^{\intercal}=\begin{bmatrix} 1 & 0& \dots&0 \end{bmatrix}, then we get back a row of the derivative:

\begin{align*} \bm{g}^{\intercal}Df(\bm{x}) &=\begin{bmatrix} 1 & 0& \dots&0 \end{bmatrix}\, Df(\bm{x})\\ &=\begin{bmatrix} \partial_{1}f(\bm{x}) & \partial_{2}f(\bm{x})& \dots&\partial_{n}f(\bm{x}) \end{bmatrix} \end{align*}

(borrowing notation from Sussman: \partial_{1}f is the partial derivative with respect to the first argument).

This is great because almost all loss functions in ML are scalar, so the Jacobian just has one row, and thus, we can efficiently compute the gradient for our descent.
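
To see how you would recover a full Jacobian this way, here's a small sketch of mine (not JAX's code): feed each standard basis covector through a Jt function and stack the resulting rows. This is essentially what reverse-mode Jacobian helpers like jax.jacrev do, one backward pass per output dimension.

def jacobian_via_jt(Jt_f, x, m):
    """Recover the full (m, n) Jacobian of f at x using only its Jt function:
    row i is e_i^T Df(x)."""
    rows = []
    for i in range(m):
        gt = np.zeros((1, m))
        gt[0, i] = 1.0               # the i-th standard basis covector
        rows.append(Jt_f(x, gt))
    return np.vstack(rows)

# e.g. jacobian_via_jt(Jt_square, np.array([1.0, 2.0, 3.0]), 3) recovers diag(2x)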

You may be thinking: hold on a minute, the gradient is actually a linear form (in the case of a scalar loss function); we can't add that to \bm{x}. And you're right. We should think of the gradient as something that can turn small changes of \bm{x} into small changes in f(\bm{x}). When we're doing gradient descent, we want a small change in \bm{x} that will maximally decrease f(\bm{x}) for a given step size. The small change in \bm{x} that will cause the most increase is in fact the transpose of the gradient (this makes sense since the dot product is maximal when the vectors are parallel). So, it is okay to use the gradient for gradient descent after all.
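
For example, a single gradient-descent step with the delta_x from the opening listing might look like this (the step size is my arbitrary choice):

lr = 0.1
x = x - lr * delta_x.ravel()  # delta_x is the transposed covector, i.e. the steepest-ascent direction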

It is important to note that the components of the Jacobian depend on the bases chosen for the domain and codomain.

Wrapping up

I hope this post helped you understand JAX backward-mode automatic differentiation. I wrote an implementation that automatically tracks the computations, so you don't have to write out the Jt_f() functions explicitly. This is one of my most technical posts. Let me know if any part of it needs a better or more in-depth explanation.

I wrote a follow-up post here, which should clarify what gradient descent actually is and why these Jacobian vector products help.

[1] morally in this sense
