A recent visit to hackerrank reminded me of a nice trick to make our recursive functions run faster - and it's called memoization.
A hackerrank question asked for the number of ways a child can climb a staircase of a given height, provided that the child can climb 1, 2 or 3 stairs at a time.
A recursive solution in Python is easy to write:
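    import sys

    def ways(n):
        # Overshooting the top of the staircase is not a valid way
        if n < 0:
            return 0
        # Exactly at the top: one valid way found
        if n == 0:
            return 1
        return ways(n-1) + ways(n-2) + ways(n-3)

    # The height is the first command-line argument
    print(ways(int(sys.argv[1])))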
And it even yields the correct result. But for some inputs this code failed with a timeout error.
What Went Wrong
Turns out this code starts to get slow for any height above 20. For example, for height 30 we get:
    $ time python ways.py 30
    53798080

    real    0m29.549s
    user    0m29.022s
    sys     0m0.159s
30 seconds is too long for hackerrank, and also too long for me.
The main problem with the above recursion is that it calculates the same values many many times. Let's consider ways(5) for example:

    ways(5) = ways(4) + ways(3) + ways(2)
    ways(4) = ways(3) + ways(2) + ways(1)
    ways(2) = ways(1) + ways(0) + ways(-1)
Turns out ways(3) is calculated twice: once directly by ways(5) and once again by ways(4). The larger the numbers, the more duplicate calculations are performed, and thus the longer this function is going to take.
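To see the duplication concretely, here is a minimal sketch that counts how often each argument gets evaluated. The `calls` counter and the `counted_ways` name are my own illustrative additions, not part of the original solution.

```python
from collections import Counter

# Hypothetical helper: tallies how many times each n is evaluated.
calls = Counter()

def counted_ways(n):
    calls[n] += 1
    if n < 0:
        return 0
    if n == 0:
        return 1
    return counted_ways(n - 1) + counted_ways(n - 2) + counted_ways(n - 3)

result = counted_ways(5)
print(result)                # 13 ways to climb 5 stairs
print(calls[3], calls[2])    # 2 4
print(sum(calls.values()))   # 46 calls in total
```

Even for a height of 5, the function is called 46 times to produce a single number, and the count grows quickly from there.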
Quick Win: Memoize
Python has built-in support to "remember" previous function results (called memoization). All you need to do is add a decorator called lru_cache, from the functools module:
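    import sys
    from functools import lru_cache

    # maxsize=None lets the cache grow without bound
    @lru_cache(maxsize=None)
    def ways(n):
        if n < 0:
            return 0
        if n == 0:
            return 1
        return ways(n-1) + ways(n-2) + ways(n-3)

    # The height is the first command-line argument
    print(ways(int(sys.argv[1])))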
That wasn't hard. And the result:
    53798080

    real    0m0.048s
    user    0m0.034s
    sys     0m0.010s
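For intuition about what the decorator buys us, here is a rough hand-written equivalent using a plain dictionary. It is only a sketch of the idea, not how `functools.lru_cache` is actually implemented; the `memo` dictionary and the `ways_memo` name are made up for illustration.

```python
# Hypothetical cache: maps n to the already-computed number of ways.
memo = {}

def ways_memo(n):
    if n < 0:
        return 0
    if n == 0:
        return 1
    if n not in memo:
        # Compute each height once, then reuse it on every later call.
        memo[n] = ways_memo(n - 1) + ways_memo(n - 2) + ways_memo(n - 3)
    return memo[n]

print(ways_memo(30))  # 53798080
```

Each height from 1 to 30 is computed exactly once, so the work grows linearly with the height instead of exponentially.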
Memoization works great with recursive functions that make multiple recursive calls on overlapping inputs. The recursive staircase and Fibonacci numbers are two good examples.
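The same decorator carries over to Fibonacci with no other changes. A minimal sketch, assuming the convention fib(0) = 0 and fib(1) = 1 (the `fib` function is my own example, not from the hackerrank question):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fib(n):
    # Assumed convention: fib(0) = 0, fib(1) = 1.
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(50))           # 12586269025
print(fib.cache_info())  # hit and miss counts show how much work was saved
```

`cache_info()` is provided by `lru_cache` itself and is a quick way to check that the cache is actually being hit.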