diff options
author | Alexey Radul <axch@google.com> | 2023-06-26 12:53:24 -0400 |
---|---|---|
committer | Alexey Radul <axch@google.com> | 2023-07-05 15:16:02 -0400 |
commit | bed0ae8eaf5060e9c3db90bfc4a93fbdd8987262 (patch) | |
tree | 126b4af54ecd2485a37d4ab1710e44ad0e9fe7c4 | |
parent | 17288242a5dbd824ab6e293d23be57aac9ea02f9 (diff) |
Generalize vectorized addition to notice that uniform + contiguous is contiguous, not varying.
This is relevant for preserving vector loads when an index is offset.
-rw-r--r-- | src/lib/Vectorize.hs | 10 | ||||
-rw-r--r-- | tests/opt-tests.dx | 6 |
2 files changed, 10 insertions, 6 deletions
diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index fce07270..d9a62728 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -487,9 +487,13 @@ vectorizePrimOp op = case op of BinOp opk arg1 arg2 -> do sx@(VVal vx x) <- vectorizeAtom arg1 sy@(VVal vy y) <- vectorizeAtom arg2 - let v = case (vx, vy) of (Uniform, Uniform) -> Uniform; _ -> Varying - x' <- if vx /= v then ensureVarying sx else return x - y' <- if vy /= v then ensureVarying sy else return y + let v = case (opk, vx, vy) of + (_, Uniform, Uniform) -> Uniform + (IAdd, Uniform, Contiguous) -> Contiguous + (IAdd, Contiguous, Uniform) -> Contiguous + _ -> Varying + x' <- if v == Varying then ensureVarying sx else return x + y' <- if v == Varying then ensureVarying sy else return y VVal v <$> emitOp (BinOp opk x' y') MiscOp (CastOp tyArg arg) -> do ty <- vectorizeType tyArg diff --git a/tests/opt-tests.dx b/tests/opt-tests.dx index 5a095695..76cecb5a 100644 --- a/tests/opt-tests.dx +++ b/tests/opt-tests.dx @@ -126,13 +126,13 @@ _ = for i:(Fin 20) j:(Fin 4). ordinal j "vectorizing int binary op" -- CHECK-LABEL: vectorizing int binary op %passes vect -_ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1 +_ = for i:(Fin 256). (n_to_i32 (ordinal i)) * 2 -- CHECK: seq (RawFin 0x10) -- CHECK: [[i0:v#[0-9]+]]:<16xInt32> = vbroadcast -- CHECK: [[i1:v#[0-9]+]]:<16xInt32> = viota -- CHECK: [[i2:v#[0-9]+]]:<16xInt32> = %iadd [[i0]] [[i1]] --- CHECK: [[ones:v#[0-9]+]]:<16xInt32> = vbroadcast 1 --- CHECK: %iadd [[i2]] [[ones]] +-- CHECK: [[twos:v#[0-9]+]]:<16xInt32> = vbroadcast 2 +-- CHECK: %imul [[i2]] [[twos]] "vectorizing float binary op" -- CHECK-LABEL: vectorizing float binary op |