
Rackpropagator: Reverse-mode automatic differentiation of Racket programs

Oliver Strickson <o.strickson@gmail.com>

Rackpropagator provides an automatic differentiation facility for a subset of Racket. It uses ‘reverse mode’ differentiation—sometimes known as backpropagation.

The source code of this package is hosted at https://github.com/ots22/rackpropagator.

This library is a work in progress: expect breaking changes, bugs, and a number of limitations.

A particular caution: currently, the performance of the generated code can be very poor in some situations.

Example:
> (define/D+ (square x)
    (* x x))
> (square 10.0)

100.0

> ((grad square) 10.0)

'(20.0)

Reverse-mode automatic differentiation is best suited to computing gradients of functions whose domain has a large dimension, larger than the dimension of the codomain. In particular, the gradient of a function to a single number (from any domain) can be computed in a single backward pass.
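As a rough illustration of why a single pass suffices, the following hand-written Racket sketch (Rackpropagator is not involved) computes a Euclidean norm and returns a backpropagator alongside it. One call of the backpropagator with the output sensitivity 1.0 yields the partial derivatives with respect to every input element at once. The name `norm2/D+` is hypothetical. Returning one sensitivity per argument is an assumed convention, chosen to match the shape of the `grad` results shown in this document.

```racket
#lang racket
;; Hand-written reverse-mode sketch; `norm2/D+` is a hypothetical name,
;; not part of Rackpropagator.
;; Returns two values: the primal result and a backpropagator that maps
;; a sensitivity of the scalar output to one sensitivity per argument.
(define (norm2/D+ xs)
  ;; forward pass
  (define norm (sqrt (apply + (map * xs xs))))
  (values norm
          ;; backward pass: d(norm)/dx_i = x_i / norm, so every component
          ;; comes out of this single traversal of xs
          (lambda (s)
            (list (map (lambda (x) (/ (* s x) norm)) xs)))))

(define-values (y <-y) (norm2/D+ '(3.0 4.0)))
y          ; => 5.0
(<-y 1.0)  ; => '((0.6 0.8))
```

The outer list has a single entry because the function takes a single argument (a list).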

Example:
> (define/D+ (list-norm2 xs)
    (sqrt (apply + (map * xs xs))))
> (list-norm2 '(1.0 1.0))

1.4142135623730951

; ∇ is an alias for grad
> ((∇ list-norm2) (range 10))

'((0
   0.05923488777590923
   0.11846977555181847
   0.1777046633277277
   0.23693955110363693
   0.29617443887954614
   0.3554093266554554
   0.4146442144313646
   0.47387910220727386
   0.533113989983183))

> (list-norm2
   (car ((∇ list-norm2) (range 10))))

1.0

Reverse-mode automatic differentiation is implemented in this library as a source transformation, at macroexpansion time. This means that the transformed code can be compiled and optimized as with any other function definition, and so—in principle—it can offer similar performance to a hand-coded derivative.
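To make the idea concrete, here is a hand-expanded sketch of the kind of plain Racket code a reverse-mode source transformation could produce for the `square` example. It is only an illustration under assumptions: the names `*/D+` and `square/D+`, and the convention of returning the primal value together with a backpropagator, are hypothetical. This is not the code that `define/D+` actually generates. Because the result is ordinary Racket, the compiler is free to optimize it like any other function.

```racket
#lang racket
;; Hand-expanded sketch of a transformed (define (square x) (* x x)).
;; All names are hypothetical; this is not define/D+'s real output.

;; A differentiable primitive: returns the primal result and a
;; backpropagator mapping the output sensitivity to a list of
;; argument sensitivities.
(define (*/D+ a b)
  (values (* a b)
          (lambda (s) (list (* s b) (* s a)))))

;; Transformed square: the forward pass keeps the backpropagator of
;; each call; the returned closure runs them in reverse order.
(define (square/D+ x)
  (define-values (y <-mul) (*/D+ x x))
  (values y
          (lambda (s)
            (define arg-sens (<-mul s))
            ;; x is used twice, so its two contributions are summed
            (list (+ (first arg-sens) (second arg-sens))))))

(define-values (y <-y) (square/D+ 10.0))
y          ; => 100.0
(<-y 1.0)  ; => '(20.0)
```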

Differentiable functions can be defined and composed using the library, and their derivatives taken to any order (again, in principle, but see Limitations). Derivatives of functions that close over variables behave as expected, as does mutation.

The system is extensible: while a number of differentiable ‘primitives’ are predefined, it is possible to register new ones. Reasons for doing this include supplying derivatives of functions that are inconvenient or impossible to define in terms of existing primitives (for example, to support a new array library, or a function imported via the FFI from a numerical library), and performance: if a particularly efficient implementation of a derivative is available, it can be registered and used directly. Derivatives can be taken with respect to numerical arguments, as well as vectors, lists, and nested and improper lists of any shape; this too can be extended to any type that can be given a linear structure. Support for arrays is a work in progress.
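One way to picture this extensibility is a table mapping primitives to their derivative information. The sketch below is hypothetical throughout: `primitives`, `register-primitive!` and `apply/D+` are illustrative names, not Rackpropagator's registration interface. It shows the performance point for `exp`, whose derivative equals its own value, so the registered backpropagator can reuse the primal result instead of recomputing it.

```racket
#lang racket
;; Hypothetical registry of differentiable primitives (illustrative
;; names only; not Rackpropagator's registration interface).
(define primitives (make-hasheq))

;; An entry maps argument values to two values: the primal result and a
;; backpropagator (sensitivity -> list of argument sensitivities).
(define (register-primitive! op entry)
  (hash-set! primitives op entry))

(define (apply/D+ op . args)
  (apply (hash-ref primitives op) args))

;; exp is its own derivative, so the backpropagator reuses the primal
;; value y rather than calling exp again.
(register-primitive! exp
  (lambda (x)
    (define y (exp x))
    (values y (lambda (s) (list (* s y))))))

(define-values (y <-y) (apply/D+ exp 0.0))
y          ; => 1.0
(<-y 2.0)  ; => '(2.0)
```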

Rackpropagator supports a subset of Racket. For details, see Supported Language.

It is possible to differentiate through recursion and many control structures:

> (define/D+ (pow x n)
    (if (= n 0)
        1.0
        (* x (pow x (- n 1)))))
> (pow 2.0 3)

8.0

> ((grad pow) 2.0 3)

'(12.0 0)
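The result can be checked by hand: with n = 3, pow computes x^3, whose derivative 3x^2 equals 12 at x = 2. Meanwhile n only steers the recursion, so it receives a sensitivity of 0. The hand-written sketch below (the name `pow/D+` is hypothetical) follows the same recursion. Each level returns its primal value together with a backpropagator that applies the product rule and passes the remaining sensitivity down to the next level.

```racket
#lang racket
;; Hand-written reverse mode through the recursive pow example.
;; `pow/D+` is a hypothetical name. Returns the primal value and a
;; backpropagator mapping a sensitivity to (list dx dn).
(define (pow/D+ x n)
  (if (= n 0)
      ;; base case: the constant 1.0 does not depend on x
      (values 1.0 (lambda (s) (list 0.0 0)))
      (let-values ([(p <-p) (pow/D+ x (- n 1))])
        (values (* x p)
                (lambda (s)
                  ;; product rule for (* x p): x receives s*p directly,
                  ;; plus whatever flows back through p (sensitivity s*x)
                  (define inner (<-p (* s x)))
                  ;; n only controls the recursion, so its sensitivity is 0
                  (list (+ (* s p) (first inner)) 0))))))

(define-values (y <-y) (pow/D+ 2.0 3))
y          ; => 8.0
(<-y 1.0)  ; => '(12.0 0)
```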

Note the shape of the result of the following gradient computation: a list containing a single list. The function takes one argument, which is itself a list, so the gradient has one entry, and that entry holds the partial derivative with respect to each element of the list:

> (define/D+ (sum-positives lst)
    (for/sum ([elt (in-list lst)]
              #:when (positive? elt))
      elt))
> (define the-list '(1.0 -2.0 3.0 -4.0 5.0))
> (sum-positives the-list)

9.0

> ((grad sum-positives) the-list)

'((1.0 0.0 1.0 0.0 1.0))
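The gradient is an indicator of which elements pass the `#:when` test. An element that is skipped never reaches the sum, so its sensitivity is 0.0. A hand-written sketch of the corresponding backpropagator follows. The name `sum-positives/D+` is hypothetical, and the outer list (one entry per argument) mirrors the shape of the `grad` result above.

```racket
#lang racket
;; Hand-written sketch; `sum-positives/D+` is a hypothetical name.
(define (sum-positives/D+ lst)
  (values
   ;; primal: the same computation as the define/D+ version
   (for/sum ([elt (in-list lst)]
             #:when (positive? elt))
     elt)
   ;; backpropagator: summed elements receive the sensitivity s,
   ;; filtered-out elements receive 0.0
   (lambda (s)
     (list (for/list ([elt (in-list lst)])
             (if (positive? elt) s 0.0))))))

(define-values (total <-total)
  (sum-positives/D+ '(1.0 -2.0 3.0 -4.0 5.0)))
total          ; => 9.0
(<-total 1.0)  ; => '((1.0 0.0 1.0 0.0 1.0))
```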

Mutation is also supported:

> (define/D+ (pochhammer3 x0)
    (define next
      (let ([x x0])
        (lambda ()
          (begin0
              x
            (set! x (+ x 1))))))
    ; only binary multiplication allowed currently
    (* (next) (* (next) (next))))
> ((grad pochhammer3) 3.0)

'(47.0)
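Racket evaluates function arguments from left to right, so the three calls to next return x0, x0 + 1 and x0 + 2. As a result, pochhammer3 computes the rising factorial x0(x0 + 1)(x0 + 2), which is 60 at x0 = 3. The product rule confirms the gradient shown above:

```latex
f(x_0) = x_0 (x_0 + 1)(x_0 + 2),
\qquad
f'(x_0) = (x_0 + 1)(x_0 + 2) + x_0 (x_0 + 2) + x_0 (x_0 + 1),
\qquad
f'(3) = 4 \cdot 5 + 3 \cdot 5 + 3 \cdot 4 = 20 + 15 + 12 = 47.
```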

If the required derivative information is not available for a binding, an error occurs. Note that double in the example below is defined with Racket’s define rather than define/D+.

> (define (double x) (+ x x))
> (define/D+ (quadruple x) (double (double x)))
> ((grad quadruple) 4.0)

lift/D+: Backpropagator unknown

  op: 'double

By default, the error only occurs if an attempt is made to call the unknown backpropagator. This permits code paths that do not contribute to the derivative (e.g. error handling, tests in conditionals) without having to register (perhaps meaningless) derivative information for every function that is called. This behaviour can be customised with current-unknown-transform.
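The deferred error can be pictured as follows. When no derivative information is found, a placeholder backpropagator is returned instead of raising immediately, and the placeholder raises only if it is called. This is a hypothetical sketch (`known-backprops` and `lookup-backprop` are illustrative names), not Rackpropagator's implementation, and it does not model current-unknown-transform.

```racket
#lang racket
;; Hypothetical sketch of deferring the "unknown backpropagator" error.
;; `known-backprops` and `lookup-backprop` are illustrative names only.
(define known-backprops
  ;; + passes its output sensitivity unchanged to both arguments
  (hash '+ (lambda (s) (list s s))))

(define (lookup-backprop op)
  (hash-ref known-backprops op
            ;; failure thunk: instead of raising now, return a
            ;; backpropagator that raises only if it is ever called
            (lambda ()
              (lambda (s)
                (error 'lookup-backprop
                       "backpropagator unknown for op: ~e" op)))))

((lookup-backprop '+) 1.0)                  ; => '(1.0 1.0)
(define <-double (lookup-backprop 'double)) ; no error yet
```

Calling (<-double 1.0) would raise the error, while a code path that never calls <-double is unaffected.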

The examples above have all demonstrated finding the gradient of functions from one or more arguments to a real number. To handle other codomains, use D to obtain a backpropagator, then call it with a sensitivity, which has the same shape as the function's result:

> (define <-result ((D cons) 3.0 4.0))
> (<-result '(1.0 . 0.0))

'(1.0 0.0)

> (<-result '(0.0 . 1.0))

'(0.0 1.0)
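Supplying each basis sensitivity in turn, as above, recovers the Jacobian one row at a time. For cons itself the Jacobian is the identity, which is why the two calls return the basis vectors. The hand-written sketch below does the same for the hypothetical function f(a, b) = (cons (* a b) (+ a b)), whose Jacobian is not the identity. The name `f/backprop` is illustrative, and nothing here uses Rackpropagator.

```racket
#lang racket
;; Hand-written backpropagator for the hypothetical function
;;   f(a, b) = (cons (* a b) (+ a b))
;; `f/backprop` is an illustrative name, not part of Rackpropagator.
(define (f/backprop a b)
  (lambda (s)
    (define s-prod (car s)) ; sensitivity of the (* a b) component
    (define s-sum (cdr s))  ; sensitivity of the (+ a b) component
    (list (+ (* s-prod b) s-sum)    ; sensitivity of a
          (+ (* s-prod a) s-sum)))) ; sensitivity of b

(define <-f (f/backprop 3.0 4.0))
(<-f '(1.0 . 0.0)) ; => '(4.0 3.0), row for the (* a b) component
(<-f '(0.0 . 1.0)) ; => '(1.0 1.0), row for the (+ a b) component
```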

If both the primal value and the backpropagator are needed at the same arguments, use D+:

> (define result ((D+ cons) 3.0 4.0))
> (primal result)

'(3.0 . 4.0)

> (define <-result (backprop result))
> (<-result '(1.0 . 0.0))

'(1.0 0.0)

> (<-result '(0.0 . 1.0))

'(0.0 1.0)
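Bundling the primal value with its backpropagator means that one forward evaluation serves both purposes. The sketch below shows one possible representation. The struct `primal+backprop` and the function `cons/D+` are hypothetical, and Rackpropagator's own representation behind D+, primal and backprop may differ.

```racket
#lang racket
;; Hypothetical representation of a D+ result: the primal value paired
;; with its backpropagator (illustrative names only).
(struct primal+backprop (value backprop))

(define (cons/D+ a b)
  (primal+backprop (cons a b)
                   ;; cons is linear: each argument's sensitivity is the
                   ;; matching component of the output sensitivity
                   (lambda (s) (list (car s) (cdr s)))))

(define result (cons/D+ 3.0 4.0))
(primal+backprop-value result)                   ; => '(3.0 . 4.0)
((primal+backprop-backprop result) '(1.0 . 0.0)) ; => '(1.0 0.0)
```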