前ステップまでに実装したコード
[1]:
import numpy as np
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
f = self.creator
if f is not None:
x = f.input
x.grad = f.backward(self.grad)
x.backward()
class Function:
def __call__(self, input):
x = input.data
y = self.forward(x)
output = Variable(y)
output.set_creator(self)
self.input = input
self.output = output
return output
def forward(self, x):
raise NotImplementedError()
def backward(self, gy):
raise NotImplementedError()
class Square(Function):
def forward(self, x):
y = x ** 2
return y
def backward(self, gy):
x = self.input.data
gx = 2 * x * gy
return gx
class Exp(Function):
def forward(self, x):
y = np.exp(x)
return y
def backward(self, gy):
x = self.input.data
gx = np.exp(x) * gy
return gx
前ステップで私たちは、Variable
クラスにbackward
メソッドを追加しました。ここでは処理効率の改善と今後の拡張を見据えて、backward
メソッドを別の実装方式へと変更します。
再掲になりますが、私たちはVariable
クラスのbackward
メソッドを次のように実装しました。
[2]:
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
f = self.creator
if f is not None:
x = f.input
x.grad = f.backward(self.grad)
x.backward()
ここで注目したいのは、backward
メソッドの中で、(入力側へ)1つ前の変数のbackward
メソッドが呼ばれている点です。これによって、「backward
メソッドの中でbackward
メソッドが呼ばれ、その呼ばれた先のbackward
メソッドでまたbackward
メソッドが呼ばれ、…」という処理が続きます(関数self.creator
がNone
になる変数が見つかるまで続きます)。これは再帰的な構造です。
ここでは、上の「再帰を使った実装」を「ループを使った実装」に書き換えます。そのコードを示すと、次のようになります。
[3]:
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
funcs = [self.creator]
while funcs:
f = funcs.pop() # 関数を取得
x, y = f.input, f.output # 関数の入出力を取得
x.grad = f.backward(y.grad) # backwardメソッドを呼ぶ
if x.creator is not None:
funcs.append(x.creator) # 1つ前の関数をリストに追加
これがループを使った実装です。重要な点は、funcs
というリストに処理すべき関数を順に追加していくことです。while
ループの中では、funcs.pop()
によって処理すべき関数がf
として取り出され、その関数f
のbackward
メソッドが呼ばれます。このとき、f.input
とf.output
によって、関数f
の入出力の変数を取得することで、f.backward()
の引数と戻り値が正しく設定されます。
NOTE
リストのpop
メソッドは、リストの末尾が削除され、その要素が取得されます。たとえば、funcs = [1, 2, 3]
のときx = funcs.pop()
とすれば、3
が取り出され、funcs
は[1, 2]
となります。
それでは、上のVariable
クラスを使って、実際に微分を求めてみましょう。ここでも前ステップと同じコードを実行してみます。
[4]:
A = Square()
B = Exp()
C = Square()
x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
# 逆伝播
y.grad = np.array(1.0)
y.backward()
print(x.grad)
3.297442541400256
結果は前と同じです。これで、「再帰」から「ループ」へと実装方式の切り替えができました。この「ループ」による実装の恩恵は、「ステップ15」で分かります。そこでは複雑な計算グラフを扱いますが、今の「ループ」による実装であれば、スムーズに拡張できます。また「ループ」の方が少しだけ処理効率も良くなります。
WARNING
再帰は、関数を再帰的に呼ぶたびに途中の結果をメモリに残しながら(スタックに積みながら)処理を続けます。そのため、一般的には、ループ方式の方が処理効率が良くなります。ただし、現代のコンピュータであれば多少のメモリ使用量は問題になりません。また「末尾再帰」という処理によって、再帰をループと同じように実行できる場合があります。
以上で、バックプロパゲーションの実装のベースは完成です。これから、さらに複雑な計算が行えるように、現状のDeZeroを拡張していきます。次のステップでは、DeZeroの「使いやすさ」の点について改善したいと思います。