Quick Tip: Using Memoization To Speed Up Recursive Functions

in python

A recent visit to HackerRank reminded me of a nice trick for making recursive functions run faster: memoization.

The Question

A HackerRank question asked for the number of ways a child can climb a staircase of a given height, given that the child can climb 1, 2 or 3 stairs at a time.
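To make the question concrete, here is a small brute-force sketch that lists every sequence of 1-, 2- and 3-stair steps adding up to a given height. It is for illustration only and is not part of the original solution; the helper name `step_sequences` is hypothetical. For a height of 4 it finds 7 sequences.

```python
def step_sequences(n, steps=(1, 2, 3)):
    """Yield every tuple of step sizes that adds up exactly to n (brute force)."""
    if n == 0:
        yield ()  # already at the top: the empty sequence is the one way
        return
    for s in steps:
        if s <= n:  # never overshoot the top of the staircase
            for rest in step_sequences(n - s, steps):
                yield (s,) + rest

for seq in step_sequences(4):
    print(seq)
print(len(list(step_sequences(4))))  # 7
```

Enumerating every sequence like this is only practical for tiny heights, since the number of sequences itself grows exponentially; it is useful as a cross-check for a smarter solution.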

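A recursive solution in Python is easy to write. The child's first step climbs 1, 2 or 3 stairs, so the number of ways to climb n stairs is the sum of the ways to climb the remaining n-1, n-2 and n-3 stairs: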

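import sys

def ways(n):
    # Overshooting the top of the staircase is not a valid way
    if n < 0:  return 0
    # Exactly at the top: one way (take no more steps)
    if n == 0: return 1
    # The first step climbs 1, 2 or 3 stairs; count the ways to finish from there
    return ways(n-1) + ways(n-2) + ways(n-3)

# Read the staircase height from the command line, e.g. `python ways.py 30`
print(ways(int(sys.argv[1])))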

And it even yields the correct result. But for some inputs, this code failed with a timeout error.

What Went Wrong

It turns out this code gets noticeably slow for heights above 20. For example, for a height of 30 we get:

$ time python ways.py 30
53798080

real	0m29.549s
user	0m29.022s
sys	0m0.159s

30 seconds is too long for hackerrank, and also too long for me.

The main problem with this recursion is that it calculates the same values over and over. Consider ways(5), for example:

ways(5) = ways(4) + ways(3) + ways(2)
ways(4) = ways(3) + ways(2) + ways(1)
ways(2) = ways(1) + ways(0) + ways(-1)

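Already ways(3) is calculated twice, and ways(2) appears twice in just these few lines; over the full recursion, ways(2) is actually calculated four times. The number of calls grows exponentially with n, so the larger the input, the more duplicate calculations are performed and the longer the function takes.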
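To see the duplication directly, one can instrument the naive function with a counter. This is a sketch for illustration only: the module-level `calls` counter is an addition that is not in the original code. For n = 5 it reports that ways(3) runs twice, ways(2) four times, and 46 calls are made in total to produce a result of 13.

```python
from collections import Counter

calls = Counter()  # number of times ways() is entered with each argument

def ways(n):
    calls[n] += 1  # record every call, including repeated ones
    if n < 0:  return 0
    if n == 0: return 1
    return ways(n-1) + ways(n-2) + ways(n-3)

print(ways(5))              # 13
print(calls[3], calls[2])   # 2 4
print(sum(calls.values()))  # 46 calls in total for n = 5
```

Most of those 46 calls are repeats: only 8 distinct arguments (5 down to -2) ever appear.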

Quick Win: Memoize

Python's standard library has built-in support for "remembering" previous function results (a technique called memoization). All you need to do is add the lru_cache decorator from the functools module:

import sys
from functools import lru_cache

@lru_cache(maxsize=None)  # maxsize=None: unbounded cache, results are never evicted
def ways(n):
    if n < 0:  return 0
    if n == 0: return 1
    return ways(n-1) + ways(n-2) + ways(n-3)

print(ways(int(sys.argv[1])))

That wasn't hard. And the result:

53798080

real	0m0.048s
user	0m0.034s
sys	0m0.010s
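Conceptually, memoization is just a lookup table keyed by the function's arguments. The sketch below hand-rolls a simplified stand-in for lru_cache(maxsize=None) using a plain dictionary. The decorator name `memoize` is hypothetical, and it only handles a single hashable argument; the real lru_cache also supports multiple and keyword arguments, bounded caches with least-recently-used eviction, and cache statistics.

```python
import functools

def memoize(func):
    """Hypothetical minimal memoizer for single-argument functions."""
    cache = {}  # maps argument -> previously computed result

    @functools.wraps(func)
    def wrapper(n):
        if n not in cache:   # first time we see n: compute and store
            cache[n] = func(n)
        return cache[n]      # every later call is a dictionary lookup
    return wrapper

@memoize
def ways(n):
    if n < 0:  return 0
    if n == 0: return 1
    # These recursive calls go through the module-level name `ways`,
    # i.e. the memoizing wrapper, so every subproblem is computed only once
    return ways(n-1) + ways(n-2) + ways(n-3)

print(ways(30))  # 53798080
```

Because the recursive calls look up `ways` by name at run time, they hit the memoized wrapper rather than the undecorated function. That is why adding a decorator is enough and the function body needs no changes.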

Memoization works best for recursive functions that make several recursive calls on overlapping subproblems, so the same arguments come up again and again. The recursive staircase and Fibonacci numbers are two good examples.
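As a sketch of the Fibonacci case, the same decorator applies unchanged. The cache_info() method that lru_cache attaches to the decorated function reports hits and misses, which makes the saved work visible: computing fib(30) from an empty cache takes 31 misses (one per distinct argument, 0 through 30) and 28 hits, 59 calls in all. On Python 3.9 and later, functools.cache is a shorthand for lru_cache(maxsize=None).

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache, same as the staircase example
def fib(n):
    if n < 2:
        return n  # fib(0) = 0, fib(1) = 1
    return fib(n - 1) + fib(n - 2)

print(fib(30))           # 832040
print(fib.cache_info())  # CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)
```

One caveat: memoization removes the repeated work but not the recursion depth, so a very large first call such as fib(5000) can still raise RecursionError under CPython's default recursion limit of 1000.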
