Continuations in Haskell

It is assumed that the reader more or less understands continuation-passing style (CPS). The first thing that the Cont monad does is abstract away the need to pass continuations around explicitly.

The Cont monad abstracts away the handling of continuations by creating a series of “partially applied” (I’m using this quite loosely) functions that are all waiting for their continuation. The original add function was add x y, and its CPS version was add_cps x y k, which explicitly took a continuation function k.

What happens if we have this?

-- add
add :: Int -> Int -> Int
add x y = x + y

add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

add_abstracted_cps :: Int -> Int -> ((Int -> r) -> r)
add_abstracted_cps x y = (\k -> k (add x y))

We may see add_abstracted_cps as a “partially applied” version of add_cps, where the missing argument is the continuation function. Hence, we can define a bunch of such “abstracted” functions, which are partially applied functions.

Notice that the type signature of the “abstracted” version is exactly the same as that of the CPS version; the only difference is that, because of partial application, we read it as returning a function that takes a continuation and produces a value of type r. The parentheses make this explicit, though they are not strictly necessary.

Put another way, we are just converting between two equivalent forms using lambda expressions, as follows:

f :: Int -> Int
f x = x+1         -- This is equivalent to this, both
f = \x -> x+1     -- in the function body and in the type signature

The above are equivalent.

Now, back to partial application, suppose we have the following:

add :: Int -> Int -> Int
add x y = x + y

add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

add_abstracted_cps :: Int -> Int -> ((Int -> r) -> r)
add_abstracted_cps x y = (\k -> k (add x y))

Now, add_abstracted_cps is just add_cps partially applied, and is just missing the continuation before it can be evaluated (i.e. it returns a function that expects k, the continuation).

However, so far the use of the words “partially applied” here has been a bit loose: you’ll notice that add_abstracted_cps still has the parameters x and y, and those have certainly not been applied.

We’ll see why I keep saying partially applied in just a moment. However, for now, let’s do the same for all of the functions and put them all together:

-- add
add :: Int -> Int -> Int
add x y = x + y

add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

add_abstracted_cps :: Int -> Int -> ((Int -> r) -> r)
add_abstracted_cps x y = (\k -> k (add x y))

-- square
square :: Int -> Int
square x = x * x

square_cps :: Int -> (Int -> r) -> r
square_cps x k = k (square x)

square_abstracted_cps :: Int -> ((Int -> r) -> r)
square_abstracted_cps x = (\k -> k (square x))

-- pythagoras
pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k = square_cps x (\x_squared ->
                       square_cps y (\y_squared ->
                       add_cps x_squared y_squared (\sum_of_squares ->
                       k sum_of_squares)))

pythagoras_abstracted_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_abstracted_cps x y = (\k ->
                                  square_cps x (\x_squared ->
                                  square_cps y (\y_squared ->
                                  add_cps x_squared y_squared (\sum_of_squares ->
                                  k sum_of_squares))))

Now, suppose we fully apply the *_abstracted_cps functions: we get a function that expects k, the continuation. Once we supply k, the function body can be “evaluated”. Hence, *_abstracted_cps, when fully applied, is just *_cps partially applied up to (but not including) the continuation parameter, k. Hence my use of the words “partially applied” earlier.

Thank you for tolerating my loose use of those words.

Hence, to reiterate, the *_abstracted_cps functions, after they have been applied to all the required parameters, return a computation (a function) that requires a continuation function in order to fully evaluate.

We’re going to be seeing this kind of function a lot, hence, we’ll give it a more intuitive name. Let’s call them waiting functions. Thus, waiting functions are functions that are waiting on a continuation as their one and only parameter.
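
For instance, supplying a continuation to a waiting function lets it finally produce a value (a quick sketch using the functions above; the test names here are purely illustrative):

-- Each of these supplies the waiting function's one and only parameter: a continuation.
test_wait_a = pythagoras_abstracted_cps 3 4 id      -- 25
test_wait_b = pythagoras_abstracted_cps 3 4 show    -- "25"
test_wait_c = pythagoras_abstracted_cps 3 4 print   -- an IO action that prints 25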

The Cont monad

Clearly, the fact that we have to handle the continuation function explicitly can be abstracted cleanly away using a monad. This is implemented in the Cont monad. Let’s take a look at the definition first. We are going to see some very big similarities.

However, the first thing that we need to constantly keep in mind is what exactly the monad represents. Here goes: the Cont monad “stores”, as the inner value of the monadic value, the fully applied *_abstracted_cps functions. That’s our waiting function. Here it is again:

The monadic value “contains” the computation that is just waiting for its continuation (the waiting function).

The Cont monad is defined in Control.Monad.Cont, hence the following line is necessary:

import Control.Monad.Cont

Here’s the data definition of Cont:

newtype Cont r a = Cont { runCont :: ((a -> r) -> r) }

First off, the inner value of the Cont monad has type ((a -> r) -> r), which looks very familiar, since it’s the type of the waiting function. It’s pretty much the rear end of all our CPS function type signatures.

The continuation itself is given by the type (a -> r), and as you can see, represents the first and only parameter of the waiting function.

The waiting function, given a continuation as its argument, evaluates to a value of type r, since the continuation returns a result of type r. Thus, a Cont “contains” a waiting function.

Hence, this would work:

Cont (pythagoras_abstracted_cps 100 200)
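
And to actually get anything out of it, we supply the top-level continuation via runCont (a quick sketch; the test names are illustrative and assume the newtype above):

-- Wrap the waiting function in Cont, then hand it the top-level continuation.
test_wrap_a = runCont (Cont (pythagoras_abstracted_cps 100 200)) print   -- prints 50000
test_wrap_b = runCont (Cont (pythagoras_abstracted_cps 3 4)) (+1)        -- 26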

return and >>=

Now let’s take a look at how the Cont monad defines return and >>=.

instance Monad (Cont r) where
  return a  = Cont (\k -> k a)
  mv >>= f  = Cont (\k -> runCont mv (\a -> runCont (f a) k))

return is simply the continuation version of identity. Consider the following:

id :: Int -> Int
id x = x

id_cps :: Int -> (Int -> r) -> r
id_cps x k = k x

id_abstracted_cps :: Int -> ((Int -> r) -> r)
id_abstracted_cps x = (\k -> k x)

That’s exactly return in the Cont monad, with the addition of applying the monad’s constructor to the result, of course. It creates a computation that is going to return just plain x (its argument), once it has its continuation supplied. Sounds familiar? :)
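
For example (a small sketch using the definitions above; the test names are illustrative):

-- return wraps a plain value as a waiting function; runCont supplies its continuation.
test_return_a = runCont (return 3) id       -- 3
test_return_b = runCont (return 3) (+100)   -- 103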

The >>= operator is just a little bit more complicated. Remember that the monadic values contain the computations just waiting for their continuation. Now suppose we have more than one of these functions. For instance, we want to combine the continuation versions of square and add (the square_cont and add_cont that we will define shortly):

square_cont 4 >>= add_cont 200

Before examining how >>= works in the Cont monad, remember that mv >>= f embodies the concept of extracting the inner value of mv, and somehow relate f to the inner value. The inner value is the computation waiting for its continuation (call this c), and f is a function that is going to take the result of the computation of c (which is bound to a, by the definition of continuations), and produce a new computation that is waiting for its continuation (call this c').

Okay that was a mouthful. Let’s go through this step by step. First of all, remember that the mechanism, or the convention if you will, of functions written in continuation-passing style, is that we have a function that contains some code, and that function takes, as an “extra” argument, a continuation, to which it passes the result.

This means that the function’s code itself, intuitively, will not contain any references to the continuation function passed into it, with the exception of the last “line” of the function which calls that continuation function. Look at the following function yet again:

add :: Int -> Int -> Int
add x y = x + y

add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

The only use of k is to call the continuation, and this is the last action. Continuations are defined as such, since the idea is to pass control to the continuation function. We can thus think of the function body being completely independent of the continuation, except for the final pass.

Again, by convention, the continuation is going to take the “result” of the function body as its argument. In this case, k is going to take a single argument, with the result of add x y.

Now we are ready to look back at the definition for >>=:

newtype Cont r a = Cont { runCont :: ((a -> r) -> r) }
instance Monad (Cont r) where
  return a  = Cont (\k -> k a)
  mv >>= f  = Cont (\k -> runCont mv (\a -> runCont (f a) k))

mv contains the computation (waiting for its continuation), which we extract using the runCont helper function. Hence, runCont mv produces that computation function (as we just said). It’s waiting for its continuation, so we pass it the continuation (a -> runCont (f a) k).

Remember that the definition of continuations is such that the body of the computation can be “evaluated” to produce a “result”, independent of the continuation? That “result”, again by definition, is going to be passed into the continuation, which takes it in as a single argument. In other words, the “result” of the body of the computation is going to be bound to a in the continuation. This represents the notion of “evaluating” mv to produce its “result”.

Then we have the monadic function f. We recall that monadic functions are going to take a value somehow extracted (and perhaps tweaked a little) from the monadic value mv, and produce a monadic value. In our case of the Cont monad, what f is expecting is the result of the computation in mv.

We hence give it exactly that, and that result was bound to a. The output, and therefore the result of (f a), is a monadic value, which is again a new computation waiting for its continuation. We extract that computation using runCont (again) and give it the continuation k (none of this evaluates yet, because the whole thing is still inside the lambda waiting for k), and that completes the process. This represents the notion of “evaluating” f to produce the final “result”.

Note that we’ve been using the words “evaluate” and “result” again quite (okay, very) loosely. Nothing gets “evaluated” in mv >>= f, because we’re dealing with computations, or the notion of computations, depending on how you want to see it.

Remember, Haskell is lazy. All that’s there is a promise to evaluate, when it’s required. Hence, the final “result” is nothing more than a large computation that is still waiting on a continuation – the top-level continuation. Only once we supply this does the entire computation have what it needs to finally and truly evaluate, and when it does, it actually produces a result.

Let’s finally use the Cont monad:

-- add
add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

{- This creates the computation waiting for its continuation:
     Cont (\k -> k (add x y)) -}

-- square
square_cont :: Int -> Cont r Int
square_cont x = return (square x)

{- This creates the computation waiting for its continuation:
     Cont (\k -> k (square x)) -}

-- pythagoras
pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y =
    square_cont x >>= (\x_squared ->
      square_cont y >>= (\y_squared ->
        add_cont x_squared y_squared >>= (\sum_of_squares ->
          return sum_of_squares)))

-- pythagoras (with do-notation)
pythagoras_cont' :: Int -> Int -> Cont r Int
pythagoras_cont' x y = do
  x_squared <- square_cont x
  y_squared <- square_cont y
  sum_of_squares <- add_cont x_squared y_squared
  return sum_of_squares

test1 = runCont (pythagoras_cont 3 4) print
test2 = runCont (pythagoras_cont' 3 4) (+100)

So far, we’ve seen how to write in continuation-passing style, and how the Cont monad takes care of it for us by turning functions in the Cont monad into computations that expect a continuation.

Hence, all the functions that we define in the Cont monad are such expect-a-continuation functions. The Cont monad also defines the >>= operator to compose such computations, and >>= takes care of passing the right continuations through the composed functions, making everything nice and implicit.
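
To make the threading concrete, here is a hand expansion (just a sketch, using the return and >>= definitions above) of the earlier expression square_cont 4 >>= add_cont 200:

{- A hand expansion (sketch) of: square_cont 4 >>= add_cont 200

     square_cont 4  =  Cont (\k -> k 16)
     add_cont 200   =  \a -> Cont (\k -> k (200 + a))

     square_cont 4 >>= add_cont 200
       = Cont (\k -> runCont (square_cont 4)
                       (\a -> runCont (add_cont 200 a) k))
       = Cont (\k -> (\a -> k (200 + a)) 16)
       = Cont (\k -> k 216)

   The result is itself just another waiting function; nothing runs until we
   supply the top-level continuation, e.g.
     runCont (square_cont 4 >>= add_cont 200) print   -- prints 216
-}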

A note on the “right” continuations

Observe that the “right” continuations change (obviously) as the computations “progress”. Because we are composing computations using the monadic >>= bind operator, and because of how the Cont monad defines >>=, the “right” continuation is always going to contain the computations on the right hand side of >>=. In other words, given mv >>= f, when “evaluating” mv, the continuation to mv is going to contain the computation represented by f. Given mv >>= f >>= g, then the continuation to mv contains the computation of f (and not g), and the continuation to f contains the computation of g. What then is contained in the continuation of g? It’s going to be the top-level continuation, the one that we pass in when we say runCont (...) print. In this example, the top-level continuation is print.
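
For instance (a sketch reusing the Cont functions defined above; the test name is illustrative):

-- The continuation of square_cont 3 contains add_cont 1 (and, transitively, the rest);
-- the continuation of add_cont 1 ... contains the final square_cont;
-- the continuation of that last computation is the top-level one: print.
test_chain = runCont (square_cont 3 >>= add_cont 1 >>= square_cont) print
-- prints 100   (3*3 = 9, then 1+9 = 10, then 10*10 = 100)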

Note that if we write our code slightly differently (though equivalently), we get something a little bit different (but equivalent). Consider the code:

mv >>= (\x -> (f x) >>= (\y -> (g y)))

Now, the continuation to mv is going to contain the computation of

(\x -> (f x) >>= (\y -> (g y)))

and the continuation to (f x) is going to contain the computation of

(\y -> (g y))

In this way, we can see things in a slightly different light. The continuation at any point is always the current-continuation, where the current-continuation represents “the rest of the program”. In the case of Haskell, because we’re trapped in the Cont monad, it represents “the rest of the monadic computation”. However, the two forms are no different in practice.
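
Concretely, the chain from the previous example can be rewritten in the nested form (again a sketch; illustrative name):

-- Same computation as test_chain, written in the right-nested form: the continuation
-- of square_cont 3 now contains the whole of (\x -> add_cont 1 x >>= (\y -> square_cont y)).
test_chain' = runCont (square_cont 3 >>= (\x ->
                         add_cont 1 x >>= (\y ->
                           square_cont y))) print
-- also prints 100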

Gaining even more control

Now, can we obtain even more control over our continuations, “jumping” to continuations whenever we want (and not just, as we’ve seen, at the end of a function under the Cont monad), and passing whatever we want into the continuation (and not just, as we’ve again seen, the currently computed value)?

Indeed so; after all, what’s stopping us from writing code that calls the continuation, k, at any point that we like? This is trivially simple in the continuation-passing style, for we have direct access to the continuation itself. For instance, we could call k under a condition, and do something else when the condition is false, as follows:

pythagoras_cps' :: Int -> Int -> (Int -> r) -> r
pythagoras_cps' x y k = square_cps x (\x_squared ->
                        square_cps y (\y_squared ->
                        add_cps x_squared y_squared (\sum_of_squares ->
                        if (sum_of_squares > 100) then k (sum_of_squares - 10)
                        else k sum_of_squares)))

test3 = pythagoras_cps' 10 10 print
test4 = pythagoras_cps' 1 1 print

Hence, a close analogy in imperative languages is the return statement (not the monadic return function), whereby we use return to “exit” a function, and pass return an argument to, well, return.

Similarly, in continuations, since we are “returning” to the continuation, we can simply call the continuation to exit, and pass the continuation an argument to, well, continue with.

However, how about when using the Cont monad? Remember that the Cont monad “hides” the continuation, by turning the computation (e.g. the pythagoras function) into what we called a waiting function. That means that we don’t have direct access to the continuation, which is the whole purpose of the monad in the first place. How then can we access the continuation so that we have the power and flexibility we just saw?

We do that using a helper function called callCC (okay, whether you see it as a helper function or not is your choice).

What is callCC? It’s what is commonly called call-with-current-continuation. This is a queer little fella, and we should get some intuitive idea about callCC first. So, what is it?

callCC takes a waiting function, and it is going to create a Cont (which contains that waiting function), and it’s going to run it.

But with what? It’s waiting for a continuation. The obvious answer is that it’s going to run it with the continuation it’s given, as it’s being composed via the >>= operator. See the section on “A note on the ‘right’ continuations” (above) for why. And the “right” continuation is – the current-continuation.

So what’s the big deal? Well, the big deal is that without callCC, all the “right” continuations were implicit. Your code could not access the implicitly created continuations (all the \a -> ... stuff in the definition of >>=). callCC exposes the continuation by passing the current-continuation explicitly into the waiting function, which, being a waiting function, has an explicit parameter that is expecting a continuation. First consider this code:

addOne :: Int -> Cont r Int
addOne n = return (n+1)

doSomething :: Int -> Cont r Int
doSomething n = return n >>= (\x -> return (x-2))

test5 = runCont (doSomething 5) print

Nothing special above. Now let’s put in a callCC, but not invoke the current-continuation passed to callCC.

doSomething' :: Int -> Cont r Int
doSomething' n = return n >>= (\x -> callCC (\breakOutOfHere -> return (x-2)))

test6 = runCont (doSomething' 5) print

Looks good, but doesn’t do anything cool. Now let’s invoke it conditionally.

doSomething'' :: Int -> Cont r Int
doSomething'' n = return n >>= (\x -> callCC (\breakOutOfHere ->
                                                if (x>5) then breakOutOfHere 100000
                                                else return (x-2)))

test7 = runCont (doSomething'' 5) print
test8 = runCont (doSomething'' 10) print

Now, we literally get to break out when we want, and pass (or throw) out whatever value we want when we break out. Now we have a fully hidden and implicit implementation of continuations, with the option to expose the continuation whenever we want, via callCC.

Thus, you saw that invoking the current-continuation (it has a name now) at any point inside the code of the waiting function breaks out of the waiting function. Break out to where, you may ask.

Again, the continuation is the current-continuation from the point of the callCC call. Hence, it breaks just straight out of the waiting function, to the point right after.
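
To see that concretely, here is a sketch (the function name doSomething''' and the test names are purely illustrative) where we break out of a callCC that is followed by more code:

doSomething''' :: Int -> Cont r Int
doSomething''' n = do
  x <- callCC (\breakOutOfHere ->
         if n > 5 then breakOutOfHere 100000
         else return (n - 2))
  -- Whether or not we broke out, control lands here, with the value bound to x.
  return (x + 1)

test9  = runCont (doSomething''' 3)  print   -- prints 2       (3-2, then +1)
test10 = runCont (doSomething''' 10) print   -- prints 100001  (100000, then +1)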

Type and definition of callCC

Now let’s take a look at the type and definition of callCC:

             arg = waiting function         result
                        |                      |
          v---------------------------v    v------v
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

callCC f = Cont $ \k -> runCont (f (\a -> Cont $ \_ -> k a)) k

So, how does callCC do what it claims to do? First, we see that it takes f, a waiting function. Like all waiting functions, it needs to take in a continuation. We already figured that out.

It then says that it wants to “run” f, so it has to call runCont on f. Again, that is quite simple. However, what does it pass to f? Normally (not in callCC), we would pass in the “right” continuation. However, callCC doesn’t want to do that. It wants to let the programmer break out of the waiting function, and where is that point exactly? It’s the “outer” continuation, which is in k. Now, the function f is a waiting function that, when “evaluated”, generates a result.

As per our normal way of passing results into continuations, the a in our example here is going to take the result of f. But this time, callCC explicitly passes that result into k, which is the “outer” continuation. It breaks out with the result of f.
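
As a sanity check, here is a hand trace (just a sketch) of test8 from earlier, plugging doSomething'' 10 into these definitions:

{- A hand trace (sketch) of test8 = runCont (doSomething'' 10) print:

     runCont (return 10 >>= (\x -> callCC (...))) print
       = runCont (return 10) (\x -> runCont (callCC (... x ...)) print)
       = runCont (callCC (\breakOutOfHere ->
                    if 10 > 5 then breakOutOfHere 100000
                    else return (10 - 2))) print
       = runCont ((\breakOutOfHere -> breakOutOfHere 100000)
                    (\a -> Cont (\_ -> print a))) print
       = runCont (Cont (\_ -> print 100000)) print
       = print 100000

   The escape function (\a -> Cont (\_ -> print a)) ignores whatever continuation it
   is handed and jumps straight to the "outer" continuation, which here is print. -}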

Now, the type. The argument to callCC is of type ((a -> Cont r b) -> Cont r a), which we can see is a function, f.

We recall that this function is a waiting function, and that it takes a continuation. The continuation therefore has the type (a -> Cont r b). Why is the type different from the type of the previous continuations that we saw? That’s because callCC is going to pass the current continuation, and the current continuation is going to represent the rest of the program.

Remember also that we’re in the Cont monad. So what does the “rest of the program” return? Well, it has to return something in the Cont monad! Hence the type ((a -> Cont r b) -> Cont r a) for the function f.
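
Finally, to see why b is left completely free: breakOutOfHere never hands a value back to the code inside the callCC, so it can be used wherever any type is expected (a sketch; the names weird, test11 and test12 are illustrative):

-- Inside the callCC, breakOutOfHere n is used where a Cont r String is expected
-- (to match the other branch of the if), even though n is an Int. This only
-- type-checks because the b in (a -> Cont r b) is free.
weird :: Int -> Cont r Int
weird n = callCC (\breakOutOfHere -> do
            s <- if n > 5 then breakOutOfHere n else return "small"
            return (length s))

test11 = runCont (weird 3)  print   -- prints 5  (length "small")
test12 = runCont (weird 10) print   -- prints 10 (we broke out with n itself)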