A useful trick for computing gradients w.r.t. matrix arguments, with some examples

I’ve spent hours this week and last computing, recomputing, and checking expressions for matrix gradients of functions. It turns out that, except in the simplest of cases, the most pain-free way to find such gradients is to use the Fréchet derivative (this is one of the few concrete benefits I derived from the differential geometry course I took back in grad school).

Remember that the Fréchet derivative of a function \(f : X \rightarrow \mathbb{R}\) at a point \(x\) is defined as the unique linear operator \(d\) that is tangent to \(f\) at \(x\), i.e. the one that satisfies
\[
f(x+h) = f(x) + d(h) + o(\|h\|).
\]
This definition of differentiability makes sense whenever \(X\) is a normed linear space. When \(X\) is moreover an inner product space and \(f\) is differentiable in this sense, the gradient of \(f\) at \(x\) is the element \(\nabla f(x)\) that represents \(d\), i.e. \(d(h) = \langle \nabla f(x), h \rangle.\)
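
To make the definition concrete, here is a quick numerical illustration (a NumPy sketch of my own, using log-sum-exp as a stand-in example): the remainder \(f(x+h) - f(x) - \langle \nabla f(x), h \rangle\) should vanish faster than \(\|h\|\).

```python
# Sketch: check that f(x+h) - f(x) - <grad f(x), h> is o(||h||) for
# f(x) = log(sum(exp(x))), whose gradient is the softmax of x.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(5)
h = rng.standard_normal(5)

def f(x):
    return np.log(np.sum(np.exp(x)))

def grad_f(x):
    e = np.exp(x)
    return e / e.sum()

for t in [1e-1, 1e-2, 1e-3, 1e-4]:
    remainder = f(x + t * h) - f(x) - grad_f(x) @ (t * h)
    # the ratio remainder / ||t h|| should itself tend to zero as t -> 0
    print(t, abs(remainder) / (t * np.linalg.norm(h)))
```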

Simple application

As an example application, let’s compute the gradient of the function
\[
f(X) = \langle A, XX^T \rangle := \mathrm{trace}(A^T XX^T) = \sum_{ij} A_{ij} (XX^T)_{ij}
\]
over the linear space of \(m \times n\) real-valued matrices equipped with the Frobenius norm (so that \(A\) is \(m \times m\)). First we expand \(f(X+H)\) as
\[
f(X + H) = \langle A, (X+H)(X+H)^T \rangle = \langle A, XX^T + XH^T + HX^T + HH^T \rangle
\]
Now we observe that the term involving more than one power of \(H\), namely \(\langle A, HH^T \rangle\), is \(O(\|H\|^2) = o(\|H\|)\) as \(H \rightarrow 0\), so
\[
f(X + H) = f(X) + \langle A, XH^T + HX^T \rangle + o(\|H\|).
\]
It follows that
\[
d(H) = \langle A, XH^T + HX^T \rangle = \mathrm{trace}(A^TXH^T) + \mathrm{trace}(A^THX^T),
\]
which is clearly a linear function of \(H\), as desired. To write this in a way that exposes the gradient, we use the
cyclic property of the trace, and exploit its invariance under transposition, to see that
\begin{align}
d(H) & = \mathrm{trace}(HX^TA) + \mathrm{trace}(X^TA^T H) \\
& = \mathrm{trace}(X^TAH) + \mathrm{trace}(X^TA^T H) \\
& = \langle A^TX, H \rangle + \langle AX, H \rangle \\
& = \langle (A + A^T)X, H \rangle.
\end{align}
The gradient of \(f\) at \(X\) is evidently \((A + A^T)X\).
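
As a sanity check (a NumPy sketch, not part of the derivation), we can compare \((A + A^T)X\) against a finite-difference approximation of the gradient:

```python
# Sketch: finite-difference check that the gradient of
# f(X) = trace(A^T X X^T) is (A + A^T) X.
import numpy as np

rng = np.random.default_rng(1)
m, n = 4, 3
A = rng.standard_normal((m, m))
X = rng.standard_normal((m, n))

def f(X):
    return np.trace(A.T @ X @ X.T)

grad = (A + A.T) @ X

eps = 1e-6
fd = np.zeros_like(X)
for i in range(m):
    for j in range(n):
        E = np.zeros_like(X)
        E[i, j] = eps
        fd[i, j] = (f(X + E) - f(X - E)) / (2 * eps)  # central difference

print(np.max(np.abs(fd - grad)))  # tiny, up to floating-point error
```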

More complicated application

If you have the patience to work through a lot of algebra, you could probably calculate the above gradient component by component using the standard rules of differential calculus, then back out the simple matrix expression \((A + A^T)X\). But what if we partitioned \(X\) into \(X = [\begin{matrix}X_1^T & X_2^T \end{matrix}]^T\) and desired the derivative of
\[
f(X_1, X_2) = \mathrm{trace}\left(A \left[\begin{matrix} X_1 \\ X_2 \end{matrix}\right] \left[\begin{matrix}X_1 \\ X_2 \end{matrix} \right]^T\right)
\]
with respect to \(X_2\)? Then the necessary bookkeeping becomes even more tedious if you want to compute component-by-component derivatives (I imagine, not having attempted it). On the other hand, the Fréchet derivative route is not significantly more complicated.

Some basic manipulations allow us to claim
\begin{align}
f(X_1, X_2 + H) & = \mathrm{trace}\left(A \left[\begin{matrix} X_1 \\ X_2 + H \end{matrix}\right] \left[\begin{matrix}X_1 \\ X_2 + H \end{matrix} \right]^T\right) \\
& = f(X_1, X_2) + \mathrm{trace}\left(A \left[\begin{matrix} 0 & X_1 H^T \\
H X_1^T & H X_2^T + X_2 H^T + H H^T \end{matrix} \right]\right).
\end{align}
Once again we drop the \(o(\|H\|)\) terms to see that
\[
d(H) = \mathrm{trace}\left(A \left[\begin{matrix} 0 & X_1 H^T \\
H X_1^T & H X_2^T + X_2 H^T \end{matrix} \right]\right).
\]
To find a simple expression for the gradient, we partition \(A\) (conformally with our partitioning of \(X\) into \(X_1\) and \(X_2\)) as
\[
A = \left[\begin{matrix} A_1 & A_2 \\ A_3 & A_4 \end{matrix} \right].
\]
Given this partitioning,
\begin{align}
d(H) & = \mathrm{trace}\left(\left[\begin{matrix}
A_2 H X_1^T & \\
& A_3 X_1 H^T + A_4 H X_2^T + A_4 X_2 H^T
\end{matrix}\right] \right) \\
& = \langle A_2^TX_1, H \rangle + \langle A_3X_1, H \rangle + \langle A_4^T X_2, H \rangle + \langle A_4X_2, H \rangle \\
& = \langle (A_2^T + A_3)X_1 + (A_4^T + A_4)X_2, H \rangle.
\end{align}
The first equality comes from noting that the trace of a block matrix is the sum of the traces of its diagonal blocks (so the off-diagonal blocks of the product, which we leave blank, are irrelevant), and the second comes from manipulating the traces using their cyclicity and invariance under transposition.

Thus \(\nabla_{X_2} f(X_1, X_2) = (A_2^T + A_3)X_1 + (A_4^T + A_4)X_2.\)
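
The same kind of finite-difference check works here (again a NumPy sketch with small, arbitrary dimensions, not part of the derivation):

```python
# Sketch: check that the gradient of f(X1, X2) = trace(A [X1; X2][X1; X2]^T)
# with respect to X2 is (A2^T + A3) X1 + (A4^T + A4) X2.
import numpy as np

rng = np.random.default_rng(2)
m1, m2, n = 3, 4, 2
X1 = rng.standard_normal((m1, n))
X2 = rng.standard_normal((m2, n))
A = rng.standard_normal((m1 + m2, m1 + m2))
A2 = A[:m1, m1:]   # conformal partition of A; the A1 block is not needed
A3 = A[m1:, :m1]
A4 = A[m1:, m1:]

def f(X2):
    X = np.vstack([X1, X2])
    return np.trace(A @ X @ X.T)

grad = (A2.T + A3) @ X1 + (A4.T + A4) @ X2

eps = 1e-6
fd = np.zeros_like(X2)
for i in range(m2):
    for j in range(n):
        E = np.zeros_like(X2)
        E[i, j] = eps
        fd[i, j] = (f(X2 + E) - f(X2 - E)) / (2 * eps)

print(np.max(np.abs(fd - grad)))  # tiny, up to floating-point error
```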

A masterclass application

Maybe you didn’t find the last example convincing. Here’s a function whose matrix gradient I actually needed, a computation I defy you to accomplish using standard componentwise calculus:
\[
f(V) = \langle 1^T K, \log(1^T \mathrm{e}^{VV^T}) \rangle = \log(1^T \mathrm{e}^{VV^T})K^T 1.
\]
Here, \(K\) is an \(n \times n\) matrix (nonsymmetric in general), \(V\) is an \(n \times d\) matrix, and \(1\) is a column vector of ones of length \(n\). The exponential \(\mathrm{e}^{VV^T}\) is computed entrywise, as is the \(\log\).

To motivate why you might want to take the gradient of this function, consider the situation where \(K_{ij}\) measures how similar items \(i\) and \(j\) are (in a possibly asymmetric way), and the rows of \(V\) are coordinates for representations of the items in Euclidean space. Then \((1^T K)_j\) measures how similar item \(j\) is to all the items, and
\[
(1^T \mathrm{e}^{VV^T})_j = \sum_{\ell=1}^n \mathrm{e}^{v_\ell^T v_j}
\]
is a measure of how similar the embedding \(v_j\) is to the embeddings of all the items. Thus, if we constrain all the embeddings to have norm 1, maximizing \(f(V)\) with respect to \(V\) ensures that the embeddings capture the item similarities in some sense. (Why do you care about this particular sense? That’s another story altogether.)

Ignoring the constraints (you could use a projected gradient method for the optimization problem), we’re now interested in finding the gradient of \(f\). In the following, I use the notation \(A \odot B\) to indicate the pointwise product of two matrices.
\begin{align}
f(V + H) & = \langle 1^T K, \log(1^T \mathrm{e}^{(V+H)(V+H)^T}) \rangle \\
& = \langle 1^T K, \log(1^T [\mathrm{e}^{VV^T} \odot \mathrm{e}^{VH^T} \odot \mathrm{e}^{HV^T} \odot \mathrm{e}^{HH^T} ]) \rangle.
\end{align}
One can use the series expansion of the exponential to see that
\begin{align}
\mathrm{e}^{VH^T} & = 11^T + VH^T + o(\|H\|), \\
\mathrm{e}^{HV^T} & = 11^T + HV^T + o(\|H\|), \text{ and}\\
\mathrm{e}^{HH^T} & = 11^T + o(\|H\|).
\end{align}
It follows that
\begin{multline}
f(V + H) = \langle 1^T K, \log(1^T [\mathrm{e}^{VV^T} \odot (11^T + VH^T + o(\|H\|)) \\
\odot (11^T + HV^T + o(\|H\|)) \odot (11^T + o(\|H\|)) ]) \rangle.
\end{multline}
This readily simplifies to
\begin{align}
f(V + H) & = \langle 1^T K, \log(1^T [\mathrm{e}^{VV^T} \odot(11^T + VH^T + HV^T + o(\|H\|) )]) \rangle \\
& = \langle 1^T K, \log(1^T [\mathrm{e}^{VV^T} + \mathrm{e}^{VV^T} \odot (VH^T + HV^T) + o(\|H\|) ]) \rangle.
\end{align}
Now recall the linear approximation of \(\log\):
\[
\log(x) = \log(x_0) + \frac{1}{x_0} (x-x_0) + o(|x- x_0|).
\]
Apply this approximation pointwise to conclude that
\begin{multline}
f(V + H) = \langle 1^T K, \log(1^T \mathrm{e}^{VV^T}) + \\
\{1^T \mathrm{e}^{VV^T}\}^{-1}\odot (1^T [\mathrm{e}^{VV^T} \odot (VH^T + HV^T) + o(\|H\|)]) \rangle,
\end{multline}
where \(\{x\}^{-1}\) denotes the pointwise inverse of a vector.
Take \(D\) to be the diagonal matrix with diagonal entries given by \(1^T \mathrm{e}^{VV^T}\). Rewriting the row vectors above as column vectors, and noting that \(\mathrm{e}^{VV^T} \odot (VH^T + HV^T)\) is symmetric, we have shown that
\[
f(V + H) = f(V) + \langle K^T1, D^{-1} [\mathrm{e}^{VV^T} \odot (VH^T + HV^T)]1 \rangle + o(\|H\|),
\]
so
\begin{align}
d(H) & = \langle K^T1, D^{-1} [\mathrm{e}^{VV^T} \odot (VH^T + HV^T)]1 \rangle \\
& = \langle D^{-1}K^T 11^T, \mathrm{e}^{VV^T} \odot (VH^T + HV^T) \rangle \\
& = \langle \mathrm{e}^{VV^T} \odot D^{-1}K^T 11^T, (VH^T + HV^T) \rangle.
\end{align}
The second equality follows from the standard properties of inner products and the third from the observation that
\[
\langle A, B\odot C \rangle = \sum_{ij} A_{ij}B_{ij}C_{ij} = \langle B \odot A, C \rangle.
\]
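
(This identity is easy to confirm numerically; here is a quick NumPy check, included purely for reassurance.)

```python
# Sketch: confirm <A, B .* C> = <B .* A, C> for the Frobenius inner product
# and the entrywise (Hadamard) product.
import numpy as np

rng = np.random.default_rng(3)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 4))
C = rng.standard_normal((4, 4))

lhs = np.sum(A * (B * C))
rhs = np.sum((B * A) * C)
print(abs(lhs - rhs))  # zero up to floating-point roundoff
```
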
Finally, manipulations in the vein of the two preceding examples (split the inner product across \(VH^T\) and \(HV^T\), then use the cyclicity and transpose invariance of the trace) allow us to claim that
\[
\nabla_V f(V) = [\mathrm{e}^{VV^T} \odot (11^T K D^{-1} + D^{-1} K^T 11^T)] V.
\]
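
Since this formula took the most work, it deserves a finite-difference sanity check (again a NumPy sketch with small, arbitrary dimensions):

```python
# Sketch: finite-difference check of the gradient of
# f(V) = log(1^T exp(V V^T)) K^T 1, with exp and log taken entrywise.
import numpy as np

rng = np.random.default_rng(4)
n, d = 5, 3
K = rng.standard_normal((n, n))
V = rng.standard_normal((n, d))
ones = np.ones(n)
J = np.ones((n, n))  # the matrix 1 1^T

def f(V):
    E = np.exp(V @ V.T)             # entrywise exponential
    return np.log(ones @ E) @ K.T @ ones

E = np.exp(V @ V.T)
Dinv = np.diag(1.0 / (ones @ E))    # D^{-1}, with D = diag(1^T exp(V V^T))
grad = (E * (J @ K @ Dinv + Dinv @ K.T @ J)) @ V

eps = 1e-6
fd = np.zeros_like(V)
for i in range(n):
    for j in range(d):
        P = np.zeros_like(V)
        P[i, j] = eps
        fd[i, j] = (f(V + P) - f(V - P)) / (2 * eps)

print(np.max(np.abs(fd - grad)))  # small relative to the size of the gradient
```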

As a caveat, note that if instead \(f(V) = \log(1^T \mathrm{e}^{VV^T} ) K 1\), then one should swap \(K\) and \(K^T\) in the last expression.
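
And since the whole point was to feed this gradient to the projected gradient method mentioned above, here is a minimal sketch of that loop (my own illustration; the step size and iteration count are arbitrary): take a gradient ascent step, then renormalize each row of \(V\) to unit norm.

```python
# Sketch: projected gradient ascent for maximizing f(V) subject to each row
# of V having unit norm: step along the gradient, then renormalize the rows.
import numpy as np

rng = np.random.default_rng(5)
n, d = 20, 4
K = rng.standard_normal((n, n))
ones = np.ones(n)
J = np.ones((n, n))

def grad_f(V):
    # gradient derived above: [exp(V V^T) .* (1 1^T K D^{-1} + D^{-1} K^T 1 1^T)] V
    E = np.exp(V @ V.T)
    Dinv = np.diag(1.0 / (ones @ E))
    return (E * (J @ K @ Dinv + Dinv @ K.T @ J)) @ V

V = rng.standard_normal((n, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)      # start on the constraint set

step = 1e-3                                        # arbitrary fixed step size
for _ in range(200):
    V = V + step * grad_f(V)
    V /= np.linalg.norm(V, axis=1, keepdims=True)  # project rows back to unit norm
```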