Building a probabilistic programming interpreter
Rob Zinkov
2015-08-25
Very often interpreters for probabilistic programming languages (PPLs) can seem a little mysterious. In actuality, if you know how to write an interpreter for a simple language, it isn’t that much more work.
Using Haskell as the host language, I’ll show how to write a simple PPL which uses importance sampling as the underlying inference method. Nothing Haskell-specific is used here other than pattern matching, so this example should be pretty easy to port to other languages.
To start let’s import some things and set up some basic types
import Data.List hiding (empty, insert, map)
import Control.Monad
import Data.HashMap.Strict hiding (map)
import System.Random.MWC as MWC
import System.Random.MWC.Distributions as MD
type Name = String
type Env = HashMap String Val
The values of our language will be functions, doubles, bools, and pairs of those.
data Val =
D Double |
B Bool |
F (Val -> Val) |
P Val Val
instance Eq Val where
D x == D y = x == y
B x == B y = x == y
P x1 x2 == P y1 y2 = x1 == y1 && x2 == y2
_ == _ = False
instance Ord Val where
D x <= D y = x <= y
B x <= B y = x <= y
P x1 x2 <= P y1 y2 = x1 <= y1 && x2 <= y2
_ <= _ = error "Comparing functions is undefined"
This language will have expressions for these values, conditionals and arithmetic.
data Expr =
Lit Double |
Var Name |
Pair Expr Expr |
Fst Expr |
Snd Expr |
If Expr Expr Expr |
Eql Expr Expr |
Les Expr Expr |
Gre Expr Expr |
And Expr Expr |
Lam Name Expr |
App Expr Expr |
Add Expr Expr |
Sub Expr Expr |
Mul Expr Expr |
Div Expr Expr
deriving (Eq, Show)
We can evaluate expressions in this language without doing anything special.
evalT :: Expr -> Env -> Val
evalT (Lit a) _ = D a
evalT (Var x) env = env ! x
evalT (Lam x body) env = F (\ x' -> evalT body (insert x x' env))
evalT (App f x) env = app (evalT f env) (evalT x env)
evalT (Eql x y) env = B $ (evalT x env) == (evalT y env)
evalT (Les x y) env = B $ (evalT x env) <= (evalT y env)
evalT (Gre x y) env = B $ (evalT x env) >= (evalT y env)
evalT (And x y) env = liftB (&&) (evalT x env) (evalT y env)
evalT (Add x y) env = liftOp (+) (evalT x env) (evalT y env)
evalT (Sub x y) env = liftOp (-) (evalT x env) (evalT y env)
evalT (Mul x y) env = liftOp (*) (evalT x env) (evalT y env)
evalT (Div x y) env = liftOp (/) (evalT x env) (evalT y env)
evalT (Pair x y) env = P (evalT x env) (evalT y env)
evalT (Fst x) env = fst_ $ evalT x env
where fst_ (P a b) = a
evalT (Snd x) env = snd_ $ evalT x env
where snd_ (P a b) = b
evalT (If b t f) env = if_ (evalT b env) (evalT t env) (evalT f env)
where if_ (B True) t' f' = t'
if_ (B False) t' f' = f'
app :: Val -> Val -> Val
app (F f') x' = f' x'
liftOp :: (Double -> Double -> Double) ->
Val -> Val -> Val
liftOp op (D e1) (D e2) = D (op e1 e2)
liftB :: (Bool -> Bool -> Bool) ->
Val -> Val -> Val
liftB op (B e1) (B e2) = B (op e1 e2)
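As a quick sanity check, here is a small expression run through the evaluator. The program and the name sanityCheck are mine rather than part of the original code, but they only use the definitions above.

-- Apply (\x -> x + 1) to 2 in an empty environment; evalT yields D 3.0.
sanityCheck :: Bool
sanityCheck = evalT (App (Lam "x" (Add (Var "x") (Lit 1))) (Lit 2)) empty == D 3.0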
Of course this isn’t a probabilistic programming language. So now we extend our language to include measures.
Let’s take a moment to explain what makes something a measure. Measures can be considered un-normalized probability distributions. If you take the sum of the probability of each disjoint outcome from an un-normalized probability distribution, the answer may not be 1.
This is relevant as we will be representing measures as a list of weighted draws from the underlying distribution. Those draws will need to be normalized to be understood as a probability distribution.
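Concretely, measures get a small datatype of their own. The definition below is a sketch inferred from the constructors that prog1 and evalM use later, together with a dirac helper for the weight-one case.

data Meas =
  Uniform Expr Expr |
  Weight Expr Expr |
  Bind Name Meas Meas

-- dirac returns its argument with weight one.
dirac :: Expr -> Meas
dirac x = Weight (Lit 1) x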
We can construct measures in one of three ways. We may simply have the continuous uniform distribution, whose bounds are defined as expressions. We may have a weighted distribution, which only returns the value of its second argument, with the probability given by its first argument; this is only a probability distribution when the first argument evaluates to one, and we’ll call that special case dirac.
The final form is what lets us build measure expressions. What Bind does is take a measure as input, and a function from draws in that measure to another measure. Because I don’t feel like defining measurable functions in their own form, Bind also takes a name that sets which variable will hold the draws, so the last argument to Bind may simply use that variable when it wants to refer to those draws. As an example, here is how to take a draw from a uniform distribution and then add it to itself.
prog1 = Bind "x" (Uniform (Lit 1) (Lit 5)) -- x <~ uniform(1, 5)
(dirac (Add (Var "x") (Var "x"))) -- return (x + x)
Measures are evaluated by producing a weighted sample from the measure space they represent. This is also called importance sampling.
evalM :: Meas -> Env -> MWC.GenIO -> IO (Val, Double)
evalM (Uniform lo hi) env g = do
let D lo' = evalT lo env
let D hi' = evalT hi env
x <- MWC.uniformR (lo', hi') g
return (D x, 1.0)
evalM (Weight i x) env g = do
let D i' = evalT i env
return (evalT x env, i')
evalM (Bind x m f) env g = do
(x', w) <- evalM m env g
let env' = insert x x' env
(f', w1) <- evalM f env' g
return (f', w*w1)
We may run these programs as follows
test1 :: IO ()
test1 = do
g <- MWC.create
draw <- evalM prog1 empty g
print draw
(7.926912543562406,1.0)
Evaluating this program repeatedly will allow you to produce as many draws from this measure as you need. This is great in that we can represent any unconditioned probability distribution. But how do we represent conditional distributions?
For that we will introduce another datatype
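A sketch of that datatype, inferred from the cases evalC handles below:

data Cond =
  UCond Meas |
  UniformC Expr Expr Expr |
  WeightC Expr Expr Expr |
  BindC Name Cond Cond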
This is just an extension of Meas, except now a measure is either unconditioned, or, for each conditioned case, we additionally specify which value it is conditioned on. To draw from a conditioned measure, we convert it into an unconditioned measure.
evalC :: Cond -> Meas
evalC (UCond m ) = m
evalC (UniformC lo hi x) = Weight (If (And (Gre x lo)
(Les x hi))
(Div x (Sub hi lo))
(Lit 0)) x
evalC (WeightC i x y) = Weight (If (Eql x y)
i
(Lit 0)) y
evalC (BindC x m f) = Bind x (evalC m) (evalC f)
What evalC does is determine what weight to assign to a measure given we know it will produce a particular value. This weight is the probability of getting this value from the measure.
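For instance, conditioning a uniform on a particular value expands into a Weight term. The example below is my own hand-trace of the definitions above; the name condExample is not from the post.

-- Conditioning Uniform(1, 5) on the value 3:
condExample :: Meas
condExample = evalC (UniformC (Lit 1) (Lit 5) (Lit 3))
-- This expands to
--   Weight (If (And (Gre (Lit 3) (Lit 1)) (Les (Lit 3) (Lit 5)))
--              (Div (Lit 3) (Sub (Lit 5) (Lit 1)))
--              (Lit 0))
--          (Lit 3)
-- so evalM gives back (D 3.0, 0.75), since 3 lies inside (1, 5).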
And that’s all you need to express probabilistic programs. Take the following example. Suppose we have two random variables x and y, where the value of y depends on x:
x <~ uniform(1, 5)
y <~ uniform(x, 7)
What’s the conditional distribution on x given that y is 3?
prog2 = BindC "x" (UCond (Uniform (Lit 1) (Lit 5))) -- x <~ uniform(1, 5)
(BindC "_" (UniformC (Var "x") (Lit 7) (Lit 3)) -- y <~ uniform(x, 7)
-- observe y 3
(UCond (dirac (Var "x")))) -- return x
test2 :: IO ()
test2 = do
g <- MWC.create
samples <- replicateM 10 (evalM (evalC prog2) empty g)
print samples
[(1.099241451531848, 0.5084092113511076),
(3.963456271781203, 0.0),
(1.637454187135532, 0.5594357800735532),
(3.781075065891581, 0.0),
(1.908186342514358, 0.5891810269980327),
(2.799366130116895, 0.714177929552209),
(3.091757816253942, 0.0),
(1.486166046469419, 0.5440860253107659),
(3.106369061983323, 0.0),
(1.225163855492708, 0.5194952592470413)]
As you can see, anything above 3 for x has a weight of 0, because then it would be impossible to observe y as 3.
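To turn weighted samples like these into an actual estimate, you normalize the weights. The helper below is only a sketch built on the definitions above; the name posteriorMean, the sample count, and the use of MWC.createSystemRandom (rather than the fixed-seed MWC.create) are my own choices.

-- Estimate the posterior mean of x given the observation, by dividing
-- the weighted sum of the draws by the total weight.
posteriorMean :: IO Double
posteriorMean = do
  g <- MWC.createSystemRandom
  draws <- replicateM 100000 (evalM (evalC prog2) empty g)
  let total = sum [w | (_, w) <- draws]
  return $ sum [x * w | (D x, w) <- draws] / total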
Further reading
For small problems, this implementation is actually fairly capable. It can be extended to support more probability distributions in a straightforward way.
If you are interested in more advanced interpreters, I suggest reading the following.