Backpropagation and Reverse-mode Autodiff

Backpropagation is a method for computing gradients of complex (as in complicated) composite functions. It is one of those things that is quite simple once you have figured out how it works (the other being Paxos).

In this document, I will try to explain it in a way that avoids the issues I had with most of the explanations I found in textbooks and on the Internet. Inherently, this means that this document is well-suited to my way of thinking, but my hope is that it may also suit yours or, at the very least, may give you a different perspective as you try to grasp how backpropagation works. In particular, I avoid:

  • using Leibniz notation as much as I can, favouring Lagrange’s notation instead. I specifically avoid talking about derivatives of variables and will always talk about derivatives of functions with respect to variables. I also make the parameters (inputs) of these functions explicit in my derivations whenever I can. This results in more verbose, but hopefully easier-to-understand, derivations;

  • talking about expression graphs too early. Most of the explanations I found start by flashing expression graphs and then explaining the derivation process from those. I found it difficult to reconcile these appeals to intuition with the actual details of the calculus underneath. I therefore take the opposite stance in this document: expression graphs are explained as a consequence of the derivations, and only at the end.

The remainder of this document is organized as follows. In Sec. 1 we talk about something called forward-mode autodiff. Indeed, backpropagation is a special case of reverse-mode autodiff, and I think it comes more naturally if we have a look at forward-mode autodiff, which is simpler, first. Then, in Sec. 2 we delve into reverse-mode autodiff proper: first by showing a very simple example, and then, after doing a recap of the multivariable chain rule in Sec. 2.1, and showing that the “simple” chain rule can be seen as a special case of its multivariable sibling (Sec. 2.2), we work through a more complex example which introduces expression graphs and bridges them to the chain rule in Sec. 2.3. In the process, we compute the gradients in the example by backpropagation and compare those with the result obtained by symbolic derivation, closing off the explanation.

At the end of the document, in the Appendix, you can find a simple reverse-mode autodiff written in Python which leverages sympy.

You will note that I made no mention of Neural Networks here. This is because a Neural Network is just a complex composite function: once you grasp the concepts of reverse-mode autodiff, the same concepts apply to the computation of the gradients of the hypothesis and error (loss) functions of the Neural Network as well.


1 Forward-Mode Autodiff

Suppose we had the following functions:

  • \(f_1(x,y) = 3 x^{2} + 5 x y + 4 y^{2}\)
  • \(f_2(x) = \sin{\left(x \right)}\)
  • \(f_3(x) = \sqrt{x} + e^{x}\)

and that we wanted to compute the derivative of the following composite function:

\[ f(x) = f_1(f_2(x), f_3(x)) \]

There are a number of ways one could go about doing that. One could, for instance, compute the derivative algebraically and arrive at a closed-form expression by, say, first replacing \(x\) and \(y\) by \(f_2(x)\) and \(f_3(x)\) in \(f\):

\[ \begin{equation} f(x) = 4 \left(\sqrt{x} + e^{x}\right)^{2} + 5 \left(\sqrt{x} + e^{x}\right) \sin{\left(x \right)} + 3 \sin^{2}{\left(x \right)} \end{equation} \]

and then taking the derivative with respect to \(x\):

\[ \begin{equation} f^\prime(x) = 4 \left(\sqrt{x} + e^{x}\right) \left(2 e^{x} + \frac{1}{\sqrt{x}}\right) + \left(5 \sqrt{x} + 5 e^{x}\right) \cos{\left(x \right)} + \left(5 e^{x} + \frac{5}{2 \sqrt{x}}\right) \sin{\left(x \right)} + 6 \sin{\left(x \right)} \cos{\left(x \right)} \tag{1.1} \end{equation} \]

Chain Rule. Or, recalling that in Lagrange’s notation we have \(\frac{\partial h}{\partial x} = h^\prime\) and \(\frac{\partial h}{\partial y} = h_\prime\), one could get to the same result by applying the multivariable chain rule:

\[ \begin{equation} f^\prime(x) = f_1^\prime(f_2(x), f_3(x))\cdot f_2^\prime(x) + f_{1\prime}(f_2(x), f_3(x))\cdot f^\prime_3(x) \tag{1.2} \end{equation} \]

which, if you were to expand the formula, would get you the same result as in Eq. (1.1).

Forward-Mode Autodiff. The essence of autodiff is that, to evaluate the derivative of \(f\) at some value of \(x\), you do not need to compute a complete formula in advance: you can instead try to compute the formulae for a set of smaller, easier-to-compute blocks, which can then be evaluated as a sequence of steps towards the final value. Note that this will not lead us to a closed-form expression of the derivative as in Eq. (1.1), but it will allow us to compute the derivative quite efficiently nevertheless: often more efficiently than with a closed-form representation.

Eq. (1.2) already provides a hint of how this could work. We can break down the constituent parts of \(f\) into a set of smaller blocks which are easier to derive symbolically:

\[ \begin{align} \frac{d}{dt} f_1(f_2(t), f_3(t)) &= 6 \operatorname{f_{2}}{\left(t \right)} \frac{d}{d t} \operatorname{f_{2}}{\left(t \right)} + 5 \operatorname{f_{2}}{\left(t \right)} \frac{d}{d t} \operatorname{f_{3}}{\left(t \right)} + 5 \operatorname{f_{3}}{\left(t \right)} \frac{d}{d t} \operatorname{f_{2}}{\left(t \right)} + 8 \operatorname{f_{3}}{\left(t \right)} \frac{d}{d t} \operatorname{f_{3}}{\left(t \right)}\\ f_2^\prime(x) &= \cos{\left(x \right)} \\ f_3^\prime(x) &= e^{x} + \frac{1}{2 \sqrt{x}} \end{align} \]

and then, supposing we wanted to compute the derivative of \(f_1(f_2(x), f_3(x))\) at the point \(x = 1\), we can simply start with the blocks that depend directly on \(x\):

# f2_dx and f3_dx hold the symbolic derivatives of f_2 and f_3 (sympy expressions);
# here we evaluate the blocks that depend directly on x at x = 1.
f3_dx_1 = f3_dx.subs({'x': 1.0})
f2_dx_1 = f2_dx.subs({'x': 1.0})
f3_1 = f_3(1.0)
f2_1 = f_2(1.0)

\[ \begin{align} f_3^\prime(1) &= 3.21828182845905\\ f_2^\prime(1) &= 0.54030230586814\\ f_3(1) &= 3.71828182845905 \\ f_2(1) &= 0.841470984807897 \end{align} \]

And then climb up to the more complex blocks to get our answer:

# evaluate d/dt f_1(f_2(t), f_3(t)) at t = 1 using the values computed above
print(6*f2_1*f2_dx_1 + 5*f2_1*f3_dx_1 + 5*f3_1*f2_dx_1 + 8*f3_1*f3_dx_1)
## 122.045158140265

We can confirm that this is the right answer by comparing with the full symbolic derivation, evaluated at \(x\):

print(f_1(f_2(x), f_3(x)).diff(x).subs({'x': 1.0}))
## 122.045158140265
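As an aside, the snippets above rely on a bit of sympy setup that is not shown in this section. The sketch below is my reconstruction of what it could look like, with the names (f_1, f_2, f_3, f2_dx, f3_dx) chosen to match the ones used above:

# My reconstruction of the setup assumed by the snippets above (not the original code).
import sympy as sy

x = sy.Symbol('x')

# The composing functions, written so that they return sympy expressions.
f_1 = lambda x, y: 3*x**2 + 5*x*y + 4*y**2
f_2 = lambda x: sy.sin(x)
f_3 = lambda x: sy.sqrt(x) + sy.exp(x)

# Symbolic derivatives of the innermost blocks.
f2_dx = f_2(x).diff(x)   # cos(x)
f3_dx = f_3(x).diff(x)   # exp(x) + 1/(2*sqrt(x))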

And, without formalizing things too much at this point, this is the essence of forward-mode autodiff: we start from the derivatives of the innermost functions, and work our way up to the outermost function. Before we move on, it is interesting to reflect on the relationship between functions and subfunctions in a composite function.

Function Breakdown. We wrote \(f = f_1(f_2(t), f_3(t))\) as a composite function of three relatively complex functions. We might just as well have written it as a function of much more fine-grained functions; e.g., we could have expressed \(f_1\) as \(f_1(x,y) = f_{+}(f_{+}(f_{\times}(a, f_{\times}(x,x)), f_{\times}(b, f_{\times}(x, y))), f_{\times}(c, f_{\times}(y, y)))\), where \(f_{+}(x,y) = x + y\), \(f_{\times}(x,y) = x \times y\), and \(a = 3\), \(b = 5\), \(c = 4\), and the same principles we have developed above would still apply.
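To make this concrete, here is a small sketch (the names f_plus, f_times and f_1_fine_grained are my own) of what such a fine-grained decomposition of \(f_1\) looks like in code:

# A sketch of the fine-grained decomposition described above (illustrative only).
# f_1(x, y) = 3x^2 + 5xy + 4y^2, built from the elementary operations f_plus and
# f_times, with a = 3, b = 5, c = 4.
f_plus = lambda x, y: x + y
f_times = lambda x, y: x * y

def f_1_fine_grained(x, y, a=3, b=5, c=4):
    return f_plus(
        f_plus(f_times(a, f_times(x, x)), f_times(b, f_times(x, y))),
        f_times(c, f_times(y, y)),
    )

# Same value as the coarse-grained definition: 3*4 + 5*6 + 4*9 = 78.
assert f_1_fine_grained(2.0, 3.0) == 3*2.0**2 + 5*2.0*3.0 + 4*3.0**2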

Indeed, this very fine-grained representation is precisely the representation one would get in an algorithmic setting where functions are provided by users in terms of elementary computing operations. Given that our final goal is understanding backpropagation in the context of neural networks, however, we do not need to reason at such a fine-grained level.

2 Reverse-Mode Autodiff

To get started with reverse-mode autodiff, we will look at an even simpler example (borrowed from these great notes), which we then use to explain the basic principles and to contrast reverse-mode with forward-mode autodiff.

Suppose we had a deep composition of single-parameter functions, as follows:

\[ \begin{align} a_1 &= f_1(a)\\ a_2 &= f_2(f_1(a)) = f_2(a_1)\\ a_3 &= f_3(f_2(f_1(a))) = f_3(a_2)\\ &\cdots\\ a_m &= f_m(f_{m-1}(\,\cdots\, f_2(f_1(a)) \,\cdots\,)) = f_m(a_{m-1}) \end{align} \]

Let’s call \(f(a) = f_m(f_{m-1}(\,\cdots\, f_2(f_1(a)) \,\cdots\,))\). The good thing about \(f\) is that it is relatively easy to understand what \(f^\prime\) looks like, even though it is a deep composite function. Indeed, by applying the (simple) chain rule repeatedly, we can expand the final derivative for this function into: \[ f^\prime(a) = f_m^\prime(a_{m-1}) f_{m-1}^\prime(a_{m-2}) \cdots f_2^\prime(a_1)f_1^\prime(a) \]

Now suppose we wanted to compute \(f^{\prime}(a_0)\) for some real-valued \(a_0\). Let us pause and think for a moment about what we already know: we know how to compute the derivatives \(f^{\prime}_{i}\) for each of the composing functions (since those are not composite functions themselves, we assume that their derivative is easy to compute and does not involve the chain rule), and we know the value of \(a_0\). But how do we proceed?

Forward-Mode Autodiff. With forward-mode autodiff, we would compute this expression for a given \(a = a_0\) by starting from the right. We would first evaluate \(f_1^{\prime}(a_0)\) and \(a_1 = f_1(a_0)\), then \(f_2^{\prime}(a_1)\) and the partial product \(f_2^\prime(a_1)f_1^\prime(a_0)\), then \(a_2 = f_2(a_1)\), and so on, alternating evaluations of the various \(f_i\) with evaluations of their derivatives \(f^\prime_i\), until we complete the evaluation of the full product, which gets us the value of \(f^\prime(a_0)\).

Reverse-Mode Autodiff. With reverse-mode autodiff, we would work our way backwards instead; i.e., the first thing we would compute is \(f_m^\prime(a_{m-1})\). But we do not know the value of \(a_{m-1}\), so we have to compute it first. This is not hard, however: we have the \(f_i\) and the value of \(a_0\), so we just have to run \(a_0\) through the entire composition and get the value we need. Note that we will, in the process, compute the values of all of the \(a_j\) as well, for \(1 \leq j \leq m-1\). This is called the “forward pass”. Starting from the left of the product this time, we then multiply in one factor at a time, each evaluated at the arguments computed in the forward pass, until we have covered the whole product and obtained \(f^\prime(a_0)\) as before. This reverse pass is called “backpropagation”.

We will look at more complicated examples in the next sections, but the main idea remains the same: we do a forward pass to compute the values of all intermediate functions, and then a backward pass to compute the gradients by following the products prescribed by the chain rule.
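To make the contrast concrete, here is a small Python sketch of both modes on a toy chain of three single-variable functions (the chain itself is made up purely for illustration); the only point is the order in which values and derivative factors are consumed:

# Illustrative sketch only: the chain of functions below is my own choice.
import math

# A chain f(a) = f3(f2(f1(a))); each entry pairs a function with its derivative.
chain = [
    (math.sin, math.cos),             # f1, f1'
    (math.exp, math.exp),             # f2, f2'
    (lambda x: x**2, lambda x: 2*x),  # f3, f3'
]

def forward_mode(a0):
    a, deriv = a0, 1.0
    for f, df in chain:        # innermost to outermost
        deriv = df(a) * deriv  # multiply in f_i'(a_{i-1}) as we go...
        a = f(a)               # ...and advance the value a_i = f_i(a_{i-1})
    return deriv

def reverse_mode(a0):
    # Forward pass: record a_0, a_1, ..., a_{m-1} (and a_m, which we do not need).
    values = [a0]
    for f, _ in chain:
        values.append(f(values[-1]))
    # Backward pass (backpropagation): multiply the factors outermost-first,
    # using the arguments recorded in the forward pass.
    deriv = 1.0
    for (_, df), a in zip(reversed(chain), reversed(values[:-1])):
        deriv *= df(a)
    return deriv

print(forward_mode(0.5), reverse_mode(0.5))  # same derivative, up to rounding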

2.1 The Multivariable Chain Rule

Let us complicate our deep composite function a little. Suppose we had a function as follows:

\[ f(a) = f_m(f_{1k}(f_{1k-1}(\cdots f_{11}(a)\cdots)), f_{2k}(f_{2k-1}(\cdots f_{21}(a)\cdots))) \]

How can we compute the derivative of something like that? Well, the first thing to note is that the inner composite functions are, from the point of view of \(f_m\), just functions. If we set \(f_{1}(a) = f_{1k}(f_{1k-1}(\cdots f_{11}(a)\cdots))\) and \(f_{2}(a) = f_{2k}(f_{2k-1}(\cdots f_{21}(a)\cdots))\), we can rewrite this as:

\[ f(a) = f_m(f_{1}(a), f_2(a)) \]

We can then apply the multivariable chain rule to obtain:

\[ \begin{equation} f^{\prime}(a) = f_m^{\prime}(f_1(a), f_2(a))f_1^\prime(a) + f_{m\prime}(f_1(a), f_2(a))f_2^\prime(a) \tag{2.1} \end{equation} \]

Assuming we have already done the forward pass, we know the values of \(f_1(a)\) and \(f_2(a)\), as well as the forms of the derivatives \(f_m^\prime\) and \(f_{m\prime}\). The values of \(f_m^{\prime}(f_1(a), f_2(a))\) and \(f_{m\prime}(f_1(a), f_2(a))\) are, therefore, immediately available, and we have the first step of reverse-mode autodiff. We do not have the derivatives \(f_1^\prime(a)\) or \(f_2^\prime(a)\) yet but, as you can imagine, all we will need to do is unpack those with the chain rule too. Before we do that, however, we need to dive a little deeper into the chain rule.
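Before that, if you want to convince yourself that Eq. (2.1) does what it promises, here is a quick sympy check, re-using the names \(f_m\), \(f_1\) and \(f_2\) from the text but with concrete functions that I made up for the occasion:

# Sanity check of Eq. (2.1) with made-up concrete functions (illustrative only).
import sympy as sy

a, x, y = sy.symbols('a x y')
f_m = lambda x, y: x**2 * sy.sin(y)   # arbitrary two-parameter outer function
f_1 = lambda a: sy.exp(a)             # arbitrary inner functions
f_2 = lambda a: sy.log(a + 1)

# Right-hand side of Eq. (2.1): partials of f_m with respect to its first and
# second parameters, evaluated at (f_1(a), f_2(a)), times f_1'(a) and f_2'(a).
rhs = (f_m(x, y).diff(x).subs({x: f_1(a), y: f_2(a)}) * f_1(a).diff(a)
       + f_m(x, y).diff(y).subs({x: f_1(a), y: f_2(a)}) * f_2(a).diff(a))

# Left-hand side: differentiate the full composition directly.
lhs = f_m(f_1(a), f_2(a)).diff(a)

print(sy.simplify(lhs - rhs))  # prints 0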

2.2 Unifying the Rules

So far, we have been dealing with composite functions of one input variable; i.e., our final function \(f(a)\) has always been a function of a single variable \(a\).

Suppose now that we had a two-variable function; i.e.: \[ f(x_1,x_2) = f_m(f_{1}(x_1, x_2), f_2(x_1)) \]

And we wanted to determine the partial derivatives of \(f\) with respect to \(x_1\) and \(x_2\). Taking a partial derivative is essentially equivalent to interpreting all of a function’s variables as constants, except the one you are taking the partial derivative with respect to. This, by construction, means that any partial derivative we take reduces to a single-variable case; i.e., there is nothing special about multiple variables.

For a function \(f(x_1,x_2)\) for which we are computing \(f^\prime(x_1,x_2)\), we will write \(f^\prime(x_1, \cdot)\) to reinforce the idea that we do not care about \(x_2\) in this case (and, conversely, we will write \(f_\prime(\cdot,x_2)\) when we take the derivative with respect to \(x_2\)).

We now take the partial derivative of \(f\) with respect to \(x_1\), which reduces to the multivariable chain rule:

\[ \begin{equation} f^\prime(x_1,x_2) = f_m^\prime(f_1(x_1,\cdot), f_2(x_1))\times f_1^\prime(x_1,\cdot) + f_{m\prime}(f_1(x_1,\cdot), f_2(x_1))\times f_{2}^\prime(x_1) \tag{2.2} \end{equation} \] and the partial derivative with respect to \(x_2\), which reduces to the simple chain rule:

\[ \begin{equation} f_\prime(x_1,x_2) = f_m^\prime(f_1(\cdot,x_2), f_2(\cdot)) \times f_{1\prime}(\cdot, x_2) \tag{2.3} \end{equation} \]

Now note that the chain rules from Eq. (2.2) and (2.3) share a certain motif. Namely, if the outer function \(f_m(x_1, \cdots, x_n)\) is a function on \(n\) parameters, and the \(k^{th}\) parameter of \(f_m\) contains a function \(f_j(x_i,\cdots)\) that involves an input variable \(x_i\), then the partial derivative of \(f\) (our composite function which involves \(f_m\)) with respect to \(x_i\) will be a summation in which one of the terms will be of the form:

\[ \begin{equation} \frac{\partial}{\partial x_k}f_m(\cdots, f_j(x_i, \cdots), \cdots) \times \frac{\partial}{\partial x_i} f_j(x_i, \cdots) \tag{2.4} \end{equation} \]

Indeed, the partial derivative with respect to \(x_i\) will be a sum containing one such term for every parameter of \(f_m\) that involves \(x_i\), either directly (if, say, \(f_j(x_i) = x_i\)) or indirectly (if, say, \(f_j(x_i)\) is itself a composite function like \(f_j(x_i) = f_k(f_s(x_i))\)).

More formally, suppose we had \(k\) possible input variables \(\{x_1, \cdots, x_k\}\) and a function \(f\) defined on a subset of these variables, i.e., suppose we had:

\[ f\left(f_1\left(x_{\pi(1)}\right), \cdots, f_n\left(x_{\pi(n)}\right)\right) \]

where \(1 \leq \pi(i) \leq k\), \(1 \leq i \leq n\), and the \(\pi(i)\) can take on the same value for different values of \(i\); i.e., it is possible that \(i \neq j\) but \(\pi(i) = \pi(j)\) for \(1 \leq j \leq n\).

If we wanted to compute the partial derivative of \(f\) with respect to \(x_j\), we would have, by the (multivariable) chain rule (writing \(\frac{\partial}{\partial u_i} f\) for the partial derivative of \(f\) with respect to its \(i^{th}\) parameter):

\[ \frac{\partial}{\partial x_j}f\left(f_1\left(x_{\pi(1)}\right), \cdots, f_n\left(x_{\pi(n)}\right)\right) = \sum_{\{i~\mid~\pi(i) = j\}} \frac{\partial}{\partial u_i}f\left(f_1\left(x_{\pi(1)}\right), \cdots, f_n\left(x_{\pi(n)}\right)\right) \times \frac{\partial}{\partial x_j} f_i\left(x_{\pi(i)}\right) \]

If this is still looking too abstract, we will work through an example from beginning to end in the next section.

2.3 Expression Graphs

In taking the partial derivative of the composite \(f\) with respect to a variable \(x_i\), we had to:

  1. identify which parameters of \(f\) contained functions which referred to \(x_i\);
  2. apply the chain rule and add a term of the form given by Eq. (2.4) for each parameter that did;
  3. sum those terms together to obtain the final result.

In particular, we have so far relied on visually inspecting the function to do step (1), but this is not practical for complex functions. It is much more convenient (and the reason will hopefully become clear as we proceed) to instead represent our expression as a graph (a DAG, really) whose nodes are either: i) intermediate functions or ii) input variables (Fig. 2.1).

Let’s look again at a two-variable example:

example <- quote(f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)))

\[ f(x_1,x_2) = f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)) \]

But this time let us assign concrete functions to them:

  • \(f_1(x, y) = \sin{\left(\pi x \right)} + \cos{\left(\pi y \right)}\)
  • \(f_2(x, y) = x^{y}\)
  • \(f_3(x) = \log{\left(x \right)}\)

Now, supposing we wanted to compute \(\frac{\partial f}{\partial x_1}\), we would get, by the (multivariable) chain rule:

\[ \begin{align} \frac{\partial}{\partial x_1}f(x_1, x_2) =~ &\frac{\partial}{\partial x}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1))\times \frac{\partial}{\partial x_1}f_2(x_1, f_2(x_2, x_1)) + \frac{\partial}{\partial y}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)) \times \frac{\partial }{\partial x_1}f_3(x_1) = \\ =~&\frac{\partial}{\partial x}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1))\times \left[\frac{\partial}{\partial x}f_2(x_1, f_2(x_2, x_1)) + \frac{\partial}{\partial y}f_2(x_1, f_2(x_2, x_1))\times \frac{\partial}{\partial x_1}f_2(x_2, x_1)\right] +\\ &+ \frac{\partial}{\partial y}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)) \times \frac{\partial }{\partial x_1}f_3(x_1) \end{align} \]

And this expands into the sum of three products:

\[ \begin{align} \frac{\partial}{\partial x_1}f(x_1, x_2) =&\frac{\partial}{\partial x}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1))\times\frac{\partial}{\partial x}f_2(x_1, f_2(x_2, x_1))~+\\ + &\frac{\partial}{\partial x}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1))\times\frac{\partial}{\partial y}f_2(x_1, f_2(x_2, x_1))\times \frac{\partial}{\partial x_1}f_2(x_2, x_1)~+ \\ + &\frac{\partial}{\partial y}f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)) \times \frac{\partial }{\partial x_1}f_3(x_1) \tag{2.5} \end{align} \]

Let us now bridge this to the expression graph of \(f\), shown in Fig. 2.1.

express(!!example) %>% render_graph()

Figure 2.1: Expression graph for \(f\).

Looking at Fig. 2.1, if we assign, to each outlink in the DAG, the partial derivative of the function represented by the node from which the edge departs, taken with respect to the parameter that the edge corresponds to, it is not hard to see that the final partial derivative \(\frac{\partial f}{\partial x_1}\) is given by the sum of the products of these edge expressions along each path that connects \(f_1\) to \(x_1\). Indeed, the three products in Eq. (2.5) correspond to the three paths, \((f_1, f_2)\), \((f_1, f_2, f_2)\), and \((f_1, f_3)\), that connect \(f_1\) to \(x_1\) in the graph.

Let’s show that this is indeed true by using this method to compute the partial derivatives of \(f\) at the point \((2,3)\), and then comparing the results with the full symbolic derivative obtained by expanding Eq. (2.5).

We first need our building blocks; i.e., the partial derivatives of each of the composing functions:

\[ \frac{\partial}{\partial x} f_1(x, y) = \pi \cos{\left(\pi x \right)}\\ \frac{\partial}{\partial y} f_1(x, y) = - \pi \sin{\left(\pi y \right)}\\ \frac{\partial}{\partial x} f_2(x, y) = \frac{x^{y} y}{x}\\ \frac{\partial}{\partial y} f_2(x, y) = x^{y} \log{\left(x \right)}\\ \frac{\partial}{\partial x} f_3(x) = \frac{1}{x} \]

And we also need the results of the forward propagation step:

\[ \begin{align} x_1 &= 2\\ x_2 &= 3\\ f_2(x_2, x_1) &= 9\\ f_2(x_1, f_2(x_2, x_1)) &= 512\\ f_3(x_1) &= 0.693147180559945 \end{align} \]

Finally, we can put these together by computing the products along the paths from \(f_1\) to \(x_1\):

\[ \begin{align} (f_1, f_2) &= 7238.22947387088\\ (f_1, f_2, f_2) &= 11023.8236395328\\ (f_1, f_3) &= -1.29038221381431 \end{align} \]

So that the sum yields 18260.7627311899. As for the symbolic derivative, we get that the formula is:

from sympy import symbols, latex
# (f_1, f_2 and f_3 are assumed here to be Python lambdas returning sympy expressions,
#  matching the concrete functions defined for this example above.)
x_1, x_2 = symbols('x_1 x_2')
df_x1 = f_1(f_2(x_1, f_2(x_2, x_1)), f_3(x_1)).diff(x_1)
print('$$ \\frac{\\partial}{\\partial x_1} f(x_1, x_2) = ' + latex(df_x1) + '$$')

\[ \frac{\partial}{\partial x_1} f(x_1, x_2) = \pi x_{1}^{x_{2}^{x_{1}}} \left(x_{2}^{x_{1}} \log{\left(x_{1} \right)} \log{\left(x_{2} \right)} + \frac{x_{2}^{x_{1}}}{x_{1}}\right) \cos{\left(\pi x_{1}^{x_{2}^{x_{1}}} \right)} - \frac{\pi \sin{\left(\pi \log{\left(x_{1} \right)} \right)}}{x_{1}}\]

Which, computing for \((2, 3)\), yields:

print(df_x1.subs({'x_1': 2, 'x_2': 3}).evalf())
## 18260.7627311899

as expected.
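For completeness, the three path products (and their sum) can be reproduced with sympy along the following lines; this is a sketch of my own, not the code that generated the numbers above:

# Sketch: reproducing the path products of Sec. 2.3 with sympy (names are mine).
import sympy as sy

x, y = sy.symbols('x y')
f1_expr = sy.sin(sy.pi * x) + sy.cos(sy.pi * y)
f2_expr = x**y
f3_expr = sy.log(x)

# Forward-pass values at (x_1, x_2) = (2, 3).
x1_val, x2_val = 2, 3
f2_inner = x2_val**x1_val       # f_2(x_2, x_1) = 9
f2_outer = x1_val**f2_inner     # f_2(x_1, f_2(x_2, x_1)) = 512
f3_val = sy.log(x1_val)         # f_3(x_1)

# Edge derivatives, evaluated at the forward-pass values.
df1_dx = f1_expr.diff(x).subs({x: f2_outer, y: f3_val})
df1_dy = f1_expr.diff(y).subs({x: f2_outer, y: f3_val})
df2_dx = f2_expr.diff(x).subs({x: x1_val, y: f2_inner})       # outer f_2, w.r.t. x
df2_dy = f2_expr.diff(y).subs({x: x1_val, y: f2_inner})       # outer f_2, w.r.t. y
df2_inner_dx1 = f2_expr.diff(y).subs({x: x2_val, y: x1_val})  # inner f_2(x_2, x_1), w.r.t. x_1
df3_dx1 = f3_expr.diff(x).subs({x: x1_val})

paths = [
    df1_dx * df2_dx,                  # path (f_1, f_2)
    df1_dx * df2_dy * df2_inner_dx1,  # path (f_1, f_2, f_2)
    df1_dy * df3_dx1,                 # path (f_1, f_3)
]
print([p.evalf() for p in paths], sum(paths).evalf())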

Appendix A: Backpropagation in Python

We will now use the ideas developed in the previous sections to put together a simple reverse-mode autodiff in Python. We will let the user specify relatively complicated functions and, instead of breaking them down into elementary pieces, we will resort to sympy to compute the partial derivatives.

import numbers
from typing import Callable, Mapping
import sympy as sy
import inspect

The code to our minimalistic reverse-mode autodiff is provided next.

class Function(object):
    """
    :class:`Function` represents a mathematical function for which we wish to compute the 
    gradient. :class:`Function`s can be created from :class:`Callable`s, which are often 
    lambda expressions:
  
    >> f = Function(lambda x: x**3)
    >>
  
    Functions can then be regularly called, but will return an :class:`AppliedFunction`, which 
    can then be used to inspect the values and compute the gradients for the declared symbols.
  
    >> call = f(2)
    >> call.value
    8
    >> symbols = call.compute_gradient().symbols()
    >> symbols['x'].gradient
    12.0000000000000
  
    Finally, functions can be composed:
  
    >> f1 = Function(lambda x, y: x*y)
    >> f2 = Function(lambda x, y: sympy.cos(sympy.pi * x)/sympy.sin(x))
    >> f3 = Function(lambda x, y: sympy.ln(x)/y)
  
    Note that we use the mathematical functions and constants from :mod:`sympy`. This is necessary 
    as otherwise we are not able to properly compute the derivatives later on. We can then do:
  
    >> x1 = Value(2, 'x1')
    >> x2 = Value(3, 'x2')
    >> call = f1(f2(x1, x2), f3(x1, x2))
    >> call.compute_gradient().gradient()
    {'x1': 0.299580760323820, 'x2': -0.0846987477622259}
    """
  
    def __init__(self, f: Callable):
        self.f = f
        self.args = inspect.signature(f).parameters.keys()
        self.derivatives = self._derivatives(f)
  
    def symbols(self):
        return {
            key: sy.Symbol(key)
            for key in self.args
        }
  
    def _derivatives(self, f: Callable):
        symbols = self.symbols()
        return {
            key: f(**symbols).diff(symbol)
            for key, symbol in symbols.items()
        }
  
    def __call__(self, *args):
        supplied = len(args)
        expected = len(self.derivatives)
        if supplied != expected:
            raise Exception('Expected %s arguments, got %s.' %
                            (expected, supplied))
  
        children = dict(zip(self.args, args))
        children = self._cast_values(children)
  
        return AppliedFunction(
            self,
            ### Forward propagation ###
            value=self.f(**{key: child.value for
                            key, child in children.items()}),
            children=children
        )
  
    def _cast_values(self, kwargs):
        return {
            key: Value(value, key) if isinstance(value, numbers.Number) else value
            for key, value in kwargs.items()
        }

  
class AppliedFunction(object):
    def __init__(self, ftype, value, children: Mapping[str, 'AppliedFunction']):
        self.ftype = ftype
        self.value = value
        self.children = children
  
    def compute_gradient(self, parent_gradient=1):
        pars = {key: child.value for key, child in self.children.items()}
        ### Backpropagation ###
        for key, child in self.children.items():
            child.compute_gradient(
                parent_gradient * self.ftype.derivatives[key].subs(pars).evalf()
            )
        return self
  
    def symbols(self):
        values = {}
        for child in self.children.values():
            values.update(child.symbols())
  
        return values
  
    def gradient(self):
        return {
            s.symbol: s.gradient for s in self.symbols().values()
        }
  
    def sympy(self):
        return self.ftype.f(
            **{key: child.sympy() for key, child in self.children.items()}
        )
  
  
class Value(object):
    def __init__(self, value, symbol):
        self.symbol = symbol
        self.value = float(value)
        self.gradient = 0
  
    def compute_gradient(self, parent_gradient):
        self.gradient += parent_gradient
  
    def symbols(self):
        return {self.symbol: self}
  
    def sympy(self):
        return sy.Symbol(self.symbol)
  
    def __repr__(self):
        return 'Value(%f, %s)' % (self.value, self.symbol)

To see that it works, we give it a go with the example we have already tried.

x1 = Value(2, 'x1')
x2 = Value(3, 'x2')
  
f_1 = Function(lambda x, y: sy.sin(sy.pi * x) + sy.cos(sy.pi * y))
f_2 = Function(lambda x, y: x**y)
f_3 = Function(lambda x: sy.ln(x))
  
f = f_1(f_2(x1, f_2(x2, x1)), f_3(x1))
f.compute_gradient()
  

<__main__.AppliedFunction object at 0x7fda3e6e0580>

print('Gradients are: $$\\begin{align}\\frac{\\partial}{\\partial x_1} f(2,3) &= %f\\\\' 
      '\\frac{\\partial}{\\partial x_2} f(2,3) &= %f\\end{align}$$' % (x1.gradient, x2.gradient))

Gradients are: \[\begin{align}\frac{\partial}{\partial x_1} f(2,3) &= 18260.762731\\\frac{\partial}{\partial x_2} f(2,3) &= 6689.544469\end{align}\]

And the results match.
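As a final note, the same values can also be read off as a dictionary through the gradient() helper defined above:

print(f.gradient())
# e.g. {'x1': 18260.762731..., 'x2': 6689.544469...}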
