summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexey Radul <axch@google.com>2023-06-26 12:53:24 -0400
committerAlexey Radul <axch@google.com>2023-07-05 15:16:02 -0400
commitbed0ae8eaf5060e9c3db90bfc4a93fbdd8987262 (patch)
tree126b4af54ecd2485a37d4ab1710e44ad0e9fe7c4
parent17288242a5dbd824ab6e293d23be57aac9ea02f9 (diff)
Generalize vectorized addition to notice that uniform + contiguous is contiguous, not varying.
This is relevant for preserving vector loads when an index is offset.
-rw-r--r--src/lib/Vectorize.hs10
-rw-r--r--tests/opt-tests.dx6
2 files changed, 10 insertions, 6 deletions
diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs
index fce07270..d9a62728 100644
--- a/src/lib/Vectorize.hs
+++ b/src/lib/Vectorize.hs
@@ -487,9 +487,13 @@ vectorizePrimOp op = case op of
BinOp opk arg1 arg2 -> do
sx@(VVal vx x) <- vectorizeAtom arg1
sy@(VVal vy y) <- vectorizeAtom arg2
- let v = case (vx, vy) of (Uniform, Uniform) -> Uniform; _ -> Varying
- x' <- if vx /= v then ensureVarying sx else return x
- y' <- if vy /= v then ensureVarying sy else return y
+ let v = case (opk, vx, vy) of
+ (_, Uniform, Uniform) -> Uniform
+ (IAdd, Uniform, Contiguous) -> Contiguous
+ (IAdd, Contiguous, Uniform) -> Contiguous
+ _ -> Varying
+ x' <- if v == Varying then ensureVarying sx else return x
+ y' <- if v == Varying then ensureVarying sy else return y
VVal v <$> emitOp (BinOp opk x' y')
MiscOp (CastOp tyArg arg) -> do
ty <- vectorizeType tyArg
diff --git a/tests/opt-tests.dx b/tests/opt-tests.dx
index 5a095695..76cecb5a 100644
--- a/tests/opt-tests.dx
+++ b/tests/opt-tests.dx
@@ -126,13 +126,13 @@ _ = for i:(Fin 20) j:(Fin 4). ordinal j
"vectorizing int binary op"
-- CHECK-LABEL: vectorizing int binary op
%passes vect
-_ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1
+_ = for i:(Fin 256). (n_to_i32 (ordinal i)) * 2
-- CHECK: seq (RawFin 0x10)
-- CHECK: [[i0:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[i1:v#[0-9]+]]:<16xInt32> = viota
-- CHECK: [[i2:v#[0-9]+]]:<16xInt32> = %iadd [[i0]] [[i1]]
--- CHECK: [[ones:v#[0-9]+]]:<16xInt32> = vbroadcast 1
--- CHECK: %iadd [[i2]] [[ones]]
+-- CHECK: [[twos:v#[0-9]+]]:<16xInt32> = vbroadcast 2
+-- CHECK: %imul [[i2]] [[twos]]
"vectorizing float binary op"
-- CHECK-LABEL: vectorizing float binary op