[Python #15] メモ化による計算の効率化

はじめまして。新しくPythonチームに参加しましたT.Kです。

今回が初めての投稿なので、まずは簡単な自己紹介をさせてください。

私は岐阜県出身で、現在は東京都内の大学に通っている大学2年生です。大学では情報工学を専攻していて、プログラミングやアルゴリズム、その他いろいろな情報理論を学んでいます。

私は山登り系のサークルに所属していて、週末は山登りをしていることもあります。一日中家に引きこもる生活を送りがちな私ですが、たまには自然を満喫するのもいいなと思っています。

本インターンには今年の秋から参加しています。プログラミングは大学に入学してから始めた初心者ですが、ある程度慣れてきたところで実務経験を積みたいと思い、インターンに参加しようと思いました。

さて本題に入ろうと思います。

まず、プログラミング言語の分類について簡単に説明したいと思います。プログラミング言語は作成したプログラム(ソースコード)の実行方法によって大きく2つに分類出来ます。

C、C ++、Java等の言語は「コンパイル言語」と呼ばれます。コンパイル言語では、ソースコードを「コンパイル」という処理を経ることによって実行ファイルという新しいファイルが作成されます。プログラムを実行するためにはソースコードを実行するのではなく、その実行ファイルを実行する必要があります。コンパイルとは、プログラムをより高速に実行出来るように最適化を施す処理のことを言います。なのでコンパイル言語は高速に実行出来るという特徴があります。

一方、PythonやJavaScript等の言語は「スクリプト言語」と呼ばれます。スクリプト言語ではコンパイルの必要がなく、ソースコードをそのまま実行することが出来ます。スクリプト言語はコンパイル言語と比べて文法が易しく、プログラミングに慣れていなくても手を出しやすいという利点がありますが、最適化処理であるコンパイルを行わないので実行速度が遅いという特徴があります。

どちらの言語にもそれぞれに利点と欠点があるので一概にどちらがいいとは言えません。

少し難しい話をしてしまいましたが、つまり何が言いたいかというと、PythonはC、C ++言語などのコンパイル言語に比べて処理速度が遅いという特徴があります。そこで少しでも計算速度を上げる方法の一つとしてメモ化を紹介します。

メモ化とは

プログラムの中で同じ関数を使い、同じ計算を繰り返し実行するなんてことがあると思います。同じ計算なのですから繰り返すのは計算の無駄ですよね。

メモ化とは、一度行った計算の答えをキャッシュに記憶しておき、2回目以降に同じ計算式を見つけたときは記憶してある答えを返すというものです。これにより計算の効率が良くなります。

ここで注意点があります。メモ化は関数に対して行い、その関数は参照透過性を持つ(同じ引数に対して同じ戻り値を返す)必要があります。そうでないと同じ計算とは言えません。また、計算結果を記憶しておくためにメモリを消費します。

メモ化の方法

標準モジュールであるfunctoolsモジュールにメモ化を簡単に行うための関数が用意されているのでfunctoolsモジュールを使用します。functoolsモジュールをインポートすることで関数に様々な操作をすることができます。

やることは簡単で、次の2つだけです。

  1. functoolsモジュールからcache関数をインポート
  2. cache関数でメモ化したい関数をデコレート
from functools import cache

@cache
def yourfunction():

2.で「デコレート」という言葉を使っていますが、上記のコードの@cacheの部分のことです。これは関数デコレータと言い、関数を受け取り関数を返す関数のことです。自作の関数に@をつけた関数の処理を追加するということだけ分かってもらえれば十分です。

cache関数では一度行った計算の答えをキャッシュに記憶しておき、2回目以降に同じ計算式を見つけたときは記憶してある答えを返す処理を行います。今回の場合、2.を行うことでメモ化したい関数にcache関数の処理を追加することが出来ます。

デコレータの詳しい説明は今回の主旨からはずれるため 割愛させていただきます。Pythonのデコレータを理解するまで を参考にしてください。

また、functools内のlru_cache関数でも同様のことができます。lruはLeast Recently Usedの頭文字で、キャッシュメモリの方式の一つです。lru_cache関数では

@functools.lru_cache(maxsize=128)

のように引数でキャッシュサイズの制限ができます。cache関数ではキャッシュサイズの制限はありません。なのでlru_cache関数においてmaxsize=Noneとするとcache関数と同じ結果になります。キャッシュサイズを超えるとLRU方式に従い呼び出しの古い順にキャッシュが更新されます。LRU方式について詳しく知りたい方は LRU(Least Recently Used)とは を参照してください。

cache関数では古い値を更新する必要がないため、キャッシュサイズを超える場合には lru_cache() よりもすこし高速に動作します。

検証

同じ計算を繰り返す例として、フィボナッチ数の計算を再帰的に行う再帰関数を挙げたいと思います。

「再帰」とは、関数の中でその関数自身を呼び出す処理のことを言います。イメージで言うと、箱の中に箱があってその箱の中にまた箱があって・・・というのを繰り返す感じです。またその再帰処理を行う関数のことを再帰関数と言います。

フィボナッチ数列とは、数字の列が並んでいて、2つ前の数字と1つ前の数字を足した数が次の数字となる数列のことです。

例えば、数列の初め(Pythonでは最初を0番目とする)の数字が0、1番目が1だとしたら、その次の2番目の数字は0+1=1。3番目の数字は1+1=2。繰り返すと0、1、1、2、3、5、8 、13、21、34...となります。

関数を実装すると以下のようになります。

def fibonacci(n):
    if n < 2:
        return n
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)

