__author__ = "Sida Wang"
__version__ = "COS 495 NLP Spring 2018"
The prediction of the multiclass linear classifier is based on the score $z_y = w_y \cdot \phi(x)$. While it is easy to learn $w$, the difficulties are hidden in $\phi(x)$, which is fixed after it is designed (by a person). One way to do better is for $\phi(x)$ to have learnable parameters as well, for example by adding another linear function and another layer of featurization of $x$, as in $w \cdot \phi(w' \cdot \phi'(x))$.
The two-layer neural network does exactly that, by computing the vector-valued, vector-input function $z = W_2 f(W_1 x + b_1) + b_2$. For arbitrary $f$, this says nothing at all. For a neural network, $f$ is a simple but non-linear elementwise function such as $f(x) = \max(0,x)$ or $f(x) = \frac1{1+\exp(-x)}$, which significantly restricts the space of functions considered. More layers can be added, for example $z = W_3 f_3(W_2 f_2(W_1 x + b_1) + b_2)+b_3$. Generally, an $n$-layer feedforward neural network implements the network function $x \mapsto z$, where
$$ \begin{align} a_{1} & \gets x \\ z_{i} & \gets W_i a_{i} + b_i \quad \text{for} \ i = 1,\ldots,n\\ a_{i+1} & \gets f_{i+1}(z_{i}) \quad \text{for} \ i = 1,\ldots,n-1\\ z & \gets z_{n}. \end{align} $$
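To make the recurrence concrete, here is a minimal numpy sketch of the forward computation. The names feedforward, Ws, bs, and fs (lists holding the per-layer weights, biases, and the elementwise non-linearities $f_2, \ldots, f_n$) are my own and are only for illustration.
import numpy as np
def feedforward(x, Ws, bs, fs):
    # Ws = [W_1, ..., W_n], bs = [b_1, ..., b_n], fs = [f_2, ..., f_n]
    a = x                                # a_1 <- x
    for i in range(len(Ws)):
        z = np.dot(Ws[i], a) + bs[i]     # z_i <- W_i a_i + b_i
        if i < len(Ws) - 1:
            a = fs[i](z)                 # a_{i+1} <- f_{i+1}(z_i)
    return z                             # z <- z_n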
The learnable parameters of the network are $\mathcal{W} = \{W_i, b_i\}_{i=1}^n$. Given data $x$, the desired prediction $y$, and the output $z$ computed from $x$, a loss function can be applied. For example, the least squares loss, the hinge loss (SVM), and the "softmax loss" (i.e. the negative log-likelihood of the data under softmax) are, respectively,
$$ \begin{align} L_\text{ls}(\mathcal{W}, x, y) &= \left\lVert y-z \right\rVert_2^2,\\ L_\text{svm}(\mathcal{W}, x, y) &= \max(0, 1-(z_{y} - \max_{y'\neq y} z_{y'})),\\ L_\text{nll}(\mathcal{W}, x, y) &= -z_y + \log\sum_{y'} \exp(z_{y'}). \end{align} $$
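As a quick sketch, the three losses can be computed in numpy for a single example. The scores z, the class index y (used by the hinge and softmax losses), and the target vector y_target (used by the least squares loss) are made-up values for illustration.
import numpy as np
z = np.array([2.0, -1.0, 0.5])           # scores z_{y'} over 3 classes (made up)
y = 0                                    # correct class index
y_target = np.array([1.0, 0.0, 0.0])     # regression target for the least squares loss
L_ls = np.sum((y_target - z) ** 2)                        # ||y - z||_2^2
L_svm = max(0.0, 1.0 - (z[y] - np.max(np.delete(z, y))))  # hinge loss
L_nll = -z[y] + np.log(np.sum(np.exp(z)))                 # negative log-likelihood under softmax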
While there is no known guarantee, gradient descent has proven to work well for many tasks. Our task here is to find the gradients of the loss with respect to all the parameters $\mathcal{W}$. Backprop is an efficient application of the chain rule, starting from the gradient of the loss w.r.t. the output $z$ and working backwards. For $i = n, n-1, \ldots, 2,$
$$ \begin{align} \frac{d L}{dz_{n}} &= \frac{d L}{dz}\\ \frac{d L}{da_{i}} &= \frac{d L}{dz_{i}} \frac{d z_{i}}{da_{i}} = W_i^T\frac{d L}{dz_{i}}\\ \frac{d L}{dz_{i-1}} &= \frac{d L}{da_{i}} \odot \frac{d a_{i}}{dz_{i-1}} = \frac{d L}{da_{i}} \odot f'_i(z_{i-1}) \\ \end{align} $$
Once these quantities are computed, the gradients with respect to the parameters are easy. For $i = 1, \ldots, n,$
$$ \begin{align} \frac{d L}{d W_i} &= \frac{d L}{dz_i} a_i^T,\\ \frac{d L}{d b_i} &= \frac{d L}{dz_i}. \end{align} $$
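Putting the recursion and the parameter gradients together, a sketch of the full backward pass might look as follows. The argument names (dLdz for $\frac{dL}{dz_n}$, Ws for the weights, zs and As for the cached $z_i$ and $a_i$, dfs for the derivatives $f_2', \ldots, f_n'$) are hypothetical and not defined above.
import numpy as np
def backprop_all(dLdz, Ws, zs, As, dfs):
    # Ws = [W_1, ..., W_n]; zs = [z_1, ..., z_n]; As = [a_1, ..., a_n]; dfs = [f_2', ..., f_n']
    n = len(Ws)
    grads = {}
    dz = dLdz                                        # dL/dz_n
    for i in range(n, 0, -1):                        # i = n, n-1, ..., 1
        grads['W%d' % i] = np.dot(dz, As[i - 1].T)   # dL/dW_i = dL/dz_i a_i^T
        grads['b%d' % i] = dz                        # dL/db_i = dL/dz_i
        if i > 1:
            da = np.dot(Ws[i - 1].T, dz)             # dL/da_i = W_i^T dL/dz_i
            dz = da * dfs[i - 2](zs[i - 2])          # dL/dz_{i-1} = dL/da_i * f_i'(z_{i-1})
    return grads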
Exercise: Check the dimensions of $a_i, z_i, \frac{d L}{da_{i}}, \frac{d L}{dz_{i-1}}, W_i$ and make sure they are consistent in backprop.
Is this just the chain rule?: It is often claimed that backprop is just the chain rule, in order to trivialize the problem. While a lot of the written steps use the chain rule, there are a few counterpoints:
The chain rule does not in itself specify the order of application. A straightforward application of the basic scalar chain rule is to expand all the terms and then evaluate, which is too expensive. In backprop, good decisions are made about when expressions are evaluated as opposed to expanded, and about which intermediate values need to be kept for efficient computation.
With a suitable vector chain rule, we can get to something that looks like backprop quickly. However, the tensor chain rule says $$\frac{d L}{d W_i} = \frac{d L}{dz_i} \frac{d z_i}{d W_i},$$ where $\frac{d z_i}{d W_i}$ is a 3rd-order tensor of size $\dim(z_i) \times \dim(W_i)=\dim(z_i)\cdot(\dim(z_i)\cdot\dim(a_{i}))$, and the product is a tensor-vector product that seems to require $\dim(z_i)\cdot\dim(z_i)\cdot\dim(a_{i})$ operations. So you would have to specify how this can be done without actually constructing $\frac{d z_i}{d W_i}$. Backprop requires only $\dim(z_i)\cdot\dim(a_{i})$ operations. For just 1000 dimensions, that makes the difference between practical and impractical.
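To make the cost argument concrete, here is a small numpy check (with made-up sizes) that contracting the explicit 3rd-order tensor $\frac{d z_i}{d W_i}$ against $\frac{d L}{d z_i}$ gives the same result as the outer product used in backprop, while requiring far more storage and operations.
import numpy as np
dim_z, dim_a = 4, 3
a = np.random.randn(dim_a, 1)            # a_i
dLdz = np.random.randn(dim_z, 1)         # dL/dz_i
# explicit tensor dz_k/dW_{jl} = delta_{kj} a_l, of shape (dim_z, dim_z, dim_a)
dzdW = np.zeros((dim_z, dim_z, dim_a))
for k in range(dim_z):
    dzdW[k, k, :] = a[:, 0]
dLdW_tensor = np.einsum('k,kjl->jl', dLdz[:, 0], dzdW)   # tensor-vector contraction
dLdW_backprop = np.dot(dLdz, a.T)                        # what backprop computes
print(np.allclose(dLdW_tensor, dLdW_backprop))           # True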
Under our formulation, which expresses the neural network as a function, and given our faith in SGD, computing its derivative is an obvious step. However, neural networks were originally developed to model learning by the brain, and it is a stretch that the brain learns by taking derivatives. From that perspective, backprop is not just the simple algorithm that computes the derivative, but also includes the empirical finding that neural network models can often learn effectively by following the gradient.
It seems that a few more advanced ingredients than the chain rule are needed to get backprop.
The code is no more complex than the math, in vectorized form. However, such code can be hard to debug. For example, it will run without error if I forget to add the bias or apply the wrong non-linear function. Gradient checking is an essential tool for ensuring correctness. While many issues and even intentional noise do not stop SGD from working, a systematically wrong gradient will make a difference.
import numpy as np
from numpy.random import randn
from copy import copy
# backpropagation code for least squares in a 2-layer neural network (single hidden layer)
# y_hat = W_2 f(W_1 x + b_1) + b_2, and L(y, y_hat) = (y - y_hat)^2
f2 = lambda x: np.maximum(0, x)
f2grad = lambda x: x > 0
# the loss function
lossfunc = lambda ypred, y: (ypred-y)*(ypred-y)
lossgrad = lambda ypred, y: 2*(ypred-y)
def fprop(x, y, params):
W1, b1, W2, b2 = [params[key] for key in ('W1', 'b1', 'W2', 'b2')]
z1 = np.dot(W1, x) + b1
a2 = f2(z1)
z2 = np.dot(W2, a2) + b2
loss = lossfunc(z2, y)
cache = {'x': x, 'y': y, 'z1': z1, 'a1': x, 'z2': z2, 'a2': a2, 'loss': loss}
for key in params:
cache[key] = params[key]
return cache
def bprop(fprop_cache):
x, y, z1, a1, z2, a2, loss = [fprop_cache[key] for key in ('x', 'y', 'z1', 'a1', 'z2', 'a2', 'loss')]
dz2 = lossgrad(z2, y)
dW2 = np.dot(dz2, a2.T)
db2 = dz2
da2 = np.dot(fprop_cache['W2'].T, dz2)
dz1 = da2 * f2grad(z1)
dW1 = np.dot(dz1, x.T)
db1 = dz1
return {'b1': db1, 'W1': dW1, 'b2': db2, 'W2': dW2}
From a software engineering perspective, this implementation exhibits strong coupling between fprop and bprop. This coupling means that the human has to maintain the consistency between fprop and bprop, when there should only be a dependence on the network architecture. This is unnecessary in principle and can quickly get out of hand for complicated networks. An example of backprop code for a network that is quite simple by modern standards is shown at the end. Imagine you made some changes to fprop, such as rearranging some layers, renaming some parameters, adding more layers, etc.; you would have to make the corresponding modifications to bprop while tracking the dependency structure in lines of code!
To understand the process more generally, consider a directed acyclic graph (DAG) $(V, E)$ that defines the function to be computed. Each node $v$ in the graph represents a function $\operatorname{in}(v) \mapsto \operatorname{out}(v)$, and each directed edge $(v, v')$ represents that $v'$ is a function of $v$ (among others). This means that (some of) the outputs of $v$ are (some of) the inputs of $v'$. In this case we say that $v'$ is a child of $v$, and we denote the set of all children of $v$ as $C(v) := \{v' \mid (v, v') \in E\}$ and the set of parents of $v$ as $P(v) := \{v' \mid (v', v) \in E\}$.
More precisely, $\operatorname{in}(v) \subseteq \operatorname{out}(P(v))$ and $\operatorname{out}(v) \subseteq \operatorname{in}(C(v))$. Forward prop evaluates each node from parents to children until the output node is reached. For backprop to work, each node has to compute the gradient of its input given the gradient of its output $$ \begin{align} \frac{dL}{d\operatorname{in}(v)} &= \frac{dL}{d\operatorname{out}(v)} \frac{d\operatorname{out}(v)}{d\operatorname{in}(v)}. \end{align} $$
Then the rest is up to the chain rule. Information based on just the network structure, $\frac{d \operatorname{in}(v')}{d\operatorname{out}(v)}$, and the gradients of the children of a node, $\frac{dL}{d \operatorname{in}(v')}$, are required before the gradient of the node can be computed: $$ \begin{align} \frac{dL}{d\operatorname{out}(v)} &= \frac{d}{d\operatorname{out}(v)} L(\operatorname{in}(C(v)), \operatorname{in}(V \backslash C(v))) \\ &= \sum_{v' \in C(v)} \frac{dL}{d \operatorname{in}(v')} \frac{d \operatorname{in}(v')}{d\operatorname{out}(v)}. \end{align} $$
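As a tiny concrete instance of the sum over children (with made-up functions), suppose $\operatorname{out}(v)$ feeds two children, $u = 2\operatorname{out}(v)$ and $w = \operatorname{out}(v)^2$, with $L = u + w$; then $\frac{dL}{d\operatorname{out}(v)} = 2 + 2\operatorname{out}(v)$, the sum of the two children's contributions.
import numpy as np
out = np.array([1.5, -0.5])                   # out(v), made-up values
grad_from_u = 1.0 * 2.0 * np.ones_like(out)   # dL/du * du/dout, with dL/du = 1
grad_from_w = 1.0 * 2.0 * out                 # dL/dw * dw/dout, with dL/dw = 1
dLdout = grad_from_u + grad_from_w            # the sum over the children C(v)
print(np.allclose(dLdout, 2.0 + 2.0 * out))   # True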
Now there is some encapsulation between the inner workings of each node and the overall structure, which means fprop and bprop of the whole graph can be decoupled.
Let's call each of these nodes, which might have parameters and might itself be a composition of a bunch of functions, a module. Then we need some modules that are easy to differentiate by themselves, which can then be composed to form complex computation graphs.
Here are some examples of simple modules:
$$ \begin{array}{ll} \text{linear:} & \operatorname{in}(v') = W \operatorname{out}(v) + b\\ \text{elementwise:} & \forall j: \operatorname{in}(v')_j = f(\operatorname{out}(v)_j)\\ \text{e.w. product:} & \operatorname{in}(v') = \operatorname{out}(v_1) \odot \operatorname{out}(v_2) \odot \operatorname{out}(v_3) \odot \ldots \\ \text{e.w. sum:} & \operatorname{in}(v') = \operatorname{out}(v_1) + \operatorname{out}(v_2) + \operatorname{out}(v_3) + \ldots\\ \end{array} $$
Given these, all the gradient computations can be done automatically by traversing the DAG from output to input in any of the orderings implied by the DAG (i.e. a reverse topological order).
More concretely, we would like to write torch-style code like:
nn2layer = Sequential(OrderedDict([
('L1', Linear(params['W1'], params['b1'])),
('Relu1', Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
('L2', Linear(params['W2'], params['b2'])),
('Relu2', Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
]))
Then you can swap out units, add more layers, and add more structure without having to consider backprop. Here is my minimal implementation of torch-style modules, which has an explicit backward() in each module; pytorch did away with these, but then the code would have to be much longer.
from collections import OrderedDict
class Module(object):
def __init__(self):
self.params = OrderedDict()
self.grads = OrderedDict()
def forward(self, *input):
raise NotImplementedError
# ideally, one infers backward from forward
def backward(self, *input, gradout):
raise NotImplementedError
class Elementwise(Module):
def __init__(self, f, dfdz):
self.f, self.dfdz = f, dfdz
def forward(self, input):
self.input = input
self.output = self.f(input)
return self.output
def backward(self, gradout):
return gradout * self.dfdz(self.input)
class Linear(Module):
def __init__(self, W, b):
super(Linear, self).__init__()
self.params['W'] = W
self.params['b'] = b
def forward(self, input):
self.input = input
W, b = self.params['W'], self.params['b']
self.output = np.dot(W, input) + b
return self.output
def backward(self, gradout):
self.grads['W'] = np.dot(gradout, self.input.T)
self.grads['b'] = gradout
return np.dot(self.params['W'].T, gradout)
class Sequential(Module):
def __init__(self, children):
self.children = children
def forward(self, input):
for child in self.children.values():
input = child.forward(input)
return input
def backward(self, gradout):
for child in reversed(self.children.values()):
gradout = child.backward(gradout)
return gradout
class Loss(Module):
def __init__(self, netfunc, lossfunc, lossderiv):
self.netfunc = netfunc
self.lossfunc = lossfunc
self.lossderiv = lossderiv
def forward(self, input, target):
self.target = target
self.pred = self.netfunc.forward(input)
return self.lossfunc(target, self.pred)
def backward(self):
deriv = self.lossderiv(self.target, self.pred)
return self.netfunc.backward(deriv)
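For completeness, the e.w. sum and e.w. product entries from the earlier table could be written in the same style. The sketch below is mine (the class names ElementwiseSum and ElementwiseProduct are made up) and is not used in the network below, since Sequential as written only chains single-input modules.
class ElementwiseSum(Module):
    # in(v') = out(v_1) + out(v_2) + ...: the gradient passes through unchanged to each input
    def forward(self, *inputs):
        self.ninputs = len(inputs)
        return sum(inputs)
    def backward(self, gradout):
        return [gradout] * self.ninputs
class ElementwiseProduct(Module):
    # in(v') = out(v_1) * out(v_2) * ...: each input gets gradout times the product of the others
    def forward(self, *inputs):
        self.inputs = inputs
        out = inputs[0].copy()
        for x in inputs[1:]:
            out = out * x
        return out
    def backward(self, gradout):
        grads = []
        for i in range(len(self.inputs)):
            g = gradout.copy()
            for j, x in enumerate(self.inputs):
                if j != i:
                    g = g * x
            grads.append(g)
        return grads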
With these modules, we can define the network.
params = {'L1.W': randn(50,100),
'L1.b': randn(50,1),
'L2.W': randn(1, 50),
'L2.b': randn(1, 1)
}
x = randn(100,1)
y = randn(1)*5
nn2layer = Sequential(OrderedDict([
('L1', Linear(params['L1.W'], params['L1.b'])),
('Relu', Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
('L2', Linear(params['L2.W'], params['L2.b']))
]))
ls_loss = lambda y, yp: np.sum(np.square(y - yp))
ls_grad = lambda y, yp: 2*(yp - y)
final_loss = Loss(nn2layer, ls_loss, ls_grad)
loss = final_loss.forward(x, y)
final_loss.backward()
print('current loss', loss)
print('current pred/target', nn2layer.forward(x), y)
print('diff', (nn2layer.forward(x)-y)**2)
One thing I ignored for the sake of really short code is parameter management. If I add a new module, I still have to add its parameters manually. It is not difficult to manage all the parameters in the modules as well.
def flatten(self):
all_modules = OrderedDict()
for key, module in self.children.items():
all_modules[key] = module
if 'children' in module.__dict__:
# recurse on children
raise NotImplementedError
return all_modules
Sequential.flatten = flatten
def collect_params(root):
params = {}
grads = {}
for name, module in root.flatten().items():
if 'params' in module.__dict__:
for key, param in module.params.items():
params[name + '.' + key] = param
grads[name + '.' + key] = module.grads[key]
return params, grads
params, grads = collect_params(nn2layer)
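With the parameters and gradients collected in one place, a basic SGD loop takes only a few lines. This is just a sketch using the final_loss, nn2layer, x, and y defined above; the learning rate and number of steps are arbitrary, and the in-place update works because params holds references to the same arrays stored inside the modules.
lr = 1e-3                                 # arbitrary step size for illustration
for step in range(5):
    loss = final_loss.forward(x, y)
    final_loss.backward()
    params, grads = collect_params(nn2layer)
    for key in params:
        params[key] -= lr * grads[key]    # in-place, so the modules see the update
    print(step, loss)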
Let us test the basic backprop code by comparing with the numerical gradient $$ \frac{d f(W,x,y)}{d W_{ij}} \approx \frac{f(W + \epsilon_{ij},x,y) - f(W - \epsilon_{ij},x,y)}{2 \epsilon}, $$ where $\epsilon_{ij}$ is the matrix of the same size as $W$ with value $\epsilon$ at position $ij$ and 0 everywhere else. We will check both the basic implementation and the module.
def numerical_grad(fprop, x, y, params):
eps = 1e-6
ng_cache = {}
# For every single parameter (W, b)
for key in params:
param = params[key]
# This will be our numerical gradient
ng = np.zeros(param.shape)
for j in range(ng.shape[0]):
for k in range(ng.shape[1]):
# For every element of parameter matrix, compute gradient of loss wrt
# that element numerically using finite differences
add_eps = np.copy(param)
min_eps = np.copy(param)
add_eps[j, k] += eps
min_eps[j, k] -= eps
add_params = copy(params)
min_params = copy(params)
add_params[key] = add_eps
min_params[key] = min_eps
ng[j, k] = (np.sum(fprop(x, y, add_params)['loss']) \
- np.sum(fprop(x, y, min_params)['loss'])) / (2 * eps)
ng_cache[key] = ng
return ng_cache
def check_grad(params, grad1, grad2):
# Compare numerical gradients to those computed using backpropagation algorithm
for key in params:
#print(bprop_grad[key])
#print(num_grad[key])
diff = grad1[key].flatten() - grad2[key].flatten()
sums = grad1[key].flatten() + grad2[key].flatten()
norm = np.max(np.abs(diff / sums))
if norm < 1e-5:
print(key, 'pass', norm)
else:
print(key, 'fail', norm)
import time
timeformat = '{0}: numdata: {1}\t time: {2:.5e}'
# test gradient check on the basic implementation
num_data, dim_data, num_hid = 1, 200, 300
W1 = np.random.rand(num_hid, dim_data)
b1 = np.random.rand(num_hid, 1)
W2 = np.random.rand(1, num_hid)
b2 = np.random.rand(1, 1)
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10
params = {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
tic = time.time() ########## BEGIN
fprop_cache = fprop(x, y, params)
bprop_grad = bprop(fprop_cache)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('bprop', num_data, tictoc))
tic = time.time() ########## BEGIN
num_grad = numerical_grad(fprop, x, y, params)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('numerical', num_data, tictoc))
# Compare numerical gradients to those computed using backpropagation algorithm
check_grad(params, num_grad, bprop_grad)
print('loss', fprop_cache['loss'])
#############################################
# BEGIN CHECK MODULE
params = {'L1.W': W1, 'L1.b': b1, 'L2.W': W2, 'L2.b': b2}
nn2layer = Sequential(OrderedDict([
('L1', Linear(params['L1.W'], params['L1.b'])),
('Relu', Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
('L2', Linear(params['L2.W'], params['L2.b']))
]))
ls_loss = lambda y, yp: np.sum(np.square(y - yp))
ls_grad = lambda y, yp: 2*(yp - y)
final_loss = Loss(nn2layer, ls_loss, ls_grad)
params = {'L1.W': W1, 'L1.b': b1, 'L2.W': W2, 'L2.b': b2}
def fprop_wrapper(x, y, params):
_children = nn2layer.children
_children['L1'].params['W'] = params['L1.W']
_children['L1'].params['b'] = params['L1.b']
_children['L2'].params['W'] = params['L2.W']
_children['L2'].params['b'] = params['L2.b']
loss = final_loss.forward(x, y)
return {'loss': loss}
tic = time.time() ########## BEGIN
loss_module = final_loss.forward(x, y)
final_loss.backward()
_children = nn2layer.children
module_grad = {'L1.W': _children['L1'].grads['W'], 'L1.b': _children['L1'].grads['b'],
'L2.W': _children['L2'].grads['W'], 'L2.b': _children['L2'].grads['b']}
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('backprop-module', num_data, tictoc))
tic = time.time() ########## BEGIN
num_grad = numerical_grad(fprop_wrapper, x, y, params)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('numerical-module', num_data, tictoc))
print('loss', loss_module)
check_grad(params, num_grad, module_grad)
Since matrix multiplications are highly optimized, it is faster to do more work per multiplication by batching data points together. In fact, the code already works with minibatches of data points in the form of $X = [x_1, x_2, \ldots, x_n]$, where each data vector occupies a column. On my CPU, I got a speedup of about 10 times for a fairly small network. Such speedups are expected to be even more significant on GPUs.
num_data, dim_data, num_hid = 200, 100, 200
backprop-batch: numdata: 200 time: 8.87632e-04
backprop-loop: numdata: 200 time: 8.80098e-03
numerical: numdata: 200 time: 1.09668e+00
The last row is the time it takes to compute all the gradients numerically, so do not get any ideas.
Exercise: there is one issue in bprop that prevents vectorized minibatching from being completely correct. Spot it and propose a fix.
num_data, dim_data, num_hid = 100, 200, 300
W1 = np.random.rand(num_hid, dim_data)
b1 = np.random.rand(num_hid, 1)
W2 = np.random.rand(1, num_hid)
b2 = np.random.rand(1, 1)
params = {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10
tic = time.time()
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10
fprop_cache = fprop(x, y, params)
bprop_grad = bprop(fprop_cache)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('backprop-batch', num_data, tictoc))
tic = time.time()
for i in range(num_data):
x = np.random.rand(dim_data, 1)
y = np.random.rand(1, 1) * 10
fprop_cache = fprop(x, y, params)
bprop_grad = bprop(fprop_cache)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('backprop-loop', num_data, tictoc))
tic = time.time()
num_grad = numerical_grad(fprop, x, y, params)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('numerical', num_data, tictoc))
If the basic coupled implementation seemed easy enough to work with, here is the backprop code for a slightly more complex network (but still very simple by modern standards). The fprop and a few functions called in both parts are here.
# highly coupled and unreadable code from the transforming autoencoder, around 2011
def backprop(self, target, input, trsf):
numcases = target.shape[1]
biasfac = 1
self.calcoutput(input, trsf)
self.diff = self.output - target
self.wu_ho = g.dot(self.diff, self.h2.T * self.pr.T) / numcases
self.wu_o = self.diff.sum(1)[:,None] / (biasfac * numcases)
dEdH2 = g.dot(self.w_ho.T, self.diff) * self.pr
dEdH2in = dEdH2 * (self.h2 - self.shift) * (1- self.h2 + self.shift)
self.wu_ch = g.dot(dEdH2in, self.ct.T) / numcases
self.wu_h2 = dEdH2in.sum(1)[:,None] / (biasfac * numcases)
dEdCin = g.dot(self.w_ch.T, dEdH2in)
self.bpdEdCin = g.garray(dEdCin)
#print 'max before invert %f %f' % (dEdCin.max(), dEdCin.min())
if not self.justtranslate: dEdCin = g.garray(self.applymatf(trsf,dEdCin,transpose=True))
dEdPout = (self.w_ho[:,:,None] * self.h2[None,:,:] * self.diff[:,None,:])\
.sum(0).reshape(self.groupcoord, self.sizehid2, numcases)\
.sum(1)
dEdPin = self.p * (1-self.p) * dEdPout
self.wu_hp = g.dot(dEdPin, self.h1.T) / numcases
self.wu_p = dEdPin.sum(1)[:,None] / (biasfac * numcases)
#dEdCin *= (self.maskm * self.pr)
self.bpdEdCinm = g.garray(dEdCin)
#print 'max after invert %f %f' % (dEdCin.max(), dEdCin.min())
self.wu_hc = g.dot(dEdCin, self.h1.T) / numcases
self.wu_c = dEdCin.sum(1)[:,None] / (biasfac * numcases)
dEdH1 = g.dot(self.w_hc.T, dEdCin) + g.dot(self.w_hp.T, dEdPin)
dEdH1in = dEdH1 * (self.h1 - self.shift) * (1 - self.h1 + self.shift)
self.wu_vh = g.dot(dEdH1in, input.T) / numcases
self.wu_h1 = dEdH1in.sum(1)[:,None] / (biasfac * numcases)