Diffstat (limited to 'src')
-rw-r--r--  src/builtins/compare.c         |   7
-rw-r--r--  src/builtins/select.c          | 117
-rw-r--r--  src/builtins/sfns.c            |   4
-rw-r--r--  src/singeli/src/equal.singeli  |   1
-rw-r--r--  src/singeli/src/mask.singeli   |   1
-rw-r--r--  src/utils/calls.h              |   3
6 files changed, 92 insertions, 41 deletions
diff --git a/src/builtins/compare.c b/src/builtins/compare.c
index a8b5268f..f55831cf 100644
--- a/src/builtins/compare.c
+++ b/src/builtins/compare.c
@@ -107,13 +107,14 @@ u8 const eqFnData[] = { // for the main diagonal, amount to shift length by; oth
#else
#define F(X) equal_##X
bool F(1_1)(void* w, void* x, u64 l, u64 d) {
+ assert(l>0);
u64* wp = w; u64* xp = x;
usz q = l/64;
for (usz i=0; i<q; i++) if (wp[i] != xp[i]) return false;
usz r = (-l)%64; return r==0 || (wp[q]^xp[q])<<r == 0;
}
#define DEF_EQ_U1(N, T) \
- bool F(1_##N)(void* w, void* x, u64 l, u64 d) { \
+ bool F(1_##N)(void* w, void* x, u64 l, u64 d) { assert(l>0); \
if (d!=0) { void* t=w; w=x; x=t; } \
u64* wp = w; T* xp = x; \
for (usz i=0; i<l; i++) if (bitp_get(wp,i)!=xp[i]) return false; \
@@ -127,7 +128,7 @@ u8 const eqFnData[] = { // for the main diagonal, amount to shift length by; oth
#define DEF_EQ_I(NAME, S, T, INIT) \
bool F(NAME)(void* w, void* x, u64 l, u64 d) { \
- INIT \
+ assert(l>0); INIT \
S* wp = w; T* xp = x; \
for (usz i=0; i<l; i++) if (wp[i]!=xp[i]) return false; \
return true; \
@@ -143,7 +144,7 @@ u8 const eqFnData[] = { // for the main diagonal, amount to shift length by; oth
#undef DEF_EQ_I
#undef DEF_EQ
#endif
-bool notEq(void* a, void* b, u64 l, u64 data) { return false; }
+bool notEq(void* a, void* b, u64 l, u64 data) { assert(l>0); return false; }
INIT_GLOBAL EqFn eqFns[] = {
F(1_1), F(1_8), F(1_16), F(1_32), F(1_f64), notEq, notEq, notEq,
F(1_8), F(8_8), F(s8_16), F(s8_32), F(s8_f64), notEq, notEq, notEq,
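Note: the asserts added above record that every EqFn now requires at least one element (see the matching comment change in src/utils/calls.h below). The bit-vector case relies on tail masking to compare a length that is not a multiple of 64. A standalone sketch of that check, with hypothetical names and not part of the patch:

  #include <assert.h>
  #include <stdbool.h>
  #include <stdint.h>

  // Sketch of the tail-masking comparison used by equal_1_1 above.
  // Compares the first l bits of two bit vectors stored as u64 words.
  static bool bits_equal(const uint64_t* w, const uint64_t* x, uint64_t l) {
    assert(l > 0);                      // contract added by this commit: l >= 1
    uint64_t q = l / 64;                // number of fully used words
    for (uint64_t i = 0; i < q; i++) if (w[i] != x[i]) return false;
    uint64_t r = (-l) % 64;             // (64 - l%64) % 64: unused trailing bits
    // Shifting the XOR left by r discards bits past position l-1; anything
    // left over is a real mismatch within the first l bits.
    return r == 0 || ((w[q] ^ x[q]) << r) == 0;
  }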
diff --git a/src/builtins/select.c b/src/builtins/select.c
index 3665dbaa..f8b8e536 100644
--- a/src/builtins/select.c
+++ b/src/builtins/select.c
@@ -372,24 +372,27 @@ B select_c2(B t, B w, B x) {
extern INIT_GLOBAL u8 reuseElType[t_COUNT];
-B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
+B select_replace(u32 chr, B w, B x, B rep, usz wia, usz cam, usz csz) { // consumes all; (⥊rep)⌾(⥊w⊏cam‿csz⥊⊢) x; assumes csz>0, that w is a typed (elNum) list of valid indices (squeeze already attempted on el_f64), and that rep has the proper element count
+ assert(csz > 0);
#if CHECK_VALID
- TALLOC(bool, set, xia);
- bool sparse = wia < xia/64;
- if (!sparse) for (i64 i = 0; i < xia; i++) set[i] = false;
+ TALLOC(bool, set, cam);
+ bool sparse = wia < cam/64;
+ if (!sparse) for (i64 i = 0; i < cam; i++) set[i] = false;
#define SPARSE_INIT(WI) \
- if (sparse) for (usz i = 0; i < wia; i++) { \
- i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xia; set[cw] = false; \
+ if (sparse) for (usz i = 0; i < wia; i++) { \
+ i64 cw = WI; if (RARE(cw<0)) cw+= (i64)cam; set[cw] = false; \
}
- #define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true;
+ #define EQ(ITER,F) if (set[cw]) ITER if (F) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true;
+ #define EQ1(F) EQ(,F)
#define FREE_CHECK TFREE(set)
#else
#define SPARSE_INIT(GET)
- #define EQ(F)
+ #define EQ(ITER,F)
+ #define EQ1(F)
#define FREE_CHECK
#endif
- #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)xia
+ #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)cam
u8 we = TI(w,elType); assert(elNum(we));
u8 xe = TI(x,elType);
u8 re = el_or(xe, TI(rep,elType));
@@ -399,16 +402,25 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
f64* wp = f64any_ptr(w);
SPARSE_INIT((i64)wp[i])
- MAKE_MUT(r, xia);
+ MAKE_MUT(r, cam*csz);
mut_init_copy(r, x, re);
NOGC_E;
MUTG_INIT(r); SGet(rep)
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- B cn = Get(rep, i);
- EQ(!equal(mut_getU(r, cw), cn));
- mut_rm(r, cw);
- mut_setG(r, cw, cn);
+ if (csz==1) {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ B cn = Get(rep, i);
+ EQ1(!equal(mut_getU(r, cw), cn));
+ mut_rm(r, cw);
+ mut_setG(r, cw, cn);
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ EQ(for (usz j = 0; j < csz; j++), !equal(mut_getU(r, cw*csz + j), Get(rep, i*csz + j)));
+ for (usz j = 0; j < csz; j++) mut_rm(r, cw*csz + j);
+ mut_copyG(r, cw*csz, rep, i*csz, csz);
+ }
}
ra = mut_fp(r);
goto dec_ret_ra;
@@ -419,7 +431,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
i32* wp = i32any_ptr(w);
SPARSE_INIT(wp[i])
bool reuse = reusable(x) && re==reuseElType[TY(x)];
- SLOWIF(!reuse && xia>100 && wia<xia/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x);
+ SLOWIF(!reuse && cam>100 && wia<cam/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x);
switch (re) { default: UD;
case el_i8: rep = toI8Any(rep); ra = reuse? a(REUSE(x)) : cpyI8Arr(x); goto do_u8;
case el_c8: rep = toC8Any(rep); ra = reuse? a(REUSE(x)) : cpyC8Arr(x); goto do_u8;
@@ -432,11 +444,19 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
TyArr* na = toBitArr(rep); rep = taga(na);
u64* np = bitarrv_ptr(na);
u64* rp = (void*)((TyArr*)ra)->a;
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- bool cn = bitp_get(np, i);
- EQ(cn != bitp_get(rp, cw));
- bitp_set(rp, cw, cn);
+ if (csz==1) {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ bool cn = bitp_get(np, i);
+ EQ1(cn != bitp_get(rp, cw));
+ bitp_set(rp, cw, cn);
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ EQ(for (usz j = 0; j < csz; j++), bitp_get(np, i*csz + j) != bitp_get(rp, cw*csz + j));
+ COPY_TO(rp, el_bit, cw*csz, rep, i*csz, csz);
+ }
}
goto dec_ret_ra;
}
@@ -444,27 +464,37 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
ra = reuse? a(REUSE(x)) : cpyHArr(x);
B* rp = harrP_parts((HArr*)ra).a;
SGet(rep)
- for (usz i = 0; i < wia; i++) {
- READ_W(cw, i);
- B cn = Get(rep, i);
- EQ(!equal(cn,rp[cw]));
- dec(rp[cw]);
- rp[cw] = cn;
+ if (csz==1) {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ B cn = Get(rep, i);
+ EQ1(!equal(cn,rp[cw]));
+ dec(rp[cw]);
+ rp[cw] = cn;
+ }
+ } else {
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ EQ(for (usz j = 0; j < csz; j++), !equal(Get(rep, i*csz + j), rp[cw*csz + j]));
+ for (usz j = 0; j < csz; j++) dec(rp[cw*csz + j]);
+ COPY_TO(rp, el_B, cw*csz, rep, i*csz, csz);
+ }
}
goto dec_ret_ra;
}
}
#define IMPL(T) do { \
+ if (csz!=1) goto do_tycell; \
T* rp = (void*)((TyArr*)ra)->a; \
T* np = tyany_ptr(rep); \
for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \
T cn = np[i]; \
- EQ(cn != rp[cw]); \
+ EQ1(cn != rp[cw]); \
rp[cw] = cn; \
} \
- goto dec_ret_ra; \
+ goto dec_ret_ra; \
} while(0)
do_u8: IMPL(u8);
@@ -473,6 +503,18 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
do_u64: IMPL(u64);
#undef IMPL
+ do_tycell:;
+ u8 cwidth = csz * elWidth(re);
+ u8* rp = (u8*) ((TyArr*)ra)->a;
+ u8* np = tyany_ptr(rep);
+ EqFnObj eq = EQFN_GET(re,re);
+ for (usz i = 0; i < wia; i++) {
+ READ_W(cw, i);
+ EQ1(!EQFN_CALL(eq, rp + cw*cwidth, np + i*cwidth, csz));
+ COPY_TO(rp, re, cw*csz, rep, i*csz, csz);
+ }
+ goto dec_ret_ra;
+
dec_ret_ra:;
@@ -482,11 +524,12 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
#undef SPARSE_INIT
#undef EQ
+ #undef EQ1
#undef FREE_CHECK
}
B select_ucw(B t, B o, B w, B x) {
- if (isAtm(x) || RNK(x)!=1 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
+ if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x);
usz wia = IA(w);
u8 we = TI(w,elType);
@@ -495,13 +538,19 @@ B select_ucw(B t, B o, B w, B x) {
if (!elNum(we)) goto def;
}
B rep;
- if (isArr(o)) {
+ if (isArr(o) && RNK(x)>0) {
i64 buf[2];
if (wia!=0 && (!getRange_fns[we](tyany_ptr(w), buf, wia) || buf[0]<-(i64)xia || buf[1]>=xia)) thrF("𝔽⌾(a⊸⊏)𝕩: Indexing out-of-bounds (%l∊a, %H≡≢𝕩)", buf[1]>=xia?buf[1]:buf[0], x);
rep = incG(o);
} else {
rep = c1(o, C2(select, incG(w), incG(x)));
}
- if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
- return select_replace(U'⊏', w, x, rep, wia, xia);
+ usz xr = RNK(x);
+ usz wr = RNK(w);
+ usz rr = RNK(rep);
+ bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1);
+ if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H ≡ shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep);
+ usz csz = arr_csz(x);
+ if (csz == 0) { decG(rep); decG(w); return x; }
+ return select_replace(U'⊏', w, x, rep, wia, *SH(x), csz);
}
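Note: this is the core of the patch: select_replace now operates per cell rather than per element, taking the cell count cam and cell size csz so that 𝔽⌾(a⊸⊏) can replace whole major cells of a rank>1 𝕩. Under CHECK_VALID, a repeated index is only allowed if the replacement cells agree element-wise. A simplified, standalone model of that scatter-with-consistency-check for plain double cells (types and names are illustrative only, not CBQN internals):

  #include <stdbool.h>
  #include <stddef.h>
  #include <string.h>

  // Write rep's rows into r at row indices w; like CHECK_VALID, require that
  // a row written twice receives identical data both times.
  static bool scatter_rows(double* r, size_t cam, size_t csz,
                           const long long* w, size_t wia,
                           const double* rep, bool* seen) {
    memset(seen, 0, cam * sizeof *seen);
    for (size_t i = 0; i < wia; i++) {
      long long cw = w[i];
      if (cw < 0) cw += (long long)cam;       // negative indices count from the end
      if (seen[cw]) {                         // row already written: results must agree
        for (size_t j = 0; j < csz; j++)
          if (r[(size_t)cw*csz + j] != rep[i*csz + j]) return false; // "Incompatible result elements"
      }
      seen[cw] = true;
      memcpy(r + (size_t)cw*csz, rep + i*csz, csz * sizeof(double));
    }
    return true;
  }

The real code additionally skips the full reset of the seen buffer when wia is small relative to cam (the sparse path), and routes typed cells through EQFN_CALL and COPY_TO in the do_tycell path above.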
diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c
index 0a9109a5..952b27f9 100644
--- a/src/builtins/sfns.c
+++ b/src/builtins/sfns.c
@@ -1302,7 +1302,7 @@ B pick_uc1(B t, B o, B x) {
-B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia);
+B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia);
B select_ucw(B t, B o, B w, B x);
B select_c2(B,B,B);
B pick_ucw(B t, B o, B w, B x) {
@@ -1329,7 +1329,7 @@ B pick_ucw(B t, B o, B w, B x) {
w = num_squeeze(mut_fcd(r, w));
B rep = isArr(o)? incG(o) : c1(o, C2(select, incG(w), C1(shape, incG(x))));
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊑)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
- return select_replace(U'⊑', w, x, rep, wia, xia);
+ return select_replace(U'⊑', w, x, rep, wia, xia, 1);
}
decG(w);
}
diff --git a/src/singeli/src/equal.singeli b/src/singeli/src/equal.singeli
index b80eb02d..3bdd7ff2 100644
--- a/src/singeli/src/equal.singeli
+++ b/src/singeli/src/equal.singeli
@@ -15,6 +15,7 @@ fn equal{W, X}(w:*void, x:*void, l:u64, d:u64) : u1 = {
def vw = arch_defvw
def bulk = vw / width{X}
if (W!=X) if (d!=0) swap{w,x}
+ assert{l>0}
if (W==u1) {
if (X==u1) { # bitarr ≡ bitarr
diff --git a/src/singeli/src/mask.singeli b/src/singeli/src/mask.singeli
index d9979b7d..adc575e2 100644
--- a/src/singeli/src/mask.singeli
+++ b/src/singeli/src/mask.singeli
@@ -103,6 +103,7 @@ def maskedLoop{bulk} = maskedLoop{bulk,0}
def maskedLoopPositive{bulk}{vars,begin==0,end:L,iter} = {
+ assert{end > 0}
i:L = 0
while(i < (end-1)/bulk) {
mlExec{i, iter, vars, bulk, maskNone}
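Note: the assert{end > 0} above guards the (end-1)/bulk loop bound, which would wrap around for an unsigned end of 0. A tiny C illustration of the wraparound being ruled out (hypothetical values, only to show the arithmetic):

  #include <stdint.h>
  #include <stdio.h>

  int main(void) {
    uint64_t end = 0, bulk = 16;
    // With unsigned arithmetic, end-1 wraps to 2^64-1, so the loop bound
    // becomes enormous instead of "no iterations".
    printf("%llu\n", (unsigned long long)((end - 1) / bulk)); // 1152921504606846975
    return 0;
  }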
diff --git a/src/utils/calls.h b/src/utils/calls.h
index d38cf376..137b8810 100644
--- a/src/utils/calls.h
+++ b/src/utils/calls.h
@@ -26,14 +26,13 @@ CMP_DEF(le, AS);
#define CMP_AA_IMM(FN, ELT, WHERE, WP, XP, LEN) CMP_AA_CALL(CMP_AA_FN(FN, ELT), WHERE, WP, XP, LEN)
#define CMP_AS_IMM(FN, ELT, WHERE, WP, X, LEN) CMP_AS_CALL(CMP_AS_FN(FN, ELT), WHERE, WP, X, LEN)
-// Check if the l elements starting at a and b match
typedef bool (*EqFn)(void* a, void* b, u64 l, u64 data);
extern INIT_GLOBAL EqFn eqFns[];
extern u8 const eqFnData[];
#define EQFN_INDEX(W_ELT, X_ELT) ((W_ELT)*8 + (X_ELT))
typedef struct { EqFn fn; u8 data; } EqFnObj;
#define EQFN_GET(W_ELT, X_ELT) ({ u8 eqfn_i_ = EQFN_INDEX(W_ELT, X_ELT); (EqFnObj){.fn=eqFns[eqfn_i_], .data=eqFnData[eqfn_i_]}; })
-#define EQFN_CALL(FN, W, X, L) (FN).fn(W, X, L, (FN).data)
+#define EQFN_CALL(FN, W, X, L) (FN).fn(W, X, L, (FN).data) // check if L elements starting at a and b match; assumes L≥1
typedef bool (*RangeFn)(void* xp, i64* res, u64 len); // writes min,max in res, assumes len≥1; returns 0 and leaves res undefined if either any (floor(x)≠x or abs>2⋆53), or (x≠(i64)x)
extern INIT_GLOBAL RangeFn getRange_fns[el_f64+1]; // limited to ≤el_f64
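Note: the calls.h change moves the description onto EQFN_CALL and states the l≥1 requirement, which the new do_tycell path in select.c upholds thanks to the csz==0 early return added in select_ucw. A minimal, self-contained model of the lookup-and-call pattern (the function and table here are stand-ins, not the real eqFns/eqFnData contents):

  #include <assert.h>
  #include <stdbool.h>
  #include <stdint.h>
  #include <stdio.h>

  // One function pointer plus a data byte per element-type pair; callers
  // must pass l >= 1, mirroring the EQFN_CALL comment above.
  typedef bool (*EqFn)(void* a, void* b, uint64_t l, uint64_t data);
  typedef struct { EqFn fn; uint8_t data; } EqFnObj;

  static bool eq_u8(void* a, void* b, uint64_t l, uint64_t d) {
    assert(l > 0); (void)d;
    uint8_t* x = a; uint8_t* y = b;
    for (uint64_t i = 0; i < l; i++) if (x[i] != y[i]) return false;
    return true;
  }

  int main(void) {
    uint8_t a[3] = {1,2,3}, b[3] = {1,2,3};
    EqFnObj eq = { eq_u8, 0 };                 // stand-in for EQFN_GET(re,re)
    printf("%d\n", eq.fn(a, b, 3, eq.data));   // stand-in for EQFN_CALL(eq, a, b, 3)
    return 0;
  }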