
Tail Recursion

Let's look at our sumUpTo function again:

Sum.hs
sumUpTo :: Int -> Int
sumUpTo 0 = 0
sumUpTo n = sumUpTo (n - 1) + n

This is an inductive definition. It says that the sum of the first 0 numbers is 0. For \(n > 0\), the sum of the first \(n\) numbers is the sum of the first \(n - 1\) numbers, plus \(n\).

We are interested in the second equation now. To calculate sumUpTo n, we first calculate sumUpTo (n - 1), and then we add n to it. This is the important part: we are not done after sumUpTo (n - 1) returns; we still need to add n to it. To support this, our program needs to remember the state of the computation of sumUpTo n when making the recursive call sumUpTo (n - 1), so it can continue the computation of sumUpTo n when sumUpTo (n - 1) returns. That's recursion.

To compute sumUpTo 10000000, our function makes 10 million recursive calls. None of these calls returns before the recursive call it makes returns. Thus, at the time when sumUpTo 0 is active, we need to remember the state of all 10 million recursive calls. We need 10 million stack frames to do this.
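To see this stack growth concretely, here is how the evaluation of the small call sumUpTo 4 unfolds. Each recursive call leaves a pending addition that must be remembered until the call returns:

```
sumUpTo 4
  = sumUpTo 3 + 4
  = (sumUpTo 2 + 3) + 4
  = ((sumUpTo 1 + 2) + 3) + 4
  = (((sumUpTo 0 + 1) + 2) + 3) + 4
  = (((0 + 1) + 2) + 3) + 4
  = 10
```

For sumUpTo 10000000, the chain of pending additions is 10 million deep.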

Contrast this with implementing the equivalent of sumUpTo in Python, using a loop:

sum.py
def sum_up_to(n):
    total = 0
    while n > 0:
        total += n
        n -= 1
    return total

There are no recursive calls at all. Thus, this function runs in constant space!

Our next exercise will be to translate the Python version of sum_up_to into a recursive function. It will not immediately be clear why this will lead us to a better implementation of sumUpTo, but bear with me.

The loop of the Python sum_up_to function maintains a state consisting of the contents of the variables n and total. At the beginning, we have total = 0, and n is whatever argument we call sum_up_to with. If n == 0, then the loop exits and the current value of total is what the function returns. If n > 0, then we start another iteration, which increases total by n and decreases n by one.

Let's write a Haskell function loop that mimics the behaviour of the loop of sum_up_to. Its arguments must capture the state of the loop, so it takes two arguments, n and total.

We just said that if n == 0, the loop exits and the current value of total becomes the result returned by sum_up_to. In Haskell,

loop 0 total = total

If n > 0, then we need to start another iteration, with total increased by n, and n decreased by 1. In Haskell,

loop n total = loop (n - 1) (total + n)

To start the loop, we need to call our loop function with the initial values of n and total as arguments. We said that at the beginning, n is whatever argument we pass to sum_up_to, and total = 0. Putting it all together, we get the following Haskell version of sum_up_to:

Sum.hs
sumUpTo :: Int -> Int
sumUpTo 0 = 0
sumUpTo n = sumUpTo (n - 1) + n

fastSumUpTo :: Int -> Int
fastSumUpTo n = loop n 0
  where
    loop 0 total = total
    loop n total = loop (n - 1) (total + n)

This translation of a loop into a recursive function always works. To do this, we figure out what the loop variables are, the ones that change (or at least can change) in each iteration of the loop. These variables become the arguments to the loop function (which, of course, doesn't have to be called loop—I just called it loop because it simulates a loop). An invocation of loop returns the final result if its arguments meet the exit condition of the loop. Otherwise, loop calls itself recursively, with its arguments changed in the same way that the body of the loop would update the state of the loop variables.
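As a second illustration of the recipe (a sketch of my own, not part of this lesson's code), consider computing a factorial. A Python loop would maintain the state n and result, with exit condition n == 0 and result starting at 1. Translating it exactly as above gives a hypothetical fastFactorial:

```haskell
-- Sketch: the loop state is (n, result); the exit condition is n == 0.
-- Each "iteration" multiplies result by n and decreases n by one.
fastFactorial :: Int -> Int
fastFactorial n = loop n 1
  where
    loop 0 result = result
    loop n result = loop (n - 1) (result * n)
```

The only decisions to make were the loop variables (n and result), their initial values (n and 1), and how each iteration updates them.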

Illustration of fastSumUpTo 4
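In text form, the evaluation of fastSumUpTo 4 proceeds like this. Note that, unlike the trace of sumUpTo, no pending additions accumulate; each step is a complete description of the remaining work:

```
fastSumUpTo 4
  = loop 4 0
  = loop 3 4
  = loop 2 7
  = loop 1 9
  = loop 0 10
  = 10
```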

So why is fastSumUpTo better than sumUpTo? It's not exactly faster (at least if you try it out in GHCi, which does not compile the code). What makes fastSumUpTo better is that the Haskell compiler can translate fastSumUpTo into a loop not unlike our Python function sum_up_to. So the code no longer builds up a massive number of stack frames corresponding to the recursive calls. It runs in constant space, just as the Python sum_up_to function does. How does this work?

fastSumUpTo calls loop once. loop then makes a lot of recursive calls to itself. However, whenever loop makes a recursive call, the result of the recursive call is the result of the current invocation. Compare this with our sumUpTo function: once the recursive call sumUpTo (n - 1) returns, we still have to add n to it to obtain the result of sumUpTo n. In contrast, we define loop n total to be exactly the same as loop (n - 1) (total + n). This means that after the recursive call loop (n - 1) (total + n) returns, loop n total has no work left to do. It simply passes the result of loop (n - 1) (total + n) through to its caller. Thus, there is no need to remember the state of loop n total when making the recursive call loop (n - 1) (total + n). Instead of calling loop (n - 1) (total + n) recursively, the compiler simply replaces the current invocation loop n total with the new invocation loop (n - 1) (total + n); the new invocation reuses the current invocation's stack frame. This effectively gives us a loop.

The terminology for this type of function is that loop is a tail-recursive function or, equivalently, that the recursive call that loop makes is in tail position.

A function call is in tail position if the result of the function call is the result of the calling function.

Whenever this is the case, the compiler performs tail call optimization, which means that it generates code where the recursive call isn't actually a recursive call. Instead, the "recursive" call reuses the current stack frame and the resulting code looks like a loop.
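To make the distinction concrete, here is a small illustration (the names f, g, and slowDouble are my own, chosen for this example). In f, the result of the call g (x + 1) is returned unchanged, so the call is in tail position. In slowDouble, the recursive call is not in tail position, because an addition still has to happen after it returns:

```haskell
g :: Int -> Int
g y = 2 * y

-- Tail position: the result of g (x + 1) is exactly the result of f x.
f :: Int -> Int
f x = g (x + 1)

-- Not tail position: after slowDouble (n - 1) returns, we still add 2,
-- so the caller's stack frame must be kept alive.
slowDouble :: Int -> Int
slowDouble 0 = 0
slowDouble n = slowDouble (n - 1) + 2
```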

Let me emphasize that not many compilers perform tail call optimization. The Haskell compiler does, and Scheme implementations are required to. In these languages, we get the efficiency of loops while programming elegantly using recursion (whether you consider recursion more elegant than a loop is a matter of taste). Haskell and Scheme support tail call optimization because they have to: neither language has loops, so without tail call optimization, programs written in these languages would be horribly inefficient. A language with built-in loop constructs has less need for tail call optimization, because if your function is tail-recursive, it is logically a loop, and then it is often better to make this explicit by expressing the computation as an actual loop.
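One Haskell-specific caveat, which this lesson glosses over: because Haskell is lazy, the accumulator total in our loop is not necessarily evaluated at each step. In unoptimized code (for example, in GHCi), it can build up a chain of unevaluated thunks (0 + 4, (0 + 4) + 3, ...), which consumes memory much like the stack frames did. GHC's optimizer usually makes the accumulator strict on its own, but a bang pattern makes this explicit. The following sketch assumes the BangPatterns extension:

```haskell
{-# LANGUAGE BangPatterns #-}

-- Sketch: the bang on total forces the accumulator at every iteration,
-- so no thunk chain builds up even without compiler optimization.
strictSumUpTo :: Int -> Int
strictSumUpTo n0 = loop n0 0
  where
    loop 0 total  = total
    loop n !total = loop (n - 1) (total + n)
```

With this version, strictSumUpTo 10000000 runs in constant space even in GHCi.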

The lesson to be learnt here is that you should aim to make your functions tail-recursive whenever you can. The resulting implementation is often less pretty, but the gain in efficiency is well worth it. As another example of how to do this, we look at a faster implementation of computing Fibonacci numbers next.