Recursion, cache, speed

Recently, a friend sent me a piece of Python trivia: how to use a cache in Python, and how much it can affect the speed of your code.

A cache, in the context of computers, is a piece of memory where you can momentarily write something down and then read it back very, very quickly. It is a bit like jotting down an order number on a napkin because we are about to call a supplier and need to have this information at hand.

Writing it down "momentarily" means that the information does not remain in the cache for too long. If we write something else down there, and then again, and so on, at some point we will run out of space on our napkin. Unlike a physical napkin, though, a computer cache that begins to run out of space makes room for new information by erasing the oldest entries. As a result, there will always be a portion of data available "at hand".

Please be aware, though, that the above description of how a cache works is extremely simplified. Actual implementations are really, really complex. As we all know, caching things is one of the two most difficult problems in computer science, together with naming things and off-by-one errors.

A special case of cache use is storing the results of function calls. If the function is deterministic, i.e. if the same set of input data is guaranteed to produce the same result each time, subsequent calls to the function with the same input data can be greatly accelerated by storing a number of recent function calls (with their results) in the cache. If it turns out that we have already executed the function with the same input data, then instead of running the function again and waiting for it to return a result, we can read the result from the cache. The speed gains can be quite remarkable, especially if the function takes a long time to do its job, or when it is called many, many times.
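To make the idea concrete, here is a minimal hand-rolled sketch of such a result cache. The names memoize and slow_square are made up for this illustration, and unlike a real cache this one never throws anything away:

```python
def memoize(func):
    """Wrap func so repeated calls with the same arguments reuse the stored result."""
    results = {}  # maps argument tuples to previously computed results

    def wrapper(*args):
        if args not in results:
            results[args] = func(*args)  # first time: do the real work
        return results[args]             # afterwards: just a dictionary lookup

    return wrapper


calls = []  # records every time the real function body actually runs


@memoize
def slow_square(x):
    # Hypothetical stand-in for an expensive, deterministic function.
    calls.append(x)
    return x * x


print(slow_square(12))  # computed
print(slow_square(12))  # read back from the cache
print(len(calls))       # the real function ran only once
```

The second call never reaches the body of slow_square, which is the whole point: the work is done once and looked up afterwards.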

Let's see an example:

import time

def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(40)
end = time.process_time_ns()
print(end-begin)

This is probably the dumbest possible implementation of a function that calculates the nth Fibonacci number: we define a fib function with a single parameter n, then call it recursively for smaller and smaller numbers until we get down to n < 2. Note, however, that this function runs two copies of itself each time it is called: fib(n-1) first and fib(n-2) immediately afterwards. Each of these two copies does the same thing in turn, so the number of calls grows exponentially with n. A crude upper bound for the fortieth Fibonacci number is \(2^{40}\) calls, which is over a trillion; the exact count is smaller, because the fib(n-2) branch bottoms out sooner than the fib(n-1) branch, and works out to \(2 \cdot \mathrm{fib}(41) - 1 = 331160281\). That is still a third of a billion function calls - it's a wonder my old box executes this script in reasonable time.
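To get a feel for how quickly the number of calls grows, here is a small sketch that counts them for modest values of n. The helper names fib_counting and count_calls are made up for this illustration; the recursion itself is the same as in fib above.

```python
def fib_counting(n, counter):
    """Same naive recursion as fib, but counter[0] is bumped on every call."""
    counter[0] += 1
    if n < 2:
        return n
    return fib_counting(n - 1, counter) + fib_counting(n - 2, counter)


def count_calls(n):
    """Return how many calls the naive recursion makes to compute fib(n)."""
    counter = [0]
    fib_counting(n, counter)
    return counter[0]


for n in (10, 20, 25):
    print(n, count_calls(n))
# 10 177
# 20 21891
# 25 242785
```

The counts match 2*fib(n+1) - 1, so each extra step in n multiplies the work by roughly 1.6.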

In what time, specifically?

24765625000 nanoseconds, or just under 25 seconds.

And now we'll do a little trick and get the above code to use a cache:

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(40)
end = time.process_time_ns()
print(end-begin)

With a slight modification, our code now executes in … zero nanoseconds?

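I don't know why for certain, but that's what it is showing to me. My best guess is the clock rather than the code: on some systems the process-time clock only advances in fairly coarse steps, so a run shorter than one step reads as zero. Let's try increasing n...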

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(400)
end = time.process_time_ns()
print(end-begin)

Still zero. Increasing again…

from functools import lru_cache
import time

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(4000)
end = time.process_time_ns()
print(end-begin)

Ah, finally some results! Maybe not quite what I was expecting…

RecursionError: maximum recursion depth exceeded

Python protects against infinite recursion by limiting the nesting depth, to 1000 by default. Fortunately, this limit can be raised. That is not normally a very sensible thing to do (1000 is a lot), but for the purposes of today's post…

from functools import lru_cache
import sys
import time

sys.setrecursionlimit(4000)

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(4000)
end = time.process_time_ns()
print(end-begin)

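This time the script exits with no error message and no result. Most likely the interpreter itself ran out of stack space: raising the recursion limit lifts Python's own safety check, but not the operating system's limit on the size of the stack.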
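As an aside, and not what this post goes on to do: one way to stay clear of deep recursion altogether is to fill the cache from the bottom up, so that every call finds its two inputs already stored. A minimal sketch, where fib_warmed is a made-up helper name:

```python
from functools import lru_cache


@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)


def fib_warmed(n):
    """Call fib for 0, 1, ..., n in order so each call hits the cache for its inputs."""
    result = 0
    for i in range(n + 1):
        # fib(i - 1) and fib(i - 2) were cached by the previous iterations,
        # so the recursion never goes more than one level deep.
        result = fib(i)
    return result


# No change to the recursion limit is needed here.
print(len(str(fib_warmed(4000))))  # number of decimal digits in fib(4000)
```

This works even with the default cache size, because each call only needs the two most recently stored results.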

By trial and error, I find the largest n for which the script still works:

from functools import lru_cache
import sys
import time

sys.setrecursionlimit(4000)

@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

begin = time.process_time_ns()
fib(1396)
end = time.process_time_ns()
print(end-begin)

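For n = 1396, our little script finally shows a non-zero execution time: 15625000 nanoseconds, or about 15.6 milliseconds. That is exactly 1/64 of a second, and the 25-second figure from the first run is an exact multiple of it, so this is probably a single tick of a coarse clock rather than a precise measurement. Either way: less than two hundredths of a second to calculate the 1396th Fibonacci number using an extremely inefficient algorithm - and all this thanks only to clever use of the cache.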
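If the zero readings earlier really do come from the clock, a finer one should show something. Here is a sketch under that assumption, which the post itself does not confirm: it swaps in time.perf_counter_ns(), which measures wall-clock time with the best resolution Python has available, and also prints the cache statistics.

```python
from functools import lru_cache
import time


@lru_cache()
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)


# What Python reports about the two clocks on this machine.
# The real step size can be coarser than the reported resolution.
for name in ("process_time", "perf_counter"):
    print(name, time.get_clock_info(name).resolution)

begin = time.perf_counter_ns()
fib(40)
end = time.perf_counter_ns()
elapsed_ns = end - begin

print(elapsed_ns)        # wall-clock nanoseconds, machine-dependent
print(fib.cache_info())  # hits, misses, maximum size and current size
```

For fib(40) the statistics show 41 misses (one real computation for each n from 0 to 40) and 38 hits, compared with hundreds of millions of calls without the cache.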

Of course, it is important to remember that the @lru_cache decorator should be used with care. In our example, it's all nice and easy, but in a larger project it might turn out that the cached function does something extra besides returning a value (for example, modifying some external variable or communicating with other processes). A result read from the cache skips that extra work, and at that point we can get ourselves into trouble.
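A small sketch of that kind of trouble, using a made-up function square_and_log whose side effect is appending to a list:

```python
from functools import lru_cache

log = []  # stands in for "some external variable"


@lru_cache()
def square_and_log(x):
    log.append(x)  # side effect: happens only when the body really runs
    return x * x


square_and_log(3)
square_and_log(3)  # answered from the cache, so nothing is appended
print(log)         # [3]

square_and_log.cache_clear()  # forget all stored results
square_and_log(3)             # the body runs again
print(log)                    # [3, 3]
```

The returned values are correct every time; it is only the side effect that silently goes missing on a cache hit.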

Pure witchcraft.
