diff options
author | Andrea Piseri <andrea.piseri@gmail.com> | 2024-05-18 21:45:18 +0200 |
---|---|---|
committer | Andrea Piseri <andrea.piseri@gmail.com> | 2024-05-18 21:45:18 +0200 |
commit | 06808414da9a99aab0222dd93863786aeae9f322 (patch) | |
tree | 3ec57c6d433f841ba1b2629f4d2ea335ca6706d3 | |
parent | 7f28308e44a4858b6ebe9d4c9944bb65c12cd71e (diff) |
Separate code path for `xcsz==1`, fix out of bounds read for `RNK(x)==0`
-rw-r--r-- | src/builtins/select.c | 123 |
1 files changed, 83 insertions, 40 deletions
diff --git a/src/builtins/select.c b/src/builtins/select.c index 70f0400d..96d4e588 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -372,7 +372,7 @@ B select_c2(B t, B w, B x) { extern INIT_GLOBAL u8 reuseElType[t_COUNT]; -B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary +B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcsz) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary #if CHECK_VALID TALLOC(bool, set, xl); bool sparse = wia < xl/64; @@ -401,19 +401,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep f64* wp = f64any_ptr(w); SPARSE_INIT((i64)wp[i]) - MAKE_MUT(r, xl * xcia); + MAKE_MUT(r, xl * xcsz); mut_init_copy(r, x, re); NOGC_E; MUTG_INIT(r); SGet(rep) - for (usz i = 0; i < wia; i++) { - READ_W(cw, i); - for (usz j = 0; j < xcia; j++) { - B cn = Get(rep, i * xcia + j); - EQ(!equal(mut_getU(r, cw * xcia + j), cn)); - mut_rm(r, cw * xcia + j); - mut_setG(r, cw * xcia + j, cn); + if (xcsz==1) { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + B cn = Get(rep, i); + EQ(!equal(mut_getU(r, cw), cn)); + mut_rm(r, cw); + mut_setG(r, cw, cn); + DONE_CW; + } + } else { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + for (usz j = 0; j < xcsz; j++) { + B cn = Get(rep, i * xcsz + j); + EQ(!equal(mut_getU(r, cw * xcsz + j), cn)); + mut_rm(r, cw * xcsz + j); + mut_setG(r, cw * xcsz + j, cn); + } + DONE_CW; } - DONE_CW; } ra = mut_fp(r); goto dec_ret_ra; @@ -437,14 +448,24 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep TyArr* na = toBitArr(rep); rep = taga(na); u64* np = bitarrv_ptr(na); u64* rp = (void*)((TyArr*)ra)->a; - for (usz i = 0; i < wia; i++) { - READ_W(cw, i); - for (usz j = 0; j < xcia; j++) { - bool cn = bitp_get(np, i * xcia + j); - EQ(cn != bitp_get(rp, cw * xcia + j)); - 
bitp_set(rp, cw * xcia + j, cn); + if (xcsz==1) { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + bool cn = bitp_get(np, i); + EQ(cn != bitp_get(rp, cw)); + bitp_set(rp, cw, cn); + DONE_CW; + } + } else { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + for (usz j = 0; j < xcsz; j++) { + bool cn = bitp_get(np, i * xcsz + j); + EQ(cn != bitp_get(rp, cw * xcsz + j)); + bitp_set(rp, cw * xcsz + j, cn); + } + DONE_CW; } - DONE_CW; } goto dec_ret_ra; } @@ -452,32 +473,54 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep ra = reuse? a(REUSE(x)) : cpyHArr(x); B* rp = harrP_parts((HArr*)ra).a; SGet(rep) - for (usz i = 0; i < wia; i++) { - READ_W(cw, i); - for (usz j = 0; j < xcia; j++) { - B cn = Get(rep, i * xcia + j); - EQ(!equal(cn,rp[cw * xcia + j])); - dec(rp[cw * xcia + j]); - rp[cw * xcia + j] = cn; + if (xcsz==1) + { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + B cn = Get(rep, i); + EQ(!equal(cn,rp[cw])); + dec(rp[cw]); + rp[cw] = cn; + DONE_CW; + } + } else { + for (usz i = 0; i < wia; i++) { + READ_W(cw, i); + for (usz j = 0; j < xcsz; j++) { + B cn = Get(rep, i * xcsz + j); + EQ(!equal(cn,rp[cw * xcsz + j])); + dec(rp[cw * xcsz + j]); + rp[cw * xcsz + j] = cn; + } + DONE_CW; } - DONE_CW; } goto dec_ret_ra; } } - #define IMPL(T) do { \ - T* rp = (void*)((TyArr*)ra)->a; \ - T* np = tyany_ptr(rep); \ - for (usz i = 0; i < wia; i++) { \ - READ_W(cw, i); \ - for (usz j = 0; j < xcia; j++) { \ - T cn = np[i * xcia + j]; \ - EQ(cn != rp[cw * xcia + j]); \ - rp[cw * xcia + j] = cn; \ - } \ - DONE_CW; \ - } \ + #define IMPL(T) do { \ + T* rp = (void*)((TyArr*)ra)->a; \ + T* np = tyany_ptr(rep); \ + if (xcsz==1) { \ + for (usz i = 0; i < wia; i++) { \ + READ_W(cw, i); \ + T cn = np[i]; \ + EQ(cn != rp[cw]); \ + rp[cw] = cn; \ + DONE_CW; \ + } \ + } else { \ + for (usz i = 0; i < wia; i++) { \ + READ_W(cw, i); \ + for (usz j = 0; j < xcsz; j++) { \ + T cn = np[i * xcsz + j]; \ + EQ(cn != rp[cw * xcsz + j]); \ + rp[cw * xcsz 
+ j] = cn; \ + } \ + DONE_CW; \ + } \ + } \ goto dec_ret_ra; \ } while(0) @@ -501,7 +544,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep } B select_ucw(B t, B o, B w, B x) { - if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } + if (isAtm(x) || RNK(x)==0 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } usz xia = IA(x); usz wia = IA(w); u8 we = TI(w,elType); @@ -522,6 +565,6 @@ B select_ucw(B t, B o, B w, B x) { usz rr = RNK(rep); bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1); if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep); - usz ia = shProd(SH(x), 1, RNK(x)); - return select_replace(U'⊏', w, x, rep, wia, SH(x)[0], ia); + usz xcsz = arr_csz(x); + return select_replace(U'⊏', w, x, rep, wia, SH(x)[0], xcsz); } |