State: Stateful Computations

Oh, imperative programmers, rejoice! You will now learn how to bring your dirty little world of destructive variable updates to the sacred halls of functional programming. The tool to do this is the State monad. Whereas the Reader monad allows us to access a context without modifying it, and the Writer monad provides us with a write-only log, the State monad gives us a state that we can both read and write at will.

How could we model stateful computations purely functionally? We want to define a type State s a that models a computation that produces some value of type a and depends on some state of type s to do so. In terms of pure functions, a computation that produces a value of type a from some state of type s is a function of type s -> a. That's nothing but the function type wrapped by our Reader type:

newtype Reader s a = Reader { runReader :: s -> a }

This type models computations that can depend on some read-only context of type s. The State monad should allow us to modify the state. The way a pure function "modifies" the state it is given as input is to return the updated state. Thus, a function that can read and write a state of type s should have the type s -> (a, s). Just as Reader s a is a wrapper around a function of type s -> a, State s a is a wrapper around a function of type s -> (a, s):

newtype State s a = State { runState :: s -> (a, s) }
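Before any Monad instance exists, it can help to build a few State values by hand to see the shape s -> (a, s) in action. The following is a minimal sketch; tick and isEven are hypothetical names chosen for illustration, not part of any library.

```haskell
-- The State type from above: a wrapper around a function s -> (a, s).
newtype State s a = State { runState :: s -> (a, s) }

-- Hypothetical example: return the current counter value as the result
-- and leave the incremented counter behind as the new state.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Hypothetical example: compute a result from the state without changing it.
isEven :: State Int Bool
isEven = State $ \n -> (even n, n)
```

Loaded into GHCi, runState tick 5 evaluates to (5,6), and runState isEven 4 evaluates to (True,4).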

Let's figure out how to turn State s into a monad. We need to come up with appropriate definitions of return and (>>=).

First, return. If x :: a, then return x should have the type State s a. Let's build it up in pieces again; the function virtually writes itself. We start with the skeleton

return x = State $ \s -> ...

because a value of type State s a should wrap a function of type s -> (a, s). The function (\s -> ...) must return a pair of type (a, s). Given that we know nothing about the types a and s, all we have at our disposal is the state s given as argument to this function, and the value x given as argument to return. Thus, the only pair of type (a, s) we can produce is the pair (x, s):

return x = State $ \s -> (x, s)

This is exactly the behaviour we would expect from return x: It's a computation that returns x and leaves the state alone.

Now what about x >>= f? Remember, this should model the idea of applying f to x. As far as passing return values along is concerned, this function writes itself again because we don't really have any choice in how to produce a return value of the correct type for this computation. However, both x and f may modify the state. We need to make sure that these state modifications are combined correctly. What we want to capture is the idea that x is a computation that produces some value of type a and reads and modifies the state in the process. Then we pass this value of type a to f. f then runs, and reads and modifies the state some more. The final result should be whatever f returns, and the state should be the state that f leaves behind. The important part to focus on is that the state we pass to f is the state left behind by x, because f runs after x. This is different from the Reader monad, where we passed the same context to both x and f, because neither can modify the context. Here's the implementation of x >>= f that implements this logic:

x >>= f = State $ \s -> let (y, s') = runState x s
                        in  runState (f y) s'

Given some state s, this runs x on this state and collects x's return value y and the updated state s'. It then passes y to f and runs the resulting computation f y on the updated state s'; if f has type a -> State s b, then f y wraps a function of type s -> (b, s). The result is whatever result-state pair f y returns.
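To see the state threading in isolation, the sketch below writes the body of (>>=) as an ordinary function, bindState (a hypothetical name), so that it compiles without a Monad instance. It chains two runs of a hypothetical tick computation: because the second tick runs on the state left behind by the first, the two results differ.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- The same logic as (>>=) above, written as a plain function.
-- bindState is a hypothetical name used only for this illustration.
bindState :: State s a -> (a -> State s b) -> State s b
bindState x f = State $ \s -> let (y, s') = runState x s
                              in  runState (f y) s'

-- Hypothetical helper: return the current counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Run tick twice and pair up the results. The second tick sees the
-- state produced by the first one, not the original state.
twoTicks :: State Int (Int, Int)
twoTicks = tick `bindState` \a ->
           tick `bindState` \b ->
           State $ \s -> ((a, b), s)
```

Here runState twoTicks 0 evaluates to ((0,1),2). With Reader-style threading, where f would receive the original state, both ticks would have returned 0.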

Our whole Monad instance for the State s monad is

instance Monad (State s) where
    return x = State $ \s -> (x, s)
    x >>= f  = State $ \s -> let (y, s') = runState x s
                             in  runState (f y) s'
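On GHC 7.10 and later, this instance alone does not compile: Applicative is a superclass of Monad, and Functor is a superclass of Applicative. One common way to satisfy both, sketched below, is to define pure directly and derive fmap and (<*>) from the Monad instance using liftM and ap from Control.Monad. The computation tick is again a hypothetical example.

```haskell
import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

-- Functor and Applicative are required superclasses of Monad.
-- liftM and ap are defined in terms of (>>=) and return, so they
-- reuse the state-threading logic of the Monad instance below.
instance Functor (State s) where
    fmap = liftM

instance Applicative (State s) where
    pure x = State $ \s -> (x, s)   -- same definition as return above
    (<*>)  = ap

instance Monad (State s) where
    return = pure
    x >>= f = State $ \s -> let (y, s') = runState x s
                            in  runState (f y) s'

-- Hypothetical example computation: return the counter, then bump it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)
```

With these instances, runState (tick >>= \a -> tick >>= \b -> return (a + b)) 10 evaluates to (21,12).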

Using the State monad without do-notation is quite painful. So, before we look at some examples, let's see how to access the state from within do-blocks. The functions for this are provided by the Control.Monad.State module of the mtl package.

Accessing the State

The most elementary functions are get and put. They respectively allow us to read the state within a do-block and to replace the current state with a new value. If we think about the state as a variable, get reads it and put writes it. This allows us to write do-blocks of the following shape:

    -- Read the state
    s <- get

    -- Compute something from s, including an updated state s'

    -- Make s' the new state
    put s'

Here is how get and put are implemented:

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

Remember, a computation of type State s a operates on some state of type s and returns a value of type a. We modelled this as a function of type s -> (a, s). get should not modify the state and should return the current state as its result. Thus, get is implemented as the function \s -> (s, s).

put s should replace the current state with s and do nothing else. In particular, put s does not return anything useful, so the first component of the pair returned by put s is () (pronounced "unit"). Since put s replaces the old state with s, the function it wraps ignores its state argument, which is why it matches that argument with the wildcard _. The new state is s, so the state component of the pair returned by put s is s.
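The following sketch puts these definitions of get and put to work on a self-contained copy of State (including the Functor and Applicative instances modern GHC requires), so they can be tried without the mtl package. doubleCounter is a hypothetical example.

```haskell
import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
    fmap = liftM

instance Applicative (State s) where
    pure x = State $ \s -> (x, s)
    (<*>)  = ap

instance Monad (State s) where
    x >>= f = State $ \s -> let (y, s') = runState x s
                            in  runState (f y) s'

-- get returns the current state as its result and leaves it unchanged.
get :: State s s
get = State $ \s -> (s, s)

-- put ignores the old state and installs a new one; its result is ().
put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical example: read the counter, write back its double,
-- and return the value that was there before.
doubleCounter :: State Int Int
doubleCounter = do
    n <- get
    put (2 * n)
    return n
```

For example, runState doubleCounter 5 evaluates to (5,10): the result is the old value and the state is the new one.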

Before looking at two more useful helpers, let's see get and put in action. We'll continue our tradition of silly examples that have the benefit of being simple, so they should be easy to follow. We want to implement a function rangeList n that generates the list [1..n]:

>>> rangeList 10
[1,2,3,4,5,6,7,8,9,10]
>>> rangeList 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]

As I said, this is a silly example. We would normally simply write the range expression [1..n] instead of calling rangeList n. To generate the result of rangeList, we will use a counter. That's our state. We will use a function go that uses this state to generate the list we want. Specifically, go will generate the list [counter..n], where counter is the current counter value. If counter > n, this is the empty list. If counter <= n, we build our list from the value counter and the list we obtain by calling go recursively after increasing counter by one:

import Control.Monad.State

rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            xs <- go
            return $ counter : xs

The function go has the type go :: State Int [Int]. It's a stateful computation with a state of type Int, the counter. go starts by retrieving the current counter value using counter <- get. If counter > n, then the list [counter..n] is empty, so we simply return this empty list. Otherwise, we increment the counter using put $ counter + 1, then call go recursively to generate the list xs = [counter + 1..n], and finally return the list counter : xs = [counter..n]. To generate the list [1..n], we call go with initial state 1. That's what runState go 1 does. The result is the pair ([1..n], n + 1) composed of the result of calling go and the state n + 1 left behind once go is done. We don't care about this final state in this example, so we extract the list [1..n] by applying fst to this result.
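To check the claim about the final state, the sketch below lifts go to the top level as rangeGo (a hypothetical name, parameterised by n) so that the whole pair returned by runState can be inspected. It assumes State, get, put and runState from Control.Monad.State in the mtl package.

```haskell
import Control.Monad.State (State, get, put, runState)

-- The go from rangeList, exposed so we can look at both components
-- of the pair that runState returns. rangeGo is a hypothetical name.
rangeGo :: Int -> State Int [Int]
rangeGo n = go
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            xs <- go
            return $ counter : xs
```

Starting from state 1, runState (rangeGo 5) 1 evaluates to ([1,2,3,4,5],6): the list [1..n] together with the leftover counter n + 1.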

Using the fact that every monad is a functor and <$> is the infix version of fmap, we can make this implementation slightly more compact:

rangeList :: Int -> [Int]
rangeList n = fst $ runState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            (counter :) <$> go

Instead of assigning the result of go to xs and then returning the list counter : xs, we fmap the function (counter :) over go. This works for any monad, and in fact for any functor: if f :: a -> b and g :: m a, then f <$> g has type m b, and its result is obtained by applying the pure function f to the result of g. That's simply what fmap does for every functor.
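A small sketch of this point, assuming Control.Monad.State from mtl: fmapping a pure function over a State computation transforms its result while the state is threaded exactly as before. next and nextAsString are hypothetical names.

```haskell
import Control.Monad.State (State, get, put, runState)

-- Hypothetical example: return the current counter and increment it.
next :: State Int Int
next = do
    n <- get
    put (n + 1)
    return n

-- <$> applies show to next's result only; the state change made by
-- next (incrementing the counter) is unaffected.
nextAsString :: State Int String
nextAsString = show <$> next
```

Here runState nextAsString 41 evaluates to ("41",42): the result has been converted to a string, and the state is the same 42 that next alone would leave behind.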

Just as the Reader monad provides us with a function asks to retrieve only part of the context, we have gets to read only part of the state:

gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)

We often use gets when our state is some record storing multiple pieces of information, say

data MyState = MyState
    { counter :: Int
    , list    :: [()]
    }

We then use the record's accessor function to read only part of this state. For example, ctr <- gets counter reads only the counter and ignores the list stored in the current state. Let me emphasize that, while we often call gets with accessor functions of the record type we use as the state, the function we pass to gets can be an arbitrary function of type s -> a.
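A sketch of gets with a record as the state, assuming mtl's Control.Monad.State. The record Counters and the computation hitRate are hypothetical; the first call to gets uses an accessor, the second an arbitrary function.

```haskell
import Control.Monad.State (State, gets, runState)

-- Hypothetical record used as the state.
data Counters = Counters
    { hits   :: Int
    , misses :: Int
    }

-- Reads parts of the state without modifying it.
hitRate :: State Counters Double
hitRate = do
    h     <- gets hits                          -- record accessor
    total <- gets (\c -> hits c + misses c)     -- any function s -> a
    return $ if total == 0 then 0 else fromIntegral h / fromIntegral total
```

For example, fst (runState hitRate (Counters 3 1)) evaluates to 0.75.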

The final function to access the state can be seen as a combination of get and put. Instead of reading the state using get, doing something that depends on this state, and then writing an updated state using put, we can also modify the current state using a pure function of type s -> s. It computes the new state value from the old state value:

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)
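The description of modify as a combination of get and put can be checked directly. The sketch below, assuming mtl's Control.Monad.State, defines a hypothetical modifyViaGetPut that reads the state, applies f, and writes the result back, and compares it with mtl's own modify.

```haskell
import Control.Monad.State (State, get, modify, put, runState)

-- modify spelled out as "read, apply f, write back".
-- modifyViaGetPut is a hypothetical name used only for this comparison.
modifyViaGetPut :: (s -> s) -> State s ()
modifyViaGetPut f = do
    s <- get
    put (f s)
```

Both runState (modifyViaGetPut (* 3)) 7 and runState (modify (* 3)) 7 evaluate to ((),21).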

To illustrate modify, let's look at our rangeList function one more time. I promise, I'll stop after this example. Here's an implementation that uses a record type as state and uses modify to generate the list. Let me emphasize that this is not a very good implementation. It's the type of code that an imperative programmer trying to program in Haskell would write.

data RangeListState = RLS
    { counter :: Int
    , list    :: [Int]
    }

rangeList :: Int -> [Int]
rangeList n = list . snd $ runState go $ RLS n []
  where
    go = do
        ctr <- gets counter
        if ctr == 0 then
            return ()
        else
            modify oneStep >> go

    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)

Let's take this apart. Our new function go does not take any arguments and does not produce anything useful: its return type is (). All its input and output are part of the state, which consists of the current counter value counter and the list list we are generating. go maintains the invariant that list = [counter + 1..n]. This invariant holds initially, when we call go with the state RLS n [].

go starts by retrieving the current counter value using gets counter. If this counter value is 0, then by our invariant, list = [1..n], so go returns. Note that rangeList n applies the function list . snd to the result of runState go $ RLS n []: snd extracts the state component from the pair returned by runState, and list extracts the list from this state. So, once go returns, rangeList retrieves the list [1..n].

If the current counter value is not 0, then go modifies the current state by calling modify oneStep and then calls itself recursively. oneStep replaces the current state RLS ctr list with the state RLS (ctr - 1) (ctr : list). In other words, it adds the current counter value to the beginning of the current list and then decreases the counter by 1. Thus, the invariant list = [counter + 1..n] is maintained, and each application of modify oneStep gets us one step closer to a counter value of 0.¹

Running Stateful Computations

Once again, code written using the State monad isn't really imperative code; it's purely functional code expressed using a monad. Thus, we should once again have a method to convert any computation that uses the State monad into a pure function. And we can. We already used runState in some of the examples in this section. If f has the type State s a, then runState f has the type s -> (a, s). This type signature provides no indication that the underlying function uses the State monad in its implementation. It is simply a function that takes a value of type s and returns a pair of type (a, s).

For the Writer monad, we also had a function execWriter that discarded the return value of the computation and returned only the final log. Similarly, execState discards the return value of a computation in the State monad and only returns the final state.

execState :: State s a -> s -> s
execState f s = snd $ runState f s
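A small sketch of execState in use, assuming mtl's Control.Monad.State (whose execState behaves like the definition above). sumInto is a hypothetical computation whose own result is (), so only the final state is of interest.

```haskell
import Control.Monad.State (State, execState, modify)

-- Hypothetical example: add every element of the list to the state.
-- mapM_ discards the per-element results, so the return value is ().
sumInto :: [Int] -> State Int ()
sumInto = mapM_ (\x -> modify (+ x))
```

Starting from 0, execState (sumInto [1 .. 10]) 0 evaluates to 55; starting from 100, execState (sumInto [1, 2, 3]) 100 evaluates to 106.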

Our fairly contrived last implementation of the rangeList function was an example of a computation where we cared only about the state left behind by the computation. go does not return anything useful. All we care about is the list component of the state left behind by go. Thus, we could have implemented rangeList using execState:

data RangeListState = RLS
    { counter :: Int
    , list    :: [Int]
    }

rangeList :: Int -> [Int]
rangeList n = list $ execState go $ RLS n []
  where
    go = do
        ctr <- gets counter
        if ctr == 0 then
            return ()
        else
            modify oneStep >> go

    oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)

The only line that has changed is the definition of rangeList n, which now uses list $ execState go instead of list . snd $ runState go.

For the Writer monad, it makes little sense to discard the log. Using this log after the computation is finished is the whole point of using the Writer monad in the first place. For the State monad, on the other hand, it makes perfect sense not to care about the final state. The state is only needed while the computation runs, but what we are interested in once all is done and dusted is the return value of the computation. For this purpose, we have

evalState :: State s a -> s -> a
evalState f s = fst $ runState f s
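A typical use of evalState, sketched with mtl's Control.Monad.State: pairing each element of a list with a fresh number drawn from a counter. The counter matters only while labelling, so evalState discards it at the end. label and labelAll are hypothetical names.

```haskell
import Control.Monad.State (State, evalState, get, put)

-- Hypothetical example: take the next number from the counter
-- and attach it to x.
label :: a -> State Int (Int, a)
label x = do
    n <- get
    put (n + 1)
    return (n, x)

-- mapM runs label on each element from left to right, threading the
-- counter through; evalState keeps the list and drops the final counter.
labelAll :: [a] -> [(Int, a)]
labelAll xs = evalState (mapM label xs) 0
```

For example, labelAll "abc" evaluates to [(0,'a'),(1,'b'),(2,'c')].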

Again, our initial implementation of rangeList was an example of such a computation. The counter used by go to construct the list was necessary only during the construction of the list. Once go returns the list, we no longer care about the value of the counter. So we can change the original implementation of rangeList to use evalState instead.

rangeList :: Int -> [Int]
rangeList n = evalState go 1
  where
    go = do
        counter <- get
        if counter > n then
            return []
        else do
            put $ counter + 1
            xs <- go
            return $ counter : xs

The definition of rangeList n now uses evalState go 1 instead of fst $ runState go 1.

  1. GHC is rich in useful language extensions. One such extension is LambdaCase, which allows us to write anonymous functions that immediately apply a case expression to their argument. Specifically, if we write

    \case
        pat1 -> expr1
        pat2 -> expr2

    after enabling LambdaCase, this is the same as if we had written

    \var -> case var of
        pat1 -> expr1
        pat2 -> expr2

    With this, we can make our definition of rangeList much prettier:

    {-# LANGUAGE LambdaCase #-}

    import Control.Monad.State

    data RangeListState = RLS
        { counter :: Int
        , list    :: [Int]
        }

    rangeList :: Int -> [Int]
    rangeList n = list . snd $ runState go $ RLS n []
      where
        go = gets counter >>= \case
            0 -> return ()
            _ -> modify oneStep >> go
        oneStep (RLS ctr list) = RLS (ctr - 1) (ctr : list)

    This is still a terrible use of the State monad to do something we can do much more easily and succinctly using pure functions. I merely wanted to demonstrate how to obtain a much more readable version of our rangeList implementation that keeps everything in its state.
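To try \case outside the State monad, here is a small sketch; describe is a hypothetical function. It also shows that guards are allowed in \case alternatives, just as in an ordinary case expression.

```haskell
{-# LANGUAGE LambdaCase #-}

-- Equivalent to: describe = \n' -> case n' of ...
-- describe is a hypothetical example function.
describe :: Int -> String
describe = \case
    0 -> "zero"
    n | n < 0     -> "negative"
      | otherwise -> "positive"
```

For example, describe 0 is "zero", describe (-3) is "negative", and describe 7 is "positive".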