2021年7月7日水曜日

functools.cache とフィボナッチ

フィボナッチ数を計算する関数を再帰で素直に書く。

def fib(n):
    if n < 0:
        return n
    return fib(n-1) + fib(n-2)

このままではものすごく遅いので、メモ化すると良い。 というのは一般常識の内だろう。

Python には 3.9 からデコレーター functools.cache があるので、簡単だ。

from functools import cache

@cache
def fib(n):
    if n < 0:
        return n
    return fib(n-1) + fib(n-2)

もう少し前のバージョンを使わなければならない人も安心して欲しい。 デコレーター functools.lru_cache が 3.2 から存在していて、lru_cache(maxsize=None) と書けば実質的に cache と同じことをしてくれる。 実は上のフィボナッチ数を計算するプログラムは functools 公式ドキュメントの lru_cache 使用例から持ってきたものだ。

話はちょっと飛ぶのだが、昔やっていた NZMATH というプロジェクトをもう一度動かそうという動きがあるとかないとか。 いや動きがあるとは聞いているのだが、実際に動き出したかどうかはまだ見えていない。 そんな NZMATH にもフィボナッチ数を計算する関数 nzmath.sequence.fibonacci があった。

FIBONACCI = {0:0, 1:1}
def fibonacci(n):
    """
    param non-negative integer n
    return the n-th term of the Fibonacci
    effect FIBONACCI[n] = fibonacci(n)
    """
    if n < 0:
        raise ValueError("fibonacci(n)  0 <= n  ?")

    if n in FIBONACCI:
        return FIBONACCI[n]

    m = n >> 1
    if n & 1 == 0:
        f1 = fibonacci(m - 1)
        f2 = fibonacci(m)
        FIBONACCI[n] = (f1 + f1 + f2) * f2
    else:
        f1 = fibonacci(m)
        f2 = fibonacci(m + 1)
        FIBONACCI[n] = f1 ** 2 + f2 ** 2

    return FIBONACCI[n]

注目したいのは2点、計算方法とメモ化だ。 計算方法は、大体半分ぐらいのところの値を使って計算するという仕組になっている。 何を参考に書いたのか判らないが、線形再帰数列の一般論から出てくるはずだ。 あるいはフィボナッチ数に特化した話としてリベンボイムの「素数の話」に書いてあったのから取ったのかもしれない。 定義通りの計算はメモ化しても O(n) 項計算するのを避けられないが、この方法は O(log n) 項の計算で済む。 一方のメモ化だが、キャッシュに使う辞書 FIBONACCI を自前で準備している。 (当時対象バージョンは Python 2.5 だったので、まだ functools.lru_cache は登場していなかった)

今ならばこの自前のキャッシュを止めて、functools.cache を使ってより簡潔に書ける。

@cache
def fibonacci(n):
    if n < 0:
        raise ValueError("fibonacci(n)  0 <= n  ?")
    if n < 2:
        return n

    m = n >> 1
    if n & 1 == 0:
        f1 = fibonacci(m - 1)
        f2 = fibonacci(m)
        return (f1 + f1 + f2) * f2
    else:
        f1 = fibonacci(m)
        f2 = fibonacci(m + 1)
        return f1 ** 2 + f2 ** 2

速さに違いがあるか timeit を仕掛けてみた、が、考えてみたら2回目以降の呼び出しはただの辞書からの読み出しだから大して意味は無いかも。 一応書いておくと、fibonacci(1000) を計算させて、元のバージョンは "2000000 loops, best of 5: 145 nsec per loop" なのに対し、 デコレーター cache を使った方は "5000000 loops, best of 5: 59.8 nsec per loop" だった。 ちなみに定義通りに書いた方はもっと手前で再帰が最大回数まで使い果たしてエラーが出る。 再帰させる行の足し算を逆転するとちょっと限界が伸びるけどね。

@cache
def fib(n):
    if n < 0:
        return n
    return fib(n-2) + fib(n-1)