Continuation Monad

Preface

This tutorial explains continuations and the continuation monad. It’s pretty long, but the payoff is an understanding of a powerful abstraction. The example code for this tutorial is here.

What’s a continuation anyway?

The continuation monad is probably the most powerful of the standard monads. (See here) But it is very difficult to acquire a mental model of how this monad works and how to use it, because it involves two very abstract concepts: continuations and monads.

The first step is to understand continuations. You’ll come across statements like ‘a continuation represents the future of the computation’ or ‘a continuation reifies the return stack’ which are absolutely true and totally unhelpful for grasping the concept, at least for me. So let’s look at some code instead.

(defn func-a [x]
  (inc x))

(defn func-b [x]
  (* 2 x))

(defn func-c [x]
  (dec x))

(defn fn1 [x]
  (let [a (func-a x)
        b (func-b a)
        c (func-c b)]
    c))

We have three simple functions and then fn1, which calls them. fn1 is not idiomatic Clojure, but it is written in this form to demonstrate a point. The output of func-a is fed into func-b and the output of func-b is fed into func-c. So execution enters fn1, goes through func-a, continues through func-b and then continues through func-c before passing out of fn1. If fn1 were a movie trilogy, we’d say that func-a was the beginning of the story and that func-b was the continuation of the story from func-a. Likewise, func-c would be the continuation of the story from func-b. In the code, it’s the same idea: func-c is the continuation of func-b, and func-b followed by func-c is the continuation of func-a.

Also note that fn1 could have been defined like:

(def fn1 (comp func-c func-b func-a))

This alternative will be important later.
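A quick check at the REPL, as a sketch, suggests that the two definitions agree. The name fn1-comp is introduced here only so both versions can coexist in one namespace; it is not part of the original code.

```clojure
(defn func-a [x] (inc x))
(defn func-b [x] (* 2 x))
(defn func-c [x] (dec x))

;; The step-by-step version from above.
(defn fn1 [x]
  (let [a (func-a x)
        b (func-b a)
        c (func-c b)]
    c))

;; fn1-comp is a hypothetical name for the comp-based version.
;; comp applies its functions right to left, so func-a runs first.
(def fn1-comp (comp func-c func-b func-a))

(fn1 10)      ; => 21
(fn1-comp 10) ; => 21
```

Both forms express the same pipeline; the comp form simply leaves the intermediate names out.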

That should give you enough of a grasp of what continuations are to get started. A more complete understanding should develop through the rest of this tutorial.

Passing continuations around

Here’s a key thing to realize: Clojure is a functional language, which means that functions can be passed as parameters to other functions. So, if we modified func-b above to look like:

(defn func-b [x f]
  (f (* 2 x)))

it would then take 2 arguments. The first would be a number and the second would be a function, which would be called with the value of x multiplied by 2. The value returned by this function call would then be the value returned by func-b. Nothing complicated, just a function being passed as a parameter. So then, fn1 could be rewritten like:

(defn fn2 [x]
  (let [a (func-a x)
        c (func-b a func-c)]
    c))

What we’ve just done is convert func-b to continuation passing style (CPS). Since func-c is the continuation of func-b, passing it to func-b as a parameter in essence is telling func-b what to do with its result, rather than just blindly returning it to the caller.
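To see what passing a continuation buys, here is a small sketch that calls the CPS version of func-b with several different continuations. It uses only the definitions above plus Clojure’s identity and str.

```clojure
(defn func-c [x] (dec x))

;; CPS version of func-b: f is the continuation that receives the result.
(defn func-b [x f]
  (f (* 2 x)))

(func-b 5 identity) ; => 10, the result is handed back unchanged
(func-b 5 func-c)   ; => 9, the result continues into func-c
(func-b 5 str)      ; => "10", any one-argument function can be the continuation
```

The caller, not func-b, now decides what happens to the doubled value.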

Now what would happen if we converted func-a to CPS like this?

(defn func-a [x f]
  (f (inc x)))

We can’t just pass func-b to func-a like this:

(func-a x func-b)

because func-b no longer takes a single argument. Not only that, func-b needs to be given its continuation when it’s called, so what we need is a function that takes a single argument, then calls func-b with that argument as well as the function that func-b is supposed to pass its result to, func-c. Something like:

(defn cont-a [x]
  (func-b x func-c))

This function is named ‘cont-a’ because it’s the continuation of func-a. With this new function, we can write fn1 like:

(defn fn3 [x]
  (let [c (func-a x cont-a)]
    c))

or, more neatly:

(defn fn3 [x]
  (func-a x cont-a))

Taking another step, let’s convert func-c to CPS:

(defn func-c [x f]
  (f (dec x)))

This makes it look like func-a and func-b, but cont-a has to be rewritten because func-c can’t be passed to func-b now. For its second parameter, func-b needs a function that takes a single argument and somehow passes it to func-c. But func-c needs to know what function to pass its result to, in addition to the result of func-b. Remember that func-c is going to call that function with its result and then return the result of that call as its return value to func-b. This value is then going to be returned by func-b as its return value to func-a, which is going to return it to fn3, which is going to return it as its final value. So what we need is a function that does nothing to the result of func-c. In Clojure that function is ‘identity’, so cont-b would look like this:

(defn cont-b [x]
  (func-c x identity))

and cont-a would be written like:

(defn cont-a [x]
  (func-b x cont-b))

but fn3 would look the same, except now cont-a is different.

(defn fn4 [x]
  (func-a x cont-a))

But what if we took one final step? (You’re probably thinking I should quit while I’m ahead.) What if we converted fn4 to CPS? Well, for one thing, it would take two parameters: an integer and a function that would be the continuation of fn4. This function would be called with the result of fn4, like this:

(defn fn5 [x f]
  (f (func-a x cont-a)))

and it could be called like:

(fn5 10 identity)

But what would be really cool is if that identity in the call to fn5 could somehow be given to cont-b to be passed as the continuation of func-c. This would mean defining cont-b inside of fn5, and that would also mean defining cont-a inside of fn5. That is no big deal in a functional language like Clojure.

(defn fn6 [x f]
  (let [cont-b (fn [x]
                 (func-c x f))
        cont-a (fn [x]
                 (func-b x cont-b))]
    (func-a x cont-a)))
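As a sketch of how fn6 behaves, the block below restates the CPS functions and calls fn6 with a few different continuations. The continuation supplied by the caller ends up being the last thing that runs.

```clojure
(defn func-a [x f] (f (inc x)))
(defn func-b [x f] (f (* 2 x)))
(defn func-c [x f] (f (dec x)))

(defn fn6 [x f]
  (let [cont-b (fn [x] (func-c x f))
        cont-a (fn [x] (func-b x cont-b))]
    (func-a x cont-a)))

(fn6 10 identity)           ; => 21
(fn6 10 (fn [r] (* 100 r))) ; => 2100, the caller's continuation receives 21
(fn6 10 str)                ; => "21"
```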

Which is all well and good, but it would be really nice to be able to do something similar to

(def fn1 (comp func-c func-b func-a))

where the func’s and fn1 are all in continuation passing style.

Looking at fn6, you’ll notice that cont-b and cont-a are really just plumbing that allows func-a, func-b and func-c to be plugged together, and that both of these functions have the same general structure. If you’ve read the earlier monad tutorial here, some bells should be going off. Composing functions that require plumbing to work together is the job of a monad.

Continuation plumbing

If we want to write a monad to handle this CPS plumbing, we have to obey the monad laws:

(m-bind (m-result x) f) is equal to (f x)
(m-bind mv m-result) is equal to mv
(m-bind (m-bind mv f) g) is equal to (m-bind mv (fn [x] (m-bind (f x) g)))

where f and g are monadic functions. Looking at the first law, we see that a monadic function has to accept a single value which is not a monadic value. So CPS functions, like func-a for example, are not monadic functions. Looking at the third law, we see that monadic functions have to return monadic values. This means that a monadic function takes a single value of some type and returns a monadic value. But what should the type of that monadic value be?

Consider func-a. What we’d like is some function similar to func-a that accepts a single value instead of a value and a continuation. This is possible if the monadic func-a (mf-a) accepted a single value and returned a function that would in turn accept a continuation and then call func-a with the value and continuation.

(defn mf-a [x]
  (fn [c]
    (func-a x c)))

If we replace the call to func-a with an inline function definition, we get:

(defn mf-a [x]
  (fn [c]
    ((fn [y f]
       (f (inc y)))
       x c)))

And if we re-arrange things to eliminate the inline function, we get:

(defn mf-a [x]
  (fn [c]
    (c (inc x))))

Doing the same kind of transformation to func-b and func-c, we get:

(defn mf-b [x]
  (fn [c]
    (c (* 2 x))))

(defn mf-c [x]
  (fn [c]
    (c (dec x))))
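All three monadic functions follow the same pattern: compute a value from x, then hand it to the continuation. As a sketch, a hypothetical helper (called plain->mf here; it is not part of the tutorial’s code or of any library) could build such a function from an ordinary one-argument function.

```clojure
;; plain->mf is a hypothetical helper, named here for illustration only.
;; It turns a plain function g into a function that takes x and returns
;; a function awaiting a continuation c.
(defn plain->mf [g]
  (fn [x]
    (fn [c]
      (c (g x)))))

(def mf-a (plain->mf inc))
(def mf-b (plain->mf #(* 2 %)))
(def mf-c (plain->mf dec))

((mf-a 10) identity) ; => 11
((mf-b 10) identity) ; => 20
((mf-c 10) identity) ; => 9
```

The hand-written versions above are kept in the rest of the tutorial because they make the continuation explicit.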

This would mean that a monadic value would be a function, which is similar to the state-m monad. In this case, the monadic value would be a function that would accept as its parameter another function which would then be called.

Now if a monadic value is a function, how can m-result be written so that it will accept a value of any type and wrap it in a function that accepts a continuation and does the right thing with it? This is pretty straightforward:

m-result   (fn m-result-cont [v]
             (fn [c]
               (c v)))

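m-bind is where things get complicated. Paying attention to the monadic laws we see m-bind must do the following:

Accept a monadic value (mv) and a monadic function (mf) as parameters.
Return a monadic value, that is, a function that accepts a continuation (c).
When that function is called, call mv with a new continuation that accepts a value (v).
In that new continuation, call mf with v to get another monadic value.
Call that monadic value with c, the original continuation.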

Whew! Fortunately, the code to do all this only takes 4 lines.

m-bind     (fn m-bind-cont [mv mf]
             (fn [c]
               (mv (fn [v]
                     ((mf v) c)))))

m-bind is deceptive in its apparent simplicity. The guys who originally discovered this monad were geniuses, and Konrad Hinsen has done an excellent job translating it elegantly into Clojure.
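Since monadic values here are functions, they cannot be compared with = directly. One way to spot-check the three laws, as a sketch, is to run both sides with identity as the continuation and compare the results. The block restates m-result and m-bind as standalone functions so it runs without the monads library; that standalone form, and the helper named run, are assumptions made for this example.

```clojure
(defn m-result [v]
  (fn [c] (c v)))

(defn m-bind [mv mf]
  (fn [c]
    (mv (fn [v]
          ((mf v) c)))))

(defn mf-a [x] (fn [c] (c (inc x))))
(defn mf-b [x] (fn [c] (c (* 2 x))))

;; run is a hypothetical helper: supply identity as the final continuation.
(defn run [mv] (mv identity))

(def mv (m-result 5))

;; Law 1: (m-bind (m-result x) f) behaves like (f x)
(= (run (m-bind (m-result 5) mf-a))
   (run (mf-a 5)))                                    ; => true

;; Law 2: (m-bind mv m-result) behaves like mv
(= (run (m-bind mv m-result))
   (run mv))                                          ; => true

;; Law 3: binding is associative
(= (run (m-bind (m-bind mv mf-a) mf-b))
   (run (m-bind mv (fn [x] (m-bind (mf-a x) mf-b))))) ; => true
```

A check with one value and one continuation is not a proof, but it is a cheap way to catch a mis-typed m-bind.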

Working with cont-m

Before we turn to reasons to use the continuation monad, let’s work with it a little to get a sense for what it does. Starting with m-result, we can embed a value in an mv and extract it.

((m-result 21) identity)

Remember that m-result returns a monadic value which is a function that accepts a continuation and calls that continuation with the value embedded in it. So this expression first creates such a monadic value with 21 in it, then calls it with identity as its continuation. This causes identity to be called with value 21 so the value of this entire expression is 21.
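The continuation does not have to be identity. As a small sketch, with m-result restated as a standalone function so the block runs on its own, the same monadic value can be called with different continuations:

```clojure
(defn m-result [v]
  (fn [c] (c v)))

;; One monadic value holding 21.
(def mv (m-result 21))

(mv identity) ; => 21
(mv inc)      ; => 22
(mv str)      ; => "21"
```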

We can also use m-bind to pass that value to a monadic function:

((m-bind (m-result 21) mf-a) identity)

which returns the value 22. Like m-result, m-bind returns a monadic value, which must be called with identity as its continuation.

((m-bind (m-bind (m-result 21) mf-a) mf-b) identity)

This expression returns the value 44. An equivalent expression is:

((m-bind (m-result 21) (m-chain [mf-a mf-b])) identity)

Finally, we can rewrite fn6 like this:

(def fn7 (m-chain [mf-a mf-b mf-c]))

((fn7 10) identity)

The last expression returns 21.
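m-chain comes from the monads library and is used above without being shown. As a rough sketch of what it does for this monad (not the library’s actual source), it can be written as a reduce that binds each monadic function onto the chain built so far. The name chain-cont is made up for this example, and m-result and m-bind are restated as standalone functions.

```clojure
(defn m-result [v] (fn [c] (c v)))

(defn m-bind [mv mf]
  (fn [c] (mv (fn [v] ((mf v) c)))))

(defn mf-a [x] (fn [c] (c (inc x))))
(defn mf-b [x] (fn [c] (c (* 2 x))))
(defn mf-c [x] (fn [c] (c (dec x))))

;; chain-cont is a hypothetical stand-in for m-chain.
;; It starts from m-result and binds each step in order,
;; returning a new monadic function.
(defn chain-cont [steps]
  (reduce (fn [chain step]
            (fn [v] (m-bind (chain v) step)))
          m-result
          steps))

(def fn7 (chain-cont [mf-a mf-b mf-c]))

((fn7 10) identity) ; => 21
```

Nothing runs when the chain is built; the steps execute only once the resulting monadic value is called with a continuation.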

But why?

We’ve almost reached the point where the power of the continuation monad becomes apparent. To show this, we need to make some small modifications to our functions.

(defn mf-a [x]
  (println "starting mf-a")
  (fn [c]
    (println "completing mf-a")
    (c (inc x))))

(defn mf-b [x]
  (println "starting mf-b")
  (fn [c]
    (println "completing mf-b")
    (c (* 2 x))))

(defn mf-c [x]
  (println "starting mf-c")
  (fn [c]
    (println "completing mf-c")
    (c (dec x))))

(def fn8 (m-chain [mf-a mf-b mf-c]))

Calling fn8 with the following expression

((fn8 10) identity)

causes the following to be printed

starting mf-a
completing mf-a
starting mf-b
completing mf-b
starting mf-c
completing mf-c

Jumping through all these hoops is pointless, unless it gives us access to something we couldn’t access or allows us to do something we couldn’t do otherwise. Let’s take a close look at mf-a to see what that might be.

(defn mf-a [x]
  (fn [c]
    (c (inc x))))

We’ve been saying that a monadic value is a function that accepts a continuation as its only parameter and does the right thing with it. The ‘right thing’ to this point has always been to call it with a value and then return the result as the mv’s return value. So what else might be done with the continuation passed to the mv? The simplest would be to do nothing with it.

(defn halt [x]
  (fn [c]
    x))

This new monadic function can be used just like any other.

(def fn9 (m-chain [mf-a halt mf-b mf-c]))

((fn9 10) identity)

which produces this output.

starting mf-a
completing mf-a
11

Do you see what happened? When the monadic value produced by halt was called with a continuation function, instead of continuing the computation, it aborted and passed the result of mf-a back out as the result of the call to fn9. This behavior closely resembles throwing an exception.
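halt as written always aborts. A variation, sketched here under the assumption that aborting should depend on the value, skips the continuation only when a predicate matches. The names halt-when and chain-cont are hypothetical; chain-cont stands in for m-chain so the block runs without the monads library.

```clojure
(defn m-result [v] (fn [c] (c v)))

(defn m-bind [mv mf]
  (fn [c] (mv (fn [v] ((mf v) c)))))

;; chain-cont is a hypothetical stand-in for m-chain.
(defn chain-cont [steps]
  (reduce (fn [chain step]
            (fn [v] (m-bind (chain v) step)))
          m-result
          steps))

(defn mf-a [x] (fn [c] (c (inc x))))
(defn mf-b [x] (fn [c] (c (* 2 x))))
(defn mf-c [x] (fn [c] (c (dec x))))

;; halt-when is hypothetical: abort with x when (pred x) is truthy,
;; otherwise pass x along to the continuation as usual.
(defn halt-when [pred]
  (fn [x]
    (fn [c]
      (if (pred x)
        x
        (c x)))))

(def fn9b (chain-cont [mf-a (halt-when odd?) mf-b mf-c]))

((fn9b 10) identity) ; => 11, mf-a gives 11, which is odd, so the chain stops
((fn9b 11) identity) ; => 23, mf-a gives 12, so mf-b and mf-c still run
```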

So anything returned by a monadic value will eventually be returned by the top level as the result of the top level call. What if we returned a function that, when called, would call the continuation with the value of x?

(defn bounce [x]
  (fn [c]
    (fn []
      (println "bounce")
      (c x))))

(def fn10 (m-chain [mf-a bounce mf-b bounce mf-c]))

If we just call fn10 like we have been, it will execute mf-a and then return the function created by the first bounce. If we then call that function, it will execute mf-b and return the function created by the second bounce. If we then call that function, it will execute mf-c and return its result. Clojure can execute a sequence of functions like that using the function ‘trampoline’.

(trampoline ((fn10 10) identity))

which produces this output.

starting mf-a
completing mf-a
bounce
starting mf-b
completing mf-b
bounce
starting mf-c
completing mf-c
21

The important thing to see is that the value returned by m-chain is itself a monadic function which can be combined with other monadic functions to create yet more monadic functions. This process can continue to any depth. And no matter how deeply nested the call to halt or bounce is, its return value will be passed back up through all the layers until it gets returned by the top level function. For bounce, the interesting thing about this is that each time it’s called and it returns a function back to the top level, the call stack is cleared. So in a program where the call stack might grow infinitely large, trampoline and bounce let you stop the program, clear the call stack and then pick up where you left off.
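To make the stack-clearing point concrete, here is a sketch of a loop written in the monad that bounces once per iteration. The function count-down is a made-up example, bounce is restated without its println, and m-result and m-bind are restated as standalone functions so the block runs on its own.

```clojure
(defn m-result [v] (fn [c] (c v)))

(defn m-bind [mv mf]
  (fn [c] (mv (fn [v] ((mf v) c)))))

;; Same idea as bounce above, without the println.
(defn bounce [x]
  (fn [c]
    (fn [] (c x))))

;; count-down is a hypothetical example: a monadic loop that bounces
;; once per iteration, so each step returns to trampoline before the
;; next one starts.
(defn count-down [n]
  (if (zero? n)
    (m-result :done)
    (m-bind (bounce n)
            (fn [_] (count-down (dec n))))))

(trampoline ((count-down 100000) identity)) ; => :done
```

Replacing (bounce n) with (m-result n) removes the trips back to trampoline, so every iteration nests deeper and a large n is likely to overflow the stack.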

One final example. If your monadic function just simply returned the continuation, like this:

(defn mark [x]
  (fn [c]
    c))

(def fn11 (m-chain [mf-a mark mf-b mf-c]))

(def mark-cont ((fn11 10) identity))
(doall (map mark-cont [0 1 2]))

You’re going to get output like this:

starting mf-a
completing mf-a
starting mf-b
completing mf-b
starting mf-c
completing mf-c
starting mf-b
completing mf-b
starting mf-c
completing mf-c
starting mf-b
completing mf-b
starting mf-c
completing mf-c
(-1 1 3)

Reading through that list carefully, you’ll see that mf-a gets called exactly once. However, mf-b and mf-c get called 3 times each. And each time, they produce a different result depending on what mark-cont is called with. If you squint a little, it looks like a loop. Another thing to notice is that each call to mark-cont totally ignored what mf-a had returned as a result. That information was lost when mark returned only the continuation.
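If the value at the mark is still needed, one option, sketched here with a hypothetical function mark-with-value, is to return both the value and the continuation in a map. chain-cont is again a made-up stand-in for m-chain, and the monadic functions are the versions without println.

```clojure
(defn m-result [v] (fn [c] (c v)))

(defn m-bind [mv mf]
  (fn [c] (mv (fn [v] ((mf v) c)))))

;; chain-cont is a hypothetical stand-in for m-chain.
(defn chain-cont [steps]
  (reduce (fn [chain step]
            (fn [v] (m-bind (chain v) step)))
          m-result
          steps))

(defn mf-a [x] (fn [c] (c (inc x))))
(defn mf-b [x] (fn [c] (c (* 2 x))))
(defn mf-c [x] (fn [c] (c (dec x))))

;; mark-with-value is hypothetical: keep the value alongside the continuation.
(defn mark-with-value [x]
  (fn [c]
    {:value x :cont c}))

(def fn11b (chain-cont [mf-a mark-with-value mf-b mf-c]))
(def marked ((fn11b 10) identity))

(:value marked)     ; => 11, what mf-a produced
((:cont marked) 11) ; => 21, resume with the original value
((:cont marked) 0)  ; => -1, or resume with a different one
```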

Closing

These three examples show that with the continuation monad, you can control the flow of execution. But the continuation monad is just a workaround for the fact that Clojure doesn’t have continuations built in, which is understandable, since the JVM doesn’t support continuations. In languages where continuations are built in, you have total control over the order of execution. Any control structure can be built out of continuations. In addition to exceptions and loops, you can also build co-routines, threads and the Common Lisp condition system, among others.

Jim Duey 24 March 2012