前ステップまでに実装したコード
[1]:
import numpy as np
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
f = self.creator
if f is not None:
x = f.input
x.grad = f.backward(self.grad)
x.backward()
class Function:
def __call__(self, input):
x = input.data
y = self.forward(x)
output = Variable(y)
output.set_creator(self)
self.input = input
self.output = output
return output
def forward(self, x):
raise NotImplementedError()
def backward(self, gy):
raise NotImplementedError()
class Square(Function):
def forward(self, x):
y = x ** 2
return y
def backward(self, gy):
x = self.input.data
gx = 2 * x * gy
return gx
class Exp(Function):
def forward(self, x):
y = np.exp(x)
return y
def backward(self, gy):
x = self.input.data
gx = np.exp(x) * gy
return gx
前ステップで私たちは、Variableクラスにbackwardメソッドを追加しました。ここでは処理効率の改善と今後の拡張を見据えて、backwardメソッドを別の実装方式へと変更します。
再掲になりますが、私たちはVariableクラスのbackwardメソッドを次のように実装しました。
[2]:
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
f = self.creator
if f is not None:
x = f.input
x.grad = f.backward(self.grad)
x.backward()
ここで注目したいのは、backwardメソッドの中で、(入力側へ)1つ前の変数のbackwardメソッドが呼ばれている点です。これによって、「backwardメソッドの中でbackwardメソッドが呼ばれ、その呼ばれた先のbackwardメソッドでまたbackwardメソッドが呼ばれ、…」という処理が続きます(関数self.creatorがNoneになる変数が見つかるまで続きます)。これは再帰的な構造です。
ここでは、上の「再帰を使った実装」を「ループを使った実装」に書き換えます。そのコードを示すと、次のようになります。
[3]:
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
funcs = [self.creator]
while funcs:
f = funcs.pop() # 関数を取得
x, y = f.input, f.output # 関数の入出力を取得
x.grad = f.backward(y.grad) # backwardメソッドを呼ぶ
if x.creator is not None:
funcs.append(x.creator) # 1つ前の関数をリストに追加
これがループを使った実装です。重要な点は、funcsというリストに処理すべき関数を順に追加していくことです。whileループの中では、funcs.pop()によって処理すべき関数がfとして取り出され、その関数fのbackwardメソッドが呼ばれます。このとき、f.inputとf.outputによって、関数fの入出力の変数を取得することで、f.backward()の引数と戻り値が正しく設定されます。
NOTE
リストのpopメソッドは、リストの末尾が削除され、その要素が取得されます。たとえば、funcs = [1, 2, 3]のときx = funcs.pop()とすれば、3が取り出され、funcsは[1, 2]となります。
それでは、上のVariableクラスを使って、実際に微分を求めてみましょう。ここでも前ステップと同じコードを実行してみます。
[4]:
A = Square()
B = Exp()
C = Square()
x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
# 逆伝播
y.grad = np.array(1.0)
y.backward()
print(x.grad)
3.297442541400256
結果は前と同じです。これで、「再帰」から「ループ」へと実装方式の切り替えができました。この「ループ」による実装の恩恵は、「ステップ15」で分かります。そこでは複雑な計算グラフを扱いますが、今の「ループ」による実装であれば、スムーズに拡張できます。また「ループ」の方が少しだけ処理効率も良くなります。
WARNING
再帰は、関数を再帰的に呼ぶたびに途中の結果をメモリに残しながら(スタックに積みながら)処理を続けます。そのため、一般的には、ループ方式の方が処理効率が良くなります。ただし、現代のコンピュータであれば多少のメモリ使用量は問題になりません。また「末尾再帰」という処理によって、再帰をループと同じように実行できる場合があります。
以上で、バックプロパゲーションの実装のベースは完成です。これから、さらに複雑な計算が行えるように、現状のDeZeroを拡張していきます。次のステップでは、DeZeroの「使いやすさ」の点について改善したいと思います。