
2 Automatic differentiation

    2.1 Background

      2.1.1 Derivatives

      2.1.2 Composition and the chain rule

      2.1.3 The reverse transform

    2.2 Reverse transform API

      2.2.1 Specifying reverse transformations

      2.2.2 Interface

      2.2.3 Reverse-transformed procedure results

      2.2.4 Handling functions unknown as primitives

This section first gives some general background on reverse-mode automatic differentiation, and then describes the interface provided by this library. The implementation is largely based on Pearlmutter and Siskind (2008).

2.1 Background

2.1.1 Derivatives

Before getting to automatic differentiation, this section gives a brief recap of differentiation itself. It largely follows the introductory sections of Elliott (2018).

Suppose A and B are finite-dimensional vector spaces, and that we have a function

f : A \to B.

The derivative of f (if it exists) is a map

df : A \to A \to B

associating a linear map from A to B with each element of A. We say that df(x) is the derivative of f at the point x.

You may be used to thinking of a derivative as a number, perhaps written \frac{df}{dx} or f^\prime(x), or as a matrix J_{ij} = \frac{\partial f_i}{\partial x_j}. This presentation is equivalent, but will have several advantages for our purposes.

For example, if f(x) = 5x, the derivative (as defined above) of f is the linear map

df(x) = \Delta x \mapsto 5 \Delta x,

and we understand notation such as \frac{df}{dx} = 5 to indicate the coefficients of this map.

Sometimes it is convenient to choose bases for A and B, and represent this linear map as a matrix (which is then known as the Jacobian), but sometimes another representation is preferable.

We will keep the representation of linear maps as functions. Eventually, just as we think of a (mathematical) function f as being implemented by (or approximated by) a Racket function, its derivative will be a Racket function too.
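As a minimal sketch of this representation, the derivative of f(x) = 5x from the example above can be written as a Racket procedure that, given a point x, returns the linear map \Delta x \mapsto 5 \Delta x as a closure. The names f and df are illustrative only and are not part of the library.

```racket
#lang racket/base
;; f(x) = 5x
(define (f x) (* 5 x))

;; df : A -> (A -> B). Given a point x, return the linear map dx |-> 5 dx.
;; For this particular f the map does not depend on x, but in general it does.
(define (df x)
  (lambda (dx) (* 5 dx)))

;; The derivative of f at the point 3, applied to the increment 2.
((df 3) 2) ; => 10
```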

It is often necessary, eventually, to turn this function into a numerical representation. Evaluating the linear map gives a directional derivative (in terms of Jacobians, this would be a Jacobian-vector product). Evaluating it for each element of a basis of A allows us to reconstruct the whole Jacobian. Notice that to reconstruct the Jacobian of f at x we would need to make \dim{A} evaluations of df(x), regardless of the dimensionality of B.

Example

f : \mathbf{R}^3 \to \mathbf{R}^2 \\ f(x,y,z) = (z + 1, xy)

then

\begin{split} d&f(x,y,z) =\\ &(\Delta x, \Delta y, \Delta z) \mapsto (\Delta z, y \Delta x + x \Delta y) \end{split}

The directional derivative, in the direction (1,0,0), is

df(x,y,z)(1,0,0) = (0,y)

Evaluating the map at the standard basis vectors (1,0,0), (0,1,0) and (0,0,1) gives the Jacobian matrix:

\begin{pmatrix} 0 & 0 & 1 \\ y & x & 0 \end{pmatrix}
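The same computation can be sketched in plain Racket, with vectors represented as lists. Here df(x,y,z) is a procedure implementing the linear map above; applying it to each standard basis vector of \mathbf{R}^3 yields one column of the Jacobian per evaluation, so \dim A = 3 evaluations in total. The helper name jacobian-via-df is hypothetical.

```racket
#lang racket/base
;; f : R^3 -> R^2, f(x,y,z) = (z + 1, xy); vectors are represented as lists.
(define (f x y z)
  (list (+ z 1) (* x y)))

;; df(x,y,z) : the linear map (dx,dy,dz) |-> (dz, y dx + x dy).
(define (df x y z)
  (lambda (dx dy dz)
    (list dz (+ (* y dx) (* x dy)))))

;; Each evaluation of df(x,y,z) at a standard basis vector of R^3 gives one
;; column of the Jacobian, so this takes dim A = 3 evaluations.
(define (jacobian-via-df x y z)
  (define columns
    (for/list ([e (in-list '((1 0 0) (0 1 0) (0 0 1)))])
      (apply (df x y z) e)))
  ;; Transpose the list of columns into a list of rows.
  (apply map list columns))

;; At (x,y,z) = (2,3,5) the rows are (0 0 1) and (y x 0) = (3 2 0).
(jacobian-via-df 2 3 5) ; => '((0 0 1) (3 2 0))
```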

The adjoint map (defined in general below) is

\begin{split} D&f(x,y,z) =\\ &(\Delta u, \Delta v) \mapsto (y\Delta v, x\Delta v, \Delta u) \end{split}

where \Delta u and \Delta v are sometimes known as sensitivity variables. The function is said to map output or result sensitivities to argument sensitivities.

Notice that it takes just two evaluations of the adjoint map, at (1,0) and (0,1), to obtain the same Jacobian as above (each evaluation gives one row), compared with three evaluations of df. Every evaluation, of either map, costs two multiplications.
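Continuing in plain Racket (again with hypothetical names, and vectors as lists), the adjoint map becomes a procedure from result sensitivities to argument sensitivities. Evaluating it at the two standard basis vectors of \mathbf{R}^2 recovers both rows of the Jacobian.

```racket
#lang racket/base
;; Df(x,y,z) : the adjoint map (du,dv) |-> (y dv, x dv, du),
;; taking result sensitivities to argument sensitivities.
(define (Df x y z)
  (lambda (du dv)
    (list (* y dv) (* x dv) du)))

;; Each evaluation at a standard basis vector of R^2 gives one *row* of the
;; Jacobian, so dim B = 2 evaluations suffice.
(define (jacobian-via-Df x y z)
  (for/list ([e (in-list '((1 0) (0 1)))])
    (apply (Df x y z) e)))

(jacobian-via-Df 2 3 5) ; => '((0 0 1) (3 2 0))
```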

A case that often arises in practice is when A is very high-dimensional and B is \mathbf{R}. Loss functions in optimization problems have this form, for example.

Handling this case more efficiently is the motivation for reverse-mode AD, which is based on the following idea.

If we further insist that A and B are both equipped with an inner product, we can obtain the adjoint of a linear map L : A \to B: this is the linear map L^* : B \to A characterized by \langle L a, b \rangle = \langle a, L^* b \rangle for all a \in A and b \in B. This allows us to define

Df : A \to B \to A\\ Df(x) = df(x)^*.

If df(x) can be represented by the Jacobian matrix J, then the matrix representation of Df(x) is the transpose of the Jacobian, J^T.

Particularly when referring to its implementation in code, we call Df(x) the backpropagator of f at x.
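For the running example f(x,y,z) = (z + 1, xy), the defining property of the adjoint, \langle df(p)(a), b \rangle = \langle a, Df(p)(b) \rangle, can be checked numerically with the standard dot product. This is a plain-Racket sketch; the helper names and the sample vectors are arbitrary choices.

```racket
#lang racket/base
;; Dot product of two lists of numbers (the standard inner product).
(define (dot u v)
  (for/sum ([a (in-list u)] [b (in-list v)]) (* a b)))

;; df and Df for f(x,y,z) = (z + 1, xy), as in the running example.
(define (df x y z)
  (lambda (dx dy dz) (list dz (+ (* y dx) (* x dy)))))
(define (Df x y z)
  (lambda (du dv) (list (* y dv) (* x dv) du)))

;; Check <df(p)(a), b> = <a, Df(p)(b)> at the point p = (2,3,5).
(define (adjoint-holds? a b)
  (= (dot (apply (df 2 3 5) a) b)
     (dot a (apply (Df 2 3 5) b))))

;; Both sides equal 25 here.
(adjoint-holds? '(1 -2 4) '(7 3)) ; => #t
```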

Returning to the case considered above, f : A \to \mathbf{R}, it is possible to reconstruct the Jacobian from a single evaluation of the linear map Df(x) : \mathbf{R} \to A.

The gradient of f is

\nabla f : A \to A\\ \nabla f(x) = Df(x)(1)

assuming the usual inner product on \mathbf{R}.
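For instance, take f(x,y) = x^2 + y, a function \mathbf{R}^2 \to \mathbf{R}. A hand-written backpropagator yields the whole gradient from a single evaluation, at the sensitivity 1. This plain-Racket sketch does not use the library, and its names are hypothetical.

```racket
#lang racket/base
;; f(x,y) = x^2 + y
(define (f x y) (+ (* x x) y))

;; Df(x,y) : R -> R^2, the adjoint of df(x,y) = (dx,dy) |-> 2x dx + dy.
(define (Df x y)
  (lambda (dw) (list (* 2 x dw) dw)))

;; The gradient is Df(x,y) evaluated at 1: one evaluation, whatever dim A is.
(define (gradient-f x y) ((Df x y) 1))

(gradient-f 2 3) ; => '(4 1)
```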

2.1.2 Composition and the chain rule

Our goal is to be able to differentiate as many (Racket) functions as possible. In some cases, we will be content with explicitly providing a function Df that computes the derivative of another function f, and associating them somehow. It would be unsatisfactory if we had to do this for every f, though, so we seek a way of determining the derivative of a function from its definition in terms of other functions. The ability to do this is the main selling point of automatic differentiation. The primary way (and in some sense, the only way) that this is achieved is via the chain rule.

The chain rule allows derivatives of compositions of functions to be related to compositions of their derivatives. The chain rule can be expressed in terms of d or D:

d(g \circ f)(x) = dg(f(x)) \circ df(x)

D(g \circ f)(x) = Df(x) \circ Dg(f(x)).

Notice the ‘reverse’ order of composition in the right hand side of the equation immediately above.
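A small sketch of the second rule with scalar functions, where every inner product is ordinary multiplication: with f(x) = x^2 and g(u) = u^3, the backpropagator of g \circ f at x applies Dg(f(x)) first and Df(x) second. All names here are illustrative.

```racket
#lang racket/base
(define (f x) (* x x))        ; f(x) = x^2
(define (g u) (* u u u))      ; g(u) = u^3

;; Backpropagators (adjoints of the derivatives) at a point.
(define (Df x) (lambda (dy) (* 2 x dy)))
(define (Dg u) (lambda (dz) (* 3 u u dz)))

;; D(g . f)(x) = Df(x) . Dg(f(x)); note that f(x) has to be computed too.
(define (D-g-of-f x)
  (lambda (dz) ((Df x) ((Dg (f x)) dz))))

;; (g . f)(x) = x^6, whose derivative at 2 is 6 * 2^5 = 192.
((D-g-of-f 2) 1) ; => 192
```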

We will focus on D for the rest of the section, but similar considerations apply to d. Notice too that for both rules we need to know f(x), not merely the derivatives of f and g, to express the derivative of the composition. There is often some shared work involved in computing Df(x) and f(x), but this is not apparent from the usual chain rule, and an interface based on it would not let us take advantage of that.

Instead, define

D^+f(x) = (f(x), Df(x))

and now

D^+(g\circ f)(x) = \big(g(f(x)), Df(x) \circ Dg(f(x))\big).

Notice that D^+(g\circ f) can now be expressed in terms of D^+g and D^+f.
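Representing D^+f(x) as a pair of the value and the backpropagator, the composition rule can be sketched as below. The value f(x) is computed once, on the way forward, and passed to D^+g; each backpropagator closes over what it needs for the way back. The name compose-D+ is hypothetical, and a pair is used only for brevity: the library itself returns a proc-result struct, described later.

```racket
#lang racket/base
;; D+f(x) = (f(x) . Df(x)), for f(x) = x^2
(define (D+f x)
  (cons (* x x) (lambda (dy) (* 2 x dy))))

;; D+g(u) = (g(u) . Dg(u)), for g(u) = u^3; u^2 is shared by both parts.
(define (D+g u)
  (let ([u2 (* u u)])
    (cons (* u2 u) (lambda (dz) (* 3 u2 dz)))))

;; D+(g . f)(x) = (g(f(x)), Df(x) . Dg(f(x))), built only from D+g and D+f.
(define (compose-D+ D+g D+f)
  (lambda (x)
    (let* ([rf (D+f x)]          ; forward through f
           [rg (D+g (car rf))])  ; forward through g, reusing f(x)
      (cons (car rg)
            ;; backwards: through Dg first, then Df
            (lambda (dz) ((cdr rf) ((cdr rg) dz)))))))

(define result ((compose-D+ D+g D+f) 2))
(car result)      ; => 64
((cdr result) 1)  ; => 192
```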

2.1.3 The reverse transform

The mapping

f \mapsto D^+f

as implemented in code is the central operation of reverse-mode AD.

Why reverse transform?

Notice the composition rule above: roughly, whereas data flows ‘forwards’ through the composition g \circ f, the derivatives of f and g are composed in the opposite order, and so data flows ‘in reverse’ through them.

Since the output of the reverse transform combines the function value and its derivative, data must in fact flow both ways. The idea is to store the result of each function evaluation on the way ‘forward’, to be consumed by the appropriate derivative computation on the way back.
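One way to picture this is sketched below for a chain of single-argument functions, a simplification that ignores multiple arguments, shared variables and mutation. Each stage is a hypothetical D^+ function returning a pair of its value and its backpropagator. The forward pass pushes each backpropagator onto a stack, and the reverse pass pops them, last stage first.

```racket
#lang racket/base
;; Hypothetical D+ functions: each maps x to (cons value backpropagator).
(define (D+square x) (cons (* x x) (lambda (dy) (* 2 x dy))))
(define (D+triple x) (cons (* 3 x) (lambda (dy) (* 3 dy))))
(define (D+add1 x) (cons (+ x 1) (lambda (dy) dy)))

;; Forward: apply each stage in order, pushing its backpropagator onto a
;; stack, so the most recent stage ends up on top.
;; Reverse: pass the output sensitivity through the stack from the top down.
(define (run-chain stages x)
  (let loop ([stages stages] [x x] [stack '()])
    (if (null? stages)
        (cons x
              (lambda (dy)
                (for/fold ([s dy]) ([bp (in-list stack)])
                  (bp s))))
        (let ([r ((car stages) x)])
          (loop (cdr stages) (car r) (cons (cdr r) stack))))))

;; add1(triple(square(x))) = 3x^2 + 1: at x = 2 the value is 13, derivative 12.
(define result (run-chain (list D+square D+triple D+add1) 2))
(car result)      ; => 13
((cdr result) 1)  ; => 12
```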

This description is far from complete. Handling variable assignment (and repeated use of a variable) as well as mutable state have been omitted, as have many technical details needed for a practical implementation.

2.2 Reverse transform API

The previous section defined the reverse transformation as a mapping f \mapsto D^+ f. This section describes how it applies to Racket code. The macros D+ and lift/D+ perform transformations similar to this one, and D and grad are provided for convenience. Of these, lift/D+ is fundamental.

When differentiating an expression, each procedure that is encountered is replaced with one that computes both the primal (the undifferentiated function value) and a backpropagator (a linear function, called when computing derivative values, that takes an output sensitivity and returns the argument sensitivities). The process of replacing the function with its primal and backpropagator is known as reverse transformation.

In the example below, the reverse transformation of * is obtained with lift/D+. The primal and backpropagator are returned in a proc-result struct.

> (define result ((lift/D+ *) 4.0 2.5))
> (primal result)

10.0

> (backprop result)

#<procedure:...ator/primitives.rkt:87:2>

Procedures whose definitions occur within the expression being differentiated can be transformed automatically by the library. Any procedure that is used but not defined within the expression must also be replaced with its reverse transform. Such procedures are known as primitives, and include, for example, arithmetic operations. They must have backpropagators that are known in advance.

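In this library, the backpropagator of a function takes two arguments: the result sensitivity, which should conform to the value returned by the function, and the box sensitivities, which are explained below. The result of evaluating a backpropagator is a list containing, first, a list of the sensitivities of any closed-over variables, followed by one sensitivity for each argument, conforming to that argument.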

The box-sensitivity argument to a backpropagator is how sensitivities of mutable data structures are handled. It is a hash table (satisfying (and/c hash? hash-eq? (not/c immutable?))) mapping each mutable data structure to its corresponding sensitivity value. The value stored under a given mutable data structure can be updated by the backpropagator of any function that refers to an element of that data structure.

Continuing the example above,
> ((backprop result) 1.0 (make-hasheq))

'(() 2.5 4.0)

Notice the empty hash table passed as the second argument, and the first element of the resulting list: the list of closed-over variable sensitivities, which is empty here because there are none. The remaining elements, 2.5 and 4.0, are the sensitivities of the two arguments.

Alternatively, use D+ to avoid the empty hash table argument, and to drop the closure sensitivities:

> (define result ((D+ *) 4.0 2.5))
> (primal result)

10.0

> ((backprop result) 1.0)

'(2.5 4.0)

2.2.1 Specifying reverse transformations

When specifying a reverse transform, it should have the form described above (as returned by lift/D+). Here is the reverse transformation of two-argument multiplication:

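(λ (x y)
  (proc-result
   (* x y)                ; primal
   (λ (Aw Abox)           ; Aw: result sensitivity, Abox: box sensitivities
     ;; no closed-over variables, then the sensitivities of x and y
     (list '() (scale Aw y) (scale Aw x)))))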

and of exp:

(λ (x)
  (let ([exp-x (exp x)])  ; computed once, shared by primal and backpropagator
    (proc-result
     exp-x
     (λ (Aw Abox)
       (list '() (scale Aw exp-x))))))

Backpropagator definitions should allow for the result sensitivity being the zero value produced by (gen-zero) (hence the use of scale). See Linear generic interface.
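Following the same pattern, a reverse transform for sin, or for two-argument subtraction, might look like the sketch below. To keep it self-contained, proc-result and scale are replaced by local stand-ins: the stand-in scale is plain multiplication and, unlike the library's, does not handle (gen-zero). The names D+sin and D+minus are hypothetical.

```racket
#lang racket/base
;; Local stand-ins, NOT the library's definitions: a two-field result struct,
;; and scale as ordinary multiplication (the real scale also handles gen-zero).
(struct proc-result (primal backprop))
(define (scale v k) (* v k))

;; Reverse transform of sin: the sensitivity of x is Aw * cos(x).
(define (D+sin x)
  (proc-result
   (sin x)
   (λ (Aw Abox)
     (list '() (scale Aw (cos x))))))

;; Reverse transform of two-argument subtraction x - y.
(define (D+minus x y)
  (proc-result
   (- x y)
   (λ (Aw Abox)
     (list '() Aw (scale Aw -1)))))

;; Calling the backpropagators as in the earlier REPL example:
((proc-result-backprop (D+sin 0.0)) 1.0 (make-hasheq))      ; => '(() 1.0)
((proc-result-backprop (D+minus 5.0 2.0)) 1.0 (make-hasheq)) ; => '(() 1.0 -1.0)
```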

The reverse transform of a binding must be provided when registering a binding as a new primitive with register-primitive!, or by the require/primal+backprop mechanism. It can subsequently be used in functions defined with define/D+.

2.2.2 Interface

syntax

(grad expr)
(grad expr result-sensitivity)

Produces a function that computes the gradient of the procedure expr at given arguments.

The result of evaluating expr must be a procedure.

This form evaluates to a function of the same arity that, when called, returns the gradient (represented as described above). In general, the gradient is a list whose length is the number of arguments passed, and whose elements conform with the corresponding arguments.

The first form is equivalent to the second with 1.0 passed as the value of result-sensitivity.

Example:
> ((grad (lambda (x y) (+ (* x x) y))) 2.0 3.0)

'(4.0 1.0)
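Results like this can be sanity-checked numerically. The plain-Racket sketch below is independent of the library (numeric-gradient is a hypothetical helper): it approximates each component of the gradient with a central finite difference, which for this function agrees with '(4.0 1.0) up to rounding error.

```racket
#lang racket/base
;; Central-difference approximation to the gradient of f at the point args.
;; Component i is (f(args + h e_i) - f(args - h e_i)) / 2h.
(define (numeric-gradient f args [h 1e-6])
  (for/list ([i (in-range (length args))])
    (let ([shift (lambda (d)
                   ;; args with d added to the i-th component only
                   (for/list ([a (in-list args)] [j (in-naturals)])
                     (if (= i j) (+ a d) a)))])
      (/ (- (apply f (shift h)) (apply f (shift (- h))))
         (* 2 h)))))

;; Approximately '(4.0 1.0)
(numeric-gradient (lambda (x y) (+ (* x x) y)) '(2.0 3.0))
```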

syntax

( expr)

The same as grad.

syntax

(grad1 expr)

Like grad, but for functions of arity one. When grad would evaluate to a list holding a single element, this form evaluates to the element without the list wrapper, which may be more convenient.

Example:
> ((grad1 (grad1 cos)) 0.0)

-1.0

syntax

(D expr)

Like D+, except the resulting function returns only the backpropagator.

Examples:
> (define/D+ (f x y)
    (vector->immutable-vector
     (vector (* x x) y)))
> (((D f) 2.0 3.0) #(1.0 0.0))

'(4.0 0.0)

; An error: sensitivity does not conform with the result:
> (((D f) 2.0 3.0) '(1.0 0.0))

raise-argument-error: contract violation

  expected: exact-nonnegative-integer?

  given: '(1.0 0.0)

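Note that this error message is currently unhelpful: it does not say that the sensitivity should have been a vector conforming to the result.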

syntax

(D+ expr)

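Like lift/D+, except that the backpropagator takes only the result sensitivity (there is no box-sensitivity hash table argument), and the list it returns omits the closed-over variable sensitivities, containing only the argument sensitivities.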

syntax

(lift/D+ expr)

Reverse transform the expression expr. The result is a function that, when called, returns a proc-result struct containing the primal at the given arguments and a backpropagator for the same arguments.

The backpropagator is the two argument form described above. The first argument is the result sensitivity, and the second must be a mutable hash table (satisfying (and/c hash? hash-eq? (not/c immutable?))).

The resulting function is of the correct form to pass to derivatives of higher-order functions.

Example:
> ((grad foldl) (lift/D+ *) 1 '(2 3 4))

'(() 24.0 (12.0 8.0 6.0))

Using it directly:

> (define D+f (lift/D+ (lambda (x) (set! x (* x x)) x)))
> (define result (D+f 2.0))
> (primal result)

4.0

> ((backprop result) 1.0 (make-hasheq))

'(() 4.0)

syntax

(define/D+ id expr)

(define/D+ (id args ... . rest-args)
   body ...+)
Define a new primitive in terms of others.

Similarly to define, bind id to the result of evaluating expr in the first case, or to a procedure in the second case. Note that this form does not support the curried function definition shorthand of define.

In addition, the reverse transform of expr or body is determined, and id is registered as a primitive. Recursive definitions are allowed.

2.2.3 Reverse-transformed procedure results

struct

(struct proc-result (primal backprop)
    #:transparent)
  primal : any/c
  backprop : procedure?

procedure

(primal r) → any/c

  r : proc-result?

procedure

(backprop r) → procedure?

  r : proc-result?
A two-member structure for holding the return value of a reverse-transformed procedure.

proc-result-primal and proc-result-backprop are also provided under the shorter aliases primal and backprop.

2.2.4 Handling functions unknown as primitives

During reverse transformation, an identifier may be encountered that is not registered as a primitive. In this case, it is transformed to the result of calling the procedure stored as the value of the parameter current-unknown-transform. Usually the job of this procedure is to raise an error, but it may also handle a few cases where the result is known.

The default value of current-unknown-transform is error-non-zero-sensitivity-transform, which raises an error only if an attempt is made to call the unknown backpropagator with a non-zero sensitivity argument. This permits code paths that do not contribute to the derivative (e.g. error handling, tests in conditionals) without having to register (perhaps meaningless) derivative information for every function that is called.

error-unknown-transform and error-unknown-proc-transform can be very useful for debugging.

The parameter current-unknown-transform determines the reverse transform to apply to a value not registered as a primitive.

Example:
> (define/D+ (car-or-void v)
    (when (pair? v)
      (car v)))
> ((grad car-or-void) '(1.0 2.0 3.0))

'((1.0 0.0 0.0))

> ((grad car-or-void) 123.0)

'(0.0)

> (parameterize ([current-unknown-transform error-unknown-transform])
    ((grad car-or-void) '(1.0 2.0 3.0)))

lift/D+: Backpropagator unknown

  op: 'pair?

procedure

(error-unknown-transform op op-name) → any

  op : any/c
  op-name : any/c
Use as a value for current-unknown-transform.

The resulting procedure will be used as the reverse transform of op. It will unconditionally raise an error when called.

procedure

(error-unknown-proc-transform op op-name) → procedure?

  op : any/c
  op-name : any/c
Use as a value for current-unknown-transform.

The resulting procedure will be used as the reverse transform of op. It will raise an error when op is a procedure; otherwise, op is returned.

procedure

(error-non-zero-sensitivity-transform op op-name) → procedure?
  op : any/c
  op-name : any/c
Use as a value for current-unknown-transform.

The resulting procedure will be used as the reverse transform of op.

When op is a non-procedure value, op is returned.

When op is a procedure, an attempt is made to construct a reverse transform for it whose primal is the result of applying the procedure, and whose backpropagator raises an error when called unless (gen-zero) is passed as the sensitivity (in which case the result of the backpropagator is also (gen-zero)).
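The behaviour just described can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: proc-result and gen-zero are local stand-ins (here gen-zero returns a unique uninterned symbol), and the name sketch-non-zero-sensitivity-transform is hypothetical.

```racket
#lang racket/base
;; Local stand-ins, NOT the library's definitions.
(struct proc-result (primal backprop))
(define the-zero (string->uninterned-symbol "zero"))
(define (gen-zero) the-zero)
(define (gen-zero? v) (eq? v the-zero))

;; Mimics the documented behaviour of error-non-zero-sensitivity-transform.
(define (sketch-non-zero-sensitivity-transform op op-name)
  (if (procedure? op)
      ;; Reverse transform: the primal is op applied to the arguments...
      (lambda args
        (proc-result
         (apply op args)
         ;; ...and the backpropagator only tolerates a zero sensitivity.
         (lambda (Aw Abox)
           (if (gen-zero? Aw)
               (gen-zero)
               (raise-arguments-error 'sketch "backpropagator unknown"
                                      "op" op-name)))))
      ;; Non-procedure values are returned unchanged.
      op))

(define r ((sketch-non-zero-sensitivity-transform pair? 'pair?) '(1.0 2.0)))
(proc-result-primal r)                                          ; => #t
(gen-zero? ((proc-result-backprop r) (gen-zero) (make-hasheq))) ; => #t
```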