Exact numeric nth derivatives
Automatic differentiation is a well-studied technique for computing exact numeric derivatives. Dan Piponi has a great introduction on the subject, but to give you an overview, the idea is that you introduce an algebraic symbol $\epsilon$ such that $\epsilon \neq 0$ but $\epsilon^2 = 0$. Formulas involving $\epsilon$ are called dual numbers (anything of the form $a + b\epsilon$ with real $a$ and $b$) in much the same way as complex numbers are formulas involving the algebraic symbol $i$, which has the property $i^2 = -1$.
Then you teach the computer how to add, subtract, multiply and divide with dual numbers. So for example, $(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon$ (since $\epsilon^2 = 0$). The computer keeps everything in “normal form,” i.e., $a + b\epsilon$, as you go along.
In order to find the derivative of some function $f$ at a point $x$, all you have to do is compute $f(x + \epsilon)$. The answer you get (in normal form) is
$$f(x + \epsilon) = f(x) + f'(x)\epsilon$$
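So for example, let $f(x) = x^2$ and let’s find $f'(3)$. To do this, we compute $f(3 + \epsilon)$:
$$f(3 + \epsilon) = (3 + \epsilon)^2 = 9 + 6\epsilon + \epsilon^2 = 9 + 6\epsilon$$
We expected $f(3 + \epsilon) = f(3) + f'(3)\epsilon$. Equating these two results, we can conclude that $f(3) = 9$ and $f'(3) = 6$. Which is true!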
This works not just for simple polynomials, but for any compound or nested formula of any kind. Somehow dual numbers keep track of the derivative during the evaluation of the formula, respecting the chain rule, the product rule and all.
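The reason this works becomes clearer when you consider the Taylor series expansion of a function:
$$f(x + h) = f(x) + f'(x)h + \frac{f''(x)}{2!}h^2 + \frac{f'''(x)}{3!}h^3 + \cdots$$
When you evaluate $f(x + \epsilon)$, all the higher-order terms drop out (because $\epsilon^2 = 0$) and all you’re left with is $f(x) + f'(x)\epsilon$.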
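To make the arithmetic rules concrete, here is a minimal first-order sketch in Scala. The names `Dual1` and `derivative` are made up for this illustration and are not from the post's code; the sketch only handles the four basic operations and assumes the divisor has a nonzero real part.

```scala
object FirstOrderDual {
  // a + b*eps with eps^2 = 0: `a` is the value, `b` the derivative part.
  // (Hypothetical name, used only in this sketch.)
  final case class Dual1(a: Double, b: Double) {
    def +(that: Dual1): Dual1 = Dual1(a + that.a, b + that.b)
    def -(that: Dual1): Dual1 = Dual1(a - that.a, b - that.b)
    // (a + b eps)(c + d eps) = ac + (ad + bc) eps, because eps^2 = 0
    def *(that: Dual1): Dual1 = Dual1(a * that.a, a * that.b + b * that.a)
    // (a + b eps)/(c + d eps) = a/c + ((bc - ad)/c^2) eps, valid when c != 0
    def /(that: Dual1): Dual1 =
      Dual1(a / that.a, (b * that.a - a * that.b) / (that.a * that.a))
  }

  // Evaluate f at x + eps and read the derivative off the eps coefficient.
  def derivative(f: Dual1 => Dual1, x: Double): Double = f(Dual1(x, 1.0)).b

  def main(args: Array[String]): Unit = {
    val square: Dual1 => Dual1 = x => x * x
    val reciprocal: Dual1 => Dual1 = x => Dual1(1.0, 0.0) / x
    println(derivative(square, 3.0))     // 6.0
    println(derivative(reciprocal, 2.0)) // -0.25
  }
}
```

Note that nothing in `derivative` knows about the product rule or the quotient rule; they fall out of how `*` and `/` are defined.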
Implementing dual numbers
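One interesting way to implement dual numbers is with matrices. The number $a + b\epsilon$ can be encoded as
$$\begin{pmatrix} a & b \\ 0 & a \end{pmatrix}$$
This encoding has the properties $\epsilon \neq 0$ and $\epsilon^2 = 0$:
$$\epsilon = \begin{pmatrix} 0 & 1 \\ 0 & 0 \end{pmatrix} \neq 0, \qquad \epsilon^2 = \begin{pmatrix} 0 & 1 \\ 0 & 0 \end{pmatrix} \begin{pmatrix} 0 & 1 \\ 0 & 0 \end{pmatrix} = \begin{pmatrix} 0 & 0 \\ 0 & 0 \end{pmatrix}$$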
Furthermore, they add, multiply and divide like dual numbers.
This is nice because if you already have a library for doing matrix math, implementing automatic differentiation is trivial.
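As a quick sanity check of that claim, the sketch below multiplies two encoded dual numbers as plain 2 x 2 matrices and compares the result with the dual-number product rule. It uses nested `Vector`s instead of a real matrix library, and the helper names (`encode`, `mul`) are assumptions for illustration only.

```scala
object MatrixEncoding {
  type Mat = Vector[Vector[Double]]

  // a + b*eps as the 2 x 2 matrix [[a, b], [0, a]]
  def encode(a: Double, b: Double): Mat = Vector(Vector(a, b), Vector(0.0, a))

  // Plain matrix multiplication, nothing dual-number specific.
  def mul(x: Mat, y: Mat): Mat =
    Vector.tabulate(x.length, y.head.length) { (r, c) =>
      x(r).indices.map(k => x(r)(k) * y(k)(c)).sum
    }

  def main(args: Array[String]): Unit = {
    val eps = encode(0.0, 1.0)
    // eps is nonzero but squares to the zero matrix
    println(mul(eps, eps) == encode(0.0, 0.0)) // true
    // (1 + 2 eps)(3 + 4 eps) = 3 + 10 eps
    println(mul(encode(1.0, 2.0), encode(3.0, 4.0)) == encode(3.0, 10.0)) // true
  }
}
```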
Higher-order derivatives
Many of the papers on automatic differentiation point out that this technique generalizes to second derivatives or arbitrary nth derivatives, but I haven’t found a good explanation of how that works. So, here’s my attempt.
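To compute second derivatives, we need to carry out the Taylor series expansion one step further. So instead of $\epsilon^2 = 0$, we need $\epsilon^2 \neq 0$ and $\epsilon^3 = 0$. Then we get numbers of the form $a + b\epsilon + c\epsilon^2$. When we want to find the first and second derivatives of some function $f$ at $x$, we compute $f(x + \epsilon)$ as before, but now we get
$$f(x + \epsilon) = f(x) + f'(x)\epsilon + \frac{f''(x)}{2!}\epsilon^2$$
since the $\epsilon^3$ and higher terms drop out.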
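Now all we need to do is find a mathematical object that behaves this way. It turns out there’s a matrix for this too:
$$\epsilon = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 0 & 0 & 0 \end{pmatrix}$$
It encodes the properties we want:
$$\epsilon^2 = \begin{pmatrix} 0 & 0 & 1 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix} \neq 0, \qquad \epsilon^3 = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix}$$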
Pretty neat!
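If you want 3rd derivatives, you need a 4 x 4 matrix:
$$\epsilon = \begin{pmatrix} 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \end{pmatrix}$$
So we have
$$f(x + \epsilon) = f(x) + f'(x)\epsilon + \frac{f''(x)}{2!}\epsilon^2 + \frac{f'''(x)}{3!}\epsilon^3$$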
You can extend this technique to any order derivative you want.
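For any size $n$, the matrix in question is just ones on the first superdiagonal. The sketch below builds it and checks that the powers behave as needed: nonzero up to $\epsilon^{n-1}$, zero at $\epsilon^n$. The helper names are invented for this example.

```scala
object NilpotentEpsilon {
  type Mat = Vector[Vector[Double]]

  // n x n matrix with ones on the first superdiagonal and zeros elsewhere.
  def eps(n: Int): Mat =
    Vector.tabulate(n, n)((r, c) => if (c == r + 1) 1.0 else 0.0)

  def identity(n: Int): Mat =
    Vector.tabulate(n, n)((r, c) => if (r == c) 1.0 else 0.0)

  def mul(x: Mat, y: Mat): Mat =
    Vector.tabulate(x.length, y.head.length) { (r, c) =>
      x(r).indices.map(k => x(r)(k) * y(k)(c)).sum
    }

  // k-th power by repeated multiplication; power(m, 0) is the identity.
  def power(m: Mat, k: Int): Mat =
    (1 to k).foldLeft(identity(m.length))((acc, _) => mul(acc, m))

  def isZero(m: Mat): Boolean = m.forall(_.forall(_ == 0.0))

  def main(args: Array[String]): Unit = {
    val n = 4 // a 4 x 4 matrix tracks derivatives up to order 3
    (1 to n).foreach { k =>
      println(s"eps^$k is zero: ${isZero(power(eps(n), k))}")
    }
  }
}
```

Each multiplication by $\epsilon$ pushes the diagonal of ones one step further up and to the right, until it falls off the matrix.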
Code
OK, enough math, let’s code it up. The `Dual` class will represent a dual number. The underlying representation is a square upper triangular diagonal-constant matrix. Unlike typical matrix libraries, I’m not going to store the cell values as a `List[List[Double]]` or anything. Instead I’m just going to have a method `get(r: Int, c: Int): Double` that returns the value in the specified cell. Also I’ll memoize it, so yeah, I guess I am storing the cell values somewhere. But this technique makes all the matrix operations easier to write.
Now the usual matrix operations.
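Here is one way the design described above might look. This is a sketch under assumptions, not the code from the gist: `cell`, `cache`, `Dual.of` and `coefficients` are names invented here, and only `get(r, c)` follows the signature given in the text.

```scala
import scala.collection.mutable

object LazyDualSketch {
  // A size x size upper triangular, diagonal-constant matrix whose cells
  // are computed on demand. `cell` and `cache` are hypothetical names.
  abstract class Dual(val size: Int) {
    protected def cell(r: Int, c: Int): Double

    private val cache = mutable.Map.empty[(Int, Int), Double]

    // Memoized lookup: each cell is computed at most once.
    def get(r: Int, c: Int): Double = cache.getOrElseUpdate((r, c), cell(r, c))

    def +(that: Dual): Dual = Dual.of(size)((r, c) => get(r, c) + that.get(r, c))
    def -(that: Dual): Dual = Dual.of(size)((r, c) => get(r, c) - that.get(r, c))

    // Matrix product. Both operands are upper triangular, so only the
    // terms with r <= k <= c can be nonzero (the range is empty if r > c).
    def *(that: Dual): Dual = Dual.of(size) { (r, c) =>
      (r to c).map(k => get(r, k) * that.get(k, c)).sum
    }

    def *(s: Double): Dual = Dual.of(size)((r, c) => get(r, c) * s)

    // The first row holds the coefficients of 1, eps, eps^2, ...
    def coefficients: Vector[Double] = Vector.tabulate(size)(c => get(0, c))
  }

  object Dual {
    // Wrap a cell function as a Dual.
    def of(size: Int)(f: (Int, Int) => Double): Dual = new Dual(size) {
      protected def cell(r: Int, c: Int): Double = f(r, c)
    }

    // a + b*eps: `a` on the diagonal, `b` on the first superdiagonal.
    def apply(size: Int, a: Double, b: Double): Dual =
      of(size)((r, c) => if (c == r) a else if (c == r + 1) b else 0.0)
  }

  def main(args: Array[String]): Unit = {
    val x = Dual(3, 2.0, 1.0) // 2 + eps, tracking up to eps^2
    println((x * x).coefficients.mkString(" ")) // 4.0 4.0 1.0
  }
}
```

Because every operation just returns a new cell function, `+`, `-` and `*` each fit on a line or two, which is presumably the convenience the approach is after.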
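Division is implemented as multiplication by the inverse: $A / B = A B^{-1}$. Matrix inverses are annoying to find and may not exist in the general case, but thankfully all we’re dealing with are square upper triangular matrices, which are a lot easier to invert. Actually, we have something even better: all of our matrices are of the form $aI + D$, where $D$ is a nilpotent matrix, meaning $D^n = 0$ for some $n$.
It turns out that for any nilpotent matrix $N$ with $N^n = 0$,
$$(I + N)^{-1} = I - N + N^2 - \cdots + (-1)^{n-1} N^{n-1}$$
à la the familiar algebraic identity
$$\frac{1}{1 + x} = 1 - x + x^2 - x^3 + \cdots$$
So now we can find the inverse as follows (as long as $a \neq 0$):
$$(aI + D)^{-1} = \frac{1}{a}(I + N)^{-1} = \frac{1}{a}\left(I - N + N^2 - \cdots + (-1)^{n-1} N^{n-1}\right)$$
where $N = D/a$. Here’s the code: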
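A minimal sketch of that inverse, written against plain nested `Vector`s rather than the `Dual` class, so it should be read as an illustration of the series and not as the gist's implementation. It assumes the input is upper triangular with a constant, nonzero diagonal; `fromCoefficients` is a helper invented for the example.

```scala
object NilpotentInverse {
  type Mat = Vector[Vector[Double]]

  def identity(n: Int): Mat =
    Vector.tabulate(n, n)((r, c) => if (r == c) 1.0 else 0.0)

  def add(x: Mat, y: Mat): Mat =
    Vector.tabulate(x.length, x.length)((r, c) => x(r)(c) + y(r)(c))

  def scale(x: Mat, s: Double): Mat = x.map(_.map(_ * s))

  def mul(x: Mat, y: Mat): Mat =
    Vector.tabulate(x.length, y.head.length) { (r, c) =>
      x(r).indices.map(k => x(r)(k) * y(k)(c)).sum
    }

  // Diagonal-constant matrix for c0 + c1*eps + c2*eps^2 + ...
  def fromCoefficients(cs: Vector[Double]): Mat =
    Vector.tabulate(cs.length, cs.length)((r, c) => if (c >= r) cs(c - r) else 0.0)

  // Inverse of m = a*I + D via (1/a) * sum_{k=0}^{n-1} (-N)^k with N = D/a.
  def inverse(m: Mat): Mat = {
    val n = m.length
    val a = m(0)(0)
    require(a != 0.0, "the diagonal entry must be nonzero")
    // -N = -(m - a*I) / a
    val negN = scale(add(m, scale(identity(n), -a)), -1.0 / a)
    // Accumulate I + (-N) + (-N)^2 + ... + (-N)^(n-1); later powers are zero.
    val (series, _) = (1 until n).foldLeft((identity(n), identity(n))) {
      case ((sum, power), _) =>
        val next = mul(power, negN)
        (add(sum, next), next)
    }
    scale(series, 1.0 / a)
  }

  def main(args: Array[String]): Unit = {
    val m = fromCoefficients(Vector(2.0, 1.0, 0.0)) // 2 + eps
    println(inverse(m).head.mkString(" ")) // 0.5 -0.25 0.125
  }
}
```

The first row of the result is the truncated series for $1/(2 + \epsilon)$, and multiplying it back by `m` gives the identity.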
Finally, some utility methods:
Now we need concrete classes representing 1 and $\epsilon$:
Let’s try it out. Suppose we want to find the first 5 derivatives of $f(x) = x^4$ at $x = 2$.
scala> val one = new I(6)
one: I = 1.0 0.0 0.0 0.0 0.0 0.0
scala> val e = new E(6)
e: E = 0.0 1.0 0.0 0.0 0.0 0.0
scala> def f(x: Dual): Dual = x.pow(4)
f: (x: Dual)Dual
scala> f(one*2 + e)
res0: Dual = 16.0 32.0 24.0 8.0 1.0 0.0
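This is $16 + 32\epsilon + 24\epsilon^2 + 8\epsilon^3 + \epsilon^4$. The coefficient of $\epsilon^k$ will be $f^{(k)}(2)/k!$. And it checks out:
$$f'(2) = 4 \cdot 2^3 = 32, \quad \frac{f''(2)}{2!} = \frac{12 \cdot 2^2}{2} = 24, \quad \frac{f'''(2)}{3!} = \frac{24 \cdot 2}{6} = 8, \quad \frac{f^{(4)}(2)}{4!} = \frac{24}{24} = 1, \quad \frac{f^{(5)}(2)}{5!} = 0$$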
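To turn the coefficients back into actual derivatives, multiply the $k$-th one by $k!$. The sketch below does this for the same example, storing only the first row of the matrix as a coefficient vector (multiplying two such rows is a truncated polynomial product). The function names are made up for this example and are not the `Dual` API.

```scala
object TaylorCoefficients {
  // Product of two truncated polynomials in eps; terms past the
  // vector length are dropped, which is what eps^n = 0 means.
  def mul(x: Vector[Double], y: Vector[Double]): Vector[Double] =
    Vector.tabulate(x.length)(k => (0 to k).map(i => x(i) * y(k - i)).sum)

  def pow(x: Vector[Double], p: Int): Vector[Double] = {
    val one = Vector.tabulate(x.length)(k => if (k == 0) 1.0 else 0.0)
    (1 to p).foldLeft(one)((acc, _) => mul(acc, x))
  }

  // k-th derivative = k! * (coefficient of eps^k)
  def derivatives(coeffs: Vector[Double]): Vector[Double] = {
    // 0!, 1!, 2!, ... (one element longer than needed; zip truncates)
    val factorials = coeffs.indices.scanLeft(1.0)((acc, k) => acc * (k + 1))
    coeffs.zip(factorials).map { case (c, f) => c * f }
  }

  def main(args: Array[String]): Unit = {
    val x = Vector(2.0, 1.0, 0.0, 0.0, 0.0, 0.0) // 2 + eps
    val fx = pow(x, 4)
    println(fx.mkString(" "))              // 16.0 32.0 24.0 8.0 1.0 0.0
    println(derivatives(fx).mkString(" ")) // 16.0 32.0 48.0 48.0 24.0 0.0
  }
}
```

The second line is $f(2), f'(2), \ldots, f^{(5)}(2)$ for $f(x) = x^4$.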
How about the first 8 derivatives of $g(x) = \frac{4x^2}{(1 - x)^3}$ at $x = 3$?
scala> val one = new I(9)
one: I = 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
scala> val e = new E(9)
e: E = 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
scala> def g(x: Dual): Dual = x.pow(2) * 4 / (one - x).pow(3)
g: (x: Dual)Dual
scala> g(one*3 + e)
res1: Dual = -4.5 3.75 -2.75 1.875 -1.21875 0.765625 -0.46875 0.28125 -0.166015625
OK, let’s just check $g^{(8)}(3)$. I’m gonna use Wolfram Alpha for this, because… yeah.
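Wolfram Alpha aside, the last coefficient can also be cross-checked by hand with a binomial series. This derivation is an addition for illustration; $c_k$ is notation introduced here for the coefficients of $(1 + \epsilon/2)^{-3}$.

```latex
\begin{aligned}
g(3+\epsilon) &= \frac{4(3+\epsilon)^2}{(-2-\epsilon)^3}
  = -\frac{1}{2}\,(9 + 6\epsilon + \epsilon^2)\left(1 + \frac{\epsilon}{2}\right)^{-3},\\
\left(1 + \frac{\epsilon}{2}\right)^{-3} &= \sum_{k \ge 0} c_k\,\epsilon^k,
  \qquad c_k = (-1)^k \binom{k+2}{2} \frac{1}{2^k},\\
[\epsilon^8]\; g(3+\epsilon) &= -\frac{1}{2}\,(9c_8 + 6c_7 + c_6)
  = -\frac{1}{2}\left(\frac{405}{256} - \frac{432}{256} + \frac{112}{256}\right)
  = -\frac{85}{512} = -0.166015625,\\
g^{(8)}(3) &= 8! \cdot \left(-\frac{85}{512}\right) = -6693.75.
\end{aligned}
```

The $\epsilon^8$ coefficient matches the last entry of `res1`.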
Neat!
Conclusion
This is a pretty amazing technique. Instead of computing a difference quotient with tiny values of $h$, which is prone to all sorts of floating-point rounding errors, you get exact numerical derivatives (exact up to the rounding in the arithmetic itself). In fact you get as many higher-order derivatives as you want, simultaneously. So, you almost never need to do symbolic differentiation.
Of course, for this to be really useful, I’d have to implement more than just the standard arithmetic operations on dual numbers. I’ll also want to be able to compute functions such as $\sin x$ or $e^x$ or $\ln x$ on dual numbers. There are certainly ways to do this! But maybe it’s a topic for another post.
All of the code in this post is available in this gist.