Implementing Gradient Descent Algorithm in Python, bit confused regarding equations - Artificial Intelligence Stack Exchange

Answer by primussucks for Implementing Gradient Descent Algorithm in Python, bit confused regarding equations


There seems to be a mismatch between the weights you provide and your network diagram. Since w[0] (the yellow connections) is meant to transform $ x \in \mathbb{R}^2 $ into the layer 0 activations, which are also in $ \mathbb{R}^2 $, w[0] should be a matrix $ \in \mathbb{R}^{2 \times 2} $, not a vector in $ \mathbb{R}^2 $ as you have it. Likewise, your w[1] (the red connections) should be a vector $ \in \mathbb{R}^2 $, not a scalar. Finally, if you are indeed scaling the output of layer 1 (the blue connection), then you'll need an additional scalar weight. However, the blue connection confuses me a bit. Usually the activated output is fed directly into the loss function, not a scaled version of it, unless the blue connection stands for the loss function itself.
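Here is a quick NumPy illustration of the mismatch. The input values and weights are arbitrary, and `w0_vector` is a hypothetical stand-in for a vector-shaped w[0]. Multiplying a 2-element input by a vector gives one number, while a $2 \times 2$ matrix gives one weighted input per layer-0 node.

```python
import numpy as np

x = np.array([0.5, -1.0])            # one input example, x in R^2

# w[0] as a vector in R^2: x @ w0_vector collapses to a single number,
# so there is nothing left to feed the second layer-0 node.
w0_vector = np.array([0.1, 0.2])
print(np.shape(x @ w0_vector))        # ()

# W^0 as a 2x2 matrix: one weighted input per layer-0 node.
W0 = np.array([[0.1, 0.2],
               [0.3, 0.4]])
print((x @ W0).shape)                 # (2,)
```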

In short, I believe if you change the shapes of your weight matrices to actually represent your network diagram, your update equations will work. I'll go through the network below to make sure I illustrate my point.

$ x \in \mathbb{R}^{2} $, an input example

$ W^0 \in \mathbb{R}^{2 \times 2} $, the yellow connections

$ W^1 \in \mathbb{R}^2 $, the red connections

$ z^0 = xW^0 \in \mathbb{R}^{2} $, the weighted inputs to the layer 0 nodes. The dimensions of this should match the number of nodes at layer 0.

$ a^0 = \sigma(z^0) \in \mathbb{R}^{2} $, the output of the layer 0 nodes. The dimensions of this should match the number of nodes at layer 0.

$ z^1 = a^0 W^1 \in \mathbb{R} $, the weighted inputs to the layer 1 nodes. The dimensions of this should match the number of nodes at layer 1.

$ a^1 = \sigma(z^1) \in \mathbb{R} $, the output of the layer 1 nodes and thus the output of the network. The dimensions of this should match the number of nodes at layer 1.
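The forward pass above can be sketched in NumPy as follows. The sketch assumes $\sigma$ is the logistic sigmoid and, like the equations, it leaves out the biases. The weight values are arbitrary and `forward` is a hypothetical helper name. The asserts encode the rule that each intermediate has one entry per node in its layer.

```python
import numpy as np

def sigmoid(z):
    """Logistic activation, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, W0, W1):
    """Forward pass for the 2-2-1 network (biases omitted, as in the equations)."""
    z0 = x @ W0            # weighted inputs to layer 0, shape (2,)
    a0 = sigmoid(z0)       # layer-0 outputs, shape (2,)
    z1 = a0 @ W1           # weighted input to the single layer-1 node, a scalar
    a1 = sigmoid(z1)       # network output, a scalar
    # One entry per node: 2 for layer 0, a single value for layer 1.
    assert z0.shape == a0.shape == (W0.shape[1],)
    assert np.ndim(z1) == 0 and np.ndim(a1) == 0
    return z0, a0, z1, a1

x = np.array([0.5, -1.0])
W0 = np.array([[0.1, 0.2], [0.3, 0.4]])
W1 = np.array([0.5, 0.6])
z0, a0, z1, a1 = forward(x, W0, W1)
print(z0, a0, z1, a1)
```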

Weight Updates

As you said before your edit, $\delta^1$ is the product of the two scalars $\nabla_a C$ and $\sigma'(z^1)$, so it is also a scalar. Since $a^0$ is a vector in $\mathbb{R}^2$, $\delta^1(a^0)^T$ is also a vector in $\mathbb{R}^2$. This is what we expect: it must have the same dimensions as $W^1$ for the element-wise subtraction in the weight update equation to work.
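The following sketches the layer-1 weight update under two assumptions the question does not state. The first is the quadratic cost $C = \frac{1}{2}(a^1 - y)^2$, so that $\nabla_a C = a^1 - y$. The second is a hypothetical learning rate `eta`. Because $\delta^1$ is a scalar, $\delta^1(a^0)^T$ is simply `delta1 * a0`, which has the same shape as `W1`.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

x = np.array([0.5, -1.0]); y = 1.0
W0 = np.array([[0.1, 0.2], [0.3, 0.4]])
W1 = np.array([0.5, 0.6])
eta = 0.5  # hypothetical learning rate

# Forward pass (biases omitted, as in the equations).
z0 = x @ W0; a0 = sigmoid(z0)
z1 = a0 @ W1; a1 = sigmoid(z1)

# Output error: a scalar times a scalar is a scalar.
# Assumes C = 0.5 * (a1 - y)**2, so grad_a C = a1 - y.
delta1 = (a1 - y) * sigmoid_prime(z1)
grad_W1 = delta1 * a0              # delta^1 (a^0)^T: a vector in R^2
assert np.ndim(delta1) == 0
assert grad_W1.shape == W1.shape   # same shape, so element-wise subtraction works
W1_new = W1 - eta * grad_W1
```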

NB. It is not the case, as you say in your edit, that the shape of $\delta^l$ should match the shape of $W^l$. Instead, the shape of $\delta^l$ should match the number of nodes in layer $l$, and it is the shape of $\delta^l(a^{l-1})^T$ that should match the shape of $W^l$. You had this right in your original post.
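The sketch below continues to layer 0, again assuming the sigmoid and the quadratic cost. $\delta^0$ has one entry per layer-0 node, while the weight gradient has the shape of $W^0$. The forward pass here uses row vectors ($z^0 = xW^0$), so the gradient comes out as the outer product $x^T\delta^0$. The $\delta^l(a^{l-1})^T$ form is the same object written for the column-vector convention $z^l = W^l a^{l-1}$. Since $W^0$ is square, both orderings have the right shape. The sketch therefore checks an off-diagonal entry against a central finite difference to confirm the orientation. The helper `cost` is hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def cost(x, y, W0, W1):
    """Quadratic cost (an assumption) for the 2-2-1 network without biases."""
    a1 = sigmoid(sigmoid(x @ W0) @ W1)
    return 0.5 * (a1 - y) ** 2

x = np.array([0.5, -1.0]); y = 1.0
W0 = np.array([[0.1, 0.2], [0.3, 0.4]])
W1 = np.array([0.5, 0.6])

z0 = x @ W0; a0 = sigmoid(z0)
z1 = a0 @ W1; a1 = sigmoid(z1)
delta1 = (a1 - y) * sigmoid_prime(z1)      # scalar: one node in layer 1
delta0 = delta1 * W1 * sigmoid_prime(z0)   # shape (2,): one entry per layer-0 node
grad_W0 = np.outer(x, delta0)              # shape (2, 2), same as W0

assert delta0.shape == (2,)                # matches the node count, not W0's shape
assert grad_W0.shape == W0.shape

# Central finite difference on an off-diagonal entry; np.outer(delta0, x)
# would have the same shape but a different value here.
eps = 1e-6
W0_plus = W0.copy(); W0_plus[0, 1] += eps
W0_minus = W0.copy(); W0_minus[0, 1] -= eps
numeric = (cost(x, y, W0_plus, W1) - cost(x, y, W0_minus, W1)) / (2 * eps)
print(grad_W0[0, 1], numeric)
```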

Bias Updates

This brings us to the bias updates. There should be one bias term per node in a given layer, so the shapes of your biases are correct (i.e. $\mathbb{R}^2$ for layer 0 and $\mathbb{R}$ for layer 1). We saw above that the shape of $\delta^l$ also matches the number of nodes in layer $l$, so the element-wise subtraction in your original bias update equation works here too.
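The same sketch can be extended with biases, using $z^0 = xW^0 + b^0$ and $z^1 = a^0W^1 + b^1$. This is an assumption, since the forward equations above omit them. It shows why the bias update lines up: $\partial C/\partial b^l = \delta^l$, and both have one entry per node. The quadratic cost, the bias values and `eta` are assumptions as well.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

x = np.array([0.5, -1.0]); y = 1.0
W0 = np.array([[0.1, 0.2], [0.3, 0.4]]); b0 = np.array([0.0, 0.1])  # one bias per layer-0 node
W1 = np.array([0.5, 0.6]);               b1 = 0.2                     # one bias for the layer-1 node
eta = 0.5  # hypothetical learning rate

z0 = x @ W0 + b0; a0 = sigmoid(z0)
z1 = a0 @ W1 + b1; a1 = sigmoid(z1)
delta1 = (a1 - y) * sigmoid_prime(z1)      # assumes C = 0.5 * (a1 - y)**2
delta0 = delta1 * W1 * sigmoid_prime(z0)

# dC/db^l = delta^l, which has one entry per node, exactly like b^l.
b0_new = b0 - eta * delta0
b1_new = b1 - eta * delta1
assert b0_new.shape == b0.shape and np.ndim(b1_new) == 0
```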

I also tried using this book to learn backprop, but I had a hard time connecting the variables to the different parts of the network and to the corresponding code. I only understood the algorithm in depth after deriving all the update equations by hand for a very small network (2 inputs, one output, no hidden layers) and then working my way up to larger networks, keeping track of the shapes of the inputs and outputs along the way. If you're having trouble with the update equations, I highly recommend this.
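As an illustration of that hand-derivation approach, here is the smallest case mentioned above: a single sigmoid unit with 2 inputs and one output, assuming the quadratic cost. Deriving by hand gives $\partial C/\partial w = (a - y)\,\sigma'(z)\,x$ and $\partial C/\partial b = (a - y)\,\sigma'(z)$. The sketch checks these against central finite differences and then runs a few plain gradient-descent steps. The helper names, data and learning rate are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cost(w, b, x, y):
    """Quadratic cost (assumed) of a single sigmoid unit: 2 inputs, 1 output."""
    return 0.5 * (sigmoid(x @ w + b) - y) ** 2

def grads(w, b, x, y):
    """Hand-derived: dC/dw = (a - y) * sigma'(z) * x, dC/db = (a - y) * sigma'(z)."""
    a = sigmoid(x @ w + b)
    delta = (a - y) * a * (1.0 - a)   # sigma'(z) = a * (1 - a) for the sigmoid
    return delta * x, delta

x = np.array([0.5, -1.0]); y = 1.0
w = np.array([0.1, 0.2]); b = 0.0

# Compare each hand-derived partial derivative with a central finite difference.
gw, gb = grads(w, b, x, y)
eps = 1e-6
for i in range(2):
    e = np.zeros(2); e[i] = eps
    numeric = (cost(w + e, b, x, y) - cost(w - e, b, x, y)) / (2 * eps)
    assert abs(gw[i] - numeric) < 1e-8
assert abs(gb - (cost(w, b + eps, x, y) - cost(w, b - eps, x, y)) / (2 * eps)) < 1e-8

# A few plain gradient-descent steps on this one example.
eta = 1.0  # hypothetical learning rate
costs = [cost(w, b, x, y)]
for _ in range(20):
    gw, gb = grads(w, b, x, y)
    w, b = w - eta * gw, b - eta * gb
    costs.append(cost(w, b, x, y))
print(costs[0], costs[-1])
```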

A final piece of advice that helped me: drop the $x$ and the summations over input examples from your formulations and just treat everything as matrices (e.g. a scalar becomes a matrix in $\mathbb{R}^{1 \times 1}$, and $X$ is a matrix in $\mathbb{R}^{N \times D}$). First, this makes it easier to interpret matrix orientations and to debug issues such as a missing transpose. Second, this is (in my limited understanding) how backprop is actually implemented in order to take advantage of optimized linear algebra libraries and GPUs, so it's arguably more relevant.
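The following sketches the all-matrices formulation for a batch $X \in \mathbb{R}^{N \times D}$. It uses row-vector conventions, stores the biases as $1 \times H$ and $1 \times 1$ matrices, and assumes a mean quadratic cost. The summation over examples disappears into products such as `A0.T @ D1`. Comparing each gradient's shape with its parameter's shape flags most orientation mistakes. The random data, sizes and function name are illustrative only.

```python
import numpy as np

def sigmoid(Z):
    return 1.0 / (1.0 + np.exp(-Z))

def forward_backward(X, Y, W0, b0, W1, b1):
    """One batched forward/backward pass; every quantity is a 2-D matrix."""
    N = X.shape[0]
    Z0 = X @ W0 + b0                      # (N, H); b0 of shape (1, H) broadcasts over rows
    A0 = sigmoid(Z0)
    Z1 = A0 @ W1 + b1                     # (N, 1)
    A1 = sigmoid(Z1)
    C = 0.5 * np.sum((A1 - Y) ** 2) / N   # mean quadratic cost (an assumption)

    D1 = (A1 - Y) * A1 * (1 - A1) / N     # (N, 1): one row per example
    D0 = (D1 @ W1.T) * A0 * (1 - A0)      # (N, H)
    grads = {
        "W1": A0.T @ D1,                      # (H, 1): the sum over examples is the matmul
        "b1": D1.sum(axis=0, keepdims=True),  # (1, 1)
        "W0": X.T @ D0,                       # (D, H)
        "b0": D0.sum(axis=0, keepdims=True),  # (1, H)
    }
    return C, grads

rng = np.random.default_rng(0)
N, D, H = 4, 2, 2
X = rng.normal(size=(N, D))
Y = rng.integers(0, 2, size=(N, 1)).astype(float)
params = {"W0": rng.normal(size=(D, H)), "b0": np.zeros((1, H)),
          "W1": rng.normal(size=(H, 1)), "b1": np.zeros((1, 1))}

C, grads = forward_backward(X, Y, **params)
for name, g in grads.items():
    assert g.shape == params[name].shape, name

# Central finite difference on one entry of W0.
eps = 1e-6
P_plus = {k: v.copy() for k, v in params.items()}; P_plus["W0"][1, 0] += eps
P_minus = {k: v.copy() for k, v in params.items()}; P_minus["W0"][1, 0] -= eps
numeric = (forward_backward(X, Y, **P_plus)[0] - forward_backward(X, Y, **P_minus)[0]) / (2 * eps)
print(grads["W0"][1, 0], numeric)
```

With $D = H = 2$, writing `D0.T @ X` instead of `X.T @ D0` would still pass the shape assert. In that case the finite-difference comparison is what exposes the wrong orientation.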

