I set aside some time to figure out how to build neural networks from scratch in Clojure, without external libraries.
After a couple of dependency-free versions, I ended up adding the neanderthal library to do faster matrix math. The different versions I wrote along the way are on GitHub, in case they’re helpful for anybody else who wants to do this in Clojure.
First impressions and hello world
Neural networks are surprisingly easy to get started with. There’s significantly more “magic” inside a good concurrent queue implementation, for example, than inside a basic neural network to recognize handwritten digits.
For example, here’s the “hello world” of neural networks, a widget to recognize a hand-drawn digit:
See this widget at github.com/matthewdowney/clojure-neural-networks-from-scratch/tree/main/mnist-scittle
And here’s the code for the pixel array -> digit computation1:
(defn sigmoid [n] (/ 1.0 (+ 1.0 (Math/exp (- n)))))

;; One layer: each neuron's output is sigmoid(bias + weights · inputs)
(defn feedforward [inputs weights biases]
  (for [[b ws :as _neuron] (map vector biases weights)]
    (let [weighted-input (reduce + (map * inputs ws))]
      (sigmoid (+ b weighted-input)))))

;; Index of the largest activation, i.e. the predicted digit
(defn argmax [numbers]
  (let [idx+val (map-indexed vector numbers)]
    (first (apply max-key second idx+val))))

;; w0/b0 and w1/b1 are the trained weights and biases of the two layers
(defn digit [pixels]
  (-> pixels (feedforward w0 b0) (feedforward w1 b1) argmax))
It’s striking that such a complicated task works without intricate code or underlying black-box libraries.2 I felt kind of dumb for not having known this already!
Resources
The three most helpful resources for me were:
- 3Blue1Brown’s video series on neural networks, with visualizations and intuitive explanations. Good initial context.
- Michael Nielsen’s neural networks and deep learning tutorial, which uses Python and numpy.
- Andrej Karpathy’s intro to neural networks and backpropagation, which is pure Python (no numpy), and was kind of a lifesaver for understanding backpropagation.
In retrospect, to get started, I’d recommend reading the first part of Nielsen’s tutorial, skipping to the Andrej Karpathy video, and then solving MNIST from scratch using those two things as references, before coming back to the rest of Nielsen’s material.
I also went through Dragan Djuric’s impressive and erudite deep learning from scratch to GPU tutorial series, but I can’t say I’d recommend it as an introduction to neural networks.3
Approach in retrospect
I’m glad I decided to start from scratch without any external libraries, including ones for matrix math.
I do, however, wish I’d watched Andrej Karpathy’s video before getting so deep into Nielsen’s tutorial, especially because of the backprop calculus4, which I struggled with for a while. Karpathy’s REPL-based, algorithmic explanation was much more intuitive for me than the formal mathematical version.
My approach was to:
- First, build a neural network for the MNIST problem with no matrix math (nn_01.clj),
- Then, create a version with handwritten matrix math,
- Eventually, add the neanderthal library for matrix math in a third version,
- Finally, enhance performance with batch training in the fourth version.
The training time for one epoch of MNIST was 400 seconds in the first two versions, 5 seconds in the third (on par with the Python sample), and down to 1 second in the final version.
I’m glad I broke it down like this. Would do again.
Before implementing the backprop algorithm, I built some unit tests for calculating the weight and bias gradients given starting weights and biases and some training data, and this turned out to be enormously helpful. I used Nielsen’s sample Python code to generate the test vectors.
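A related sanity check, sketched here as an alternative rather than the tests I actually wrote: a finite-difference gradient check, which approximates each partial derivative numerically and needs no reference implementation.

;; Approximate d(loss)/d(param_i) as (f(p + h) - f(p - h)) / 2h for each
;; parameter, then compare against whatever backprop produces.
(defn numerical-gradient [loss-fn params h]
  (map-indexed
    (fn [i _]
      (let [bump (fn [delta] (update (vec params) i + delta))]
        (/ (- (loss-fn (bump h)) (loss-fn (bump (- h))))
           (* 2.0 h))))
    params))

;; For loss = sum of p², the gradient of each parameter is 2p:
(numerical-gradient #(reduce + (map * % %)) [1.0 -2.0 3.0] 1e-5)
;; => approximately (2.0 -4.0 6.0)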
Finally, invoking numpy via libpython-clj at the REPL was useful for figuring out the equivalent neanderthal expressions.
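Here’s the flavor of that comparison (a hedged sketch assuming libpython-clj2 and a local Python with numpy installed): a matrix-vector product in both.

;; numpy, via libpython-clj
(require '[libpython-clj2.python :as py]
         '[libpython-clj2.require :refer [require-python]])
(py/initialize!)
(require-python '[numpy :as np])

(np/dot (np/array [[1 2] [3 4]]) (np/array [5 6]))
;; => a numpy array with entries 17 and 39

;; the equivalent neanderthal expression (dge is column-major by default,
;; so pass a :row layout to match the numpy literal above)
(require '[uncomplicate.neanderthal.core :refer [mv]]
         '[uncomplicate.neanderthal.native :refer [dge dv]])

(mv (dge 2 2 [1 2 3 4] {:layout :row}) (dv 5 6))
;; => a neanderthal vector with entries 17.0 and 39.0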
Basic things that I should have already known but didn’t
- A neuron in a neural network is just a function [inputs] -> scalar output, where the output is a linear combination of the inputs and the neuron’s weights, summed together with the neuron’s bias and passed to an activation function (there’s a small code sketch of this after the list).
- Much of the magic inside of neural network libraries has less to do with cleverer algorithms and more to do with vectorized SIMD instructions and/or being parsimonious with GPU memory usage and communication back and forth with main memory.
- Neural networks can, theoretically, compute any function. And a more readily believable fact: with linear activation functions, no matter how many layers you add to a neural network, it simplifies to a linear transformation.
- But the activation function is not necessarily all that squiggly — ReLU is just max(0, x) and it’s widely used.
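Here’s a minimal sketch of that first point: a single neuron as a plain function, with the activation passed in (the sigmoid from earlier, or ReLU). The names are mine, just for illustration.

(defn relu [x] (max 0.0 x))

;; weights · inputs, plus the bias, through an activation function
(defn neuron [activation weights bias inputs]
  (activation (+ bias (reduce + (map * weights inputs)))))

(neuron relu [0.5 -1.0] 0.1 [2.0 1.0])
;; => 0.1, since 0.5*2.0 + -1.0*1.0 + 0.1 = 0.1 and relu leaves it alone

(neuron sigmoid [0.5 -1.0] 0.1 [2.0 1.0])
;; => roughly 0.52, i.e. sigmoid of that same 0.1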