Tail Recursion
Let's look at our sumUpTo function again:
sumUpTo :: Int -> Int
sumUpTo 0 = 0
sumUpTo n = sumUpTo (n - 1) + n
This is an inductive definition. It says that the sum of the first 0 numbers is 0. For \(n > 0\), the sum of the first \(n\) numbers is the sum of the first \(n - 1\) numbers, plus \(n\).
We are interested in the second equation now. To calculate sumUpTo n, we first calculate sumUpTo (n - 1), and then we add n to it. This is the important part: we are not done after sumUpTo (n - 1) returns; we still need to add n to it. To support this, our program needs to remember the state of the computation of sumUpTo n when making the recursive call sumUpTo (n - 1), so it can continue the computation of sumUpTo n when sumUpTo (n - 1) returns.
That's recursion.
To compute sumUpTo 10000000, our function makes 10 million recursive calls. Each of these calls can return only after the recursive call it makes has returned. Thus, at the time when sumUpTo 0 is active, we need to remember the state of all 10 million pending calls. We need 10 million stack frames to do this.
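To make the pending work concrete, here is how sumUpTo 3 unfolds when we repeatedly apply the two equations of sumUpTo (the small input 3 is chosen only for illustration). Each step leaves one more addition waiting for its recursive call to return; for sumUpTo 10000000 there would be 10 million such pending additions.

\[
\begin{aligned}
\mathit{sumUpTo}\ 3 &= \mathit{sumUpTo}\ 2 + 3 \\
&= (\mathit{sumUpTo}\ 1 + 2) + 3 \\
&= ((\mathit{sumUpTo}\ 0 + 1) + 2) + 3 \\
&= ((0 + 1) + 2) + 3 \\
&= 6
\end{aligned}
\]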
Contrast this with implementing the equivalent of sumUpTo in Python, using a loop:
def sum_up_to(n):
    total = 0
    while n > 0:
        total += n
        n -= 1
    return total
There are no recursive calls at all. Thus, this function runs in constant space!
Our next exercise will be to translate the Python version of sum_up_to into a recursive function. It will not immediately be clear why this will lead us to a better implementation of sumUpTo, but bear with me.
The loop of the Python sum_up_to function maintains a state consisting of the contents of the variables n and total. At the beginning, we have total = 0, and n is whatever argument we call sum_up_to with. If n == 0 (we assume n is never negative), then the loop exits and the current value of total is what the function returns. If n > 0, then we start another iteration, which increases total by n and decreases n by one.
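As a small illustration, the following Haskell sketch lists the successive (n, total) states that the Python loop goes through, starting from (n, 0) and stopping once n reaches 0. The helper name loopStates is hypothetical and not part of the original code, and the sketch assumes a non-negative starting value of n.

```haskell
-- Hypothetical helper: the successive (n, total) states of the Python
-- loop in sum_up_to, assuming the starting n is non-negative.
loopStates :: Int -> [(Int, Int)]
loopStates n0 = go (n0, 0)
  where
    -- Exit condition: n has reached 0, so record the final state.
    go (0, total) = [(0, total)]
    -- One iteration: total increases by n, and n decreases by one.
    go (n, total) = (n, total) : go (n - 1, total + n)
```

For example, loopStates 4 evaluates to [(4,0),(3,4),(2,7),(1,9),(0,10)], ending with total = 10 = 4 + 3 + 2 + 1.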
Let's write a Haskell function loop that mimics the behaviour of the loop of sum_up_to. Its arguments must capture the state of the loop, so it takes two arguments, n and total.
We just said that if n == 0, the loop exits and the current value of total becomes the result returned by sum_up_to. In Haskell,
loop 0 total = total
If n > 0, then we need to start another iteration, with total increased by n, and n decreased by 1. In Haskell,
loop n total = loop (n - 1) (total + n)
To start the loop, we need to call our loop function with the initial values of n and total as arguments. We said that at the beginning, n is whatever argument we pass to sum_up_to, and total = 0. Putting it all together, we get the following Haskell version of sum_up_to:
sumUpTo :: Int -> Int
sumUpTo 0 = 0
sumUpTo n = sumUpTo (n - 1) + n
fastSumUpTo :: Int -> Int
fastSumUpTo n = loop n 0
  where
    loop 0 total = total
    loop n total = loop (n - 1) (total + n)
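A quick way to convince yourself that the two versions agree is to compare them on a range of inputs, together with the closed form n(n + 1)/2 for the sum of the first n numbers. The sketch below repeats both definitions so it can be loaded on its own; agreeUpTo is a hypothetical helper introduced only for this check.

```haskell
sumUpTo :: Int -> Int
sumUpTo 0 = 0
sumUpTo n = sumUpTo (n - 1) + n

fastSumUpTo :: Int -> Int
fastSumUpTo n = loop n 0
  where
    loop 0 total = total
    loop n total = loop (n - 1) (total + n)

-- Hypothetical helper: True if sumUpTo, fastSumUpTo, and the closed
-- form n * (n + 1) `div` 2 agree for every n from 0 to limit.
agreeUpTo :: Int -> Bool
agreeUpTo limit =
  and [ sumUpTo n == fastSumUpTo n && fastSumUpTo n == n * (n + 1) `div` 2
      | n <- [0 .. limit] ]
```

Loaded into GHCi, agreeUpTo 1000 evaluates to True, and fastSumUpTo 100 evaluates to 5050.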
This translation of a loop into a recursive function always works. To do this, we figure out what the loop variables are: the ones that change (or at least can change) in each iteration of the loop. These variables become the arguments to the loop function (which, of course, doesn't have to be called loop; I just called it loop because it simulates a loop). An invocation of loop returns the final result if its arguments meet the exit condition of the loop. Otherwise, loop calls itself recursively, with its arguments changed in the same way that the body of the loop would update the state of the loop variables.
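To see the recipe applied to a different loop, consider Euclid's algorithm for the greatest common divisor, which in Python might be written as a while loop that repeats a, b = b, a % b until b == 0. The loop variables are a and b, so they become the two arguments of the recursive function. The name euclid is hypothetical (Haskell's Prelude already provides gcd), and the sketch assumes non-negative inputs.

```haskell
-- Hypothetical example: Euclid's algorithm written with the
-- loop-to-recursion recipe. Assumes a >= 0 and b >= 0.
euclid :: Int -> Int -> Int
-- Exit condition of the loop: b == 0, so return the current a.
euclid a 0 = a
-- Loop body: the new state of the loop variables is (b, a `mod` b).
euclid a b = euclid b (a `mod` b)
```

For example, euclid 12 18 steps through the states (12, 18), (18, 12), (12, 6), (6, 0) and returns 6.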
So why is fastSumUpTo better than sumUpTo? It's not exactly faster (at least if you try it out in GHCi, which interprets bytecode without applying optimizations). What makes fastSumUpTo better is that the Haskell compiler can translate fastSumUpTo into a loop not unlike our Python function sum_up_to. So the code no longer builds up a massive number of stack frames corresponding to the recursive calls. It runs in constant space, just as the Python sum_up_to function does. How does this work?
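One caveat, hedged because it depends on the compiler and its flags: Haskell is lazy, so in fastSumUpTo the argument total + n is not necessarily computed before the recursive call. Without optimization (as in GHCi), total can grow into a long chain of unevaluated additions, costing memory proportional to n. When compiling with optimizations (for example ghc -O), GHC's strictness analysis typically notices that total is always needed and evaluates it eagerly. The sketch below, under the hypothetical name strictSumUpTo, forces the accumulator explicitly with the Prelude function seq, so it does not rely on the optimizer.

```haskell
-- Hypothetical variant of fastSumUpTo that forces the accumulator at
-- every step, so no chain of unevaluated additions can build up.
strictSumUpTo :: Int -> Int
strictSumUpTo n = loop n 0
  where
    loop 0 total = total
    loop k total =
      -- seq evaluates total' before the recursive call is made, and the
      -- call to loop is still the last thing this invocation does.
      let total' = total + k
      in total' `seq` loop (k - 1) total'
```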
fastSumUpTo calls loop once. loop then makes a lot of recursive calls to itself. However, whenever loop makes a recursive call, the result of the recursive call is the result of the current invocation. Compare this to our sumUpTo function: once the recursive call sumUpTo (n - 1) returns, we still have to add n to it to obtain the result of sumUpTo n. In contrast, we define loop n total to be exactly the same as loop (n - 1) (total + n). This means that after the recursive call loop (n - 1) (total + n) returns, loop n total has no work left to do. It simply passes the result of loop (n - 1) (total + n) through to its caller. Thus, there is no need to remember the state of loop n total when making the recursive call loop (n - 1) (total + n). Instead of calling loop (n - 1) (total + n) recursively, the compiler simply replaces the current invocation loop n total with the new invocation loop (n - 1) (total + n), which reuses the current invocation's stack frame. This effectively gives us a loop.
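Unfolding fastSumUpTo 3 by hand using the equations of fastSumUpTo and loop (writing each new value of total as a number) shows this: each step replaces the whole expression with a new call to loop, and no pending work accumulates around it.

\[
\begin{aligned}
\mathit{fastSumUpTo}\ 3 &= \mathit{loop}\ 3\ 0 \\
&= \mathit{loop}\ 2\ 3 \\
&= \mathit{loop}\ 1\ 5 \\
&= \mathit{loop}\ 0\ 6 \\
&= 6
\end{aligned}
\]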
The terminology for this type of function is that loop is a tail-recursive function or, equivalently, that the recursive call that loop makes is in tail position.
A function call is in tail position if the result of the function call is the result of the calling function.
Whenever this is the case, the compiler performs tail call optimization, which means that it generates code where the recursive call isn't actually a recursive call. Instead, the "recursive" call reuses the current stack frame and the resulting code looks like a loop.
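As an illustration of the definition, here are two ways to compute the length of a list; Haskell's Prelude already has length, and the names lengthRec and lengthAcc are made up for this sketch. In lengthRec, the recursive call is not in tail position because its result still has to be incremented. In lengthAcc's helper go, the recursive call is the entire result, so it is in tail position.

```haskell
-- Not tail-recursive: after lengthRec xs returns, we still add 1.
lengthRec :: [a] -> Int
lengthRec [] = 0
lengthRec (_ : xs) = 1 + lengthRec xs

-- Tail-recursive: the recursive call to go is the whole result, and
-- the count so far is carried in the accumulator acc.
lengthAcc :: [a] -> Int
lengthAcc xs0 = go xs0 0
  where
    go [] acc = acc
    go (_ : xs) acc = go xs (acc + 1)
```

Both return 5 for the string "hello"; only lengthAcc has the shape that lets the compiler reuse the stack frame.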
Let me emphasize that many languages do not guarantee tail call optimization; CPython and the Java virtual machine, for example, do not perform it at all. Haskell compilers such as GHC do perform it, and the Scheme standard requires every Scheme implementation to do so. In these languages, we get the efficiency of loops while programming elegantly using recursion (whether you consider recursion more elegant than a loop is a matter of taste). Haskell and Scheme support tail call optimization because they have to: neither language has loops as primitive constructs, so without tail call optimization, programs written in these languages would be horribly inefficient. A language that has built-in loop constructs has no real need for tail call optimization, because if your function is tail-recursive, it is logically a loop, and then it is often better to make this explicit by expressing the computation as an actual loop.
The lesson to be learnt here is that you should aim to make your functions tail-recursive whenever you can. The resulting implementation is often less pretty, but the gain in efficiency is well worth it. As another example of how to do this, we look at a faster implementation of computing Fibonacci numbers next.