diff options
author | Andrea Piseri <andrea.piseri@gmail.com> | 2024-05-18 12:12:17 +0200 |
---|---|---|
committer | Andrea Piseri <andrea.piseri@gmail.com> | 2024-05-18 12:12:17 +0200 |
commit | 7f28308e44a4858b6ebe9d4c9944bb65c12cd71e (patch) | |
tree | 5bff7cf9b6cefb62ac9b798e4a9028d0c214b6ff | |
parent | 4f898f38d22ebec96965f7293ed4e40bd24c152d (diff) |
native path in `select_ucw` for high rank `𝕩`
`𝔽⌾(a⊸⊏)𝕩` no longer needs to go through the self-hosted runtime if
`1<=𝕩` (i.e. the rank of `𝕩` is greater than 1). Instead, the `select_replace` helper function is parametrized
over the length of `𝕩` (`xl`) and the number of elements in each major cell of `𝕩` (`xcia`).
- The `EQ` macro no longer marks the cell as populated immediately,
so that every element of the cell can be replaced during the first
assignment to that cell.
- The `DONE_CW` macro is invoked to mark the current cell as populated
when every element in it has been assigned.
- A loop over the cell contents is introduced to copy every element of
the cell into the result. This should be cheap, as the loop branch is easily predictable,
but a performance regression is still possible, and a separate code path
could be introduced in the future.
The change introduces more extensive checking on the shape of `𝔽`'s
result, as for high rank `𝕩` the requirement should be `(≢𝔽a⊏𝕩)≡(≢a)∾1↓≢𝕩`.
The old behaviour of `select_replace` is recovered by passing `xl=xia`
and `xcia=1` in the implementation of `pick_ucw`.
-rw-r--r-- | src/builtins/select.c | 86 | ||||
-rw-r--r-- | src/builtins/sfns.c | 4 |
2 files changed, 55 insertions, 35 deletions
diff --git a/src/builtins/select.c b/src/builtins/select.c index 3665dbaa..70f0400d 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -372,24 +372,26 @@ B select_c2(B t, B w, B x) { extern INIT_GLOBAL u8 reuseElType[t_COUNT]; -B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary +B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary #if CHECK_VALID - TALLOC(bool, set, xia); - bool sparse = wia < xia/64; - if (!sparse) for (i64 i = 0; i < xia; i++) set[i] = false; + TALLOC(bool, set, xl); + bool sparse = wia < xl/64; + if (!sparse) for (i64 i = 0; i < xl; i++) set[i] = false; #define SPARSE_INIT(WI) \ - if (sparse) for (usz i = 0; i < wia; i++) { \ - i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xia; set[cw] = false; \ + if (sparse) for (usz i = 0; i < wia; i++) { \ + i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xl; set[cw] = false; \ } - #define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true; + #define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); + #define DONE_CW set[cw] = true; #define FREE_CHECK TFREE(set) #else #define SPARSE_INIT(GET) #define EQ(F) + #define DONE_CW #define FREE_CHECK #endif - #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)xia + #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)xl u8 we = TI(w,elType); assert(elNum(we)); u8 xe = TI(x,elType); u8 re = el_or(xe, TI(rep,elType)); @@ -399,16 +401,19 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊ f64* wp = f64any_ptr(w); SPARSE_INIT((i64)wp[i]) - MAKE_MUT(r, xia); + MAKE_MUT(r, xl * xcia); mut_init_copy(r, x, re); NOGC_E; MUTG_INIT(r); SGet(rep) for (usz i = 0; i < wia; i++) { READ_W(cw, i); - B cn = Get(rep, i); - 
EQ(!equal(mut_getU(r, cw), cn)); - mut_rm(r, cw); - mut_setG(r, cw, cn); + for (usz j = 0; j < xcia; j++) { + B cn = Get(rep, i * xcia + j); + EQ(!equal(mut_getU(r, cw * xcia + j), cn)); + mut_rm(r, cw * xcia + j); + mut_setG(r, cw * xcia + j, cn); + } + DONE_CW; } ra = mut_fp(r); goto dec_ret_ra; @@ -419,7 +424,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊ i32* wp = i32any_ptr(w); SPARSE_INIT(wp[i]) bool reuse = reusable(x) && re==reuseElType[TY(x)]; - SLOWIF(!reuse && xia>100 && wia<xia/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x); + SLOWIF(!reuse && xl>100 && wia<xl/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x); switch (re) { default: UD; case el_i8: rep = toI8Any(rep); ra = reuse? a(REUSE(x)) : cpyI8Arr(x); goto do_u8; case el_c8: rep = toC8Any(rep); ra = reuse? a(REUSE(x)) : cpyC8Arr(x); goto do_u8; @@ -434,9 +439,12 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊ u64* rp = (void*)((TyArr*)ra)->a; for (usz i = 0; i < wia; i++) { READ_W(cw, i); - bool cn = bitp_get(np, i); - EQ(cn != bitp_get(rp, cw)); - bitp_set(rp, cw, cn); + for (usz j = 0; j < xcia; j++) { + bool cn = bitp_get(np, i * xcia + j); + EQ(cn != bitp_get(rp, cw * xcia + j)); + bitp_set(rp, cw * xcia + j, cn); + } + DONE_CW; } goto dec_ret_ra; } @@ -446,24 +454,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊ SGet(rep) for (usz i = 0; i < wia; i++) { READ_W(cw, i); - B cn = Get(rep, i); - EQ(!equal(cn,rp[cw])); - dec(rp[cw]); - rp[cw] = cn; + for (usz j = 0; j < xcia; j++) { + B cn = Get(rep, i * xcia + j); + EQ(!equal(cn,rp[cw * xcia + j])); + dec(rp[cw * xcia + j]); + rp[cw * xcia + j] = cn; + } + DONE_CW; } goto dec_ret_ra; } } - #define IMPL(T) do { \ - T* rp = (void*)((TyArr*)ra)->a; \ - T* np = tyany_ptr(rep); \ - for (usz i = 0; i < wia; i++) { \ - READ_W(cw, i); \ - T cn = np[i]; \ - EQ(cn != rp[cw]); \ - rp[cw] = cn; \ - } \ + #define IMPL(T) do { \ + T* rp = 
(void*)((TyArr*)ra)->a; \ + T* np = tyany_ptr(rep); \ + for (usz i = 0; i < wia; i++) { \ + READ_W(cw, i); \ + for (usz j = 0; j < xcia; j++) { \ + T cn = np[i * xcia + j]; \ + EQ(cn != rp[cw * xcia + j]); \ + rp[cw * xcia + j] = cn; \ + } \ + DONE_CW; \ + } \ goto dec_ret_ra; \ } while(0) @@ -482,11 +496,12 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊ #undef SPARSE_INIT #undef EQ + #undef DONE_CW #undef FREE_CHECK } B select_ucw(B t, B o, B w, B x) { - if (isAtm(x) || RNK(x)!=1 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } + if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } usz xia = IA(x); usz wia = IA(w); u8 we = TI(w,elType); @@ -502,6 +517,11 @@ B select_ucw(B t, B o, B w, B x) { } else { rep = c1(o, C2(select, incG(w), incG(x))); } - if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep); - return select_replace(U'⊏', w, x, rep, wia, xia); + usz xr = RNK(x); + usz wr = RNK(w); + usz rr = RNK(rep); + bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1); + if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep); + usz ia = shProd(SH(x), 1, RNK(x)); + return select_replace(U'⊏', w, x, rep, wia, SH(x)[0], ia); } diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 0a9109a5..952b27f9 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -1302,7 +1302,7 @@ B pick_uc1(B t, B o, B x) { -B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia); +B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia); B select_ucw(B t, B o, B w, B x); B select_c2(B,B,B); B pick_ucw(B t, B o, B w, B x) { @@ -1329,7 +1329,7 @@ B pick_ucw(B t, B o, B w, B x) { w = num_squeeze(mut_fcd(r, w)); B rep = isArr(o)? 
incG(o) : c1(o, C2(select, incG(w), C1(shape, incG(x)))); if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊑)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep); - return select_replace(U'⊑', w, x, rep, wia, xia); + return select_replace(U'⊑', w, x, rep, wia, xia, 1); } decG(w); } |