summaryrefslogtreecommitdiff
path: root/examples/mcmc.dx
diff options
context:
space:
mode:
Diffstat (limited to 'examples/mcmc.dx')
-rw-r--r--examples/mcmc.dx24
1 files changed, 13 insertions, 11 deletions
diff --git a/examples/mcmc.dx b/examples/mcmc.dx
index 4644a747..9d7ea5f9 100644
--- a/examples/mcmc.dx
+++ b/examples/mcmc.dx
@@ -12,7 +12,7 @@ def runChain(
numSamples: Nat,
k:Key
) -> Fin numSamples => a given (a|Data) =
- [k1, k2] = split_key k
+ [k1, k2] = split_key(n=2, k)
with_state (initialize k1) \s.
for i:(Fin numSamples).
x = step (ixkey k2 i) (get s)
@@ -24,13 +24,13 @@ def propose(
cur: a,
proposal: a,
k: Key
- ) -> a given (a) =
+ ) -> a given (a:Type) =
accept = logDensity proposal > (logDensity cur + log (rand k))
select accept proposal cur
def meanAndCovariance(xs:n=>d=>Float) -> (d=>Float, d=>d=>Float) given (n|Ix, d|Ix) =
- xsMean : d=>Float = (for i. sum for j. xs[j,i]) / n_to_f (size n)
- xsCov : d=>d=>Float = (for i i'. sum for j.
+ xsMean : d=>Float = (for i:d. sum for j:n. xs[j,i]) / n_to_f (size n)
+ xsCov : d=>d=>Float = (for i:d i':d. sum for j:n.
(xs[j,i'] - xsMean[i']) *
(xs[j,i ] - xsMean[i ]) ) / (n_to_f (size n) - 1)
(xsMean, xsCov)
@@ -45,7 +45,7 @@ def mhStep(
k:Key,
x:d=>Float
) -> d=>Float given (d|Ix) =
- [k1, k2] = split_key k
+ [k1, k2] = split_key(n=2, k)
proposal = x + stepSize .* randn_vec k1
propose logProb x proposal k2
@@ -80,8 +80,8 @@ def hmcStep(
) -> d=>Float given (d|Ix) =
def hamiltonian(s:HMCState (d=>Float)) -> Float =
logProb s.x - 0.5 * vdot s.p s.p
- [k1, k2] = split_key k
- p = randn_vec k1
+ [k1, k2] = split_key(n=2, k)
+ p = randn_vec k1 :: d => Float
proposal = leapfrogIntegrate params logProb HMCState(x, p)
final = propose hamiltonian HMCState(x, p) proposal k2
final.x
@@ -93,6 +93,8 @@ def hmcStep(
def myLogProb(x:(Fin 2)=>Float) -> LogProb =
x' = x - [1.5, 2.5]
neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x'
+def myInitializer(k:Key) -> Fin 2 => Float =
+ randn_vec(k)
numSamples : Nat =
if dex_test_mode()
@@ -101,21 +103,21 @@ numSamples : Nat =
k0 = new_key 1
mhParams = 0.1
-mhSamples = runChain randn_vec (\k x. mhStep mhParams myLogProb k x) numSamples k0
+mhSamples = runChain myInitializer (\k x. mhStep mhParams myLogProb k x) numSamples k0
:p meanAndCovariance mhSamples
> ([0.5455918, 2.522631], [[0.3552593, 0.05022133], [0.05022133, 0.08734216]])
:html show_plot $ y_plot $
- slice (map head mhSamples) 0 (Fin 1000)
+ slice (each mhSamples head) 0 (Fin 1000)
> <html output>
hmcParams = HMCParams(10, 0.1)
-hmcSamples = runChain randn_vec (\k x. hmcStep hmcParams myLogProb k x) numSamples k0
+hmcSamples = runChain myInitializer (\k x. hmcStep hmcParams myLogProb k x) numSamples k0
:p meanAndCovariance hmcSamples
> ([1.472011, 2.483082], [[1.054705, -0.002082013], [-0.002082013, 0.05058844]])
:html show_plot $ y_plot $
- slice (map head hmcSamples) 0 (Fin 1000)
+ slice (each hmcSamples head) 0 (Fin 1000)
> <html output>