Here is the code implemented in the previous step:

```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator  # The function that created this variable
        if f is not None:
            x = f.input  # That function's input
            x.grad = f.backward(self.grad)  # Propagate the gradient to the input
            x.backward()  # Continue from the input variable


class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)  # Record this function as the output's creator
        self.input = input  # Remember the input for backward
        self.output = output  # Remember the output
        return output

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()


class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x = self.input.data
        gx = 2 * x * gy  # d(x^2)/dx = 2x
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy  # d(e^x)/dx = e^x
        return gx
```
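As a quick sanity check on the two backward formulas above (2x·gy for `Square`, e^x·gy for `Exp`), you can compare them with a central-difference approximation of the derivative. This is a minimal plain-NumPy sketch: the helper name `numerical_diff` and the step size `eps=1e-4` are assumptions for illustration, and the helper works on ordinary functions rather than on the `Function`/`Variable` classes.

```python
import numpy as np


def numerical_diff(f, x, eps=1e-4):
    # Hypothetical helper: central-difference approximation of f'(x)
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x = np.array(0.5)
gy = np.array(1.0)  # upstream gradient, as if this were the final output

# Gradients computed with the same formulas as Square.backward and Exp.backward
square_analytic = 2 * x * gy
exp_analytic = np.exp(x) * gy

# Numerical approximations of the same derivatives at x = 0.5
square_numeric = numerical_diff(lambda t: t ** 2, x)
exp_numeric = numerical_diff(np.exp, x)

print(square_analytic, square_numeric)  # both close to 1.0
print(exp_analytic, exp_numeric)  # both close to exp(0.5)
```

If the analytic and numerical values disagreed noticeably, that would point to a mistake in the corresponding `backward` method.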
In the previous step we added the `backward` method to the `Variable` class. In this step we change that method to a different implementation, both for efficiency and to make future extensions easier.

To recap, the `backward` method of the `Variable` class was implemented as follows:
```python
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward()
```
Notice that inside the `backward` method, the `backward` method of the variable one step back (toward the input side) is called. That call in turn calls the `backward` method of the next variable back, and so on, until a variable whose `creator` is `None` is reached. This is a recursive structure.
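To make the recursive structure visible, the sketch below re-creates the classes above and gives `Variable.backward` two extra parameters, `depth` and `log`. These are hypothetical additions used only for tracing; each call records how deeply it is nested before recursing toward the input.

```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self, depth=0, log=None):
        # `depth` and `log` are hypothetical tracing parameters, not part of DeZero
        if log is not None:
            log.append(depth)
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward(depth + 1, log)  # nested call, one step toward the input


class Function:
    def __call__(self, input):
        y = self.forward(input.data)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output


class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        return 2 * self.input.data * gy


class Exp(Function):
    def forward(self, x):
        return np.exp(x)

    def backward(self, gy):
        return np.exp(self.input.data) * gy


x = Variable(np.array(0.5))
y = Square()(Exp()(Square()(x)))
y.grad = np.array(1.0)

log = []
y.backward(log=log)
print(log)  # [0, 1, 2, 3]: one entry per variable, from y back to x
print(x.grad)
```

At the deepest point, four `backward` calls are active at the same time, one per variable in the chain, so the nesting depth grows with the length of the chain.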
Now we rewrite this recursive implementation as one that uses a loop instead. The code looks like this:
```python
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()  # Get a function
            x, y = f.input, f.output  # Get the function's input and output
            x.grad = f.backward(y.grad)  # Call the function's backward method

            if x.creator is not None:
                funcs.append(x.creator)  # Add the preceding function to the list
```
This is the loop-based implementation. The key point is that the functions still to be processed are appended, one after another, to the list `funcs`. In the `while` loop, `funcs.pop()` retrieves the next function to process as `f`, and then the `backward` method of `f` is called. Because the input and output variables of `f` are obtained via `f.input` and `f.output`, the argument and return value of `f.backward()` are wired up correctly: `y.grad` is passed in, and the result is stored in `x.grad`.
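The following sketch traces the loop on the same three-function chain. It re-creates the loop-based classes and records, on each iteration, which function was popped and how many functions were waiting in `funcs` just before the pop; the `trace` parameter is a hypothetical addition for illustration only.

```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self, trace=None):
        # `trace` is a hypothetical tracing parameter, not part of DeZero
        funcs = [self.creator]
        while funcs:
            pending = len(funcs)  # functions waiting before this pop
            f = funcs.pop()
            if trace is not None:
                trace.append((f, pending))
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:
                funcs.append(x.creator)


class Function:
    def __call__(self, input):
        y = self.forward(input.data)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output


class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        return 2 * self.input.data * gy


class Exp(Function):
    def forward(self, x):
        return np.exp(x)

    def backward(self, gy):
        return np.exp(self.input.data) * gy


A, B, C = Square(), Exp(), Square()
x = Variable(np.array(0.5))
y = C(B(A(x)))
y.grad = np.array(1.0)

trace = []
y.backward(trace=trace)
order = [f for f, _ in trace]
pending = [n for _, n in trace]
print([type(f).__name__ for f in order])  # ['Square', 'Exp', 'Square'], i.e. C, B, A
print(pending)  # [1, 1, 1]
print(x.grad)
```

For this linear chain the list never holds more than one function at a time, and the functions are processed from the output side (C) back to the input side (A) without any nested `backward` calls.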
NOTE
A list's `pop` method removes the last element of the list and returns it. For example, if `funcs = [1, 2, 3]` and `x = funcs.pop()`, then `3` is taken out and `funcs` becomes `[1, 2]`.
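The behavior described in the note can be checked directly in the interpreter:

```python
funcs = [1, 2, 3]
x = funcs.pop()  # removes and returns the last element
print(x)  # 3
print(funcs)  # [1, 2]
```

Because the most recently appended element comes out first, `funcs` behaves as a stack (last in, first out).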
Now let’s use the `Variable` class above to actually compute a derivative. We run the same code here as in the previous step.
```python
A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

y.grad = np.array(1.0)
y.backward()
print(x.grad)
```

```
3.297442541400256
```
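The printed value can also be cross-checked by hand. With A and C squaring and B exponentiating, and writing a = x² and b = e^a, the chain of backward calls multiplies 1 by 2b (C), then by e^a (B), then by 2x (A):

$$
\begin{aligned}
y &= C(B(A(x))) = \left(e^{x^2}\right)^2 = e^{2x^2} \\
\frac{dy}{dx} &= 2b \cdot e^{a} \cdot 2x = 2e^{x^2} \cdot e^{x^2} \cdot 2x = 4x\,e^{2x^2} \\
\left.\frac{dy}{dx}\right|_{x=0.5} &= 4 \cdot 0.5 \cdot e^{0.5} = 2e^{0.5} \approx 3.2974
\end{aligned}
$$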
The result is the same as before. We have now switched the implementation from recursion to a loop. The benefits of the loop will become clear in “Step 15”, where we deal with more complex computational graphs; the current loop-based implementation can be extended smoothly to handle them. A loop is also slightly more efficient.
WARNING
With recursion, each recursive call keeps its intermediate state in memory (stacked up) until it returns, so memory use grows with the depth of the recursion. For this reason, a loop is generally more efficient. That said, on modern computers this small amount of extra memory is rarely a problem. In some cases, tail-call optimization also allows a recursive call to be executed as efficiently as a loop.
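The memory point in the warning has a practical consequence in Python: CPython does not perform tail-call optimization, and it caps the depth of nested Python calls (the default limit reported by `sys.getrecursionlimit()` is 1000). The sketch below builds a long chain from a hypothetical `Identity` function (y = x, so the gradient stays 1) and runs both a recursive and a loop-based backward pass on it. The method names `backward_recursive` and `backward_loop`, the helper `build_chain`, and the chain length are all made up for this comparison. Under the default limit, the recursive version is expected to raise `RecursionError`, while the loop finishes.

```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward_recursive(self):
        # Recursive version: one nested Python call per function in the chain
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward_recursive()

    def backward_loop(self):
        # Loop version: pending functions live in a list instead of the call stack
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:
                funcs.append(x.creator)


class Function:
    def __call__(self, input):
        y = self.forward(input.data)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output


class Identity(Function):
    # Hypothetical function for this sketch: y = x, so dy/dx = 1
    def forward(self, x):
        return x

    def backward(self, gy):
        return gy


def build_chain(n):
    # Hypothetical helper: apply Identity n times and seed the output gradient
    x = Variable(np.array(1.0))
    y = x
    for _ in range(n):
        y = Identity()(y)
    y.grad = np.array(1.0)
    return x, y


N = 10_000  # well above CPython's default recursion limit of 1000

x_loop, y_loop = build_chain(N)
y_loop.backward_loop()
print(x_loop.grad)  # 1.0

x_rec, y_rec = build_chain(N)
try:
    y_rec.backward_recursive()
    recursion_failed = False
except RecursionError:
    recursion_failed = True
print(recursion_failed)  # True under the default recursion limit
```

Raising the limit with `sys.setrecursionlimit` would let the recursive version go deeper, but the loop avoids the limit altogether.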
This completes the foundation of the backpropagation implementation. From here on, we will extend DeZero so that it can handle more complex computations. In the next step, I would like to improve the “ease of use” of DeZero.