では、メモ化を行う前と後でどれくらい速さが違うか検証していきます。

時間の計測にはtimeモジュールを使用します。timeモジュール内のtime関数で現在時刻を取得することが出来ます。時間を計測したい処理の前後で現在時刻を取得し、その差を取ることで時間を計測することが出来ます。time関数で取得する時間の単位は秒です。

メモ化を行わないとき

import time

def fibonacci(n):
    if n < 2:
        return n
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)

n = 40
start = time.time()
print(f'fibonacci({n}) =', fibonacci(n))
end = time.time()
print('time:', end - start, 'seconds')

出力

fibonacci(40) = 102334155
time: 27.186326265335083 seconds   

メモ化を行うとき

from functools import cache
import time

@cache
def fibonacci(n):
    if n < 2:
        return n
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)

n = 40
start = time.perf_counter_ns()
print(f'fibonacci({n}) =', fibonacci(n))
end = time.perf_counter_ns()
print('time:', end - start, 'nanoseconds')

出力

fibonacci(40) = 102334155
time: 327600 nanoseconds

今回はn=40としてフィボナッチ数列の40番目の値を求めています。

結果はメモ化を行わないと27秒、メモ化を行うと327600ナノ秒(0.0003276秒)になりました。計測時間が秒だと細かく表示されないのでtime関数ではなくperf_counter_ns関数というものを使ってナノ秒を出力するようにしました。

メモ化によってかなり高速に計算が行われていることが分かりますね。

nをもっと大きくするとこの差は広がっていきます。

別の方法

上の検証では再帰関数を用いていますが、そもそも再帰的な計算を行うアルゴリズムが計算量を大きくしているのです。(アルゴリズムとは計算等の処理の方法、手順のことです。)なので単に再帰関数を用いない方法を考えるだけで計算は速くなります。

メモ化と少し似たところがありますが、計算した結果を変数に代入することで再帰処理を避けます。

import time

def fibonacci(n):
    if n < 2:
        return n
    else:
        tmp1 = 0
        tmp2 = 1
        i = 1
        while i < n:
            tmp3 = tmp1 + tmp2
            tmp1, tmp2 = tmp2, tmp3
            i += 1
        return tmp3

n = 40
start = time.perf_counter_ns()
print(f'fibonacci({n}) =', fibonacci(n))
end = time.perf_counter_ns()
print('time:', end - start, 'nanoseconds')

出力

fibonacci(40) = 102334155
time: 117700 nanoseconds

結果は117700ナノ秒=0.0001177秒 になりました。

メモ化を使わなくてもアルゴリズムの改良によって計算の高速化を図ることもできます。

上のコードでは繰り返しの処理にfor文ではなくwhile文を使っています。Pythonのfor文は遅いと言われているので代わりにwhile文を使うのも高速化のポイントの一つです。

実はPythonには再帰呼び出しの回数上限が定められており、上限回数を超えて再帰呼び出しを行おうとすると以下のようなエラーが出ます。

RecursionError: maximum recursion depth exceeded

なので再帰関数はどんな場合でも実行出来る訳ではなく、汎用性があまり高くないと言えます。そういった点からも再帰処理はできるだけ避けた方がよいでしょう。

おわりに

今回はメモ化による計算の効率化について紹介しました。

説明の中で難しい用語を所々使ってしまったので少しわかりにくいところがあったかもしれませんがご容赦ください。

そういった難しい話は置いておいて、メモ化は普段のプログラムにたったの2行を追加するだけで簡単に実装できるので知っておくと便利かもしれません。

今回のようなメモ化は実際の業務で使うということはあまりないかと思います。しかし、業務上処理の速さは重視されます。処理時間に加えてメモリ消費、コードの汎用性を意識するのも大切です。

私も今後のプログラミングライフでこれらを意識していきたいと思います。

今回紹介したものの他にも効率化のポイントはたくさんありますので興味がある方は調べてみて下さい。

参考

functools --- 高階関数と呼び出し可能オブジェクトの操作

1行追加するだけの超お手軽な高速化その1 メモ化

pythonのfunctoolsについてまとめてみた

Pythonの再帰回数の上限を確認・変更