State: Stateful Computations
Oh, imperative programmers, rejoice! You will now learn how to bring your dirty
little world of destructive variable updates to the sacred halls of functional
programming. The tool to do this is the State monad. Whereas the Reader
monad allows us to access a context without modifying it, and the Writer monad
provides us with a write-only log, the State monad gives us a state that we
can both read and write at will.
How could we model stateful computations purely functionally? We want to define
a type State s a that models a computation that produces some value a and
depends on some state of type s to do so. In terms of pure functions, a
computation that produces a value of type a from some state of type s is a
function of type s -> a. That's nothing but the function type wrapped by our
Reader type:
newtype Reader s a = Reader { runReader :: s -> a }
This type models computations that can depend on some read-only context of
type s. The State monad should allow us to modify the state. The way a
pure function "modifies" the state it is given as input is to return the updated
state. Thus, a function that can read and write a state of type s should have
the type s -> (a, s). Just as Reader s a is a wrapper around a function of
type s -> a, State s a is a wrapper around a function of type s -> (a, s):
newtype State s a = State { runState :: s -> (a, s) }
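Before any monad enters the picture, it helps to see what threading a state through plain s -> (a, s) functions looks like. The following is a minimal sketch; fresh and twoFresh are hypothetical names chosen for illustration. fresh returns the current counter as its result and the incremented counter as the new state, and twoFresh chains two such steps by hand.

```haskell
-- Hypothetical "stateful" step written as a plain function of type s -> (a, s):
-- it returns the current counter and produces the incremented counter as the
-- new state.
fresh :: Int -> (Int, Int)
fresh n = (n, n + 1)

-- Hypothetical: two steps in sequence. We must pass the state left behind by
-- the first step (s1) into the second step ourselves.
twoFresh :: Int -> ((Int, Int), Int)
twoFresh s0 =
  let (a, s1) = fresh s0
      (b, s2) = fresh s1
  in ((a, b), s2)
```

Here twoFresh 10 evaluates to ((10, 11), 12). Passing s0 instead of s1 to the second call would silently reuse the old state, which is exactly the kind of plumbing mistake the State monad takes off our hands.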
Let's figure out how to turn State s into a monad. We need to come up with
appropriate definitions of return and (>>=).
First return. If x :: a, then return x should have the type State s a.
Let's build it up in pieces again. The function virtually writes itself. We
start with the skeleton
return x = State $ \s -> ...
because a value of type State s a should wrap a function of type
s -> (a, s). The function (\s -> ...) must return a pair of type (a, s).
Given that we know nothing about the types a and s, all we have at our
disposal is the state s given as argument to this function, and the value x
given as argument to return. Thus, the only pair of type (a, s) we can
produce is the pair (x, s):
return x = State $ \s -> (x, s)
This is exactly the behaviour we would expect from return x: It's a
computation that returns x and leaves the state alone.
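As a quick sanity check, here's return x written as a standalone function. The name returnState is hypothetical (it avoids a clash with the Prelude's return), and the newtype is repeated so the sketch compiles on its own.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- return x as a standalone function (hypothetical name): the result is x,
-- and the incoming state is passed through untouched.
returnState :: a -> State s a
returnState x = State $ \s -> (x, s)
```

For example, runState (returnState 'x') 5 evaluates to ('x', 5): the result is 'x' and the state 5 is left alone.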
Now what about x >>= f? Remember, this should model the idea of applying f
to the result of x. As far as passing return values along is concerned, this
function writes itself again because we don't really have any choice in how to
produce a return value of the correct type for this computation. However, both x and f may
modify the state. We need to make sure that these state modifications are
combined correctly. What we want to capture is the idea that x is a
computation that produces some value of type a and reads and modifies the
state in the process. Then we pass this value of type a to f. f then runs,
and reads and modifies the state some more. The final result should be whatever
f returns, and the state should be the state that f leaves behind. The
important part to focus on is that the state we pass to f is the state left
behind by x, because f runs after x. This is different from the Reader
monad, where we passed the same context to both x and f, because neither can
modify the context. Here's the implementation of x >>= f that implements this
logic:
x >>= f = State $ \s -> let (y, s') = runState x s
                        in runState (f y) s'
Given some state s, this runs x on this state and collects x's return
value y and the updated state s'. It then passes y to f and runs the function
wrapped by the resulting computation f y on the updated state s'. The result
is whatever result-state pair f y produces.
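Here's the same logic as a standalone function, bindState (a hypothetical name), plus a small computation tick (also hypothetical) that returns the current counter and increments it. Chaining two ticks shows that the second one sees the state the first one left behind. This is a sketch built on the hand-rolled newtype from this section, not on Control.Monad.State.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- x >>= f as an ordinary function (hypothetical name): run x on s, then run
-- the computation f y on the state s' that x left behind.
bindState :: State s a -> (a -> State s b) -> State s b
bindState x f = State $ \s ->
  let (y, s') = runState x s
  in runState (f y) s'

-- Hypothetical: return the current counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Two ticks in a row; the second tick runs on the state left by the first.
twoTicks :: State Int (Int, Int)
twoTicks =
  tick `bindState` \a ->
  tick `bindState` \b ->
  State $ \s -> ((a, b), s)
```

runState twoTicks 10 evaluates to ((10, 11), 12): the first tick returns 10 and leaves 11 behind, and the second returns 11 and leaves 12.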
Our whole Monad instance for the State s monad is
instance Monad (State s) where
  return x = State $ \s -> (x, s)
  x >>= f = State $ \s -> let (y, s') = runState x s
                          in runState (f y) s'
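On GHC 7.10 and later this instance is not enough by itself: Monad has Applicative as a superclass, and Applicative has Functor as a superclass, so both instances must exist before the Monad instance is accepted. The sketch below adds them for the hand-rolled State of this section. (The State in Control.Monad.State is instead a synonym for StateT s Identity, and its values are built with the state function rather than a State constructor.) tick is a hypothetical helper used only to exercise the instances.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Apply f to the result; the state is threaded through unchanged.
instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

-- pure plays the role of return. (<*>) runs the function-producing
-- computation first, then the argument-producing one on the updated state.
instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

-- return defaults to pure, so only (>>=) is defined here.
instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s
                          in runState (f y) s'

-- Hypothetical: return the current counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)
```

With these instances, runState (tick >>= \a -> tick >>= \b -> return (a + b)) 1 evaluates to (3, 3). Note that (<*>) runs its two arguments left to right, threading the state in the same order as (>>=).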
Using the State monad without using do-notation is quite painful. So, before
we look at some examples, let's introduce how to access the state from within
do-blocks. The functions are provided by the Control.Monad.State module of the
mtl package.
Accessing the State
The most elementary functions are get and put. They respectively allow us to
read the state within a do-block and to replace the current state with a new
value. If we think about the state as a variable, get reads it and put
writes it. This allows us to write do-blocks of the following shape:
do
  -- Read the state
  s <- get
  -- Compute something from s, including an updated state s'
  ...
  -- Make s' the new state
  put s'
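Here is how get and put could be implemented for our State type (in
Control.Monad.State, they are methods of the more general MonadState class):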
get :: State s s
get = State $ \s -> (s, s)
put :: s -> State s ()
put s = State $ \_ -> ((), s)
Remember, a computation of type State s a operates on some state of type s
and returns a value of type a. We modelled this as a function of type
s -> (a, s). get should not modify the state and should return the current
state as its result. Thus, get is implemented as the function \s -> (s, s).
put s should replace the current state with s and do nothing else. In
particular, put s does not even return anything useful. Thus, the first
component of the pair returned by put is (), the unit value, which plays the
role of void here. Since put s replaces the old state with s, the function it
wraps ignores its state argument, which is why it uses the wildcard _. The new
state is s, so the state component of the pair
returned by put s is s.
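To see get and put work together in the read, compute, write shape shown earlier, here's a minimal sketch on the hand-rolled State. The Functor, Applicative and Monad instances are repeated so the block compiles on its own (do-notation needs the Monad instance), and incrementBy is a hypothetical example: it reads the counter, writes back counter + k, and returns the old value.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s in runState (f y) s'

-- Return the current state without changing it.
get :: State s s
get = State $ \s -> (s, s)

-- Replace the state, ignoring the old one.
put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical: read the counter, write back counter + k, and return the
-- value the counter had before the update.
incrementBy :: Int -> State Int Int
incrementBy k = do
  old <- get
  put (old + k)
  return old
```

runState (incrementBy 5) 10 evaluates to (10, 15), and running incrementBy 1 >> incrementBy 1 from state 0 gives (1, 2).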
Before looking at two more useful helpers, let's see get and put in action.
We'll continue our tradition of silly examples that have the benefit of being
simple, so they should be easy to follow. We want to implement a function
rangeList n that generates the list [1..n]:
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
As I said, this is a silly example. We would normally simply use the
arithmetic sequence [1..n] instead of calling rangeList n. To generate the result
of rangeList, we will use a counter. That's our state. We will use a function
go that uses this state to generate the list we want. Specifically, go will
generate the list [counter..n], where counter is the current counter value.
If counter > n, this is the empty list. If counter <= n, we build our list
from the value counter and the list we obtain by calling go recursively
after increasing counter by one:
rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
      counter <- get
      if counter > n then
        return []
      else do
        put $ counter + 1
        xs <- go
        return $ counter : xs
>>> import Control.Monad.State
>>> :{
| rangeList :: Int -> [Int]
| rangeList n = fst $ runState go 1
|   where
|     go = do
|       counter <- get
|       if counter > n then
|         return []
|       else do
|         put $ counter + 1
|         xs <- go
|         return $ counter : xs
| :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
The function go has the type go :: State Int [Int]. It's a stateful
computation with a state of type Int, the counter. go starts by retrieving
the current counter value using counter <- get. If counter > n, then the
list [counter..n] is empty, so we simply return this empty list. Otherwise, we
increment the counter using put $ counter + 1, then call go recursively to
generate the list xs = [counter + 1..n], and finally return the list
counter : xs = [counter..n]. To generate the list [1..n], we call go with
initial state 1. That's what runState go 1 does. The result is the pair
([1..n], n + 1) composed of the result of calling go and the state n + 1
left behind once go is done. We don't care about this final state in this
example, so we extract the list [1..n] by applying fst to this result.
Using the fact that every monad is a functor and <$> is the infix version of
fmap, we can make this implementation slightly more compact:
rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
      counter <- get
      if counter > n then
        return []
      else do
        put $ counter + 1
        (counter :) <$> go
Instead of assigning the result of go to xs and then returning the list
counter : xs, we fmap the function (counter :) over go. This works for
any monad. If f :: a -> b and g :: m a, then f <$> g has type m b. This
result is obtained by applying the pure function f to the result of g.
That's simply what fmap does for every functor.
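Here's a small sketch of fmap on the hand-rolled State, with a hypothetical tick computation (return the counter, then increment it) and a hypothetical prependResult that mirrors (counter :) <$> go. Only a Functor instance is needed for <$>.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- fmap applies a pure function to the result; the state comes out exactly
-- as the original computation leaves it.
instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

-- Hypothetical: return the current counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Hypothetical: prepend c to the list returned by g, like (counter :) <$> go.
prependResult :: Int -> State s [Int] -> State s [Int]
prependResult c g = (c :) <$> g
```

runState ((* 2) <$> tick) 21 evaluates to (42, 22): the result is doubled, while the state is incremented exactly as tick alone would do.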
Just as the Reader monad provides us with a function asks to retrieve only
part of the context, we have gets to read only part of the state:
gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)
We often use gets when our state is some record storing multiple pieces of
information, say
data MyState = MyState
  { counter :: Int
  , list    :: [()]
  }
We then use the record's accessor function to read only part of this state. For
example, ctr <- gets counter would read only the counter and ignore the
list of the current state. Let me emphasize that while we often call gets
with accessor functions of some record type that we use as the state, the
function we pass to gets can really be an arbitrary function of type
s -> a.
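A minimal sketch of gets on the MyState record from above, again on the hand-rolled State with its instances repeated. summary is a hypothetical computation that reads two projections of the state, one of them an arbitrary function rather than a field accessor; neither read changes the state.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s in runState (f y) s'

-- Read a projection of the state, leaving the state itself unchanged.
gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)

data MyState = MyState
  { counter :: Int
  , list    :: [()]
  }

-- Hypothetical: read the counter via its accessor and the list length via an
-- arbitrary function of type MyState -> Int.
summary :: State MyState (Int, Int)
summary = do
  ctr <- gets counter
  len <- gets (length . list)
  return (ctr, len)
```

For the state MyState 7 [(), ()], summary returns (7, 2) and leaves the state as it was.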
The final function to access the state can be seen as a combination of get and
put. Instead of reading the state using get, doing something that depends on
this state, and then writing an updated state using put, we can also modify
the current state using a pure function of type s -> s. It computes the new
state value from the old state value:
modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)
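The claim that modify combines get and put can be checked directly. The sketch below, on the hand-rolled State with its instances repeated, defines modify as above and a hypothetical modifyViaGetPut that performs the same update with get followed by put.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s in runState (f y) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Apply a pure update function to the state.
modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

-- Hypothetical: the same update, spelled out as a read followed by a write.
modifyViaGetPut :: (s -> s) -> State s ()
modifyViaGetPut f = do
  s <- get
  put (f s)
```

Both runState (modify (* 3)) 7 and runState (modifyViaGetPut (* 3)) 7 evaluate to ((), 21).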
To illustrate modify, let's look at our rangeList function one more time. I
promise, I'll stop after this example. Here's an implementation that uses a
record type as state and uses modify to generate the list. Let me emphasize
that this is not a very good implementation. It's the type of code that an
imperative programmer trying to program in Haskell would write.
data RangeListState = RLS
  { counter :: Int
  , list    :: [Int]
  }
rangeList :: Int -> [Int]
rangeList n = list . snd $ runState go $ RLS n []
  where
    go = do
      ctr <- gets counter
      if ctr == 0 then
        return ()
      else
        modify oneStep >> go
    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
>>> :{
| data RangeListState = RLS
|   { counter :: Int
|   , list    :: [Int]
|   }
|
| rangeList :: Int -> [Int]
| rangeList n = list . snd $ runState go $ RLS n []
|   where
|     go = do
|       ctr <- gets counter
|       if ctr == 0 then
|         return ()
|       else
|         modify oneStep >> go
|
|     oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
| :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
Let's take this apart. Our new function go does not take any arguments and
does not produce anything useful. Its return type is (). All its input and
output are part of the state. This state consists of the current counter value
counter and the list list we are generating. go maintains the invariant
that list = [counter + 1..n]. This invariant holds initially, when we call
go with the state RLS n []. go starts by retrieving the current counter
value using gets counter. If this counter value is 0, then based on our
invariant, we have list = [1..n]. Thus, go returns. Note that rangeList n
applies the function list . snd to the result of runState go $ RLS n [].
snd extracts the state component from the pair returned by runState, and
list extracts the list from this state. So, once go returns, rangeList
retrieves the list [1..n]. If the current counter value is not 0, then go
modifies the current state by calling modify oneStep and then calls itself
recursively. oneStep takes the current state RLS ctr list and replaces it
with the state RLS (ctr - 1) (ctr : list). In other words, it adds the current
counter value to the beginning of the current list, and then decreases the
counter by 1. Thus, the invariant that list = [counter + 1..n] is maintained,
and each application of modify oneStep gets us one step closer to having a
counter value of 0.[1]
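To watch the invariant at work without the State monad, here's a pure sketch that applies oneStep repeatedly, using the Prelude's iterate, and collects the states go passes through. visitedStates and invariantHolds are hypothetical helpers, and the pattern in oneStep names the list xs to avoid shadowing the list accessor.

```haskell
data RangeListState = RLS
  { counter :: Int
  , list    :: [Int]
  }

-- One loop step: prepend the counter to the list, then decrement it.
oneStep :: RangeListState -> RangeListState
oneStep (RLS ctr xs) = RLS (ctr - 1) (ctr : xs)

-- Hypothetical: every state from RLS n [] down to counter == 0 (n + 1 states).
visitedStates :: Int -> [RangeListState]
visitedStates n = take (n + 1) (iterate oneStep (RLS n []))

-- Hypothetical: check that list == [counter + 1 .. n] holds in every state.
invariantHolds :: Int -> Bool
invariantHolds n = all (\st -> list st == [counter st + 1 .. n]) (visitedStates n)
```

For n = 3, map list (visitedStates 3) is [[], [3], [2,3], [1,2,3]] while the counter goes 3, 2, 1, 0, so the final state holds [1..3].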
Running Stateful Computations
Once again, code written using the State monad isn't really imperative code;
it's purely functional code expressed using a monad. Thus, we should once again
have a method to convert any computation that uses the State monad into a pure
function. And we can. We already used runState in some of the examples in this
section. If f has the type State s a, then runState f has the type
s -> (a, s). This type signature provides no indication that the underlying
function uses the State monad in its implementation. It is simply a function
that takes a value of type s and returns a pair of type (a, s).
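A small sketch of this point, with a hypothetical tick computation on the hand-rolled State: once we apply runState, we have an ordinary function that we can, for example, map over several initial states.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Hypothetical: return the current counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- runState tick is a plain function; its type no longer mentions State.
tickFrom :: Int -> (Int, Int)
tickFrom = runState tick

-- Use it like any other pure function.
results :: [(Int, Int)]
results = map tickFrom [0, 10, 100]
```

results evaluates to [(0,1),(10,11),(100,101)].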
For the Writer monad, we also had a function execWriter that discarded the
return value of the computation and returned only the final log. Similarly,
execState discards the return value of a computation in the State monad and
only returns the final state.
execState :: State s a -> s -> s
execState f s = snd $ runState f s
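As a minimal sketch of execState, here's a hypothetical sumInto that adds every element of a list to an Int state using modify and the Prelude's mapM_. Its return value is (), so only the final state matters. The hand-rolled State and its instances are repeated so the block compiles on its own.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s in runState (f y) s'

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

-- Keep only the final state, discarding the result.
execState :: State s a -> s -> s
execState f s = snd $ runState f s

-- Hypothetical: add each element of the list to the state, in order.
sumInto :: [Int] -> State Int ()
sumInto = mapM_ (\x -> modify (+ x))
```

execState (sumInto [1..10]) 0 evaluates to 55.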
Our fairly contrived last implementation of the rangeList function was an
example of a computation where we cared only about the state left behind by the
computation. go does not return anything useful. All we care about is the
list component of the state left behind by go. Thus, we could have
implemented rangeList using execState:
data RangeListState = RLS
  { counter :: Int
  , list    :: [Int]
  }
rangeList :: Int -> [Int]
rangeList n = list $ execState go $ RLS n []
  where
    go = do
      ctr <- gets counter
      if ctr == 0 then
        return ()
      else
        modify oneStep >> go
    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
The only change is in the definition of rangeList n, which now uses
list $ execState go instead of list . snd $ runState go.
For the Writer monad, it makes little sense to discard the log. Using this log
after the computation is finished is the whole point of using the Writer monad
in the first place. For the State monad, on the other hand, it makes perfect
sense not to care about the final state. The state is only needed while the
computation runs, but what we are interested in once all is done and dusted is
the return value of the computation. For this purpose, we have
evalState :: State s a -> s -> a
evalState f s = fst $ runState f s
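Similarly, a minimal sketch of evalState: the hypothetical label pairs each element of a list with a running index kept in the state. The counter is needed only while labelling, so evalState discards it. The hand-rolled State and its instances are repeated, and mapM is the Prelude's.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f x = State $ \s -> let (y, s') = runState x s in (f y, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  sf <*> sx = State $ \s ->
    let (f, s')  = runState sf s
        (y, s'') = runState sx s'
    in (f y, s'')

instance Monad (State s) where
  x >>= f = State $ \s -> let (y, s') = runState x s in runState (f y) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Keep only the result, discarding the final state.
evalState :: State s a -> s -> a
evalState f s = fst $ runState f s

-- Hypothetical: pair each element with the current counter, then bump it.
label :: [a] -> State Int [(Int, a)]
label = mapM $ \x -> do
  i <- get
  put (i + 1)
  return (i, x)
```

evalState (label "abc") 0 evaluates to [(0,'a'),(1,'b'),(2,'c')].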
Again, our initial implementation of rangeList was an example of such a
computation. The counter used by go to construct the list was necessary only
during the construction of the list. Once go returns the list, we no longer
care about the value of the counter. So we can change the original
implementation of rangeList to use evalState instead.
rangeList :: Int -> [Int]
rangeList n = evalState go 1
  where
    go = do
      counter <- get
      if counter > n then
        return []
      else do
        put $ counter + 1
        xs <- go
        return $ counter : xs
Here, the definition of rangeList n uses evalState go 1 instead of
fst $ runState go 1.
[1] GHC is rich in useful language extensions. One such extension is
LambdaCase, which allows us to write anonymous functions that immediately apply
a case expression to their arguments. Specifically, if we write
\case
  pat1 -> expr1
  pat2 -> expr2
after enabling LambdaCase, this is the same as if we had written
\var -> case var of
  pat1 -> expr1
  pat2 -> expr2
With this, we can make our definition of rangeList much prettier:
data RangeListState = RLS
  { counter :: Int
  , list    :: [Int]
  }
rangeList :: Int -> [Int]
rangeList n = list . snd $ runState go $ RLS n []
  where
    go = gets counter >>= \case
      0 -> return ()
      _ -> modify oneStep >> go
    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
>>> :set -XLambdaCase
>>> :{
| data RangeListState = RLS
|   { counter :: Int
|   , list    :: [Int]
|   }
|
| rangeList :: Int -> [Int]
| rangeList n = list . snd $ runState go $ RLS n []
|   where
|     go = gets counter >>= \case
|       0 -> return ()
|       _ -> modify oneStep >> go
|
|     oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
| :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
This is still a terrible use of the State monad to do something we can do much
more easily and succinctly using pure functions. I merely wanted to demonstrate
how to obtain a much more readable version of our rangeList implementation that
keeps everything in its state.