Nikola Janjušević
Published 2023-04-12

On using the multi-dimensional chain-rule correctly.

I've recently seen some graduate students get confused about partial derivatives and applying the chain rule. The following is a quiz question I wrote for my PhD advisor's Image and Video Processing students. I will use it to illustrate the proper application of the multi-dimensional chain-rule. The quiz question is as follows:

Consider the following loss function, with input $\mathbf{x}[n]$, $1 \leq n \leq N$, and target $\mathbf{y}[n]$, $1 \leq n \leq N$,

$$
\mathcal{L}(\mathbf{y}, \, \mathbf{x}; \, \mathbf{h}) = \frac{1}{2}\lVert \mathbf{y} - \mathbf{h} \ast \mathbf{x} \rVert_2^2,
$$

where $\mathbf{h}[m]$, $-M \leq m \leq M$, is a learnable 1D filter and $\ast$ denotes 1D convolution with zero-padding, i.e.,

$$
\mathbf{h} \ast \mathbf{x} \quad \Leftrightarrow \quad (\mathbf{h} \ast \mathbf{x})[n] = \sum_{m=-M}^M \mathbf{h}[m] \, \mathbf{x}[n-m], \quad 1 \leq n \leq N,
$$

where $\mathbf{x}[n < 1] \equiv 0$ and $\mathbf{x}[n > N] \equiv 0$.

Derive the partial derivative,

$$
\frac{\partial \mathcal{L}}{\partial \mathbf{h}[m]}, \quad \text{for } -M \leq m \leq M.
$$
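Before getting to the solution, it can help to have the forward computation in code. Below is a minimal NumPy sketch of the zero-padded convolution and the loss; the function names and brute-force loops are my own, chosen for clarity rather than speed.

```python
import numpy as np

def conv_zero_pad(h, x):
    """Zero-padded 1D convolution: (h * x)[n] = sum_m h[m] x[n-m] for 1 <= n <= N.
    h has length 2M+1 and is indexed by m = -M..M; x has length N."""
    M = (len(h) - 1) // 2
    N = len(x)
    y_hat = np.zeros(N)
    for n in range(1, N + 1):            # math index n
        for m in range(-M, M + 1):       # math index m
            if 1 <= n - m <= N:          # x[n-m] is zero outside 1..N
                y_hat[n - 1] += h[m + M] * x[n - m - 1]
    return y_hat

def loss(y, x, h):
    """L(y, x; h) = 0.5 * || y - h * x ||_2^2"""
    r = y - conv_zero_pad(h, x)
    return 0.5 * np.dot(r, r)
```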

Below is a solution that I would expect students to arrive at. It makes use of the scalar chain rule and doesn't worry about deriving Jacobians, as the course does not emphasize this perspective heavily.

Let $\hat{\mathbf{y}} = \mathbf{h} \ast \mathbf{x}$. Then,

$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \mathbf{h}[m]}
&= \frac{\partial}{\partial \mathbf{h}[m]} \left( \frac{1}{2}\sum_{i=1}^N \left( \mathbf{y}[i] - \hat{\mathbf{y}}[i] \right)^2 \right) \\
&= \sum_{i=1}^N \frac{\partial \hat{\mathbf{y}}[i]}{\partial \mathbf{h}[m]} \left(\hat{\mathbf{y}}[i] - \mathbf{y}[i] \right) \\
&= \sum_{i=1}^N \mathbf{x}[i-m] \left(\hat{\mathbf{y}}[i] - \mathbf{y}[i] \right) \\
&= \sum_{i=1}^N \mathbf{x}[i-m] \left((\mathbf{h} \ast \mathbf{x})[i] - \mathbf{y}[i] \right), \quad \text{for } -M \leq m \leq M.
\end{aligned}
$$

Note that this may be interpreted as a convolution of a flipped version of $\mathbf{x}[n]$ with $(\hat{\mathbf{y}}-\mathbf{y})[n]$. To see this, let $\vec{\mathbf{x}}[n] = \mathbf{x}[-n]$. Then,

$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \mathbf{h}[m]}
&= \sum_{i=1}^N \vec{\mathbf{x}}[m-i] \left(\hat{\mathbf{y}}[i] - \mathbf{y}[i] \right), \quad \text{for } -M \leq m \leq M \\
&\Leftrightarrow \\
\left(\frac{\partial \mathcal{L}}{\partial \mathbf{h}}\right)[m]
&= \left((\hat{\mathbf{y}} - \mathbf{y}) \ast \vec{\mathbf{x}} \right)[m], \quad \text{for } -M \leq m \leq M.
\end{aligned}
$$

Again, we assume zero-padding for $\mathbf{y}$ and $\hat{\mathbf{y}}$.
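As a sanity check on the derivation, the following sketch computes the gradient straight from the sum above and compares it against a central finite-difference estimate of the loss (reusing the conv_zero_pad and loss helpers from the earlier sketch).

```python
def grad_h(y, x, h):
    """dL/dh[m] = sum_i x[i-m] (y_hat[i] - y[i]), for -M <= m <= M."""
    M = (len(h) - 1) // 2
    N = len(x)
    y_hat = conv_zero_pad(h, x)
    g = np.zeros(2 * M + 1)
    for m in range(-M, M + 1):           # math index m
        for i in range(1, N + 1):        # math index i
            if 1 <= i - m <= N:          # x[i-m] is zero outside 1..N
                g[m + M] += x[i - m - 1] * (y_hat[i - 1] - y[i - 1])
    return g

# Central finite-difference check of the derived gradient.
rng = np.random.default_rng(0)
N, M = 16, 2
x, y, h = rng.normal(size=N), rng.normal(size=N), rng.normal(size=2 * M + 1)
eps = 1e-6
g_fd = np.array([(loss(y, x, h + eps * e) - loss(y, x, h - eps * e)) / (2 * eps)
                 for e in np.eye(2 * M + 1)])
print(np.allclose(grad_h(y, x, h), g_fd, atol=1e-5))   # expect True
```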

On using the multi-dimensional chain-rule

There are two distinct approaches to deriving this partial derivative: the Jacobian way and the scalar way. I see many students trip themselves up because they are aware of both methods but not fully aware of the distinction between them.

From preschool we are all familiar with the standard scalar chain rule of calculus. For differentiable $f: \mathbb{R} \rightarrow \mathbb{R}$ and $g: \mathbb{R} \rightarrow \mathbb{R}$,

$$
\frac{\partial (f\circ g)}{\partial x} = f^\prime(g(x)) \, g^\prime(x) = \frac{\partial f}{\partial g}\Bigm\lvert_{g(x)} \frac{\partial g}{\partial x} \Bigm\lvert_{x}.
$$

It is this rule that we directly employ in the solution above, by expanding the loss function in terms of its scalar variables, $\hat{\mathbf{y}}[n]$ and $\mathbf{y}[n]$.
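If you want to see the scalar rule in action numerically, here is a tiny check with an arbitrary choice of $f$ and $g$ (my own example, not from the quiz).

```python
import numpy as np

f, df = np.sin, np.cos                  # f(u) = sin(u),  f'(u) = cos(u)
g, dg = np.square, lambda t: 2.0 * t    # g(t) = t^2,     g'(t) = 2t

x = 0.7
analytic = df(g(x)) * dg(x)             # f'(g(x)) g'(x)
eps = 1e-6
numeric = (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)
print(np.isclose(analytic, numeric))    # expect True
```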

However, students also learn that a similar chain rule exists for vector input/output mappings. Namely, for differentiable mappings $f: \mathbb{R}^K \rightarrow \mathbb{R}^M$ and $g: \mathbb{R}^N \rightarrow \mathbb{R}^K$,

$$
\frac{\partial (f\circ g)}{\partial x} = \frac{\partial f}{\partial g}\Bigm\lvert_{g(x)} \frac{\partial g}{\partial x} \Bigm\lvert_{x},
$$

where $\frac{\partial f}{\partial g} \in \mathbb{R}^{M \times K}$ is the Jacobian matrix of $f$, defined element-wise as

$$
\left(\frac{\partial f}{\partial g}\right)_{ij} = \frac{\partial f_i}{\partial g_j}, \quad 1 \leq i \leq M, \quad 1 \leq j \leq K,
$$

and an analogous definition for $\frac{\partial g}{\partial x} \in \mathbb{R}^{K \times N}$. As a sanity check, observe that the shapes in the matrix multiplication of the vector chain rule above work out, and that the Jacobian of $f \circ g$ is an $M \times N$ matrix.

Trouble then arises when students derive the elements of the two Jacobian matrices in the chain rule above separately and then forget to compose them using matrix multiplication, i.e.,

$$
\left(\frac{\partial f}{\partial x}\right)_{ij} = \sum_{k=1}^K \left(\frac{\partial f}{\partial g}\right)_{ik} \left(\frac{\partial g}{\partial x}\right)_{kj} = \sum_{k=1}^K \frac{\partial f_i}{\partial g_k} \frac{\partial g_k}{\partial x_j}.
$$

Students who have trouble will often try to do element-wise multiplication of $\frac{\partial f}{\partial g}$ and $\frac{\partial g}{\partial x}$, even when the shapes don't make sense.
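The shape mismatch is easy to demonstrate. In the sketch below, $f$ and $g$ are made-up linear maps, so their Jacobians are simply the matrices A and B; they compose only through matrix multiplication, and NumPy refuses the element-wise product outright.

```python
import numpy as np

rng = np.random.default_rng(1)
Mdim, K, Ndim = 3, 4, 5                  # f: R^K -> R^Mdim,  g: R^Ndim -> R^K
A = rng.normal(size=(Mdim, K))           # Jacobian of f (here f(u) = A u)
B = rng.normal(size=(K, Ndim))           # Jacobian of g (here g(x) = B x)

J = A @ B                                # Jacobian of f o g, shape (Mdim, Ndim)
print(J.shape)                           # (3, 5)

# The tempting element-wise product is not even defined for these shapes:
try:
    A * B
except ValueError as err:
    print("element-wise product fails:", err)
```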

Returning to the quiz, the Jacobian approach goes as follows. Let $\hat{\mathbf{y}} = \mathbf{h} \ast \mathbf{x}$. Then,

$$
\frac{\partial \mathcal{L}}{\partial \mathbf{h}} = \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}}\frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{h}},
$$

where $\frac{\partial \mathcal{L}}{\partial \mathbf{h}} \in \mathbb{R}^{1 \times (2M+1)}$, $\frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}} \in \mathbb{R}^{1 \times N}$, and $\frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{h}} \in \mathbb{R}^{N \times (2M+1)}$. The Jacobian of the loss w.r.t. $\hat{\mathbf{y}}$ is,

$$
\frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}[j]} = \hat{\mathbf{y}}[j] - \mathbf{y}[j], \quad \text{for } 1 \leq j \leq N.
$$

And the Jacobian of the convolution w.r.t. the kernel is,

$$
\frac{\partial \hat{\mathbf{y}}[i]}{\partial \mathbf{h}[j]} = \mathbf{x}[i-j], \quad \text{for } 1 \leq i \leq N, ~ -M \leq j \leq M.
$$

Combining them using matrix multiplication, we get,

$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \mathbf{h}[m]}
&= \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}[i]}\frac{\partial \hat{\mathbf{y}}[i]}{\partial \mathbf{h}[m]} \\
&= \sum_{i=1}^N \mathbf{x}[i-m](\hat{\mathbf{y}}[i] - \mathbf{y}[i]), \quad \text{for } -M \leq m \leq M.
\end{aligned}
$$
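This matches the result from the scalar approach. To close the loop numerically, the sketch below forms the Jacobian of the convolution w.r.t. the kernel as an explicit $N \times (2M+1)$ matrix, left-multiplies it by the row Jacobian of the loss, and checks the product against the grad_h helper from the earlier sketch (reusing x, y, h, and M defined there).

```python
def jacobian_yhat_wrt_h(x, M):
    """J[i, j] = d y_hat[i] / d h[j] = x[i-j], shape (N, 2M+1)."""
    N = len(x)
    J = np.zeros((N, 2 * M + 1))
    for i in range(1, N + 1):            # math index i
        for j in range(-M, M + 1):       # math index j
            if 1 <= i - j <= N:          # x[i-j] is zero outside 1..N
                J[i - 1, j + M] = x[i - j - 1]
    return J

dL_dyhat = conv_zero_pad(h, x) - y       # row Jacobian of L w.r.t. y_hat, shape (N,)
J = jacobian_yhat_wrt_h(x, M)            # shape (N, 2M+1)
g_jac = dL_dyhat @ J                     # (1 x N) times (N x (2M+1))
print(np.allclose(g_jac, grad_h(y, x, h)))   # expect True
```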