State: Stateful Computations
Oh imperative programmers rejoice. You will learn now how to bring your dirty
little world of destructive variable updates to the sacred halls of functional
programming. The tool to do this is the State monad. Whereas the Reader monad
allows us to access a context without modifying it, and the Writer monad
provides us with a write-only log, the State monad gives us a state that we
can both read and write at will.
How could we model stateful computations purely functionally? We want to define
a type State s a that models a computation that produces some value of type a
and depends on some state of type s to do so. In terms of pure functions, a
computation that produces a value of type a from some state of type s is a
function of type s -> a. That's nothing but the function type wrapped by our
Reader type:
newtype Reader s a = Reader { runReader :: s -> a }
This type models computations that can depend on some read-only context of
type s. The State monad should allow us to modify the state. The way a pure
function "modifies" the state it is given as input is to return the updated
state. Thus, a function that can read and write a state of type s should have
the type s -> (a, s). Just as Reader s a is a wrapper around a function of
type s -> a, State s a is a wrapper around a function of type s -> (a, s):
newtype State s a = State { runState :: s -> (a, s) }
Let's figure out how to turn State s into a monad. We need to come up with
appropriate definitions of return and (>>=).
First return. If x :: a, then return x should have the type State s a.
Let's build it up in pieces again. The function virtually writes itself. We
start with the skeleton
return x = State $ \s -> ...
because a value of type State s a should wrap a function of type
s -> (a, s). The function (\s -> ...) must return a pair of type (a, s).
Given that we know nothing about the types a and s, all we have at our
disposal is the state s given as argument to this function, and the value x
given as argument to return. Thus, the only pair of type (a, s) we can
produce is the pair (x, s):
return x = State $ \s -> (x, s)
This is exactly the behaviour we would expect from return x: It's a
computation that returns x and leaves the state alone.
Now what about x >>= f? Remember, this should model the idea of applying f
to the value produced by x. As far as passing return values along is concerned,
this function writes itself again because we don't really have any choice how
to produce a return value of the correct type for this computation. However,
both x and f may modify the state. We need to make sure that these state
modifications are combined correctly. What we want to capture is the idea that
x is a computation that produces some value of type a and reads and modifies
the state in the process. Then we pass this value of type a to f. f then runs,
and reads and modifies the state some more. The final result should be whatever
f returns, and the state should be the state that f leaves behind. The
important part to focus on is that the state we pass to f is the state left
behind by x, because f runs after x. This is different from the Reader monad,
where we passed the same context to both x and f, because neither can modify
the context. Here's the implementation of x >>= f that implements this logic:
x >>= f = State $ \s -> let (y, s') = runState x s
                        in  runState (f y) s'
Given some state s, this runs x on this state and collects x's return value y
and the updated state s'. It then passes y to f and runs the resulting function
of type s -> (a, s) on the updated state s'. The result is whatever
result-state pair f returns.
Our whole Monad instance for the State s monad is
instance Monad (State s) where
    return x = State $ \s -> (x, s)
    x >>= f  = State $ \s -> let (y, s') = runState x s
                             in  runState (f y) s'
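If you try to compile this instance on its own with a recent GHC, it will be
rejected: since GHC 7.10, Applicative is a superclass of Monad, so State s needs
Functor and Applicative instances as well. Here is a minimal, self-contained
sketch that adds them using liftM and ap from Control.Monad. The helpers tick
and twoTicks are made up for this illustration and are not part of any library.

```haskell
import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

-- Functor and Applicative are derived from the Monad operations below.
instance Functor (State s) where
    fmap = liftM

instance Applicative (State s) where
    pure x = State $ \s -> (x, s)
    (<*>)  = ap

-- return defaults to pure, so only (>>=) needs to be given here.
instance Monad (State s) where
    x >>= f = State $ \s -> let (y, s') = runState x s
                            in  runState (f y) s'

-- Hypothetical helper: return the current counter and increment the state.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Run tick twice; the second call sees the state left behind by the first.
twoTicks :: State Int (Int, Int)
twoTicks = tick >>= \a -> tick >>= \b -> return (a, b)
```

Evaluating runState twoTicks 5 threads the counter through both calls to tick
and gives ((5,6),7): the first tick sees 5, the second sees 6, and the final
state is 7.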
Using the State monad without using do-notation is quite painful. So, before
we look at some examples, let's introduce how to access the state from within
do-blocks. The functions are provided by the Control.Monad.State module.
Accessing the State
The most elementary functions are get and put. They respectively allow us to
read the state within a do-block and to replace the current state with a new
value. If we think about the state as a variable, get reads it and put
writes it. This allows us to write do-blocks of the following shape:
do
    -- Read the state
    s <- get
    -- Compute something from s, including an updated state s'
    ...
    -- Make s' the new state
    put s'
Here is how get and put are implemented:
get :: State s s
get = State $ \s -> (s, s)
put :: s -> State s ()
put s = State $ \_ -> ((), s)
Remember, a computation of type State s a operates on some state of type s
and returns a value of type a. We modelled this as a function of type
s -> (a, s). get should not modify the state and should return the current
state as its result. Thus, get is implemented as the function \s -> (s, s).
put s should replace the current state with s and do nothing else. In
particular, put s does not even return anything useful. Thus, the first
component of the pair returned by put is () (the unit value, Haskell's
counterpart to void). Since put s replaces the old state with s, the function
it wraps ignores its state argument; it uses a wildcard. The new state is s,
so the state component of the pair returned by put s is s.
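To get a feel for how get and put combine inside do-blocks, here is a small
sketch that uses them to hand out numbered labels. It assumes the
Control.Monad.State module from the mtl package; freshLabel and threeLabels are
names invented for this example. Note that the library defines State in terms
of a more general type, so its get and put have more general types than the
simplified definitions above, but they behave the same way here.

```haskell
import Control.Monad.State

-- Hypothetical helper: use the current counter value to build a label,
-- then advance the counter by one.
freshLabel :: State Int String
freshLabel = do
    n <- get
    put (n + 1)
    return ("label" ++ show n)

-- Each call to freshLabel sees the counter left behind by the previous one.
threeLabels :: State Int [String]
threeLabels = do
    a <- freshLabel
    b <- freshLabel
    c <- freshLabel
    return [a, b, c]
```

Running runState threeLabels 0 produces the labels label0, label1 and label2
together with the final counter value 3.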
Before looking at two more useful helpers, let's see get and put in action.
We'll continue our tradition of silly examples that have the benefit of being
simple, so they should be easy to follow. We want to implement a function
rangeList n that generates the list [1..n]:
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
As I said, this is a silly example. We would normally simply use the range
expression [1..n] instead of calling rangeList n. To generate the result
of rangeList, we will use a counter. That's our state. We will use a function
go that uses this state to generate the list we want. Specifically, go will
generate the list [counter..n], where counter is the current counter value.
If counter > n, this is the empty list. If counter <= n, we build our list
from the value counter and the list we obtain by calling go recursively
after increasing counter by one:
rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            xs <- go
            return $ counter : xs

>>> import Control.Monad.State
>>> :{
  | rangeList :: Int -> [Int]
  | rangeList n = fst $ runState go 1
  |   where
  |     go = do
  |         counter <- get
  |         if counter > n then
  |             return []
  |         else do
  |             put $ counter + 1
  |             xs <- go
  |             return $ counter : xs
  | :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
The function go has the type go :: State Int [Int]. It's a stateful
computation with a state of type Int, the counter. go starts by retrieving
the current counter value using counter <- get. If counter > n, then the
list [counter..n] is empty, so we simply return this empty list. Otherwise, we
increment the counter using put $ counter + 1, then call go recursively to
generate the list xs = [counter + 1..n], and finally return the list
counter : xs = [counter..n]. To generate the list [1..n], we call go with
initial state 1. That's what runState go 1 does. The result is the pair
([1..n], n + 1) composed of the result of calling go and the state n + 1
left behind once go is done (assuming n >= 0; for negative n, go returns
immediately and the state stays 1). We don't care about this final state in
this example, so we extract the list [1..n] by applying fst to this result.
Using the fact that every monad is a functor and <$> is the infix version of
fmap, we can make this implementation slightly more compact:
rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            (counter :) <$> go
Instead of assigning the result of go to xs and then returning the list
counter : xs, we fmap the function (counter :) over go. This works for
any monad. If f :: a -> b and g :: m a, then f <$> g has type m b. This
result is obtained by applying the pure function f to the result of g.
That's simply what fmap does for every functor.
Just as the Reader monad provides us with a function asks to retrieve only
part of the context, we have gets to read only part of the state:
gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)
We often use gets when our state is some record storing multiple pieces of
information, say
data MyState = MyState
    { counter :: Int
    , list    :: [()]
    }
We then use the record's accessor function to read only part of this state. For
example, ctr <- gets counter would read only the counter and ignore the
list of the current state. Let me emphasize that while we often use gets
with accessor functions of some record type that we use as the state, the
function we pass to gets can really be an arbitrary function of type s -> a.
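Here is a small sketch of both uses of gets, once with a record accessor and
once with an arbitrary function. It assumes Control.Monad.State from the mtl
package. The Inventory type and the names appleCount and totalFruit are
hypothetical and exist only for this illustration.

```haskell
import Control.Monad.State

-- Hypothetical state type, used only for this illustration.
data Inventory = Inventory
    { apples  :: Int
    , oranges :: Int
    }

-- Read a single field of the state via its accessor function.
appleCount :: State Inventory Int
appleCount = gets apples

-- Read a derived value: the argument to gets can be any function
-- from the state type to the result type.
totalFruit :: State Inventory Int
totalFruit = gets (\inv -> apples inv + oranges inv)
```

With the initial state Inventory 3 4, appleCount returns 3 and totalFruit
returns 7. Neither computation changes the state.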
The final function to access the state can be seen as a combination of get and
put. Instead of reading the state using get, doing something that depends on
this state, and then writing an updated state using put, we can also modify
the current state using a pure function of type s -> s. It computes the new
state value from the old state value:
modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)
To illustrate modify, let's look at our rangeList function one more time. I
promise, I'll stop after this example. Here's an implementation that uses a
record type as state and uses modify to generate the list. Let me emphasize
that this is not a very good implementation. It's the type of code that an
imperative programmer trying to program in Haskell would write.
data RangeListState = RLS
    { counter :: Int
    , list    :: [Int]
    }

rangeList :: Int -> [Int]
rangeList n = list . snd $ runState go $ RLS n []
  where
    go = do
        ctr <- gets counter
        if ctr == 0 then
            return ()
        else
            modify oneStep >> go

    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)

>>> :{
  | data RangeListState = RLS
  |     { counter :: Int
  |     , list    :: [Int]
  |     }
  |
  | rangeList :: Int -> [Int]
  | rangeList n = list . snd $ runState go $ RLS n []
  |   where
  |     go = do
  |         ctr <- gets counter
  |         if ctr == 0 then
  |             return ()
  |         else
  |             modify oneStep >> go
  |
  |     oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
  | :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
Let's take this apart. Our new function go does not take any arguments and
does not produce anything useful. Its return type is (). All its input and
output are part of the state. This state consists of the current counter value
counter and the list list we are generating. go maintains the invariant
that list = [counter + 1..n]. This invariant holds initially, when we call
go with the state RLS n []. go starts by retrieving the current counter
value using gets counter. If this counter value is 0, then based on our
invariant, we have list = [1..n]. Thus, go returns. Note that rangeList n
applies the function list . snd to the result of runState go $ RLS n [].
snd extracts the state component from the pair returned by runState, and
list extracts the list from this state. So, once go returns, rangeList
retrieves the list [1..n]. If the current counter value is not 0, then go
modifies the current state by calling modify oneStep and then calls itself
recursively. oneStep takes the current state RLS ctr list and replaces it
with the state RLS (ctr - 1) (ctr : list). In other words, it adds the current
counter value to the beginning of the current list, and then decreases the
counter by 1. Thus, the invariant that list = [counter + 1..n] is maintained,
and each application of modify oneStep gets us one step closer to having a
counter value of 0.[1]
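The earlier description of modify as a combination of get and put can be made
concrete with a short sketch. It assumes Control.Monad.State from the mtl
package; modifyViaGetPut and bumpTwice are hypothetical names used only here,
and the library's own modify is not necessarily implemented this way.

```haskell
import Control.Monad.State

-- Hypothetical name: modify spelled out as a get followed by a put.
modifyViaGetPut :: (s -> s) -> State s ()
modifyViaGetPut f = do
    s <- get
    put (f s)

-- Mix the library's modify with the hand-written version; both update
-- the state with a pure function and return ().
bumpTwice :: State Int ()
bumpTwice = modify (+ 1) >> modifyViaGetPut (* 10)
```

Starting from the state 4, bumpTwice first adds 1 and then multiplies by 10,
so runState bumpTwice 4 gives ((),50).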
Running Stateful Computations
Once again, code written using the State monad isn't really imperative code;
it's purely functional code expressed using a monad. Thus, we should once again
have a method to convert any computation that uses the State monad into a pure
function. And we can. We already used runState in some of the examples in this
section. If f has the type State s a, then runState f has the type
s -> (a, s). This type signature provides no indication that the underlying
function uses the State monad in its implementation. It is simply a function
that takes a value of type s and returns a pair of type (a, s).
For the Writer monad, we also had a function execWriter that discarded the
return value of the computation and returned only the final log. Similarly,
execState discards the return value of a computation in the State monad and
only returns the final state.
execState :: State s a -> s -> s
execState f s = snd $ runState f s
Our fairly contrived last implementation of the rangeList function was an
example of a computation where we cared only about the state left behind by the
computation. go does not return anything useful. All we care about is the
list component of the state left behind by go. Thus, we could have
implemented rangeList using execState:
data RangeListState = RLS
    { counter :: Int
    , list    :: [Int]
    }

rangeList :: Int -> [Int]
rangeList n = list $ execState go $ RLS n []
  where
    go = do
        ctr <- gets counter
        if ctr == 0 then
            return ()
        else
            modify oneStep >> go

    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
The only line that has changed is the definition of rangeList n itself, which
uses list $ execState go instead of list . snd $ runState go.
For the Writer monad, it makes little sense to discard the log. Using this log
after the computation is finished is the whole point of using the Writer monad
in the first place. For the State monad, on the other hand, it makes perfect
sense not to care about the final state. The state is only needed while the
computation runs, but what we are interested in once all is done and dusted is
the return value of the computation. For this purpose, we have
evalState :: State s a -> s -> a
evalState f s = fst $ runState f s
Again, our initial implementation of rangeList was an example of such a
computation. The counter used by go to construct the list was necessary only
during the construction of the list. Once go returns the list, we no longer
care about the value of the counter. So we can change the original
implementation of rangeList to use evalState instead.
rangeList :: Int -> [Int]
rangeList n = evalState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            xs <- go
            return $ counter : xs
The second line uses evalState go 1 instead of fst $ runState go 1 to
implement rangeList n.
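To see the three ways of running a computation side by side, here is a small
sketch that applies runState, evalState and execState to the same computation.
It assumes Control.Monad.State from the mtl package; push and pushTwo are
hypothetical names chosen for this example.

```haskell
import Control.Monad.State

-- Hypothetical example: push a value onto a stack kept in the state
-- and return the new depth of the stack.
push :: Int -> State [Int] Int
push x = do
    modify (x :)
    gets length

-- Push two values; the result is the depth after the second push.
pushTwo :: State [Int] Int
pushTwo = push 1 >> push 2
```

Starting from the empty stack, runState pushTwo [] gives the pair (2,[2,1]),
evalState pushTwo [] keeps only the return value 2, and execState pushTwo []
keeps only the final state [2,1].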
[1] GHC is rich in useful language extensions. One such extension is
LambdaCase, which allows us to write anonymous functions that immediately apply
a case expression to their arguments. Specifically, if we write
\case
    pat1 -> expr1
    pat2 -> expr2
after enabling LambdaCase, this is the same as if we had written
\var -> case var of
    pat1 -> expr1
    pat2 -> expr2
With this, we can make our definition of rangeList much prettier:
data RangeListState = RLS
    { counter :: Int
    , list    :: [Int]
    }

rangeList :: Int -> [Int]
rangeList n = list . snd $ runState go $ RLS n []
  where
    go = gets counter >>= \case
        0 -> return ()
        _ -> modify oneStep >> go

    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)

>>> :set -XLambdaCase
>>> :{
  | data RangeListState = RLS
  |     { counter :: Int
  |     , list    :: [Int]
  |     }
  |
  | rangeList :: Int -> [Int]
  | rangeList n = list . snd $ runState go $ RLS n []
  |   where
  |     go = gets counter >>= \case
  |         0 -> return ()
  |         _ -> modify oneStep >> go
  |
  |     oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)
  | :}
>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
This is still a terrible use of the State monad to do something we can do much
more easily and succinctly using pure functions. I merely wanted to demonstrate
how to obtain a much more readable version of our rangeList implementation that
keeps everything in its state.