The code implemented in the previous step
[1]:
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward()


class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()


class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x = self.input.data
        gx = 2 * x * gy
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx
In the previous step we added the backward method to the Variable class. Here, we change the backward method to a different implementation for improved efficiency and future extensions.
As a reminder, the backward method of the Variable class was implemented as follows:
[2]:
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward()
Notice that backward ends by calling the backward method of the variable one step closer to the input (x.backward()). That call in turn calls backward on the next variable toward the input, and so on, until it reaches a variable whose self.creator is None. This is a recursive structure.
Here, we rewrite the recursive implementation above as a loop-based one. The code looks like this:
[3]:
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()  # Get a function
            x, y = f.input, f.output  # Get the function's input and output
            x.grad = f.backward(y.grad)  # Call the backward method
            if x.creator is not None:
                funcs.append(x.creator)  # Add the preceding function to the list
This is the loop-based implementation. The key point is that the functions still to be processed are kept in the list funcs. Inside the while loop, funcs.pop() fetches the next function as f, and the backward method of f is called. f.input and f.output give the input and output variables of f, so the argument of f.backward() (y.grad) and its return value (assigned to x.grad) are wired up correctly.
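To make the processing order visible, here is a minimal sketch that runs the same loop while recording the name of each function as it is popped. The helper backward_with_trace is a hypothetical name introduced only for this illustration, and the classes are condensed copies of the ones above (the creator is assigned directly instead of through set_creator).

```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None


class Function:
    def __call__(self, input):
        output = Variable(self.forward(input.data))
        output.creator = self  # same effect as set_creator in the text
        self.input, self.output = input, output
        return output


class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        return 2 * self.input.data * gy


class Exp(Function):
    def forward(self, x):
        return np.exp(x)

    def backward(self, gy):
        return np.exp(self.input.data) * gy


def backward_with_trace(y):
    """Hypothetical helper: loop-based backward that records the order."""
    trace = []
    funcs = [y.creator]
    while funcs:
        f = funcs.pop()
        trace.append(type(f).__name__)  # remember which function ran
        x, out = f.input, f.output
        x.grad = f.backward(out.grad)
        if x.creator is not None:
            funcs.append(x.creator)
    return trace


x = Variable(np.array(0.5))
y = Square()(Exp()(Square()(x)))
y.grad = np.array(1.0)
trace = backward_with_trace(y)
print(trace)  # ['Square', 'Exp', 'Square']
```

The functions are processed from the output side toward the input side, and for this straight-line graph the list never holds more than one function at a time.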
NOTE
The pop method of a list removes the last element of the list and returns it. For example, if funcs = [1, 2, 3] and x = funcs.pop(), then x is 3 and funcs becomes [1, 2].
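The behavior of pop can be confirmed with a few lines of plain Python:

```python
funcs = [1, 2, 3]
x = funcs.pop()  # removes and returns the last element

print(x)      # 3
print(funcs)  # [1, 2]
```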
Now let’s use the Variable class above to actually compute the derivative. Let’s run the same code here as in the previous step.
[4]:
A = Square()
B = Exp()
C = Square()
x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
y.grad = np.array(1.0)
y.backward()
print(x.grad)
3.297442541400256
The result is the same as before. We have now switched the implementation from recursion to a loop. The benefit of the loop version will become clear in Step 15, which deals with complex computational graphs; the current loop-based implementation extends smoothly to handle them. The loop is also slightly more efficient.
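As an independent sanity check on the printed value, the derivative can also be worked out by hand. The graph computes y = (exp(x^2))^2 = exp(2x^2), so by the chain rule dy/dx = 4x * exp(2x^2). The short sketch below evaluates that formula at x = 0.5 (the variable names are chosen for this illustration only):

```python
import math

x0 = 0.5
analytic_grad = 4 * x0 * math.exp(2 * x0 ** 2)  # dy/dx = 4x * exp(2x^2)
print(analytic_grad)
```

Up to floating-point rounding, this agrees with the value obtained by backpropagation.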
WARNING
Each recursive call keeps its intermediate state in memory (on the call stack) until it returns, so the stack grows as the recursion gets deeper. For this reason, a loop is generally more efficient. On modern computers, however, a little extra memory usage is rarely a problem. In some cases, a technique called "tail recursion" optimization lets recursion run as efficiently as a loop.
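The stack growth has a practical consequence in Python, which limits recursion depth. The sketch below builds a chain of functions longer than that limit and runs both versions of backward on it. Identity, build_chain, and backward_recursive are hypothetical names introduced for this example, and the classes are condensed from the ones in this step.

```python
import sys
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def backward_recursive(self):  # the recursive version from the text
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward_recursive()

    def backward(self):  # the loop version from the text
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:
                funcs.append(x.creator)


class Identity:
    """Hypothetical function y = x, so the gradient stays 1.0."""

    def __call__(self, input):
        output = Variable(input.data)
        output.creator = self
        self.input, self.output = input, output
        return output

    def backward(self, gy):
        return gy


def build_chain(n):
    """Apply Identity n times to a fresh input variable."""
    x = Variable(np.array(1.0))
    y = x
    for _ in range(n):
        y = Identity()(y)
    y.grad = np.array(1.0)
    return x, y


n = sys.getrecursionlimit() + 100  # deeper than Python allows recursion

x, y = build_chain(n)
try:
    y.backward_recursive()
    recursion_failed = False
except RecursionError:
    recursion_failed = True

x2, y2 = build_chain(n)
y2.backward()  # the loop handles the same chain without trouble
loop_grad = float(x2.grad)

print(recursion_failed, loop_grad)
```

The recursive version raises RecursionError on this chain, while the loop version finishes and yields a gradient of 1.0. Graphs this deep are unusual, but the example shows that the loop version does not depend on the interpreter's call stack.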
This completes the foundation of the backpropagation implementation. From here on, we will extend DeZero to support more complex calculations. The next step improves the ease of use of DeZero.