Rekurencja, pamięć podręczna, szybkość

Znajomy podesłał mi niedawno ciekawostkę pogramistyczną, a mianowicie jak w Pythonie korzystać z pamięci podręcznej i w jaki sposób używanie pamięci podręcznej wpływa na szybkość działania kodu.

Pamięć podręczna w kontekście komputerów to kawałek pamięci, do której można sobie coś na chwilę zapisać, a potem odczytać bardzo, bardzo szybko. Coś jakby zanotować sobie na serwetce numer zamówienia, bo zaraz dzwonimy do dostawcy i potrzebujemy mieć tę informację pod ręką.

Zapisanie "na chwilę" oznacza, że informacja nie pozostaje w pamięci podręcznej zbyt długo. Jeżeli zapiszemy tam sobie coś innego, a potem jeszcze coś innego i tak dalej, w pewnym momencie na naszej serwetce skończy się miejsce. W odróżnieniu jednak od fizycznej serwetki, w przypadku komputerowej pamięci podręcznej jeżeli zaczyna brakować miejsca, dodawanie kolejnych informacji będzie kasować te najstarsze. W efekcie zawsze jakaś tam porcja danych będzie dostępna "pod ręką".

Szczególnym przypadkiem zastosowania takiej pamięci podręcznej jest wykonywanie funkcji i odczytywanie jej wyniku. Jeżeli funkcja jest deterministyczna, a więc jeżeli ten sam zestaw danych wejściowych gwarantuje wyprodukowanie za każdym razem tego samego wyniku, to kolejne wywołania funkcji z takimi samymi danymi wejściowymi można znacznie przyspieszyć zapamiętując w pamięci podręcznej ileś-tam ostatnich wywołań funkcji (wraz z wynikami). Jeżeli okaże się, że wcześniej już uruchamialiśmy tę funkcję z takimi samymi danymi wejściowymi, wówczas zamiast uruchamiać funkcję ponownie i czekać aż zwróci wynik, możemy odczytać gotowy wynik z pamięci podręcznej. Zyski na prędkości działania mogą okazać się całkiem niebanalne, zwłaszcza jeżeli funkcja potrzebuje znacząco dużo czasu na wykonanie pracy, albo kiedy daną funkcję wywołujemy wiele, wiele razy.

Zobaczmy na przykładzie:

import time

def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(40)
end = time.process_time_ns()
print(end-begin)

Jest to prawdopodobnie najgłupsza możliwa implementacja funkcji wyliczającej n-tą z kolei liczbę Fibonacciego: definiujemy funkcję fib z jednym parametrem n, następnie wołamy ją rekurencyjnie dla coraz mniejszych liczb aż zejdziemy do dwójki. Proszę jednak zwrócić uwagę, że funkcja ta za każdym wywołaniem uruchamia swoje dwie kopie: najpierw fib(n-1) i zaraz potem fib(n-2). Każda z tych dwóch kopii robi to samo, więc licząc - tak jak w przykładzie powyżej - czterdziestą z kolei liczbę Fibonacciego, funkcja wywoła samą siebie około \(2^{40}\) razy. \(2^{40}\) to dobrze ponad bilion razy (polski bilion, czyli milion milionów, nie angielski billion) - aż cud, że mój stary grat wykonuje ten skrypt w skończonym czasie.

W jakim czasie, konkretnie?

24765625000 nanosekund, czyli niecałe 25 sekund.

A teraz zrobimy mały trick i nakażemy powyższemu kodowi używać pamięci podręcznej:

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(40)
end = time.process_time_ns()
print(end-begin)

Nieznaczna modyfikacja sprawiła, że nasz kod wykonuje się teraz w czasie... zero nanosekund?

Nie wiem dlaczego, ale tak mi właśnie pokazuje. Spróbujmy zwiększyć n...

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(400)
end = time.process_time_ns()
print(end-begin)

Nadal zero. Zwiększamy ponownie...

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(4000)
end = time.process_time_ns()
print(end-begin)

Ach, w końcu jakieś efekty. Może nie do końca takie, jakich oczekiwałem...

RecursionError: maximum recursion depth exceeded

Python zabezpiecza się przed nieskończoną rekurencją ustawiając limit zagnieżdżeń na 1000. Na szczęście można ten limit nieco podnieść, co w normalnych warunkach nie jest operacją zbyt rozsądną (1000 to dużo), ale na potrzeby dzisiejszego wpisu...

from functools import lru_cache
import sys
import time

sys.setrecursionlimit(4000)

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(4000)
end = time.process_time_ns()
print(end-begin)

Tym razem skrypt wykrzacza się bez żadnego komunikatu błędu i bez wyników. Pewnie skończyło się miejsce na stosie czy coś...

Metodą prób i błędów znajduję największą liczbę, dla której skrypt zaczyna wreszcie działać:

from functools import lru_cache
import sys
import time

sys.setrecursionlimit(4000)

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(1396)
end = time.process_time_ns()
print(end-begin)

Dla n = 1396 nasz skrypt pokazuje niezerowy czas wykonania: 15625000 nanosekund czyli 15 milisekund. Niecałe dwie setne sekundy, żeby policzyć 1396 z kolei liczbę Fibonacciego przy użyciu wyjątkowo nieefektywnego algorytmu - a wszystko to wyłącznie dzięki sprytnemu użyciu pamięci podręcznej.

Oczywiście należy pamiętać, że takie użycie dekoratora @lru_cache wymaga ostrożnego planowania. W naszym przykładzie wszystko jest łatwe proste i przyjemne, ale w większym projekcie może się okazać, że funkcja fib robi jeszcze coś dodatkowego (na przykład modyfikuje jakąś zmienną zewnętrzną albo komunikuje się z innymi procesami) i w tym momencie możemy sobie napytać biedy.

Czary, panie.

2 komentarze

  1. Fajne i prawdopodobnie kiedyś bardzo by mi się przydało przy pewnym PoCu. Zamiast tego wymyśliłem koło od nowa. Dekorator, przechwycenie wywołania, sha256 z argumentów i użycie memcached.

    Chociaż tam była wolna i powtarzająca się komunikacja z zewnętrzem i chyba bardziej zależało mi na cache pomiędzy uruchomieniami… Tak czy inaczej, dzięki.

    1. Jeżeli chodzi o wolną komunikację, trick też zadziała:

      import functools
      import time
      
      @functools.lru_cache()
      def func(x):
          time.sleep(1)
          print(f"Heavy operation for {x}")
          return x * 10
      
      print("Func returned:", func(1))
      print("Func returned:", func(2))
      print("Func returned:", func(3))
      print("Func returned:", func(3))
      print("Func returned:", func(2))
      print("Func returned:", func(1))
      

Leave a Comment

Komentarze mile widziane.

Jeżeli chcesz do komentarza wstawić kod, użyj składni:
[code]
tutaj wstaw swój kod
[/code]