summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrea Piseri <andrea.piseri@gmail.com>2024-05-18 21:45:18 +0200
committerAndrea Piseri <andrea.piseri@gmail.com>2024-05-18 21:45:18 +0200
commit06808414da9a99aab0222dd93863786aeae9f322 (patch)
tree3ec57c6d433f841ba1b2629f4d2ea335ca6706d3
parent7f28308e44a4858b6ebe9d4c9944bb65c12cd71e (diff)
Separate code path for `cxsz=1`, fix out of bounds read for `RNK(x)==0`
-rw-r--r--src/builtins/select.c123
1 files changed, 83 insertions, 40 deletions
diff --git a/src/builtins/select.c b/src/builtins/select.c
index 70f0400d..96d4e588 100644
--- a/src/builtins/select.c
+++ b/src/builtins/select.c
@@ -372,7 +372,7 @@ B select_c2(B t, B w, B x) {
extern INIT_GLOBAL u8 reuseElType[t_COUNT];
-B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
+B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcsz) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
#if CHECK_VALID
TALLOC(bool, set, xl);
bool sparse = wia < xl/64;
@@ -401,19 +401,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
f64* wp = f64any_ptr(w);
SPARSE_INIT((i64)wp[i])
- MAKE_MUT(r, xl * xcia);
+ MAKE_MUT(r, xl * xcsz);
mut_init_copy(r, x, re);
NOGC_E;
MUTG_INIT(r); SGet(rep)
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- for (usz j = 0; j < xcia; j++) {
- B cn = Get(rep, i * xcia + j);
- EQ(!equal(mut_getU(r, cw * xcia + j), cn));
- mut_rm(r, cw * xcia + j);
- mut_setG(r, cw * xcia + j, cn);
+ if (xcsz==1) {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ B cn = Get(rep, i);
+ EQ(!equal(mut_getU(r, cw), cn));
+ mut_rm(r, cw);
+ mut_setG(r, cw, cn);
+ DONE_CW;
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ for (usz j = 0; j < xcsz; j++) {
+ B cn = Get(rep, i * xcsz + j);
+ EQ(!equal(mut_getU(r, cw * xcsz + j), cn));
+ mut_rm(r, cw * xcsz + j);
+ mut_setG(r, cw * xcsz + j, cn);
+ }
+ DONE_CW;
}
- DONE_CW;
}
ra = mut_fp(r);
goto dec_ret_ra;
@@ -437,14 +448,24 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
TyArr* na = toBitArr(rep); rep = taga(na);
u64* np = bitarrv_ptr(na);
u64* rp = (void*)((TyArr*)ra)->a;
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- for (usz j = 0; j < xcia; j++) {
- bool cn = bitp_get(np, i * xcia + j);
- EQ(cn != bitp_get(rp, cw * xcia + j));
- bitp_set(rp, cw * xcia + j, cn);
+ if (xcsz==1) {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ bool cn = bitp_get(np, i);
+ EQ(cn != bitp_get(rp, cw));
+ bitp_set(rp, cw, cn);
+ DONE_CW;
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ for (usz j = 0; j < xcsz; j++) {
+ bool cn = bitp_get(np, i * xcsz + j);
+ EQ(cn != bitp_get(rp, cw * xcsz + j));
+ bitp_set(rp, cw * xcsz + j, cn);
+ }
+ DONE_CW;
}
- DONE_CW;
}
goto dec_ret_ra;
}
@@ -452,32 +473,54 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
ra = reuse? a(REUSE(x)) : cpyHArr(x);
B* rp = harrP_parts((HArr*)ra).a;
SGet(rep)
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- for (usz j = 0; j < xcia; j++) {
- B cn = Get(rep, i * xcia + j);
- EQ(!equal(cn,rp[cw * xcia + j]));
- dec(rp[cw * xcia + j]);
- rp[cw * xcia + j] = cn;
+ if (xcsz==1)
+ {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ B cn = Get(rep, i);
+ EQ(!equal(cn,rp[cw]));
+ dec(rp[cw]);
+ rp[cw] = cn;
+ DONE_CW;
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ for (usz j = 0; j < xcsz; j++) {
+ B cn = Get(rep, i * xcsz + j);
+ EQ(!equal(cn,rp[cw * xcsz + j]));
+ dec(rp[cw * xcsz + j]);
+ rp[cw * xcsz + j] = cn;
+ }
+ DONE_CW;
}
- DONE_CW;
}
goto dec_ret_ra;
}
}
- #define IMPL(T) do { \
- T* rp = (void*)((TyArr*)ra)->a; \
- T* np = tyany_ptr(rep); \
- for (usz i = 0; i < wia; i++) { \
- READ_W(cw, i); \
- for (usz j = 0; j < xcia; j++) { \
- T cn = np[i * xcia + j]; \
- EQ(cn != rp[cw * xcia + j]); \
- rp[cw * xcia + j] = cn; \
- } \
- DONE_CW; \
- } \
+ #define IMPL(T) do { \
+ T* rp = (void*)((TyArr*)ra)->a; \
+ T* np = tyany_ptr(rep); \
+ if (xcsz==1) { \
+ for (usz i = 0; i < wia; i++) { \
+ READ_W(cw, i); \
+ T cn = np[i]; \
+ EQ(cn != rp[cw]); \
+ rp[cw] = cn; \
+ DONE_CW; \
+ } \
+ } else { \
+ for (usz i = 0; i < wia; i++) { \
+ READ_W(cw, i); \
+ for (usz j = 0; j < xcsz; j++) { \
+ T cn = np[i * xcsz + j]; \
+ EQ(cn != rp[cw * xcsz + j]); \
+ rp[cw * xcsz + j] = cn; \
+ } \
+ DONE_CW; \
+ } \
+ } \
goto dec_ret_ra; \
} while(0)
@@ -501,7 +544,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
}
B select_ucw(B t, B o, B w, B x) {
- if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
+ if (isAtm(x) || RNK(x)==0 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x);
usz wia = IA(w);
u8 we = TI(w,elType);
@@ -522,6 +565,6 @@ B select_ucw(B t, B o, B w, B x) {
usz rr = RNK(rep);
bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1);
if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep);
- usz ia = shProd(SH(x), 1, RNK(x));
- return select_replace(U'⊏', w, x, rep, wia, SH(x)[0], ia);
+ usz xcsz = arr_csz(x);
+ return select_replace(U'⊏', w, x, rep, wia, SH(x)[0], xcsz);
}