This page is located in the archive.

Lab 13: State Monad

This lab focuses on the state monad State. In the lecture, I showed you how it is implemented. In this lab, we are going to use the implementation State.hs from the lecture. So include the following lines in your source file:

import Control.Monad
import State
import System.Random
import Data.List

The third import is important as we are going to work with pseudorandom numbers. The last import allows us to use some extra functions to manipulate lists.

If import System.Random doesn't work for you, you need to install the package random as follows:
  1. either locally into the current directory by
    cabal install --lib random --package-env .
  2. or globally by
    cabal install --lib random

If you don't have cabal (e.g., on the computers in the labs), put the file Random.hs into the directory containing your Lab 13 code and replace import System.Random with import Random.

The type State s a of the state monad has two parameters, s and a, representing the type of states and the type of outputs, respectively. You can imagine this type as

newtype State s a = State { runState :: s -> (a, s) }
The unary type constructor State s is an instance of Monad. The values enclosed in this monadic context are functions taking a state and returning a pair whose first component is an output value and the second one is a new state. Using the bind operator, we can compose such functions in order to create more complex stateful computations out of simpler ones. The function runState :: State s a -> s -> (a, s) is the accessor function extracting the actual function from the value of type State s a.

As State s is a monad, we can use all generic functions working with monads, including the do-notation. Apart from that, the implementation of the state monad comes with several functions for handling the state.

get :: State s s                  -- outputs the state but doesn't change it
put :: s -> State s ()            -- sets the state to the given value, outputs the empty value ()
modify :: (s -> s) -> State s ()  -- modifies the state by the given function, outputs the empty value ()
evalState :: State s a -> s -> a  -- computes just the output
evalState p s = fst (runState p s)
execState :: State s a -> s -> s  -- computes just the final state
execState p s = snd (runState p s)
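To see these functions in action, here is a small sketch of a counter that outputs the current state and then increments it. It assumes mtl's Control.Monad.State as a stand-in for the lecture's State.hs (mtl provides the same get, put, modify, runState, evalState and execState); the names tick and twoTicks are hypothetical.

```haskell
import Control.Monad.State (State, get, put, modify, runState, evalState, execState)

-- Hypothetical example: output the current counter value, then increment it.
tick :: State Int Int
tick = do
  n <- get       -- read the state
  put (n + 1)    -- overwrite it with n + 1
  return n       -- output the old value

-- Run tick twice, then double the state with modify.
twoTicks :: State Int (Int, Int)
twoTicks = do
  a <- tick
  b <- tick
  modify (* 2)
  return (a, b)
```

For instance, runState twoTicks 0 gives ((0,1),4): the two ticks output 0 and 1, and the final state 2 is doubled to 4. Accordingly, evalState twoTicks 0 gives (0,1) and execState twoTicks 0 gives 4.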

In purely functional languages, state must be handled explicitly via accumulators added to function signatures. Using the state monad allows us to abstract those accumulators away.

Exercise 1: Consider the function reverse, which reverses a given list. We can implement it as a tail-recursive function using an accumulator, in the same way as in Scheme.

reverseA :: [a] -> [a]
reverseA xs = iter xs [] where
    iter [] acc = acc
    iter (y:ys) acc = iter ys (y:acc) 

Now we try to implement that via the state monad. The above accumulator is a list. So we will use it as our state. We don't have to output anything as the resulting reversed list is stored in the accumulator/state. Thus we are interested in the type State [a] () whose values contain functions of type [a] -> ((), [a]). We will implement a function reverseS :: [a] -> State [a] () which takes a list and returns a stateful computation reversing the given list (i.e. a monadic value enclosing a function of type [a] -> ((), [a]) reversing the given list).

I will show several variants. The first one more or less copies the tail-recursive function reverseA.
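A minimal sketch of such a first variant, assuming mtl's Control.Monad.State in place of the lecture's State.hs (it offers the same get, put, runState and execState). The accumulator of reverseA becomes the state, read with get and overwritten with put:

```haskell
import Control.Monad.State (State, get, put, runState, execState)

-- Same shape as reverseA, but the accumulator lives in the state
-- instead of being passed as an extra argument.
reverseS :: [a] -> State [a] ()
reverseS [] = return ()            -- nothing left: the state holds the result
reverseS (x:xs) = do
  acc <- get                       -- read the accumulator
  put (x : acc)                    -- push x onto it
  reverseS xs                      -- continue with the rest of the list
```

With this definition, execState (reverseS [1,2,3,4]) [] returns [4,3,2,1], and runState returns the same list paired with the empty output ().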


Now we can execute the returned computation as follows:

> runState (reverseS [1,2,3,4]) []
> execState (reverseS [1,2,3,4]) []

The above variant just strips off the first element x and modifies the state by the function (x:). Thus we can rewrite it as follows:
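A sketch of this rewrite, again assuming mtl's Control.Monad.State as a stand-in for State.hs: the get/put pair collapses into a single modify (x:).

```haskell
import Control.Monad.State (State, modify, execState)

-- Strip off the first element x and push it onto the state with modify (x:).
reverseS :: [a] -> State [a] ()
reverseS [] = return ()
reverseS (x:xs) = do
  modify (x:)      -- state becomes x : state
  reverseS xs      -- process the rest of the list
```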


Finally, the above variant just applies the action modify (x:) to every x in the list. Thus we can use the monadic function mapM_ :: (a -> m b) -> [a] -> m () (here specialized to lists), which takes a function creating a monadic action from an argument of type a and a list of values of type a. The resulting action outputs the empty value (). Once it is executed, it executes, in order, all the actions obtained by applying the given function to each element of the given list.
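A sketch of the mapM_ version, assuming mtl's Control.Monad.State in place of State.hs. The lambda builds the action modify (x:) for each element; modify . (:) would be an equivalent point-free spelling.

```haskell
import Control.Monad.State (State, modify, execState)

-- Apply the action modify (x:) to every x of the list, left to right.
reverseS :: [a] -> State [a] ()
reverseS = mapM_ (\x -> modify (x:))
```

Since the elements are pushed onto whatever the initial state is, execState (reverseS [1,2]) [9] yields [2,1,9].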


Task 1: Suppose you are given a list of elements. Your task is to return a list of all its pairwise different elements together with the number of their occurrences. E.g. for “abcaacbb” it should return [('a',3),('b',3),('c',2)] in any order. A typical imperative approach to this problem is to keep in memory (as a state) a map from elements to their numbers of occurrences. This state is updated as we iterate through the list. With the state monad, we can implement this approach in Haskell.

First, we need a data structure representing a map from elements to their numbers of occurrences. We can simply represent it as a list of pairs (el, n) where el is an element and n is its number of occurrences. We also define a type representing a stateful computation over the map Map a Int.

type Map a b = [(a,b)]
type Freq a = State (Map a Int) ()

Hint: First implement the following pure function taking an element x and a map m and returning an updated map. If the element x is already in m (i.e., there is a pair (x,n)), then return the updated map which is the same as m except the pair (x,n) is replaced by (x,n+1). If x is not in m, then return the map extending m by (x,1). To check that x is in m, use the function lookup :: Eq a => a -> [(a, b)] -> Maybe b that returns Nothing if x is not in m and otherwise Just n where n is the number of occurrences of x.

update :: Eq a => a -> Map a Int -> Map a Int
update x m = case lookup x m of
    Nothing -> (x,1):m
    Just n -> (x,n+1):[p | p <- m, fst p /= x] 

Once you have that, take inspiration from Exercise 1 and implement a function freqS taking a list and returning the stateful computation that computes the map of occurrences once executed. E.g.

> execState (freqS "Hello World") [] 
[('d',1),('l',3),('r',1),('o',2),('W',1),(' ',1),('e',1),('H',1)]

freqS :: Eq a => [a] -> State (Map a Int) ()
freqS = mapM_ (modify . update)

-- Alternatively, an explicitly recursive version (renamed freqS' so that both definitions can coexist):
freqS' :: Eq a => [a] -> State (Map a Int) ()
freqS' [] = return ()
freqS' (x:xs) = do m <- get
                   let m' = update x m
                   put m'
                   -- modify (update x)  -- or replace the three lines above with this
                   freqS' xs

Exercise 2: Recall that pseudorandom numbers from a given interval $(x,y)$ can be generated by the function

randomR :: (RandomGen g, Random a) => (a, a) -> g -> (a, g)
located in the module System.Random. It takes an interval and a generator and returns a random value of type a in the given interval together with a new generator. Random is a type class collecting types whose random values can be generated by randomR. An initial generator can be created from a given seed by mkStdGen :: Int -> StdGen.
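For instance, drawing two dice rolls without any monad means passing the generator returned by the first call of randomR into the second call by hand. A small sketch (the name twoDice is hypothetical):

```haskell
import System.Random (StdGen, randomR, mkStdGen)

-- Two random values in {1,...,6}; every call of randomR returns a new
-- generator that must be passed on to the next call.
twoDice :: StdGen -> ((Int, Int), StdGen)
twoDice g0 =
  let (a, g1) = randomR (1, 6) g0   -- first draw uses the initial generator
      (b, g2) = randomR (1, 6) g1   -- second draw uses the generator from the first
  in ((a, b), g2)
```

A call such as fst (twoDice (mkStdGen 7)) returns a pair of rolls. Reusing g0 for the second draw would give the same value twice; this bookkeeping is what the state monad takes over in the rest of this exercise.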

If we want to generate a sequence of random numbers, in each step we have to use the new generator obtained from the previous step. To abstract the generators away, we use the state monad whose states are generators, i.e., State StdGen a where StdGen is the type of generators. The type a serves as the type of the generated random values. To shorten the type annotations, we introduce a new name:

type R a = State StdGen a

Our task is to implement a function that integrates a function $f\colon\mathbb{R}\to\mathbb{R}$ on the given interval $(a,b)$ by the Monte-Carlo method, i.e., we want to compute approximately $\int_a^b f(x)\mathrm{d}x$. For simplicity, we assume that $f(x)\geq 0$ for all $x\in (a,b)$. The Monte-Carlo method is a sampling method. If we know an upper bound $u$ for $f$ on the interval $(a,b)$, we can estimate the area below the graph of $f$ by generating a sequence of random points in the rectangle $(a,b)\times(0,u)$. Then we count how many points lie below the graph of $f$. The integral is approximately $\frac{k}{n}(b-a)u$ where $k$ is the number of points below the graph of $f$ and $n$ is the number of all generated points (see the picture).

Solution: We first prepare a stateful computation generating a sequence of random points in a given rectangle. We define two types:

type Range a = (a,a)
type Pair a = (a,a)
The first one represents intervals, and the second one points. Next, we define a function taking an interval and returning a stateful computation generating a single random value in the given interval.
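A minimal sketch of such a function that handles the generator explicitly with get and put, assuming mtl's Control.Monad.State in place of State.hs:

```haskell
import Control.Monad.State (State, get, put, evalState)
import System.Random (Random, StdGen, randomR, mkStdGen)

type R a = State StdGen a
type Range a = (a, a)

-- One random value from the interval r: read the current generator,
-- draw a value with randomR, and store the new generator as the state.
randR :: Random a => Range a -> R a
randR r = do
  g <- get
  let (x, g') = randomR r g
  put g'
  return x
```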


The above function can be simplified using the constructor state :: (s -> (a,s)) -> State s a as was shown in the lecture.

randR :: Random a => Range a -> R a
randR r = state (randomR r)

Since we need to generate points, we define a function taking two intervals and returning a stateful computation generating a random point in their Cartesian product.
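A sketch of such a function, here called randPair as in the calls used later on this page, assuming mtl's Control.Monad.State as a stand-in for State.hs. It simply sequences two calls of randR:

```haskell
import Control.Monad.State (State, state, evalState)
import System.Random (Random, StdGen, randomR, mkStdGen)

type R a = State StdGen a
type Range a = (a, a)
type Pair a = (a, a)

randR :: Random a => Range a -> R a
randR r = state (randomR r)

-- A random point in the rectangle given by xr and yr: draw x, then y.
-- The generator produced by the first draw is passed to the second one
-- by the bind operator, not by hand.
randPair :: Random a => Range a -> Range a -> R (Pair a)
randPair xr yr = do
  x <- randR xr
  y <- randR yr
  return (x, y)
```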


Note that we don't have to deal with generators when sequencing randR xr with randR yr. Now if we want to generate a random point, we can execute the stateful computation returned by randPair. E.g., we create an initial state/generator by mkStdGen seed and then use the function evalState, because we are interested only in the output, not in the final generator.

> evalState (randPair (0,1) (4,5)) (mkStdGen 7)

To simplify the above call, we define a function executing any stateful computation of type R a.

runRandom :: R a -> Int -> a
runRandom action seed = evalState action $ mkStdGen seed

Now we need a sequence of random points. We can define a recursive function doing that as follows:
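A sketch of such a recursive function randSeq, assuming mtl's Control.Monad.State in place of State.hs. The argument order (number of points first, then the two intervals) is an assumption made for this sketch:

```haskell
import Control.Monad.State (State, state, evalState)
import System.Random (Random, StdGen, randomR, mkStdGen)

type R a = State StdGen a
type Range a = (a, a)
type Pair a = (a, a)

randR :: Random a => Range a -> R a
randR r = state (randomR r)

randPair :: Random a => Range a -> Range a -> R (Pair a)
randPair xr yr = do
  x <- randR xr
  y <- randR yr
  return (x, y)

runRandom :: R a -> Int -> a
runRandom action seed = evalState action $ mkStdGen seed

-- n random points in the rectangle given by xr and yr (assumes n >= 0).
-- Hypothetical argument order: count, x-interval, y-interval.
randSeq :: Random a => Int -> Range a -> Range a -> R [Pair a]
randSeq 0 _ _ = return []
randSeq n xr yr = do
  p  <- randPair xr yr        -- first point
  ps <- randSeq (n - 1) xr yr -- remaining n - 1 points
  return (p : ps)
```

For example, runRandom (randSeq 5 (0,1) (0,1)) 7 produces five points; with the same seed it produces exactly the same points as running five randPair actions one after another.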


The function randSeq just sequences the actions randPair and collects their results. So we can use sequence, which takes a list of monadic actions and returns a single action that performs them one after another and outputs the list of their outputs. To create a list of randPair actions, use the function replicate.

> replicate 5 3
> runRandom (sequence $ replicate 3 (randPair (0,1) (0,1))) 7

In fact, there is a monadic version of replicate, namely replicateM from Control.Monad. So we can rewrite the last call as follows:

> runRandom (replicateM 3 (randPair (0,1) (0,1))) 7

Now we are ready to finish the Monte-Carlo integration. The function integrate takes as arguments a function $f$, an interval $(a,b)$, an upper bound $u$ and the number of points to be generated.
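A minimal sketch of integrate that follows the formula $\frac{k}{n}(b-a)u$, assuming mtl's Control.Monad.State in place of State.hs. Because the calls below return a plain number, the sketch fixes the seed internally; the seed value 1 is an arbitrary assumption.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State (State, state, evalState)
import System.Random (Random, StdGen, randomR, mkStdGen)

type R a = State StdGen a
type Range a = (a, a)
type Pair a = (a, a)

randR :: Random a => Range a -> R a
randR r = state (randomR r)

randPair :: Random a => Range a -> Range a -> R (Pair a)
randPair xr yr = do
  x <- randR xr
  y <- randR yr
  return (x, y)

runRandom :: R a -> Int -> a
runRandom action seed = evalState action $ mkStdGen seed

-- Monte-Carlo estimate of the integral of f over (a,b), assuming
-- 0 <= f x <= u on (a,b). Generates n points in (a,b) x (0,u), counts
-- the k points lying below the graph and returns k/n * (b-a) * u.
-- The fixed seed 1 is an assumption of this sketch.
integrate :: (Double -> Double) -> Range Double -> Double -> Int -> Double
integrate f (a, b) u n = runRandom estimate 1
  where
    estimate = do
      pts <- replicateM n (randPair (a, b) (0, u))
      let k = length [() | (x, y) <- pts, y <= f x]
      return (fromIntegral k / fromIntegral n * (b - a) * u)
```

Passing the seed as an extra argument instead would let you check how much the estimate varies between runs.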


You can test it on functions whose integrals you know.

> integrate id (0,1) 1 10000    -- f(x)=x on (0,1) should be close to 0.5
> integrate (^2) (0,1) 1 10000  -- f(x)=x^2 on (0,1) should be close to 1/3
> integrate sin (0,pi) 1 10000  -- f(x)=sin x on (0,pi) should be close to 2
> integrate exp (0,1) 3 10000   -- f(x)=e^x on (0,1) should be close to e-1

Task 2: Implement a function generating a random binary tree with $n$ nodes. To be more specific, consider the following type:

data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Eq, Show)
The values of Tree a are binary trees having a value of type a in their nodes together with left and right children. The Nil value indicates that there is no left (resp. right) child. Leaves are of the form Node x Nil Nil.

Your task is to implement a function randTree taking a number of nodes n and a natural number k and returning a stateful computation that generates a random binary tree with n nodes containing values from $\{0,1,\ldots,k-1\}$.

Hint: To generate random integers in a given interval, use the above function randR. Generating a random binary tree can be done recursively. In each node, you generate a random integer $m$ from $\{0,\ldots,n-1\}$. This is the number of nodes of the left subtree. So you recursively generate the left subtree ltree with m many nodes. Then you recursively generate the right subtree rtree with $n-m-1$ many nodes. Finally you return Node x ltree rtree where x is a randomly generated integer from $\{0,1,\ldots,k-1\}$. The base case for $n=0$ just returns Nil, i.e., no subtree.

randTree :: Int -> Int -> R (Tree Int)
randTree 0 _ = return Nil
randTree n k = do m <- randR (0,n-1)
                  ltree <- randTree m k
                  rtree <- randTree (n-m-1) k
                  x <- randR (0,k-1)
                  return $ Node x ltree rtree

You can use your function to generate a random binary tree with 10 nodes containing integers from $\{0,1,2\}$ as follows:

> runRandom (randTree 10 3) 1
Node 1 (Node 0 (Node 2 Nil Nil) (Node 1 (Node 2 Nil Nil) (Node 1 Nil Nil))) (Node 1 (Node 1 Nil Nil) (Node 2 Nil (Node 2 Nil Nil)))

You can also check that the method does not generate trees uniformly, i.e., some trees with the same number of nodes are more likely than others.

> trees = runRandom (replicateM 10000 (randTree 3 2)) 1
> execState (freqS trees) []
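Why the distribution is not uniform can be checked by hand for $n=3$. There are $C_3=5$ tree shapes with three nodes, so a uniform method would pick each with probability $\frac15$. In randTree, however, the root chooses $m\in\{0,1,2\}$ uniformly, and a subtree with two nodes chooses its own split $m'\in\{0,1\}$ uniformly:

```latex
\begin{aligned}
P(\text{balanced shape}) &= P(m=1) = \tfrac13,\\
P(\text{each of the 4 other shapes}) &= P(m=0 \text{ or } m=2 \text{, fixed})\cdot P(m'=j) = \tfrac13\cdot\tfrac12 = \tfrac16,\\
\tfrac13 + 4\cdot\tfrac16 &= 1.
\end{aligned}
```

Since the node values are drawn independently and uniformly from $\{0,1\}$ (here $k=2$), each labelled balanced tree should occur with probability $\frac13\cdot\frac18=\frac1{24}$ and each of the other labelled trees with probability $\frac16\cdot\frac18=\frac1{48}$. So among 10000 samples one would expect roughly 417 versus 208 occurrences, respectively.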

courses/fup/tutorials/lab_13_-_state_monad.txt · Last modified: 2022/05/18 09:35 by xhorcik