Tail Call Optimization in Clojure: Using Loop/Recur

Recently, I’ve been getting into Clojure. To get started, I’ve begun working through the Clojure Koans. They’ve provided a really helpful introduction to the language’s power and syntax.

One thing that I got stuck on, however, was a problem involving recursion. The objective is to define a function capable of computing the factorial of a really big integer.

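A naive recursive definition of factorial is relatively straightforward (*' is Clojure's auto-promoting multiply, which switches to arbitrary-precision integers instead of overflowing):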

(defn factorial [n]
  (if (zero? n)
    1
    (*' n (factorial (dec n)))))

This will work fine with smaller numbers:

(factorial 5)
;= 120

But it runs into trouble with bigger numbers:

(factorial 10000)
;= StackOverflowError   clojure.lang.Numbers.isZero (Numbers.java:90)

The problem is that each recursive call to factorial adds a new frame to the stack: the multiplication by n cannot happen until the inner call returns, so every pending value of n has to be kept in memory until the base case is reached.
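To make the stack growth concrete, here is a hand-expanded trace of a small call, written as comments. It is only an illustration of the evaluation order, not output from any tool:

```clojure
(defn factorial [n]
  (if (zero? n)
    1
    (*' n (factorial (dec n)))))

;; Each line below has one more multiplication waiting on an inner call,
;; which corresponds to one more stack frame:
;; (factorial 3)
;; => (*' 3 (factorial 2))
;; => (*' 3 (*' 2 (factorial 1)))
;; => (*' 3 (*' 2 (*' 1 (factorial 0))))
;; => (*' 3 (*' 2 (*' 1 1)))
;; => 6
(factorial 3)
```

For an input of 10000 the chain of pending multiplications is 10000 deep, which is more than the default stack can hold.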

A standard remedy in functional programming is to rewrite the function so that the recursive call is in tail position, which makes it a candidate for tail call optimization. The function carries the result computed so far in an accumulator argument, so nothing from earlier calls has to be retained in memory:

(defn factorial [n acc]
  (if (zero? n)
    acc
    (factorial (dec n) (*' n acc))))

But that will not save us here:

(factorial 10000 1)
;= StackOverflowError   clojure.lang.Numbers.isZero (Numbers.java:90)

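Clojure does not optimize tail calls automatically: it uses the JVM's calling conventions, and the JVM does not provide tail call elimination. So even a call in tail position pushes a new stack frame.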
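A hand-written trace shows why the accumulator version would be a good candidate for optimization if the runtime supported it. The name factorial-acc is made up here so that the sketch does not clash with the one-argument version:

```clojure
;; Same accumulator-passing function as above, under a hypothetical name.
(defn factorial-acc [n acc]
  (if (zero? n)
    acc
    (factorial-acc (dec n) (*' n acc))))

;; Nothing is left pending after each call; the whole state is in the arguments:
;; (factorial-acc 3 1)
;; => (factorial-acc 2 3)
;; => (factorial-acc 1 6)
;; => (factorial-acc 0 6)
;; => 6
(factorial-acc 3 1)
```

The expression never grows, yet each step is still an ordinary JVM method call, so the stack grows anyway.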

So what’s a lonely aspiring Clojure programmer to do?

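The answer, it turns out, is the recur special form. recur must appear in tail position; it rebinds the function’s parameters to new values and jumps back to the top of the function without adding a stack frame, which gives us the benefit we’d expect from tail call optimization.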

(defn factorial [n acc]
  (if (zero? n)
    acc
    (recur (dec n) (*' n acc))))

This method works!

(factorial 10000 1)
;= 2846259680917054518906413212119868890148...(really big number)
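Note that recur only works in tail position, and the compiler enforces this. As a rough sketch, swapping recur into the original non-accumulator version is rejected at compile time. The names bad-form and compile-failed? are made up for this illustration, and eval is used only so the error can be caught rather than aborting the whole file:

```clojure
;; The original, non-tail-recursive shape with recur substituted in.
;; recur is an argument to *', so it is not the last thing the function does.
(def bad-form
  '(fn [n]
     (if (zero? n)
       1
       (*' n (recur (dec n))))))

;; Compiling the form throws, so the catch branch returns true.
(def compile-failed?
  (try
    (eval bad-form)
    false
    (catch Throwable _ true)))

compile-failed?
```

This is why the accumulator is needed: it moves the multiplication into the arguments, leaving recur as the final step.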

One final hangup in the Clojure Koans is that the solution has to keep the arity of the original function (factorial takes only one parameter). Since we still need an accumulator to hold the value computed so far, we have to bind it inside the function rather than pass it in as an argument.

We can accomplish this with the Clojure loop construct:

(defn factorial [n]
  (loop [new-n-for-looping n
         acc-for-recursion 1]
    (if (zero? new-n-for-looping)
      acc-for-recursion
      (recur (dec new-n-for-looping) (*' new-n-for-looping acc-for-recursion)))))

Here, loop binds new-n-for-looping to the passed value of n and acc-for-recursion to 1. Each recur jumps back to the top of the loop with the counter decremented and the accumulator multiplied by the counter’s current value. When the base case is reached (new-n-for-looping is zero), the accumulator is returned, and we can compute the factorial of very large numbers without that pesky StackOverflowError.
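As a sanity check, the same loop can be written with shorter binding names and compared against an independent definition. The reduce-based factorial-reduce below is not from the koans; it is only a hypothetical cross-check:

```clojure
;; Same loop/recur structure as above, with shorter (made-up) binding names.
(defn factorial [n]
  (loop [i n
         acc 1]
    (if (zero? i)
      acc
      (recur (dec i) (*' i acc)))))

;; Independent formulation: multiply 1 by every integer from 1 to n.
(defn factorial-reduce [n]
  (reduce *' 1 (range 1 (inc n))))

;; The two definitions should agree, including on inputs that would
;; overflow the stack in the naive recursive version.
(= (factorial 10000) (factorial-reduce 10000))
```

Both versions run in constant stack space; loop/recur simply makes the iteration explicit.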