Introduction

Here, I explore some geometric intuition for the layer normalization found in transformers. In general, I’ve been trying to “imagine” the operations inside transformers. My “visualization” of multi-head attention largely follows the insights from A Mathematical Framework for Transformer Circuits, while I picture the feedforward neural network part as “warping the space” of embedding vectors. I struggled a bit with layer normalization until I realized that it has quite a nice interpretation. I’m sharing it here in case anyone is interested.

Visualizing layer normalization

In a neural network, the layer norm operation on a vector \(x \in \mathbb{R}^n\) is defined by

\[\operatorname{lnorm}(x;a,b) = a \odot \frac{x - \operatorname{mean}(x)}{\operatorname{std}(x) + \epsilon} + b \\ \quad \\ \operatorname{mean}(x) = \frac{1}{n} \sum_{i=1}^n x_i \\ \quad \\ \operatorname{std}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n (x_i - \operatorname{mean}(x))^2}\]

where \(a, b \in \mathbb{R}^n\) are parameters of the layer norm, \(\epsilon\) is a small stability factor (for instance, \(\epsilon = 10^{-12}\)), and \(\odot\) denotes the element-wise product. Layer normalization becomes more interpretable when seen as the composition of three functions

\[\operatorname{lnorm}(x;a,b) = h(g(f(x));a, b) \\ \quad\\ f(x) = \left(I_{n} - \frac{1}{n} \mathbf{1}_{n \times n} \right) x \\ \quad \\ g(x) = \sqrt{n} \frac{x}{\|x\|_2 + \epsilon \sqrt{n}} \\ \quad \\ h(x;a,b) = \operatorname{diag}(a)x+b ,\]

where \(I_{n}\) is the \(n \times n\) identity matrix and \(\mathbf{1}_{n \times n}\) is the \(n \times n\) matrix of ones. (The equivalence follows since \(f(x) = x - \operatorname{mean}(x)\,\mathbf{1}_n\) and \(\operatorname{std}(x) = \|f(x)\|_2 / \sqrt{n}\), so the \(\sqrt{n}\) factors cancel.) We can interpret each of these functions in turn.
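As a quick sanity check, here is a minimal NumPy sketch of both formulations (the function and variable names are mine, not from any particular library), verifying numerically that they agree:

```python
import numpy as np

def lnorm(x, a, b, eps=1e-12):
    # Layer normalization as defined above (np.std uses the 1/n convention).
    return a * (x - x.mean()) / (x.std() + eps) + b

def f(x):
    # Orthogonal projection onto U, the hyperplane orthogonal to the ones vector.
    return x - x.mean()

def g(x, eps=1e-12):
    # (Near-)radial projection onto the sphere of radius sqrt(n).
    n = x.shape[0]
    return np.sqrt(n) * x / (np.linalg.norm(x) + eps * np.sqrt(n))

def h(x, a, b):
    # Per-coordinate scaling by a, followed by a translation by b.
    return a * x + b

rng = np.random.default_rng(0)
n = 8
x, a, b = rng.normal(size=(3, n))
print(np.allclose(lnorm(x, a, b), h(g(f(x)), a, b)))  # True
```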

We begin with \(f\). First, consider the matrix \(P = \frac{1}{n} \mathbf{1}_{n \times n}\). It is not hard to see that the vector of ones \(\mathbf{1}_n\) is an eigenvector of \(P\) with eigenvalue \(\lambda_1 = 1\). Moreover, since \(P\) has rank 1, its null space has dimension \(n-1\); it is the eigenspace for eigenvalue 0, and it is orthogonal to \(\mathbf{1}_n\) (since \(P\) is symmetric). Therefore, \(P\) is an orthogonal projection onto the 1-dimensional subspace \(\{x : x_1 = x_2 = \cdots = x_n\}\) spanned by \(\mathbf{1}_n\). It follows that \(I_n - P\) is an orthogonal projection onto the \((n-1)\)-dimensional subspace orthogonal to \(\mathbf{1}_n\), which we denote by \(U\). Therefore, \(f\) is an orthogonal projection onto \(U\).
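Reusing the sketch above, these properties of \(f\) are easy to confirm numerically:

```python
# f subtracts the mean, so its output is orthogonal to the ones vector,
# and projecting twice changes nothing (f is idempotent).
ones = np.ones(n)
print(np.isclose(f(x) @ ones, 0.0))            # True: f(x) lies in U
print(np.allclose(f(f(x)), f(x)))              # True: f is a projection
P = np.ones((n, n)) / n
print(np.allclose((np.eye(n) - P) @ x, f(x)))  # True: matrix form matches
```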

Moving on to \(g\): if \(\|x\|_2 \gg \epsilon \sqrt{n}\), which holds except for vectors very close to the origin, then \(g(x) \approx \sqrt{n}\, x / \|x\|_2\), which is just a radial projection onto the \((n-1)\)-dimensional sphere of radius \(\sqrt{n}\), denoted \(\sqrt{n} S_{n-1} \subset \mathbb{R}^n\). Since \(f\) already projects \(x\) onto \(U\), the composition \(g(f(x))\) actually lives on an \((n-2)\)-dimensional sphere contained in \(U\). We denote this sphere by \(\sqrt{n} S_{n-2, U} = \sqrt{n} S_{n-1} \cap U\).
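Again with the sketch above, we can check where \(g \circ f\) lands:

```python
# After g ∘ f, the vector has norm sqrt(n) and zero mean (it stays in U),
# i.e. it lives on the (n-2)-dimensional sphere sqrt(n) * S_{n-2,U}.
y = g(f(x))
print(np.isclose(np.linalg.norm(y), np.sqrt(n)))  # True (up to eps)
print(np.isclose(y.mean(), 0.0))                  # True
```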

Finally, \(h(x; a, b)\) is a scaling by \(A = \operatorname{diag}(a)\) followed by a translation by \(b\). Notice that, since our vectors live on \(\sqrt{n} S_{n-2, U}\) after the composition \(g \circ f\), \(A\) not only distorts \(\sqrt{n} S_{n-2, U}\) into an ellipsoid, but also takes \(U\) into a different subspace \(AU\); the translation by \(b\) then moves everything to an **affine** subspace, so the image is an ellipsoid centered at \(b\).
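Consequently, inverting the affine map \(h\) on the output of \(\operatorname{lnorm}\) should recover a point on \(\sqrt{n} S_{n-2, U}\); continuing the sketch:

```python
# Undo h (assuming all entries of a are nonzero): what remains is a point
# on sqrt(n) * S_{n-2,U}, so the image of lnorm is an (n-2)-dimensional
# ellipsoid centered at b.
z = (lnorm(x, a, b) - b) / a
print(np.isclose(np.linalg.norm(z), np.sqrt(n)))  # True
print(np.isclose(z.mean(), 0.0))                  # True
```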

In summary, \(\operatorname{lnorm}\) projects an embedding vector \(x \in \mathbb{R}^n\) onto an \((n-2)\)-dimensional ellipsoid whose shape and location are determined by \(a\) and \(b\).

What about RMS layer normalization?

This transformation onto an ellipsoid seems needlessly complicated. What if we instead projected the embedding vectors directly onto a sphere, and manipulated that sphere only if necessary? This is the intuition behind RMSNorm, which drops the operation \(f\) above and sets \(b = 0\). This way, \(g\) projects the embedding vectors \(x\) onto the sphere \(\sqrt{n} S_{n-1}\), and \(a\) just warps the principal axes of this sphere (along the standard basis) to lengths \(\sqrt{n}\, a_i\). That makes the interpretation considerably easier. Unfortunately, layer normalization is still what we mostly find in the wild, so the interpretation above remains the most important one.
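Here is a minimal RMSNorm sketch in the same spirit, reusing the setup above (the exact placement of \(\epsilon\) varies between implementations; this mirrors the \(g\) defined earlier):

```python
def rmsnorm(x, a, eps=1e-12):
    # No mean subtraction and no bias: just project onto the sphere and scale.
    rms = np.sqrt(np.mean(x ** 2))  # equals ||x||_2 / sqrt(n)
    return a * x / (rms + eps)

# Undoing the per-axis scaling lands us back on the sphere sqrt(n) * S_{n-1}.
y = rmsnorm(x, a)
print(np.isclose(np.linalg.norm(y / a), np.sqrt(n)))  # True
```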