In a previous post, we explored techniques for visualizing high-dimensional data. Visualizing high-dimensional data is very interesting in its own right, but my real goal is something else. I think these techniques form a set of basic building blocks for understanding machine learning, and specifically for understanding the internal operations of deep neural networks.
Deep neural networks are an approach to machine learning that has revolutionized computer vision and speech recognition in the last few years, blowing the previous state-of-the-art results out of the water. They’ve also brought promising results to many other areas, including language understanding and machine translation. Despite this, it remains challenging to understand what, exactly, these networks are doing.
I think that dimensionality reduction, thoughtfully applied, can give us a lot of traction on understanding neural networks.
Understanding neural networks is just scratching the surface, however, because understanding the network is fundamentally tied to understanding the data it operates on. The combination of neural networks and dimensionality reduction turns out to be a very interesting tool for visualizing high-dimensional data – a much more powerful tool than dimensionality reduction on its own.
As we dig into this, we’ll observe what I believe to be an important connection between neural networks, visualization, and user interface.
Not all neural networks are hard to understand. In fact, low-dimensional neural networks – networks which have only two or three neurons in each layer – are quite easy to understand.
Consider the following dataset, consisting of two curves on the plane. Given a point on one of the curves, our network should predict which curve it came from.
A network with just an input layer and an output layer tries to divide the two classes with a straight line.
For this dataset, no straight line separates the two curves perfectly. And so a network with only an input layer and an output layer cannot classify it perfectly.
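To make this concrete, here is a minimal sketch. The two parabolas are a hypothetical stand-in for the curves in the figure, not the dataset used in this post, and plain gradient-descent logistic regression is just one way to fit a network with no hidden layer. One point of the upper curve sits inside the triangle formed by three points of the lower curve, so no straight line can get every point right.

```python
import numpy as np

# Hypothetical stand-in dataset: two parabolas, one shifted above the other.
xs = np.linspace(-1.0, 1.0, 101)
curve_a = np.stack([xs, xs**2], axis=1)          # class 0
curve_b = np.stack([xs, xs**2 + 0.5], axis=1)    # class 1
X = np.concatenate([curve_a, curve_b])
y = np.concatenate([np.zeros(len(xs)), np.ones(len(xs))])

# Input layer -> output layer only: a single sigmoid unit on the raw points.
w = np.zeros(2)
b = 0.0
learning_rate = 0.5
for _ in range(5000):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))   # predicted probability of class 1
    error = p - y                            # gradient of cross-entropy w.r.t. the logit
    w -= learning_rate * (X.T @ error) / len(y)
    b -= learning_rate * error.mean()

# The decision boundary is the straight line w . x + b = 0.
predictions = (X @ w + b) > 0
linear_accuracy = (predictions == y).mean()
print(f"accuracy with no hidden layer: {linear_accuracy:.3f}")
```

However long we train, the accuracy stays below 100%, because the limitation comes from the straight-line boundary and not from the optimizer.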
But, in practice, neural networks have additional layers in the middle, called “hidden” layers. These layers warp and reshape the data to make it easier to classify.
We call the versions of the data corresponding to different layers representations. The input layer’s representation is the raw data. The middle “hidden” layer’s representation is a warped, easier-to-classify version of the raw data.
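As a rough illustration of what "warped, easier to classify" means, the sketch below applies a hand-picked warp to a hypothetical pair of parabolas (a stand-in for the two curves, not the post's dataset). The warp is chosen by hand and is not a trained hidden layer; it only plays the role a learned hidden representation would play. After warping, the two curves become flat lines that a single straight-line threshold separates.

```python
import numpy as np

# Hypothetical stand-in dataset: two parabolas, one shifted above the other.
xs = np.linspace(-1.0, 1.0, 101)
raw = np.concatenate([
    np.stack([xs, xs**2], axis=1),         # class 0
    np.stack([xs, xs**2 + 0.5], axis=1),   # class 1
])
labels = np.concatenate([np.zeros(len(xs)), np.ones(len(xs))])

def warp(points):
    """Hand-picked warp (an assumption, not a learned layer): unbend the parabolas."""
    x, y = points[:, 0], points[:, 1]
    return np.stack([x, y - x**2], axis=1)

# "Representations": the raw data, then the warped version of it.
representations = [raw, warp(raw)]

# In the warped representation, the horizontal line y = 0.25 splits the classes.
warped = representations[1]
predictions = warped[:, 1] > 0.25
warped_accuracy = (predictions == labels).mean()
print(f"accuracy after the warp: {warped_accuracy:.3f}")
```

A real hidden layer has to learn a transformation like this from data, using its weights and nonlinearity, but the idea is the same: reshape the data until the final layer's straight line is enough.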
Low-dimensional neural networks are really easy to reason about because we can just look at their representations, and at how one representation transforms into another. If we have a question about what such a network is doing, we can just look. (There’s quite a bit we can learn from low-dimensional neural networks, as explored in my post Neural Networks, Manifolds, and Topology.)
Unfortunately, neural networks are usually not low-dimensional. The strength of neural networks is classifying high-dimensional data, like computer vision data, which often has tens or hundreds of thousands of dimensions. The hidden representations we learn are also of very high dimensionality.
For example, suppose we are trying to classify MNIST. The input representation, MNIST, is a collection of 784-dimensional vectors (one dimension per pixel of a 28×28 image)! And, even for a very simple network, we’ll have a high-dimensional hidden representation. To be concrete, let’s use one hidden layer with a hundred sigmoid neurons.
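To get a feel for the sizes involved, here is a sketch of the forward pass through such a network. The weights are random and the inputs are random placeholders rather than real MNIST images, so nothing here is trained. The ten-way softmax output layer is an assumption (one unit per digit); the text above only fixes the 784 inputs and the hundred sigmoid hidden neurons.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Placeholder batch: 5 fake "images", each flattened to 784 values in [0, 1).
batch = rng.random((5, 784))

# Untrained, randomly initialized parameters (illustrative only).
W1 = rng.normal(0.0, 0.01, (784, 100))
b1 = np.zeros(100)
W2 = rng.normal(0.0, 0.01, (100, 10))   # assumed 10-way output, one unit per digit
b2 = np.zeros(10)

# Hidden representation: each input becomes a 100-dimensional vector.
hidden = sigmoid(batch @ W1 + b1)

# Assumed softmax output layer, computed in a numerically stable way.
logits = hidden @ W2 + b2
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

print(batch.shape, hidden.shape, probs.shape)
```

Every input is a point in 784-dimensional space, and its hidden representation is a point in 100-dimensional space. Neither can be plotted directly.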
While we can’t visualize the high-dimensional representations directly, we can visualize them using dimensionality reduction. Below, we look at nearest neighbor graphs of MNIST in its raw form and in a hidden representation from a trained MNIST network.
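The graph-building step can be sketched as follows. This is a brute-force k-nearest-neighbor search, assuming Euclidean distance and a hypothetical choice of k = 3. The arrays are random placeholders with the right shapes, not MNIST or a trained network's activations, and the layout step that turns the graph into a picture is not shown.

```python
import numpy as np

def knn_graph(points, k):
    """Return an (n, k) array: row i holds the indices of point i's k nearest neighbors."""
    points = np.asarray(points, dtype=float)
    diffs = points[:, None, :] - points[None, :, :]
    sq_dists = (diffs ** 2).sum(axis=-1)     # pairwise squared Euclidean distances
    np.fill_diagonal(sq_dists, np.inf)       # a point is not its own neighbor
    return np.argsort(sq_dists, axis=1)[:, :k]

rng = np.random.default_rng(0)

# Placeholder representations of the same 20 inputs (random, for shape only).
raw = rng.random((20, 784))      # stands in for raw 784-dimensional inputs
hidden = rng.random((20, 100))   # stands in for 100-dimensional hidden activations

raw_graph = knn_graph(raw, k=3)
hidden_graph = knn_graph(hidden, k=3)
print(raw_graph.shape, hidden_graph.shape)
```

The same function works on any representation, which is what lets us compare the raw data with a hidden layer: the points are the same inputs, and only the space in which "nearest" is measured changes. For anything much larger than a few thousand points, the full pairwise distance array becomes too big and a tree-based or approximate search would be the more practical choice.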