Handling state with monads

The previous post on “Fundamentals of Monads” suggested that monads empower the programmer with abstraction of context (algebraic data type) details, as well as order of evaluation, with the latter bringing us one step closer to the model of imperative languages when we need it, while still maintaining all the beautiful properties of the functional language that Haskell is. If you haven’t read the previous post, you may want to, as some of the terminology and intuition described there is used here.

However, for Haskell to be truly useful, not only do we need to guarantee order of execution (when it’s needed), but we also need a way to maintain state (again, when it’s needed). Many problems are stateful by nature (i.e. they need side effects), and pure functional languages don’t handle the notion of state very intuitively, at least not in their primitive language constructs.

A simple solution often given to Haskell beginners is to capture state by literally passing it around: each function takes the current state as an extra argument, and returns the (possibly modified) state alongside its result.
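To make this concrete, here is a minimal sketch of explicit state passing, assuming the state is just an Int counter. The names fresh and labelThree are hypothetical, chosen only for illustration; notice how every call has to receive the current state and hand the new one on by hand.

```haskell
-- Hypothetical example: hand out consecutive labels from an Int counter.
-- Each function takes the current state and returns (result, new state).
fresh :: Int -> (Int, Int)
fresh n = (n, n + 1)   -- the label is the current counter; the counter is bumped

-- Label three things, threading the state manually through each call.
labelThree :: Int -> ((Int, Int, Int), Int)
labelThree s0 =
  let (a, s1) = fresh s0   -- must pass s0 ...
      (b, s2) = fresh s1   -- ... then s1 ...
      (c, s3) = fresh s2   -- ... then s2, and never mix them up
  in ((a, b, c), s3)
```

Here labelThree 0 evaluates to ((0,1,2),3). Passing the wrong intermediate state (say, s1 twice) would still type-check but silently produce a duplicate label, which is part of why this style is error-prone.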

However, that approach is tedious and full of boilerplate. Again, can we find a way to abstract away the whole notion of state, so that programmers can ignore the fact that they need to pass state around, while still having the side-effecting notion of state?

Indeed we can. Enter the State monad.

Note: At this point, because we’re introducing the state monad, we’re going to focus purely on the state monad. Haskell offers other ways of handling state, as well as more powerful monads that do this. This is just the very beginning.

As usual, a monad is characterized by its data type definition and its definitions of return and >>=. Here’s what the state monad’s type looks like:

data State s a = State (s -> (a, s))

Alternatively, to get a “getter” function for free, we can use newtype with record syntax instead (this is the classic definition; the modern mtl and transformers libraries define State s as StateT s Identity instead):

newtype State s a = State { runState :: s -> (a, s) }

Now, this looks very different from the Maybe and List monads. The key difference is that the so-called monadic “value” is not a value in the intuitive sense that 4 or 100 is, but a function. A more general and precise word for it is “computation”, and since that’s what it truly is, we’ll use that term here. See why it’s a computation? A computation can produce a value when it runs, but what the monad holds here is the function itself.

In short, remember that a monad gives context to its value (or computation). In this case, our value is a function, and the context we give it is state. Since the function acts on a state, it logically has to receive one: it takes a state of type s and produces a result and a new state, represented as a tuple of type (a, s). Because these functions act on state (hey, it’s a state monad), we’ll call them state-transformation functions.

runState gives us a “reference” to that function for free. In other words, runState “extracts” the function for us.
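A small sketch of the record syntax at work, using the newtype from above. The value tick is a hypothetical state-transformation function (its result is twice the current state, and it increments the state) wrapped in State; runState pulls the bare function back out so we can apply it to an initial state.

```haskell
-- The newtype from the text; the record field gives us
-- runState :: State s a -> (s -> (a, s)) for free.
newtype State s a = State { runState :: s -> (a, s) }

-- Hypothetical state-transformation function, wrapped in State:
-- the result is twice the current state, and the state is incremented.
tick :: State Int Int
tick = State (\s -> (s * 2, s + 1))

-- runState "extracts" the wrapped function; it is an ordinary Int -> (Int, Int).
extracted :: Int -> (Int, Int)
extracted = runState tick
```

Evaluating runState tick 10 gives (20,11), and extracted 0 gives (0,1).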

Is State a Monad? Not quite.

Looking at the type of State, we notice that it is:

State s a

That’s two type parameters. If you were alert while figuring out monads, you’d have noticed that a monad takes only one type parameter. Look at the Monad class (as defined in the Haskell 98 Prelude):

class Monad m where
  return     :: a -> m a
  (>>=)      :: m a -> (a -> m b) -> m b
  (>>)       :: m a -> m b -> m b
  fail       :: String -> m a

Everywhere, you see m a and m b, which means that m is a type constructor that takes exactly one type parameter. So how can State be a monad?

Indeed, State itself is not a monad, but (State s) is. That is why the type declaration lists s before a, even though the tuple inside puts a first: partially applying State to s gives (State s), a one-parameter type constructor that can be a monad, with a as its parameter. This means that the type of the state, s, is fixed, while the type of the result, a, is not, and can indeed change as we >>= computations together. Put another way, in the Monad typeclass above, m is (State s).
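One way to watch s stay fixed while a varies is a Functor instance, which, like Monad, can only be written for the one-parameter constructor (State s). The instance and the names current and currentAsText below are our own sketch, not a library definition:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- The instance is for (State s), not State: the state type s is fixed,
-- and fmap only changes the result type a.
instance Functor (State s) where
  -- fmap :: (a -> b) -> State s a -> State s b
  fmap g (State f) = State (\s -> let (a, s') = f s in (g a, s'))

-- Hypothetical example: the result is the current state, which is then incremented.
current :: State Int Int
current = State (\s -> (s, s + 1))

-- Same computation, but the result is turned into a String; the state is still an Int.
currentAsText :: State Int String
currentAsText = fmap show current
```

Here runState currentAsText 1 gives ("1",2): the result type changed from Int to String while the state stayed an Int. Writing instance Functor State instead would be rejected by GHC with a kind error.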

As such, for the State monad, the method types look like this:

-- The Monad methods, with m specialized to (State s)
return     :: a -> State s a
(>>=)      :: State s a -> (a -> State s b) -> State s b
(>>)       :: State s a -> State s b -> State s b
fail       :: String -> State s a

This is more important than it looks. The monad is (State s), which means that the monadic function expects a value of type a (which is neither a function nor a state, but a resultant output value), acts on that value, and returns a state-transformation function. Yours truly was confused about this for a while. I reasoned that since the monadic value here is a state-transformation function, the monadic function must take that function (unwrapped from the monad), act on it, and return a new state-transformation function, just as a monadic function for Maybe takes the value inside a Just. It does not!

The monadic function for the state monad takes a resultant output value, acts on it, and returns a state-transformation function.
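A sketch of that shape, assuming an Int state and a hypothetical monadic function addToCounter: it receives a plain Int (a result from some earlier step, neither a function nor a state) and returns a state-transformation function wrapped in State.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Hypothetical monadic function: it takes a plain result value x (not a function,
-- not a state) and returns a state-transformation function wrapped in State.
-- That function reports the old state as its result and adds x to the state.
addToCounter :: Int -> State Int Int
addToCounter x = State (\s -> (s, s + x))

-- Applying the monadic function only builds a computation; no state is involved yet.
addFive :: State Int Int
addFive = addToCounter 5
```

Applying addToCounter to 5 does not touch any state yet; runState addFive 10 then gives (10,15), with the old state as the result and the updated state alongside it.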

Getting an Intuitive Sense of State

At this point, it would do us some good to revise the intuitive concept of a monad, and see exactly how it applies to state. To recall, a monad, in general, comprises the following ideas:

  • A monad gives a context to a value.

  • It also allows us to abstract away details of the context.

  • We also had the concept of a monadic value, which is the value that the monad contextualizes.

  • And we also had the concept of a monadic function, which is a function that takes a plain value and returns a monadic value (the f in m >>= f).

How does that specifically apply to the state monad?

  • A state monad gives a context to a function.

  • It also allows us to abstract away details of the context.

  • We also had the concept of a monadic value, which is the function that the state monad contextualizes.

  • And we also had the concept of a monadic function, which is a function that takes a plain value and returns a monadic value (the f in m >>= f).

The tricky bit is the last point. For simple monads like Maybe, recall that the monadic function is a function that acts on the inner value of the monadic value (the 4 in Just 4), does something with it, and outputs a monadic value. Hence, this would be right:

monadicFunction :: Int -> Maybe Int
monadicFunction x = Just (x + 1)

Just 4 >>= monadicFunction    ==> Just 5

How about for our State monad? Our monadic value represents a function. Hence, what is our monadic function? We’ll use exactly the same words here, to see the parallel.

Our monadic function is a function that takes a resultant output value, does something with it, and outputs a state-transformation function.

Yes, I’m belaboring the same point as in the previous section, because it’s important.

Exploring return

Here’s the definition of return, for the state monad:

return a = State (\s -> (a, s))

Quite simple. The constructor expects a function that takes a state and returns a result and a (possibly modified) state, as explained earlier. When we give a “simple” value to return, we wrap it in State as a function that, when given some state, just hands back the simple value, with no state modification. It can’t be simpler than that.

However, a common point of confusion here is: what is the “state” that the function expects? We’re so used to wrapping values that we sometimes forget this is a function, so I’ll say it again: this is a function. We do not know which state it will receive; it takes that state as a parameter precisely so that it can act on whatever state it is eventually given. If this is obvious to you, please ignore it.
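A minimal sketch of that return, written as a standalone function named returnState (a hypothetical name, chosen so it does not clash with the Prelude’s return) over the newtype from above. The state only shows up when we finally run the function:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- return for State, as a standalone function: the wrapped function hands back
-- the given value and passes whatever state it receives through unchanged.
returnState :: a -> State s a
returnState a = State (\s -> (a, s))
```

For example, runState (returnState 'x') 42 gives ('x',42), and the same returnState 'x' works unchanged with a state of any other type.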

Exploring >>=

How about the >>= bind operator? At this point, let’s not confuse ourselves with the official definition of >>=. That will come very soon. But let us get some more intuition first: the >>= operator is a mechanism for composing state-transformation functions. That’s the outcome we want out of >>=.

So, when dealing with state, what do we want the bind operator >>= to mean? Remember, monads are all about giving some context, and >>= is all about abstracting away that context (plus extracting values for sequenced computations). So, again, what does >>= mean?

Intuitively, and quite accurately, it’s just function composition! Yes, just as head . tail composes a function that returns the second element of a list. But things are a bit special here: we are composing state-transformation functions, not ordinary (pure) functions like head and tail. Since state is sequential, we need to be careful to pass the correct intermediate state from one function to the next when composing them.
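Before looking at the real definition, here is a sketch of what composing state-transformation functions means on bare functions of type s -> (a, s), with no State wrapper at all. The helper andThen and the two steps are hypothetical; andThen runs the first step, hands its result to a function that builds the second step, and threads the intermediate state between them, which is the bookkeeping that >>= will hide for us:

```haskell
-- Hypothetical helper: compose two state-transformation functions,
-- where the second one may depend on the first one's result.
andThen :: (s -> (a, s)) -> (a -> s -> (b, s)) -> s -> (b, s)
andThen first next = \s ->
  let (a, s') = first s   -- run the first step on the incoming state
  in next a s'            -- feed its result and the *new* state to the second

-- Hypothetical steps on an Int state.
readAndBump :: Int -> (Int, Int)
readAndBump s = (s, s + 1)        -- result: the old state; state: incremented

addResult :: Int -> Int -> (Int, Int)
addResult x s = (x + s, s * 10)   -- result: x plus the current state; state: times ten

combined :: Int -> (Int, Int)
combined = readAndBump `andThen` addResult
```

Here combined 3 evaluates to (7,40): readAndBump 3 gives (3,4), and addResult 3 4 gives (7,40). Plain (.) cannot express this, because it has nowhere to carry the intermediate state.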

Once again, we must remember not to confuse state-transformation functions, which live inside the state monad as part of the monadic value, with the monadic function, which is the f in m >>= f.

The state-transformation function, f’, looks like this:

f' :: s -> (a, s)

This state-transformation function is, to repeat, wrapped up in our state monad as the “value”. So, if we really had such a function, f', as defined above, we could certainly do this:

State f'

and happily wrap f' up in a state monad. But that’s beside the point here.

The monadic function, f, does whatever the programmer wants with the resultant value. An example may look like this:

f :: Num a => a -> State s a        -- which you can think of as f :: a -> (s -> (a, s))
f x = State (\s -> (x * 2, s))

Now, remember the type of the >>= operator:

(>>=) :: m a -> (a -> m b) -> m b

What the bind operator >>= does is take a monadic value (containing our state-transformation function), strip out the context, get the result of the state-transformation function, and feed that, as the resultant output value, into the monadic function (f), which produces another monadic value (containing a state-transformation function). It then gives the new state-transformation function the new state produced by the first state-transformation function.

If that sounds like a mouthful, it is. Let’s go through this step by step.

First and foremost, here’s what >>= wants to achieve:

  • >>= wants to compose state-transformation functions.

  • It wants to take a state-transformation function wrapped in a state monad, and a monadic function.

  • It knows that the monadic function is going to act on the result of the state-transformation function, and produce yet another state-transformation function.

  • It knows that, at the end, it has two state-transformation functions to handle.

  • It wants to join the state-transformation functions, and correctly thread the state through them.

  • It produces a new state-transformation function, which is the composition of the two functions, with properly threaded state.

The >>= bind operator looks like this:

m >>= f  = State (\s ->
                      let (a, s') = (runState m) s
                      in runState (f a) s')

Let’s unpack that. m is the monadic value that contains a state-transformation function. Remember that runState basically pulls out that function; we can see it as a helper function that just accesses the inner value of the monad. Certainly, we could have pattern matched as well, like this:

State f' >>= f  = State (\s ->
                             let (a, s') = f' s
                             in runState (f a) s')

And here’s how it achieves what it wants to achieve:

  • (runState m) extracts the state-transformation function, and we apply it to the state s. Whatever that function was supposed to do, it was waiting for a state, and hence was just a computation (think of it as a piece of code, a function). Applying it to s looks like running it, but not quite: s is itself the parameter of the new lambda that we are composing. So (f' s), or (runState m) s, is not evaluated until that lambda is given a state.

  • Once that happens (when we provide the state), we bind the result to a, and the new state it produces to s', in the tuple (a, s').

  • We then apply the monadic function, f, to the result of the previous computation, a. Remember that the monadic function is expecting the resultant output value. It’s also going to produce, as output, a state-transformation function.

  • The value produced by (f a) is therefore yet another state-transformation function, waiting on a state and wrapped in a state monad, so we use runState to pull it out. To thread the state correctly, we give it the new state, s', produced by the previous computation.

  • The final result is a composition of the function wrapped in m and the function produced by (f a).
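Putting the pieces together, here is a minimal sketch that should compile with a modern GHC. It assumes GHC 7.10 or later, where every Monad must also be an Applicative (and therefore a Functor), so those two instances are included; >>= is the definition above, and return falls back to its default, pure. The computation labelTwo is a hypothetical example:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap g m = State (\s -> let (a, s') = runState m s in (g a, s'))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))                -- same as return above
  mg <*> ma = State (\s ->
    let (g, s1) = runState mg s                -- run the function-producing step first
        (a, s2) = runState ma s1               -- then the value-producing step
    in (g a, s2))

instance Monad (State s) where
  m >>= f = State (\s ->
    let (a, s') = runState m s                 -- run m on the incoming state
    in runState (f a) s')                      -- run f's computation on m's new state

-- Hypothetical example: a counter that hands out its current value.
next :: State Int Int
next = State (\s -> (s, s + 1))

labelTwo :: State Int (Int, Int)
labelTwo = next >>= \a -> next >>= \b -> return (a, b)
```

Here runState labelTwo 0 gives ((0,1),2): the two uses of next see states 0 and 1, without any state being passed around by hand.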

We can continue along this path, “composing” functions together which each modify the state (s gets modified to s'), into one huge function that contains all the modifications (think many nested lets), all waiting on intermediate states, whereby the outermost function is waiting on the initial state. We provide that initial state, and the whole gigantic function evaluates, and produces a new state, and a result.
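As a sketch of that idea, here is a larger computation built entirely out of >>= (written with do-notation, which desugars to >>=). It assumes the same newtype and instances as before, repeated so the block stands alone, plus two small helpers, get and put, defined locally here; runningTotals is a hypothetical example. Nothing runs until runState supplies the initial state:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap g m = State (\s -> let (a, s') = runState m s in (g a, s'))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  mg <*> ma = State (\s ->
    let (g, s1) = runState mg s
        (a, s2) = runState ma s1
    in (g a, s2))

instance Monad (State s) where
  m >>= f = State (\s ->
    let (a, s') = runState m s
    in runState (f a) s')

-- Local helpers for this sketch. mtl's Control.Monad.State exports functions
-- with the same names, but these are our own definitions, not imports.
get :: State s s
get = State (\s -> (s, s))        -- result: the current state; state unchanged

put :: s -> State s ()
put s = State (\_ -> ((), s))     -- ignore the old state, install a new one

-- Hypothetical example: keep a running total in the state and return
-- every intermediate total. Each line of the do block is one >>= step.
runningTotals :: [Int] -> State Int [Int]
runningTotals [] = return []
runningTotals (x : xs) = do
  total <- get
  let total' = total + x
  put total'
  rest <- runningTotals xs
  return (total' : rest)
```

Here runState (runningTotals [1, 2, 3]) 0 gives ([1,3,6],6), and the same composed function run from an initial state of 10 gives ([11,13,16],16).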

Also, remember that monads provide a guarantee of order of execution by means of lambda functions (see the post on “Fundamentals of Monads”). Note that our >>= operator does exactly that. Let’s see what happens when we compose more than one function.

m >>= f >>= g        -- parses as (m >>= f) >>= g, since >>= is left-associative

  ==> (State (\s1 ->
                 let (a1, s1') = (runState m) s1
                 in runState (f a1) s1')) >>= g

  ==> (State (\s2 ->
                 let (a2, s2') = (runState (State (\s1 ->
                                                   let (a1, s1') = (runState m) s1
                                                   in runState (f a1) s1'))) s2
                 in runState (g a2) s2'))

Notice how the lambda expressions force the correct order of evaluation, and how the state is correctly threaded through the final expression? s2, in this case, is the initial state that, when provided, triggers the evaluation of the whole composition: the state-transformation function inside m, and those produced by f and by g:

runState (State (\s2 ->
                 let (a2, s2') = (runState (State (\s1 ->
                                                   let (a1, s1') = (runState m) s1
                                                   in runState (f a1) s1'))) s2
                 in runState (g a2) s2')) initState
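To check the expansion against real code, here is a sketch with hypothetical concrete choices for m, f and g on an Int state, and the nested expression above written out as expanded. The newtype and instances are repeated so the block stands alone, with the >>= parameters renamed so they do not shadow the top-level m, f and g:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap h mx = State (\s -> let (a, s') = runState mx s in (h a, s'))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  mh <*> mx = State (\s ->
    let (h, s1) = runState mh s
        (a, s2) = runState mx s1
    in (h a, s2))

instance Monad (State s) where
  -- Same definition as in the text, with m and f renamed to mx and k.
  mx >>= k = State (\s ->
    let (a, s') = runState mx s
    in runState (k a) s')

-- Hypothetical concrete choices for m, f and g, all on an Int state.
m :: State Int Int
m = State (\s -> (s, s + 1))          -- result: the current state; state: incremented

f :: Int -> State Int Int
f a = State (\s -> (a * 10, s * 2))   -- result: a times ten; state: doubled

g :: Int -> State Int Int
g b = State (\s -> (b + s, s - 3))    -- result: b plus the state; state: minus three

-- The hand-expanded form of m >>= f >>= g from the text.
expanded :: State Int Int
expanded =
  State (\s2 ->
    let (a2, s2') = runState (State (\s1 ->
                      let (a1, s1') = runState m s1
                      in runState (f a1) s1')) s2
    in runState (g a2) s2')
```

With an initial state of 5, both runState (m >>= f >>= g) 5 and runState expanded 5 give (62,9): m yields (5,6), f 5 run on state 6 yields (50,12), and g 50 run on state 12 yields (62,9). Swapping the order to m >>= g >>= f gives (110,6) instead, which is the sequencing at work.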