
State via Cont

State s a is just a wrapper around a function of type s -> (a, s):

f :: s -> (a, s)
f s = (a, s')

What if we encode its result in CPS? That would be the type ((a, s) -> r) -> r. The type of the whole function becomes f :: s -> ((a, s) -> r) -> r. f now takes two arguments: the initial state s and a continuation k to which it passes its result. It calls k with argument (a, s'):

f :: s -> ((a, s) -> r) -> r
f s k = k (a, s')

Next let's curry the continuation k. Its type becomes a -> s -> r. For the fun of it, let's also swap the arguments of f:

f :: (a -> s -> r) -> s -> r
f k s = k a s'

And finally, let's desugar this definition of f:

f :: (a -> s -> r) -> s -> r
f = \k s -> k a s'

Note that we haven't really changed the behaviour of f. We still have a function that maps an initial state s to a result a and an updated state s'; only instead of returning a and s' directly, it passes these values to a continuation k.
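To make this concrete, here is a small sketch using a hypothetical counter, tick, that returns the current state and increments it (so a = s and s' = s + 1). The names tick, tickPairK, tickK, tickViaPair, and tickViaCurried are illustrative and not from the text. The two CPS versions mirror the transformation steps above, and both give back the original pair when the continuation simply rebuilds it.

```haskell
-- Hypothetical example (not from the text): return the current counter
-- and increment it. In the text's notation, a = s and s' = s + 1.
tick :: Int -> (Int, Int)
tick s = (s, s + 1)

-- Step 1: pass the pair (a, s') to a continuation instead of returning it.
tickPairK :: Int -> ((Int, Int) -> r) -> r
tickPairK s k = k (s, s + 1)

-- Step 2: curry the continuation, swap the arguments, and desugar.
tickK :: (Int -> Int -> r) -> Int -> r
tickK = \k s -> k s (s + 1)

-- Both CPS forms give back the original pair when the continuation
-- just rebuilds it.
tickViaPair :: Int -> (Int, Int)
tickViaPair s = tickPairK s id

tickViaCurried :: Int -> (Int, Int)
tickViaCurried s = tickK (,) s
```

For example, tickK (,) 5 evaluates to (5, 6), while tickK (\a s' -> a + s') 5 evaluates to 11: the continuation decides what to do with the result and the new state.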

If we define r' = s -> r, then f has the type

f :: (a -> r') -> r'

and that's just a plain old value encoded in CPS. Wrapping f in Cont, we obtain a value of type Cont r' a. So here is our stateful function f expressed using the Cont monad:

f :: Cont (s -> r) a
f = Cont $ \k s -> k a s'

It turns out that this is exactly the correct implementation of the State monad via the Cont monad:

type StateC r s a = Cont (s -> r) a
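As a sketch, assuming a hand-rolled Cont newtype with the same shape as the Cont constructor used here (the transformers library instead defines Cont as a type synonym with a lowercase cont function), the synonym and a stateful value look like this. tickC is a hypothetical example that returns the current counter and increments it.

```haskell
-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

-- State via Cont: the answer type of Cont is itself a function s -> r.
type StateC r s a = Cont (s -> r) a

-- Hypothetical stateful value in the form f = Cont $ \k s -> k a s',
-- with a = s and s' = s + 1.
tickC :: StateC r Int Int
tickC = Cont $ \k s -> k s (s + 1)
```

With this definition, runCont tickC (,) 5 evaluates to (5, 6), and running tickC twice starting from 5 yields (6, 7).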

We can easily verify that return and (>>=) for this specialization of the Cont monad behave exactly the same way as return and (>>=) for the State monad.

Let's start with return. In the State monad, we have

return x = State $ \s -> (x, s)

So return x is a function that maps a state s to a pair consisting of x and the exact same state s. According to our discussion above, a function \s -> (a, s') turns into a two-argument function with arguments k and s that calls k with arguments a and s'. Thus, we should have

return :: a -> StateC r s a
return x = Cont $ \k s -> k x s

The implementation of return in the Monad instance of Cont r' is

return :: a -> Cont r' a
return x = Cont ($ x)
         = Cont $ \k -> k x

But here we have r' = s -> r; that is, the continuation k has the type a -> r' = a -> s -> r, so it takes two arguments. Using η-expansion, we obtain

return :: a -> Cont r' a
return x = Cont $ \k -> \s -> k x s
         = Cont $ \k s -> k x s

exactly what we hoped to obtain.
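A quick runnable check of this calculation, again assuming a minimal hand-rolled Cont newtype rather than the library type: returnC is a hypothetical name for the CPS return derived above, and the test compares it with the return that the Monad instance of Cont provides.

```haskell
-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)       -- return defaults to pure
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

-- Hypothetical name for the CPS form of State's return: hand x to the
-- continuation together with the unchanged state s.
returnC :: a -> StateC r s a
returnC x = Cont $ \k s -> k x s
```

For example, runCont (returnC 'x') (,) 7 evaluates to ('x', 7): the state passes through unchanged.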

Now let's look at (>>=). Consider two functions

f   = State $ \s  -> (a, s')
g a = State $ \s' -> (b, s'')

For this discussion, you shouldn't read a, b, s, s', and s'' as variables but as concrete values. Given the state s, f produces the value a and the state s'. g a maps s' to the value b and the state s''. The definition of (>>=) for the State monad says that

f >>= g = State $ \s -> let (a, s') = runState f s
                        in  runState (g a) s'
        = State $ \s -> (b, s'')
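To have something concrete to compare against, here is a hand-rolled State newtype whose (>>=) is exactly the definition above, together with hypothetical choices for f and g: fS returns the state and doubles it, and gS a returns a plus the state and increments the state. These names are for illustration only.

```haskell
-- Hand-rolled State monad (a sketch, not a library type).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
  pure x    = State $ \s -> (x, s)
  mf <*> mx = State $ \s ->
    let (f, s')  = runState mf s
        (x, s'') = runState mx s'
    in  (f x, s'')

instance Monad (State s) where
  -- The definition of (>>=) quoted in the text.
  f >>= g = State $ \s -> let (a, s') = runState f s
                          in  runState (g a) s'

-- Hypothetical f: return the state and double it.
fS :: State Int Int
fS = State $ \s -> (s, 2 * s)

-- Hypothetical g: return a plus the state and increment the state.
gS :: Int -> State Int Int
gS a = State $ \s' -> (a + s', s' + 1)
```

Starting from 3, fS produces a = 3 and s' = 6, and gS 3 then produces b = 9 and s'' = 7, so runState (fS >>= gS) 3 evaluates to (9, 7).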

Encoded in CPS, f >>= g becomes the computation

f >>= g :: StateC r s b
f >>= g = Cont $ \k s -> k b s''

f >>= g takes the initial state s and the continuation k as arguments and calls k with its result value b and the final state s'' as arguments. The CPS encodings of f and g a are

f   = Cont $ \k s  -> k a s'
g a = Cont $ \k s' -> k b s''

Let's see whether the definition of (>>=) for the Cont monad, applied to f and g, gives us the value f >>= g we expect. We have

f >>= g
    = Cont $ \k -> runCont f (\a -> runCont (g a) k)                -- Definition of (>>=)
    = Cont $ \k -> (\k s -> k a s') (\a -> (\k s' -> k b s'') k)    -- Definitions of f and g a
    = Cont $ \k -> (\k s -> k a s') (\a -> \s' -> k b s'')          -- Function application
    = Cont $ \k -> (\s -> (\a -> \s' -> k b s'') a s')              -- Function application
    = Cont $ \k -> \s -> k b s''                                    -- Function application
    = Cont $ \k s -> k b s''

And that is exactly the CPS encoding of f >>= g that we wanted to obtain.
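The same calculation can be checked by running it. The sketch below assumes a minimal hand-rolled Cont newtype whose (>>=) is the Cont definition used in the derivation; fC and gC are hypothetical CPS encodings of an f that returns the state and doubles it and a g a that returns a plus the state and increments it. Worked by hand with State's (>>=), f >>= g started in state 3 gives b = 9 and s'' = 7 (f produces a = 3 and s' = 6), and the CPS version should agree.

```haskell
-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  -- The definition of (>>=) used in the derivation.
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

-- Hypothetical f in the form Cont $ \k s -> k a s'.
fC :: StateC r Int Int
fC = Cont $ \k s -> k s (2 * s)

-- Hypothetical g a in the form Cont $ \k s' -> k b s''.
gC :: Int -> StateC r Int Int
gC a = Cont $ \k s' -> k (a + s') (s' + 1)
```

Passing (,) as the final continuation, runCont (fC >>= gC) (,) 3 evaluates to (9, 7).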

To turn StateC into a complete state monad, we also need implementations of get, put, and runState. The definitions of get and put for the State monad are:

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

get leaves the state s alone but also returns it as its result. put returns () (void) and does not care what the original state was. The new state is whatever state s we passed to put as its argument. Converted into CPS, this gives

get :: StateC r s s
get = Cont $ \k s -> k s s

put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s
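Here is a runnable sketch of get and put in CPS form, assuming a minimal hand-rolled Cont newtype with the text's Cont constructor. squareState is a hypothetical combination that reads the state, stores its square, and returns the old value.

```haskell
-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

-- Return the current state and leave it unchanged.
get :: StateC r s s
get = Cont $ \k s -> k s s

-- Ignore the current state and replace it with s.
put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s

-- Hypothetical example: store the square of the state, return the old value.
squareState :: StateC r Int Int
squareState = do
  x <- get
  put (x * x)
  return x
```

For example, runCont squareState (,) 4 evaluates to (4, 16).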

For the State monad, we have

runState :: State s r -> s -> (r, s)
runState (State f) s = f s

In particular, this is a computation that returns a result of type (r, s). When using the Cont monad, a value of type Cont r' a is a value of type a in a computation that produces a result of type r' at the end. Since StateC r s a = Cont (s -> r) a, the first parameter of StateC is the final result type, so we want it to be (r, s), that is, r' = s -> (r, s). This gives the type

runState :: StateC (r, s) s r -> s -> (r, s)

The implementation is

runState m = runCont m (,)

This surely requires some unpacking. Once again, if we have the function

f :: State s r
f = State $ \s -> (r, s')

in the standard State monad, then according to the definition above, we have

runState f s = (r, s')

Encoded in CPS, f becomes

f :: StateC (r, s) s r
f = Cont $ \k s -> k r s'

With our definition of runState for the StateC monad, we obtain

runState f s = runCont f (,) s           -- Definition of runState
             = (\k s -> k r s') (,) s    -- Definition of f
             = (\s -> (,) r s') s        -- Function application
             = (,) r s'                  -- Function application
             = (r, s')                   -- Function application

and that's what we wanted to obtain.
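The same unpacking can be run directly. This sketch assumes a minimal hand-rolled Cont newtype and picks hypothetical concrete values for the generic f above: the result is show s and the new state is s + 1.

```haskell
-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

-- runState as defined in the text: the final continuation pairs the
-- result with the final state.
runState :: StateC (r, s) s r -> s -> (r, s)
runState m = runCont m (,)

-- Hypothetical f = Cont $ \k s -> k r s', with r = show s and s' = s + 1.
fC :: StateC r Int String
fC = Cont $ \k s -> k (show s) (s + 1)
```

For example, runState fC 41 evaluates to ("41", 42).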

Let's try it all out, using a fairly silly function as an example:

rangeList :: Int -> Int -> [Int]
rangeList m n = fst $ runState (replicateM (n - m + 1) rangeElem) m
  where
    rangeElem = do
        i <- get
        put (i + 1)
        return i

We implemented a variation of this function before, when we discussed the real state monad. The rangeElem computation retrieves the current state, an integer, and updates the state to the next integer. Then it returns the current integer value. By repeating this n - m + 1 times using replicateM, we obtain a list of the results produced by these n - m + 1 invocations of rangeElem. We initialize the state to m by passing it as the second argument to runState, so the list we generate is [m..n]. runState returns this list paired with the final state n + 1, and we extract the list from this pair using fst.
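To check the claim about the final state as well, here is a self-contained sketch that runs the same computation but keeps both components of the pair. It assumes a minimal hand-rolled Cont newtype instead of the library one; rangeWithState is a hypothetical name.

```haskell
import Control.Monad (replicateM)

-- Minimal hand-rolled Cont (an assumption: not the transformers/mtl type).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont $ \k -> runCont m (k . f)

instance Applicative (Cont r) where
  pure x    = Cont ($ x)
  mf <*> mx = Cont $ \k -> runCont mf (\f -> runCont mx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

get :: StateC r s s
get = Cont $ \k s -> k s s

put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s

runState :: StateC (r, s) s r -> s -> (r, s)
runState m = runCont m (,)

-- Hypothetical helper: like rangeList, but keeps the final state
-- instead of discarding it with fst.
rangeWithState :: Int -> Int -> ([Int], Int)
rangeWithState m n = runState (replicateM (n - m + 1) rangeElem) m
  where
    rangeElem = do
      i <- get
      put (i + 1)
      return i
```

For example, rangeWithState 5 20 evaluates to ([5 .. 20], 21), confirming that the final state is n + 1.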

Here's a GHCi session that shows that this does exactly what we want:

GHCi
>>> import Control.Monad (replicateM)
>>> import Control.Monad.Trans.Cont (Cont, cont, runCont)
>>> type State r s a = Cont (s -> r) a
>>> :{
  | get :: State r s s
  | get = cont (\k s -> k s s)
  |
  | put :: s -> State r s ()
  | put s = cont (\k _ -> k () s)
  |
  | runState :: State (r, s) s r -> s -> (r, s)
  | runState m = runCont m (,)
  | :}
>>> :{
  | rangeList :: Int -> Int -> [Int]
  | rangeList m n = fst $ runState (replicateM (n - m + 1) rangeElem) m
  |   where
  |     rangeElem = do
  |         i <- get
  |         put (i + 1)
  |         return i
  | :}
>>> rangeList 5 20
[5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]