Automatic Differentiation in 38 lines of Haskell

gist.github.com

154 points by ttesmer 3 years ago · 59 comments (58 loaded)

sordina 3 years ago

One of my favourite Haskell "one-liners" is combining the AD package with Number.Symbolic:

    {-# LANGUAGE ImportQualifiedPost #-}
    
    module Module_1663406024_9206 where
    
    import Numeric.AD qualified as Ad
    import Data.Number.Symbolic qualified as Sym
    
    -- >>> f x = x^2 + 3 * x
    -- >>> Ad.diff f 1
    -- >>> Ad.diff f (Sym.var "a")
    -- 5
    -- 3+a+a

    -- >>> Ad.diff sin pi
    -- >>> Ad.diff sin (Sym.var "a")
    -- -1.0
    -- cos a

The package authors did not need to coordinate to make this possible, which is pretty wild.

Saw it first here: https://twitter.com/GabriellaG439/status/647601518871359489 and https://www.reddit.com/r/haskell/comments/3r75hq/comment/cwm...

  • cryptonector 3 years ago

    > The package authors did not need to coordinate to make this possible which is pretty wild.

    It works because `f` is polymorphic. The type of its `x` argument is not constrained in `f`'s definition, so you can plug in an `x` of any type, provided that type implements the methods used in `f`'s definition. With the `Dual` scheme, the `x` you pass in is a "dual" number that pairs a value with its derivative. `f` is then instantiated at that dual type, so the methods `f` calls are the dual-number ones: instead of the traditional numeric addition and multiplication you get "dual" addition and multiplication, everything "chains" through, and `diff f x` ends up being the `y'` in the resulting dual of `y` and `y'`. You don't care about the `y`, just the `y'`, because you want the `diff` (the derivative).

    It's brilliant.
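To make the mechanism concrete, here is a minimal single-file sketch. It is not the code of the `ad` or `numbers` packages: `Dual`, `diff`, `Expr` and `render` are hypothetical names for a simplified forward-mode part and a toy symbolic part. The two halves never mention each other; they meet only through the standard `Num`, `Fractional` and `Floating` classes.

```haskell
import Data.Ratio (denominator, numerator)

-- ---------- "AD part": knows nothing about symbols ----------

-- A value paired with its derivative (tangent).
data Dual a = Dual a a

instance Num a => Num (Dual a) where
  Dual u u' + Dual v v' = Dual (u + v) (u' + v')
  Dual u u' - Dual v v' = Dual (u - v) (u' - v')
  Dual u u' * Dual v v' = Dual (u * v) (u' * v + u * v') -- product rule
  abs (Dual u u')       = Dual (abs u) (u' * signum u)
  signum (Dual u _)     = Dual (signum u) 0
  fromInteger n         = Dual (fromInteger n) 0          -- constants: derivative 0

instance Fractional a => Fractional (Dual a) where
  Dual u u' / Dual v v' = Dual (u / v) ((u' * v - u * v') / (v * v)) -- quotient rule
  fromRational r        = Dual (fromRational r) 0

-- Only a few methods; the rest are omitted in this sketch (GHC warns about them).
instance Floating a => Floating (Dual a) where
  pi              = Dual pi 0
  exp (Dual u u') = Dual (exp u) (u' * exp u)
  log (Dual u u') = Dual (log u) (u' / u)
  sin (Dual u u') = Dual (sin u) (u' * cos u) -- calls the base type's cos
  cos (Dual u u') = Dual (cos u) (negate u' * sin u)

-- Seed the tangent with 1 (dx/dx = 1) and read off the derivative.
diff :: Num a => (Dual a -> Dual a) -> a -> a
diff f x = let Dual _ d = f (Dual x 1) in d

-- ---------- "Symbolic part": knows nothing about AD ----------

data Expr = Var String | Lit Rational | App String Expr | Bin Char Expr Expr

-- Smart constructors with a few algebraic simplifications.
add, mul :: Expr -> Expr -> Expr
add (Lit 0) e       = e
add e (Lit 0)       = e
add (Lit a) (Lit b) = Lit (a + b)
add a b             = Bin '+' a b
mul (Lit 0) _       = Lit 0
mul _ (Lit 0)       = Lit 0
mul (Lit 1) e       = e
mul e (Lit 1)       = e
mul (Lit a) (Lit b) = Lit (a * b)
mul a b             = Bin '*' a b

instance Num Expr where
  (+)         = add
  (*)         = mul
  negate      = mul (Lit (-1))
  abs         = App "abs"
  signum      = App "signum"
  fromInteger = Lit . fromInteger

instance Fractional Expr where
  a / b        = Bin '/' a b
  fromRational = Lit

-- Functions are simply left as uninterpreted applications.
instance Floating Expr where
  pi  = Var "pi"
  exp = App "exp"
  log = App "log"
  sin = App "sin"
  cos = App "cos"

render :: Expr -> String
render (Var v) = v
render (Lit r)
  | denominator r == 1 = show (numerator r)
  | otherwise          = show (numerator r) ++ "/" ++ show (denominator r)
render (App f e)   = f ++ "(" ++ render e ++ ")"
render (Bin o a b) = "(" ++ render a ++ " " ++ [o] ++ " " ++ render b ++ ")"
```

Assuming these definitions, `render (diff sin (Var "a"))` evaluates to `"cos(a)"` while `diff sin (0 :: Double)` evaluates to `1.0`: the single `sin` rule in the dual-number instance calls whatever `cos` the base type provides, which is the lack of coordination being discussed.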

    • magicalhippo 3 years ago

      I still don't get how sin ends up as cos, without any coordination.

      • krastanov 3 years ago

        Presumably the Ad package has a list of known derivatives. The Sym package now "automatically" uses it, without ever having to have known of it.

        The "coordination" is that they both use the "symbol" sin to refer to the idea of sine function.

        • shele 3 years ago

          Interesting that it is this way in Haskell in this example: in Julia, there are community/consensus-based processes for deciding what the "idea" of a symbol is (or rather, the idea of a function with its current methods and its future methods defined for new types), and package interoperability often works in a similar way.

          • leephillips 3 years ago

            There is some confusion about what a symbol is in Julia due to bad tutorial information floating around the web. Here is Stefan Karpinski’s masterful explanation of what a Symbol really is:

            https://stackoverflow.com/a/23482257

          • krastanov 3 years ago

            As the sibling comment mentioned, this thread started using the word "symbol" in two different ways. There is the Julia/Lisp/Ruby use of the word symbol which is related to representing code in code and to homoiconicity. That is discussed in the sibling stack-overflow link.

            Earlier in this thread though "symbol" was used in a less formal, handwavy way as "a named handle to some concept in the language". In that sense, indeed, the way that Julia adds multiple methods to a single named function and supports multiple dispatch permits amazing interoperability (and has little to do with Julia/Lisp/Ruby symbol datastructures).

            • shele 3 years ago

              Right, my point is just this: as new methods are added by the owners of new types, it requires a common understanding and community consensus on how a function should act on new types (i.e. about the "idea" of the function).

        • magicalhippo 3 years ago

          Right, I was confused because for some reason I imagined "sin" coming from the symbolic library, but I'm assuming it's just built in, so AD knows about it.

          • cryptonector 3 years ago

            The `sin` function comes from this bit at the end of TFA:

                instance VectorSpace d => Floating (Dual d) where
                  pi             = D pi zero
                  exp   (D u u') = D (exp u)  (scale (exp u) u')
                  log   (D u u') = D (log u)  (scale (log u) u')
                  sin   (D u u') = D (sin u)  (scale (cos u) u')   -- <--- this line
                  cos   (D u u') = D (cos u)  (scale (-sin u) u')
                  sinh  (D u u') = D (sinh u) (scale (cosh u) u')
                  cosh  (D u u') = D (cosh u) (scale (sinh u) u')

            and the `sin` function on the right-hand side comes from `Float`'s own `Floating` instance, since `Float` is the type of the argument `u` in `sin u` in `D (sin u) (scale (cos u) u')`.

            • krastanov 3 years ago

              Not quite. This subthread is about the extremely short, one-line implementation mentioned here https://news.ycombinator.com/item?id=32882825 (which merges two unrelated modules, autodiff and symbolic, and uses autodiff to implement symbolic differentiation). Your comment is true for the original 38-line implementation of autodiff at the very top of the thread, but not in this subthread. The 38-line implementation is similar to the aforementioned autodiff module, though.

              • pfortuny 3 years ago

                However, if the symbolic package knows that

                cos(a+b) == cos(a)cos(b) - sin(a)sin(b)

                then it works.

                Something must be known in advance about the relationship between sin and cos for addition, otherwise you cannot go from one to the other (the basic identity

                cos(a)^2+sin(a)^2=1

                is not enough for that).

              • cryptonector 3 years ago

                Yes, indeed, I was referring to TFA, and obviously I was confused.

sterlind 3 years ago

This is an interesting approach. Haskell is not a symbolic language, but you take advantage of the abstractness of type parameters in function definitions to thread your implementation of "D x" through, and pattern match on that.

It's a neat design pattern. I bet it'd work in Julia too.

  • version_five 3 years ago

    For an example in julia, see Mike Innes tutorial: https://github.com/MikeInnes/diff-zoo

    I'm only a beginner in Julia and not an AD expert, but I went through the exercise of porting this to Python and found it very enlightening.

  • xiphias2 3 years ago

    Yes, but Julia has both forward and backward differentiation implemented (backward mode is the harder one).

    • ttesmerOP 3 years ago

      As I wrote in the Markdown file, there's also a usable package for Haskell called `ad` on Hackage. It has both forward and backward autodiff and prevents expression swell, among other things. This gist is just for illustration purposes.

  • bmitc 3 years ago

    > Haskell is not a symbolic language

    I'm not sure I understand what this means. What is a symbolic language that excludes languages like Haskell, F#, OCaml, etc.?

    • chowells 3 years ago

      A symbolic language is something like Mathematica. Mathematica does evaluation as term rewriting with rules for how to rewrite all sorts of things, and if it can't find matching rules, it leaves the expression in the same symbolic form as the input.

      • bmitc 3 years ago

        You can do that with symbolic expressions in all the languages I listed.

        • sterlind 3 years ago

          In Mathematica, the equivalent of a "function call" would be something like f[x,y]. I can write and run that, and it won't give an error if f, x or y aren't defined. Fundamentally, all atoms in the language are terms of symbols. Expressions are broken down by patterns until they're reduced as far as they can be, then the language happily stops and returns how far it got.

          It's like Lisp, if everything in Lisp were quoted, and eval consisted of pattern matching and applying substitution rules.

          As an example, say in Haskell I have `foo x y = x + 2/y`. Can you write a function that takes any function in and replaces all `+` with `*`? In Mathematica you can: `foo[x_,y_] := x + 2/y; foo[x,y] /. l_ + r_ -> l*r`.

          I'm sure you can do this stuff in other languages if you try hard enough. Some kind of reflection in Haskell or F#. In Lisp, grab and quote the definition of foo and apply macro machinery to it. But that's not the tao of those languages, while it is the essence of Mathematica.

          • chowells 3 years ago

            > Some kind of reflection in Haskell

            In Haskell you can only do that kind of manipulation at compile time. At run time, the source code is long gone. And even at compile time, expressions are a very different type than code that can be run. Haskell just isn't anything like a symbolic language.

            As pointed out way upthread, it provides lots of hooks for programmers to have nice syntax for entry points to symbolic systems - if someone writes one.
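A sketch of such an entry point, assuming the function is written against `Fractional` rather than a concrete type (the names `Expr`, `plusToTimes` and `eval` are hypothetical). Applying the polymorphic `foo` to a term type makes it build its own expression tree, which can then be rewritten much like the Mathematica `/.` example above. Compiled code stays opaque; this only works because `foo` never committed to a concrete number type.

```haskell
-- A toy term type (hypothetical names).
data Expr
  = Var String
  | Lit Rational
  | Expr :+: Expr
  | Expr :*: Expr
  | Expr :/: Expr
  deriving (Eq, Show)

-- Arithmetic on Expr builds syntax instead of computing.
instance Num Expr where
  (+)         = (:+:)
  (*)         = (:*:)
  negate e    = Lit (-1) :*: e
  abs _       = error "abs is not supported in this sketch"
  signum _    = error "signum is not supported in this sketch"
  fromInteger = Lit . fromInteger

instance Fractional Expr where
  (/)          = (:/:)
  fromRational = Lit

-- The example from the thread, written polymorphically.
foo :: Fractional a => a -> a -> a
foo x y = x + 2 / y

-- Analogue of Mathematica's  foo[x, y] /. l_ + r_ -> l*r
plusToTimes :: Expr -> Expr
plusToTimes (a :+: b) = plusToTimes a :*: plusToTimes b
plusToTimes (a :*: b) = plusToTimes a :*: plusToTimes b
plusToTimes (a :/: b) = plusToTimes a :/: plusToTimes b
plusToTimes e         = e

-- Evaluate a term, looking variables up in an environment.
eval :: (String -> Double) -> Expr -> Double
eval env (Var v)   = env v
eval _   (Lit r)   = fromRational r
eval env (a :+: b) = eval env a + eval env b
eval env (a :*: b) = eval env a * eval env b
eval env (a :/: b) = eval env a / eval env b
```

Here `foo (Var "x") (Var "y")` yields the term `Var "x" :+: (Lit 2 :/: Var "y")`, and `plusToTimes` turns it into `Var "x" :*: (Lit 2 :/: Var "y")`. This is the same trick the symbolic package relies on upthread.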

    • klipt 3 years ago

      Perhaps they mean not homoiconic like Lisp or Forth.

chrsig 3 years ago

One of the references is a talk by Simon Peyton Jones on AD[0].

I'm neither a math expert nor a Haskell expert, but I happen to enjoy both. It's been a while since I've watched an SPJ lecture, and I'd forgotten how much I want him to explain everything.

[0] https://www.youtube.com/watch?v=EPGqzkEZWyw

MrUssek 3 years ago

To be clear, this is a forward-mode autodiff implementation, not reverse mode, as might be inferred from the reference to the SPJ talk, correct?

mghwaz 3 years ago

How do you come up with code that looks this good? I've been working in Haskell professionally for 6 months, and currently I wouldn't be able to come up with something like this in a day (or more).

  • chrsig 3 years ago

    generally this isn't something that you come up with in a day. it's distilled down and refined over time, you just see it when it reaches maturity.

    can't speak to the OP's process though, maybe they shat it out on a whim :)

tbensky 3 years ago

Fun exercise in Prolog too: https://www.codebymath.com/index.php/welcome/lesson/prolog-d....

bmitc 3 years ago

Forward-mode automatic differentiation is always fun to see because of the power but simplicity of the method. Any language with pattern matching makes it almost trivial to implement.

That said, I'm rarely interested in <such and such> in <n> lines of <language>. The more interesting things are overall conciseness with regard to the problem, the expressiveness, and the clarity of the resulting code.

cryptonector 3 years ago

    log   (D u u') = D (log u)  (scale (log u) u')

That doesn't look like the derivative of the natural logarithm!

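For comparison, the rule for any unary function follows from a first-order expansion with ε² = 0, and for `log` it gives u'/u as the tangent, not log u · u'. In the gist's notation that would presumably be `log (D u u') = D (log u) (scale (recip u) u')`, assuming `scale` multiplies the tangent by a scalar as in the neighbouring `sin` and `exp` lines.

```latex
\begin{aligned}
f(u + u'\varepsilon) &= f(u) + f'(u)\,u'\,\varepsilon, \qquad \varepsilon^2 = 0 \\
\sin(u + u'\varepsilon) &= \sin u + (\cos u)\,u'\,\varepsilon \\
\log(u + u'\varepsilon) &= \log u + \frac{u'}{u}\,\varepsilon
\end{aligned}
```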
green_on_black 3 years ago

Curious: I don't see the `^` op defined. I'm guessing it's translated into `exp`?

  • ttesmerOP 3 years ago

    Since the type was made an instance of the Num typeclass, any function that works on Nums can now be used on the type (Dual d). As per the Prelude[1], ^ works for any Num type (it is an ordinary function with a Num constraint, not a method of the class). Likewise, ** comes with Floating[2]. The hyperbolic tangent can also be used without being explicitly coded, as Floating's default definition derives it from sinh and cosh!

    EDIT: As for the differentiation, it works for ^ since it is just multiplication (https://hackage.haskell.org/package/base-4.17.0.0/docs/src/G...) for which the derivative was defined using the product rule.

    [1]: https://hackage.haskell.org/package/base-4.17.0.0/docs/Prelu... [2]: https://hackage.haskell.org/package/base-4.17.0.0/docs/Prelu...
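A minimal sketch of these defaults at work, using a simplified `Dual` over `Double` rather than the gist's `Dual d` (the names `Dual` and `diff` are hypothetical here). Only `pi`, `exp`, `log`, `sin`, `cos`, `sinh` and `cosh` get derivative rules; `tanh`, `tan`, `sqrt` and `**` fall back to the `Floating` class defaults, and `^` is the ordinary Prelude function built on `*`.

```haskell
-- Minimal dual numbers over Double: value and derivative.
data Dual = Dual Double Double

instance Num Dual where
  Dual u u' + Dual v v' = Dual (u + v) (u' + v')
  Dual u u' - Dual v v' = Dual (u - v) (u' - v')
  Dual u u' * Dual v v' = Dual (u * v) (u' * v + u * v') -- product rule
  abs (Dual u u')       = Dual (abs u) (u' * signum u)
  signum (Dual u _)     = Dual (signum u) 0
  fromInteger n         = Dual (fromInteger n) 0

instance Fractional Dual where
  Dual u u' / Dual v v' = Dual (u / v) ((u' * v - u * v') / (v * v)) -- quotient rule
  fromRational r        = Dual (fromRational r) 0

-- tan, tanh, sqrt and (**) are NOT defined here, so Floating's defaults
-- (tan x = sin x / cos x, tanh x = sinh x / cosh x, sqrt x = x ** 0.5,
-- x ** y = exp (log x * y)) are used; they differentiate correctly because
-- they are built from methods that do have rules. The inverse functions
-- (asin, acos, atan, asinh, ...) are omitted and GHC warns about them.
instance Floating Dual where
  pi               = Dual pi 0
  exp  (Dual u u') = Dual (exp u) (u' * exp u)
  log  (Dual u u') = Dual (log u) (u' / u)
  sin  (Dual u u') = Dual (sin u) (u' * cos u)
  cos  (Dual u u') = Dual (cos u) (negate u' * sin u)
  sinh (Dual u u') = Dual (sinh u) (u' * cosh u)
  cosh (Dual u u') = Dual (cosh u) (u' * sinh u)

-- Seed the tangent with 1 and read off the derivative.
diff :: (Dual -> Dual) -> Double -> Double
diff f x = let Dual _ d = f (Dual x 1) in d
```

With this sketch, `diff tanh 0` is `1.0` and `diff (^ 3) 2` is `12.0`, neither of which was written by hand. Default-based derivatives are only as accurate as the rules they are built from (`sqrt` here goes through `exp` and `log`).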

  • tikhonj 3 years ago

    The ^ operator is defined in Haskell's standard library for raising values of any numeric type to non-negative, integral powers. Conceptually a ^ n just expands to a * a * ... * a, but the actual code is a bit more complex[1] for performance reasons.

    The neat thing with this approach is that ^ works for any numeric type, including user-defined types like Dual in this example. Since the Dual type can handle calculating derivatives for *, it gets derivatives for ^ for free.

    [1]: https://hackage.haskell.org/package/base-4.17.0.0/docs/src/G...

    • green_on_black 3 years ago

      Ah, `^` being naturals-only makes more sense in terms of how it could work with no special logic!

      • lalaithion 3 years ago

        Haskell has three exponentiation operators:

            (^) :: (Num a, Integral b) => a -> b -> a
            (^^) :: (Fractional a, Integral b) => a -> b -> a
            (**) :: Floating a => a -> a -> a
        
        The first one has the un-typed requirement that `b` be nonnegative, and basically allows any value with multiplication (implements the Num typeclass) to be raised to a natural power.

        The second one allows any value with multiplication and multiplicative inverses (Fractional) to be raised to an integral power.

        The third one allows any value which supports exponentials, trig functions, and logarithms to be raised to a power of the same type. `**` is in fact a method of the Floating typeclass, with a default definition (`x ** y = exp (log x * y)`), so instances don't have to implement it themselves.
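A small illustration of the three signatures at concrete types; the binding names are made up for the example, and the commented-out lines show what each constraint rules out.

```haskell
-- (^): any Num base, non-negative Integral exponent (repeated multiplication).
cube :: Integer
cube = 5 ^ (3 :: Int)        -- 125

-- (^^): Fractional base, any Integral exponent (negative means reciprocal).
eighth :: Rational
eighth = 2 ^^ (-3 :: Int)    -- 1 % 8

-- (**): Floating base and exponent of the same type.
rootTwo :: Double
rootTwo = 2 ** 0.5           -- approximately 1.41421

-- What the constraints rule out:
--   (5 :: Integer) ^^ (-1 :: Int)  -- type error: Integer is not Fractional
--   (5 :: Integer) ** 2            -- type error: Integer is not Floating
--   2 ^ (-1 :: Int)                -- type-checks, but fails at run time ("Negative exponent")
```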

  • cryptonector 3 years ago

    What Haskell calls a "class" is what Java calls an "interface". What Haskell calls an "instance" of an interface is what Java would call a class implementing that interface.

    `Num`, then, is a Haskell class (Java interface).

    The `Num` class will have a bunch of what Java would call "default methods".

    Now, the "instance" of `Num` defined here has only a few methods defined, but the other default methods of `Num` will use those. So if `Num` has a `^` defined in terms of `*`, and you define an instance of `Num` that defines `*`, then you get `^` for free if you don't implement it.
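This can be checked with the real `Num` class, whose `(-)` has the default `x - y = x + negate y` in GHC's base library. One nuance: `^` is not actually a `Num` method but a standalone Prelude function with a `Num` constraint, so it comes for free in a slightly different way. `Mod7` and `mk` below are hypothetical names for a toy type.

```haskell
-- Integers modulo 7 (a toy example type).
newtype Mod7 = Mod7 Int deriving (Show, Eq)

-- Normalise into 0..6 (Haskell's `mod` is non-negative for a positive divisor).
mk :: Int -> Mod7
mk n = Mod7 (n `mod` 7)

instance Num Mod7 where
  Mod7 a + Mod7 b = mk (a + b)
  Mod7 a * Mod7 b = mk (a * b)
  negate (Mod7 a) = mk (negate a)
  abs             = id
  signum (Mod7 a) = Mod7 (signum a)
  fromInteger n   = mk (fromInteger (n `mod` 7))
  -- (-) is not written here: Num's default, x - y = x + negate y, fills it in,
  -- much like a Java default method.
```

Here `mk 3 - mk 5` gives `Mod7 5` through the default `(-)`, and `mk 3 ^ (6 :: Int)` gives `Mod7 1` through the Prelude's `^`, which only needs `*` from the instance.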

  • ayjtyjtyjgfjc 3 years ago

    It is one of three exponentiation operators in standard Haskell, no translation involved. https://stackoverflow.com/a/6400630/14768587

mountainriver 3 years ago

Being purely functional makes this quite easy; still, a beauty to see.

  • sterlind 3 years ago

    it's not just the functionalness. you couldn't do it this way in Lisp/Scheme (I think?) because of the lack of multiple dispatch.

    If you did (f 'x) for instance, you'd end up with things like (* 2 'x), which would blow up, since Lisp would try to compute the answer instead of giving you '(* 2 x) back.

    • cryptonector 3 years ago

      The Common Lisp Object System has multiple dispatch.

      • patrec 3 years ago

        But math operations are not generic functions. So you'd need to create your own version of all math functions. And that won't compose with independently developed libraries (unlike the Haskell example elsewhere in the thread).

    • jhgb 3 years ago

      I'm sure that systems like ScmUtils wouldn't have problems with this.

naasking 3 years ago

For those less familiar with Haskell, here's a C# implementation of dual numbers comparable to this:

https://github.com/naasking/AutoDiffSharp/blob/master/AutoDi...

quickthrower2 3 years ago

Is the Float' type there to not pollute the Float type with the defined typeclasses? Or is there something I am missing?

  • bradrn 3 years ago

    I have no clue. It looks totally redundant to me too.

    • ttesmerOP 3 years ago

      You're absolutely right, it is totally redundant. You could reduce the entire thing even more in size by using only `D Float Float`, removing the `TypeSynonymInstances` pragma and removing the `VectorSpace`, instead using normal Float operations for the derivatives in the Num, Floating and Fractional instances. Further, you can remove the `diff` function entirely if you want, since it does something very simple: calling `f` with a `D` instead of a normal Float, Int, etc. So you would simply get the function applied to x and its derivative by doing `(\x -> x^2) (D 2 1)`, setting x=2 and the derivative to 1.

      However, I think the Float', `diff` etc. are at least a little helpful in understanding it. I got them from SPJ's talk, which I linked to in the file. They also make it easier (e.g. in the case of `diff`) to later build on the autodiff, for example by implementing reverse mode, Jacobians, etc.
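A sketch of the reduced version described above, assuming a Float-only tangent (so no `VectorSpace`, no `Float'` synonym and no `diff` helper). Only `Num` is implemented, which is enough for polynomials.

```haskell
-- Value and derivative, both plain Floats.
data D = D Float Float
  deriving (Show, Eq)

instance Num D where
  D u u' + D v v' = D (u + v) (u' + v')
  D u u' - D v v' = D (u - v) (u' - v')
  D u u' * D v v' = D (u * v) (u' * v + u * v') -- product rule
  abs (D u u')    = D (abs u) (u' * signum u)
  signum (D u _)  = D (signum u) 0
  fromInteger n   = D (fromInteger n) 0           -- constants have derivative 0
```

Loading this into GHCi, `(\x -> x^2) (D 2 1)` should print `D 4.0 4.0`: the value 2² = 4 and the derivative 2·2 = 4 at x = 2.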

melony 3 years ago

Where's the vector Jacobian matrix?

  • omnicognate 3 years ago

    You don't need to explicitly form a Jacobian matrix to do AD. The key insight is that the matrix is just a representation of a linear function. In forward mode AD you evaluate the linear function the Jacobian represents (i.e. the derivative) at the same time as evaluating the function you are differentiating, usually without ever explicitly building the matrix. The derivative is a local linear approximation of the original function, so it composes the same way. This allows you to compute the parts of it as you compute the parts of the original function, and pass on the intermediate results alongside the intermediate results of the original calculation.
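A sketch of that idea for a map from R² to R², assuming simplified dual numbers over `Double` (`Dual`, `g` and `jvp` are hypothetical names). Seeding the input tangents with a direction v makes one forward pass return the Jacobian-vector product J v; the Jacobian itself is never built.

```haskell
-- Minimal dual numbers over Double: value and tangent.
data Dual = Dual Double Double

instance Num Dual where
  Dual u u' + Dual v v' = Dual (u + v) (u' + v')
  Dual u u' - Dual v v' = Dual (u - v) (u' - v')
  Dual u u' * Dual v v' = Dual (u * v) (u' * v + u * v')
  abs (Dual u u')       = Dual (abs u) (u' * signum u)
  signum (Dual u _)     = Dual (signum u) 0
  fromInteger n         = Dual (fromInteger n) 0

-- Example map R^2 -> R^2, written polymorphically.
-- Its Jacobian is [[y, x], [1, 2y]], but that matrix is never formed below.
g :: Num a => (a, a) -> (a, a)
g (x, y) = (x * y, x + y * y)

-- Forward mode: seed the inputs' tangents with direction (vx, vy);
-- the output tangents are then J v, computed in a single pass.
jvp :: ((Dual, Dual) -> (Dual, Dual))
    -> (Double, Double) -> (Double, Double) -> (Double, Double)
jvp f (x, y) (vx, vy) =
  let (Dual _ a, Dual _ b) = f (Dual x vx, Dual y vy)
  in (a, b)
```

For g(x, y) = (xy, x + y²) at (2, 3), seeding with (1, 0) gives (3, 1), the first column of J. Recovering all of J this way takes one pass per input, which is why reverse mode is preferred when there are many inputs and few outputs.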

mhh__ 3 years ago

It's not as pretty but forward mode AD isn't that many lines even in C
