Open In Colab

ステップ8 再帰からループへ

前ステップまでに実装したコード

[1]:
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward()


class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()


class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x = self.input.data
        gx = 2 * x * gy
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx

前ステップで私たちは、Variableクラスにbackwardメソッドを追加しました。ここでは処理効率の改善と今後の拡張を見据えて、backwardメソッドを別の実装方式へと変更します。

8.1 現時点のVariableクラス

再掲になりますが、私たちはVariableクラスのbackwardメソッドを次のように実装しました。

[2]:
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        f = self.creator
        if f is not None:
            x = f.input
            x.grad = f.backward(self.grad)
            x.backward()


ここで注目したいのは、backwardメソッドの中で、(入力側へ)1つ前の変数のbackwardメソッドが呼ばれている点です。これによって、「backwardメソッドの中でbackwardメソッドが呼ばれ、その呼ばれた先のbackwardメソッドでまたbackwardメソッドが呼ばれ、…」という処理が続きます(関数self.creatorNoneになる変数が見つかるまで続きます)。これは再帰的な構造です。

8.2 ループを使った実装

ここでは、上の「再帰を使った実装」を「ループを使った実装」に書き換えます。そのコードを示すと、次のようになります。

[3]:
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()  # 関数を取得
            x, y = f.input, f.output  # 関数の入出力を取得
            x.grad = f.backward(y.grad)  # backwardメソッドを呼ぶ

            if x.creator is not None:
                funcs.append(x.creator)  # 1つ前の関数をリストに追加


これがループを使った実装です。重要な点は、funcsというリストに処理すべき関数を順に追加していくことです。whileループの中では、funcs.pop()によって処理すべき関数がfとして取り出され、その関数fbackwardメソッドが呼ばれます。このとき、f.inputf.outputによって、関数fの入出力の変数を取得することで、f.backward()の引数と戻り値が正しく設定されます。

NOTE

リストのpopメソッドは、リストの末尾が削除され、その要素が取得されます。たとえば、funcs = [1, 2, 3]のときx = funcs.pop()とすれば、3が取り出され、funcs[1, 2]となります。

8.3 動作確認

それでは、上のVariableクラスを使って、実際に微分を求めてみましょう。ここでも前ステップと同じコードを実行してみます。

[4]:
A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# 逆伝播
y.grad = np.array(1.0)
y.backward()
print(x.grad)
3.297442541400256


結果は前と同じです。これで、「再帰」から「ループ」へと実装方式の切り替えができました。この「ループ」による実装の恩恵は、「ステップ15」で分かります。そこでは複雑な計算グラフを扱いますが、今の「ループ」による実装であれば、スムーズに拡張できます。また「ループ」の方が少しだけ処理効率も良くなります。

WARNING

再帰は、関数を再帰的に呼ぶたびに途中の結果をメモリに残しながら(スタックに積みながら)処理を続けます。そのため、一般的には、ループ方式の方が処理効率が良くなります。ただし、現代のコンピュータであれば多少のメモリ使用量は問題になりません。また「末尾再帰」という処理によって、再帰をループと同じように実行できる場合があります。

以上で、バックプロパゲーションの実装のベースは完成です。これから、さらに複雑な計算が行えるように、現状のDeZeroを拡張していきます。次のステップでは、DeZeroの「使いやすさ」の点について改善したいと思います。