summaryrefslogtreecommitdiff
path: root/benchmarks/conv_py.py
blob: fa1682b98a8a504a4368d269a3d09db23083932a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import jax
import dex
from dex.interop import jax as djax
import numpy as np

import time
import timeit

def bench_python(f, loops=None, *, target_time=2.0, min_loops=4):
  """Return average runtime of `f` in seconds and number of iterations used.

  Args:
    f: Zero-argument callable to benchmark.
    loops: Number of timed calls. If None, `f` is called once as a warm-up
      (e.g. to keep one-time costs such as JIT compilation out of the
      estimate) and once more to estimate its duration. `loops` is then
      chosen so the timed run lasts roughly `target_time` seconds, but never
      fewer than `min_loops` calls.
    target_time: Desired total duration, in seconds, of the timed run when
      `loops` is calibrated automatically.
    min_loops: Lower bound on the automatically calibrated number of loops.

  Returns:
    A tuple `(seconds_per_call, loops)`.

  Raises:
    ValueError: If an explicit `loops` is smaller than 1.
  """
  if loops is None:
    f()  # warm-up call, not timed
    start = time.perf_counter()
    f()
    duration = time.perf_counter() - start
    # A very fast `f` can measure as exactly 0 on a coarse clock; floor the
    # estimate at the clock's resolution so the division below is safe.
    duration = max(duration, time.get_clock_info('perf_counter').resolution)
    loops = max(min_loops, int(target_time / duration))
  elif loops < 1:
    # number=0 would divide by zero below; negatives give nonsense averages.
    raise ValueError(f"loops must be >= 1, got {loops}")
  # `f` is a callable, so timeit needs no globals namespace.
  return timeit.timeit(f, number=loops) / loops, loops


def main():
  """Benchmark the Dex `conv_spec` function when called through JAX.

  Loads `conv.dx` from the directory containing this script and wraps its
  `conv_spec` with `djax.primitive`. It then builds a random float32 input of
  shape `(n, width, side, side)` from the module's constants and prints the
  average runtime of one call.

  Raises:
    TypeError: If the wrapped call does not return a NumPy array.
  """
  from pathlib import Path  # only main() needs it

  # Resolve relative to this file rather than the CWD, so the benchmark works
  # wherever it is launched from (including the repository root).
  dx_path = Path(__file__).resolve().with_name('conv.dx')
  # Read the source up front so the file is not held open while benchmarking.
  with open(dx_path, 'r', encoding='utf-8') as f:
    source = f.read()

  m = dex.Module(source)
  dex_conv = djax.primitive(m.conv_spec)
  shp = (int(m.n), int(m.width), int(m.side), int(m.side))
  xs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32)
  filter_size = int(m.filter_size)

  # Timing the call directly is only meaningful if it returns a materialized
  # NumPy array. An asynchronously dispatched JAX array would need
  # block_until_ready(). Use an explicit raise: `assert` is stripped by -O.
  msg = ("TODO Make dex.interop.primitive return Jax Device Arrays, "
         "and change this check to a block_until_ready() call.")
  if not isinstance(dex_conv(xs, filter_size), np.ndarray):
    raise TypeError(msg)

  time_s, loops = bench_python(lambda: dex_conv(xs, filter_size))
  print(f"> Run time: {time_s} s \t(based on {loops} runs)")


if __name__ == '__main__':
  # Run the benchmark only when executed as a script, not when imported.
  main()