State via Cont
State s a is just a wrapper around a function
f :: s -> (a, s)
f s = (a, s')
What if we encode its result in CPS? That would be the type ((a, s) -> r) -> r.
The type of the whole function becomes
f :: s -> ((a, s) -> r) -> r.
f now takes two arguments: the initial state s, and a continuation k to which
to pass its result. It calls k with argument (a, s'):
f :: s -> ((a, s) -> r) -> r
f s k = k (a, s')
Next, let's curry the continuation k. Its type becomes a -> s -> r. For the
fun of it, let's also swap the arguments of f:
f :: (a -> s -> r) -> s -> r
f k s = k a s'
And finally, let's desugar this definition of f:
f :: (a -> s -> r) -> s -> r
f = \k s -> k a s'
Note that we haven't really changed the behaviour of f. We still have a
function that maps an initial state s to a result a and an updated state s',
only instead of returning a and s' directly, it passes these values to a
continuation k.
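To make these three steps concrete, here is a small sketch using a made-up counter function. The names tickDirect, tickPairCPS and tickCPS are hypothetical and only serve as an illustration: the function returns the current state as its result and increments the state.

```haskell
-- Hypothetical example: return the current counter value and increment it.

-- Direct style: result and new state are returned as a pair.
tickDirect :: Int -> (Int, Int)
tickDirect s = (s, s + 1)

-- Result encoded in CPS: the pair is passed to a continuation k.
tickPairCPS :: Int -> ((Int, Int) -> r) -> r
tickPairCPS s k = k (s, s + 1)

-- Continuation curried and the two arguments swapped.
tickCPS :: (Int -> Int -> r) -> Int -> r
tickCPS k s = k s (s + 1)

main :: IO ()
main = do
  print (tickDirect 41)      -- direct result
  print (tickPairCPS 41 id)  -- identity continuation recovers the pair
  print (tickCPS (,) 41)     -- pair constructor as curried continuation
```

All three calls in main produce the same pair, which is the sense in which the behaviour has not changed.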
If we define r' = s -> r, then f has the type
f :: (a -> r') -> r'
and that's just a plain old value encoded in CPS. Wrapping f in Cont, we
obtain a value of type Cont r' a. So here is our stateful function f expressed
using the Cont monad:
f :: Cont (s -> r) a
f = Cont $ \k s -> k a s'
It turns out that this is exactly the correct implementation of the State
monad via the Cont monad:
type StateC r s a = Cont (s -> r) a
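As a rough sketch of how the two representations relate, we can write a pair of conversion helpers. The names toStateC and fromStateC are hypothetical, and the Cont newtype below is a minimal local stand-in (an assumption about its shape) so that the snippet compiles with base alone.

```haskell
-- Minimal local stand-in for Cont (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

type StateC r s a = Cont (s -> r) a

-- Hypothetical helper: CPS-encode a direct-style state function.
toStateC :: (s -> (a, s)) -> StateC r s a
toStateC f = Cont $ \k s -> let (a, s') = f s in k a s'

-- Hypothetical helper: recover the direct-style function by using the
-- pair constructor as the final continuation.
fromStateC :: StateC (a, s) s a -> s -> (a, s)
fromStateC m = runCont m (,)

-- Hypothetical example function: return the counter and increment it.
tick :: Int -> (Int, Int)
tick s = (s, s + 1)

main :: IO ()
main = print (fromStateC (toStateC tick) 41)
```

Converting tick to StateC and back should give a function that behaves like tick itself.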
We can easily verify that return and (>>=) for this specialization of the Cont
monad behave exactly the same way as return and (>>=) for the State monad.
First return. In the State monad, we have
return x = State $ \s -> (x, s)
So return x is a function that maps a state s to a pair consisting of x and
the exact same state s. According to our discussion above, a function
\s -> (a, s') turns into a two-argument function with arguments k and s that
calls k with arguments a and s'. Thus, we should have
return :: a -> StateC r s a
return x = Cont $ \k s -> k x s
The implementation of return in the Monad instance of Cont r' is
return :: a -> Cont r' a
return x = Cont ($ x)
         = Cont $ \k -> k x
But here we have r' = s -> r, that is, the continuation k has the type
a -> r' = a -> s -> r; it takes two arguments. Using \(\eta\)-expansion, we
obtain
return :: a -> Cont r' a
return x = Cont $ \k -> \s -> k x s
         = Cont $ \k s -> k x s
exactly what we hoped to obtain.
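We can check this on a concrete value. The sketch below assumes a minimal local definition of Cont with the usual instances (a stand-in, not necessarily identical to the definition used elsewhere). The name returnC is hypothetical: it is the hand-written version derived above, compared against the return of the Monad instance.

```haskell
-- Minimal local stand-in for the Cont monad (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont c) = Cont $ \k -> c (k . f)

instance Applicative (Cont r) where
  pure x = Cont ($ x)
  Cont cf <*> Cont cx = Cont $ \k -> cf (\f -> cx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

-- Hypothetical name: return written out by hand for StateC.
returnC :: a -> StateC r s a
returnC x = Cont $ \k s -> k x s

main :: IO ()
main = do
  -- Hand-written version, run with the pair constructor and state 0.
  print (runCont (returnC 'x') (,) (0 :: Int))
  -- The generic return of Cont, specialized to StateC.
  print (runCont (return 'x' :: StateC (Char, Int) Int Char) (,) 0)
```

Both lines should print the same pair of the value and the untouched state.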
Now let's look at (>>=). Consider two functions
f = State $ \s -> (a, s')
g a = State $ \s' -> (b, s'')
For this discussion, you shouldn't read a, b, s, s', and s'' as variables but
as concrete values. Given the state s, f produces the value a and the state
s'. g a maps s' to the value b and the state s''.
The definition of (>>=) for the State monad says that
f >>= g = State $ \s -> let (a, s') = runState f s
                        in  runState (g a) s'
        = State $ \s -> (b, s'')
Encoding this in CPS, this becomes the computation
f >>= g :: StateC r s b
f >>= g = Cont $ \k s -> k b s''
f >>= g takes the initial state s and the continuation k as arguments and
calls k with its result value b and the final state s'' as arguments. The CPS
encodings of f and g a are
f = Cont $ \k s -> k a s'
g a = Cont $ \k s' -> k b s''
Let's see whether the definition of (>>=) for the Cont monad, applied to f and
g, gives us the value f >>= g we expect. We have
f >>= g
  = Cont $ \k -> runCont f (\a -> runCont (g a) k)               -- Definition of (>>=)
  = Cont $ \k -> (\k s -> k a s') (\a -> (\k s' -> k b s'') k)   -- Definitions of f and g a
  = Cont $ \k -> (\k s -> k a s') (\a -> \s' -> k b s'')         -- Function application
  = Cont $ \k -> (\s -> (\a -> \s' -> k b s'') a s')             -- Function application
  = Cont $ \k -> \s -> k b s''                                   -- Function application
  = Cont $ \k s -> k b s''
Wow, we did it right? This is what we wanted to obtain.
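The same check can be run on concrete functions. In the sketch below, f and g are made-up examples (f returns the counter and increments it, g a returns show a and doubles the state), and direct is a hypothetical direct-style reference that threads the state by hand. Cont is again a minimal local stand-in with the usual instances.

```haskell
-- Minimal local stand-in for the Cont monad (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap h (Cont c) = Cont $ \k -> c (k . h)

instance Applicative (Cont r) where
  pure x = Cont ($ x)
  Cont cf <*> Cont cx = Cont $ \k -> cf (\h -> cx (k . h))

instance Monad (Cont r) where
  m >>= h = Cont $ \k -> runCont m (\a -> runCont (h a) k)

type StateC r s a = Cont (s -> r) a

-- Hypothetical f: result is the current state, new state is s + 1.
f :: StateC r Int Int
f = Cont $ \k s -> k s (s + 1)

-- Hypothetical g: result is show a, new state is s' * 2.
g :: Int -> StateC r Int String
g a = Cont $ \k s' -> k (show a) (s' * 2)

-- Direct-style reference: thread the state through f and g by hand.
direct :: Int -> (String, Int)
direct s =
  let (a, s')  = (s, s + 1)         -- what f does
      (b, s'') = (show a, s' * 2)   -- what g a does
  in  (b, s'')

main :: IO ()
main = do
  print (runCont (f >>= g) (,) 5)  -- Cont's (>>=) on the CPS encodings
  print (direct 5)                 -- hand-threaded state
```

Starting from state 5, f yields the value 5 and the state 6, and g 5 then yields "5" and the state 12, so both lines should print ("5",12).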
To turn StateC into a complete state monad, we also need implementations of
get, put, and runState. The definitions of get and put for the State monad
were:
get :: State s s
get = State $ \s -> (s, s)
put :: s -> State s ()
put s = State $ \_ -> ((), s)
get leaves the state s alone but also returns it as its result. put returns ()
(void) and does not care what the original state was. The new state is
whatever state s we passed to put as its argument. Converted into CPS, this
gives
get :: StateC r s s
get = Cont $ \k s -> k s s
put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s
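Here is a small sketch of these two definitions in action. Since we have no dedicated run function at this point, the example passes the final continuation and the initial state to runCont by hand. The computation doubleState is a hypothetical example, and Cont is a minimal local stand-in with the usual instances.

```haskell
-- Minimal local stand-in for the Cont monad (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont c) = Cont $ \k -> c (k . f)

instance Applicative (Cont r) where
  pure x = Cont ($ x)
  Cont cf <*> Cont cx = Cont $ \k -> cf (\f -> cx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

get :: StateC r s s
get = Cont $ \k s -> k s s

put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s

-- Hypothetical example: double the state and return the old value.
doubleState :: StateC r Int Int
doubleState = do
  old <- get
  put (old * 2)
  return old

main :: IO ()
main =
  -- Final continuation pairs result and state; 21 is the initial state.
  print (runCont doubleState (\a s -> (a, s)) 21)
```

With initial state 21, the result is the old state 21 and the final state is 42.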
For the State monad, we have
runState :: State s r -> s -> (r, s)
runState (State f) s = f s
In particular, this is a computation that returns a result of type (r, s).
When using the Cont monad, a value of type Cont r' a is a value of type a in a
computation that produces a result of type r' at the end. For StateC, the
final result, obtained once the initial state has been supplied, is the first
type parameter. Thus, we want this parameter to be (r, s). This gives the type
runState :: StateC (r, s) s r -> s -> (r, s)
The implementation is
runState m = runCont m (,)
This surely requires some unpacking. Once again, if we have the function
f :: State s r
f = State $ \s -> (r, s')
in the standard State monad, then according to the definition above, we have
runState f s = (r, s')
Encoded in CPS, f becomes
f :: StateC (r, s) s r
f = Cont $ \k s -> k r s'
With our definition of runState for the StateC monad, we obtain
runState f s = runCont f (,) s          -- Definition of runState
             = (\k s -> k r s') (,) s   -- Definition of f
             = (\s -> (,) r s') s       -- Function application
             = (,) r s'                 -- Function application
             = (r, s')                  -- Function application
and that's what we wanted to obtain.
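The final continuation is what decides the shape of the result, so other choices are possible too. As a sketch, the hypothetical helpers evalStateC and execStateC below (names and definitions are assumptions, modelled on the usual evalState and execState) keep only the result or only the final state. Cont is a minimal local stand-in, and tick is a made-up example computation.

```haskell
-- Minimal local stand-in for Cont (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

type StateC r s a = Cont (s -> r) a

-- As in the text: pair up the result and the final state.
runState :: StateC (a, s) s a -> s -> (a, s)
runState m = runCont m (,)

-- Hypothetical helper: keep the result, drop the final state.
evalStateC :: StateC a s a -> s -> a
evalStateC m = runCont m const

-- Hypothetical helper: keep the final state, drop the result.
execStateC :: StateC s s a -> s -> s
execStateC m = runCont m (\_ s -> s)

-- Hypothetical example: return the counter and increment it.
tick :: StateC r Int Int
tick = Cont $ \k s -> k s (s + 1)

main :: IO ()
main = do
  print (runState tick 41)    -- result and final state
  print (evalStateC tick 41)  -- result only
  print (execStateC tick 41)  -- final state only
```

Because tick is polymorphic in r, the same computation can be run with all three final continuations.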
Let's try it all out, using a fairly silly function as an example:
rangeList :: Int -> Int -> [Int]
rangeList m n = fst $ runState (replicateM (n - m + 1) rangeElem) m
  where
    rangeElem = do
      i <- get
      put (i + 1)
      return i
We implemented a variation of this function before, when we discussed the real
state monad. The rangeElem function retrieves the current state, an integer, and
updates the state to the next integer. Then it returns the current integer
value. By repeating this n - m + 1 times using replicateM, we obtain a list of
the results produced by these n - m + 1 invocations of rangeElem. We
initialize the state to m by giving this as the second argument to runState.
So the list we generate is the list [m..n]. runState returns this list paired
with the final state n + 1. We extract the list from this pair using fst.
Here's a GHCi session that shows that this does exactly what we want:
>>> type State r s a = Cont (s -> r) a
>>> :{
| get :: State r s s
| get = cont (\k s -> k s s)
|
| put :: s -> State r s ()
| put s = cont (\k _ -> k () s)
|
| runState :: State (r, s) s r -> s -> (r, s)
| runState m = runCont m (,)
| :}
>>> :{
| rangeList :: Int -> Int -> [Int]
| rangeList m n = fst $ runState (replicateM (n - m + 1) rangeElem) m
|   where
|     rangeElem = do
|       i <- get
|       put (i + 1)
|       return i
| :}
>>> rangeList 5 20
[5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
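For readers who prefer a file over a GHCi session, here is a standalone sketch of the same experiment. It makes one assumption that differs from the session: instead of importing Cont from a library, it defines a minimal local Cont with the usual instances, so that it compiles with base alone. The type is called StateC here, as in the discussion above.

```haskell
import Control.Monad (replicateM)

-- Minimal local stand-in for the Cont monad (assumed shape, base only).
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont c) = Cont $ \k -> c (k . f)

instance Applicative (Cont r) where
  pure x = Cont ($ x)
  Cont cf <*> Cont cx = Cont $ \k -> cf (\f -> cx (k . f))

instance Monad (Cont r) where
  m >>= g = Cont $ \k -> runCont m (\a -> runCont (g a) k)

type StateC r s a = Cont (s -> r) a

get :: StateC r s s
get = Cont $ \k s -> k s s

put :: s -> StateC r s ()
put s = Cont $ \k _ -> k () s

runState :: StateC (a, s) s a -> s -> (a, s)
runState m = runCont m (,)

-- Same function as in the text: count from m to n via the state.
rangeList :: Int -> Int -> [Int]
rangeList m n = fst $ runState (replicateM (n - m + 1) rangeElem) m
  where
    rangeElem = do
      i <- get        -- read the current counter
      put (i + 1)     -- advance it
      return i        -- yield the value we read

main :: IO ()
main = print (rangeList 5 8)
```

Running main should print [5,6,7,8], matching the behaviour seen in the session for a shorter range.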