summaryrefslogtreecommitdiff
path: root/src/utils/bits.c
blob: 525b8fc832d9f1544515bec86a0aa89a92dfca84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#include "../core.h"
#include "mut.h"


#if SINGELI_SIMD
  #define SINGELI_FILE bits
  #include "../utils/includeSingeli.h"
  #define bitselFns simd_bitsel
#endif

#if defined(__BMI2__) && !SLOW_PDEP
  #define FAST_PDEP 1
  #include <immintrin.h>
#endif

// Create a bit array of ia bits (via m_bitarrp) with every backing u64 word
// cleared, including the unused bits of the final word.
NOINLINE Arr* allZeroes(usz ia) {
  u64* words;
  Arr* res = m_bitarrp(&words, ia);
  usz wordCount = BIT_N(ia);
  for (usz k = 0; k < wordCount; k++) words[k] = 0;
  return res;
}
// Create a bit array of ia bits (via m_bitarrp) with every backing u64 word
// set to all ones, including the unused bits of the final word.
NOINLINE Arr* allOnes(usz ia) {
  u64* words;
  Arr* res = m_bitarrp(&words, ia);
  usz wordCount = BIT_N(ia);
  for (usz k = 0; k < wordCount; k++) words[k] = ~0ULL;
  return res;
}

// Select between two values by the bits of b: element i of the result is e1
// where bit i of b is 1, and e0 where it is 0.
// Consumes b: it is decG'd at dec_ret, returned as-is, or handed to
// i64EachDec / bit_negate in the all-boolean case (presumably consuming it).
// e0 and e1 are never decremented here; the slow path adds one reference per
// stored copy (incBy), so they appear to be borrowed — confirm against callers.
// Fast paths build a typed array when e0/e1 are both numbers or both chars;
// everything else goes through a generic array of B.
NOINLINE B bit_sel(B b, B e0, B e1) {
  u8 t0 = selfElType(e0); // element type for e0 alone (presumably the narrowest that holds it)
  u64* bp = bitarr_ptr(b);
  usz ia = IA(b);
  B r;
  {
    u8 type, width;  // result array type, and log2 of its element size in bytes
    u64 e0i, e1i;    // raw bit patterns of e0/e1 at the chosen width
    if (elNum(t0) && isF64(e1)) {
      // Both numeric: pick the narrowest result type that is at least as wide
      // as t0 and can also represent e1.
      f64 f0 = o2fG(e0);
      f64 f1 = o2fG(e1);
      switch (t0) { default: UD;
        // el_bit falls into the el_i8 ladder when e1 is not 0 or 1.
        case el_bit: if (f1==0||f1==1) goto t_bit;
        case el_i8:  if (q_fi8(f1)) goto t_i8; if (q_fi16(f1)) goto t_i16; if (q_fi32(f1)) goto t_i32; goto t_f64; // not using fallthrough to allow deduplicating float→int conversion
        case el_i16:                           if (q_fi16(f1)) goto t_i16; if (q_fi32(f1)) goto t_i32; goto t_f64;
        case el_i32:                                                       if (q_fi32(f1)) goto t_i32; goto t_f64;
        case el_f64: goto t_f64;
      }
      // Both values are 0/1: the result is derived from b itself.
      t_bit:
        if (f0) return f1? i64EachDec(1, b) : bit_negate(b); // e0=1: constant 1s (presumably) or ¬b
        else    return f1? b : i64EachDec(0, b);             // e0=0: b itself or constant 0s (presumably)
      // Encode both values as raw bit patterns of the chosen element width.
      t_i8:  type=t_i8arr;  width=0; e0i=( u8)( i8)f0; e1i=( u8)( i8)f1; goto sel;
      t_i16: type=t_i16arr; width=1; e0i=(u16)(i16)f0; e1i=(u16)(i16)f1; goto sel;
      t_i32: type=t_i32arr; width=2; e0i=(u32)(i32)f0; e1i=(u32)(i32)f1; goto sel;
      t_f64: type=t_f64arr; width=3; e0i=     b(f0).u; e1i=     b(f1).u; goto sel;
      
    } else if (elChr(t0) && isC32(e1)) {
      // Both characters: widen past t0 only as far as e1 requires.
      u32 u0 = o2cG(e0); u32 u1 = o2cG(e1);
      switch(t0) { default: UD;
        case el_c8:  if (u1==( u8)u1) { type=t_c8arr;  width=0; e0i=u0; e1i=u1; goto sel; } // else fallthrough
        case el_c16: if (u1==(u16)u1) { type=t_c16arr; width=1; e0i=u0; e1i=u1; goto sel; } // else fallthrough
        case el_c32:                  { type=t_c32arr; width=2; e0i=u0; e1i=u1; goto sel; }
      }
    } else goto slow;
    
    // Typed fill: result element i = bit i of b ? e1i : e0i.
    sel:;
    void* rp = m_tyarrlc(&r, width, b, type);
    #if SINGELI_SIMD
      bitselFns[width](rp, bp, e0i, e1i, ia);
    #else
      switch(width) {
        case 0: for (usz i=0; i<ia; i++) (( u8*)rp)[i] = bitp_get(bp,i)? e1i : e0i; break;
        case 1: for (usz i=0; i<ia; i++) ((u16*)rp)[i] = bitp_get(bp,i)? e1i : e0i; break;
        case 2: for (usz i=0; i<ia; i++) ((u32*)rp)[i] = bitp_get(bp,i)? e1i : e0i; break;
        case 3: for (usz i=0; i<ia; i++) ((u64*)rp)[i] = bitp_get(bp,i)? e1i : e0i; break;
      }
    #endif
    goto dec_ret;
  }
  
  // Generic path: store e0/e1 directly, then add one reference per stored copy.
  slow:;
  HArr_p ra = m_harrUc(b);
  SLOW3("bit_sel", e0, e1, b);
  for (usz i = 0; i < ia; i++) ra.a[i] = bitp_get(bp,i)? e1 : e0;
  NOGC_E;
  
  u64 c1 = bit_sum(bp, ia); // number of positions holding e1
  u64 c0 = ia-c1;           // number of positions holding e0
  incBy(e0,c0);
  incBy(e1,c1);
  r = ra.b;
  
  dec_ret:
  decG(b); return r;
}


// Read a bit-unaligned u64: returns the 64 bits of the bit stream at p that
// start at bit offset `off` (aka 64↑off↓p).
// Always loads both the word containing bit `off` and the following word.
// NOTE(review): when `off` falls in the last word of a buffer, the second load
// reads one word past it; presumably callers rely on allocation slack — confirm.
static inline u64 rbuu64(u64* p, ux off) {
  u64* w = p + (off>>6);
  ux sh = off&63;
  u64 lo = w[0];
  u64 hi = w[1];
  #if HAS_U128
    return (u64)((((u128)hi)<<64 | lo) >> sh);
  #else
    if (sh == 0) return lo; // avoid the undefined 64-bit shift of hi
    return (lo>>sh) | (hi<<(64-sh));
  #endif
}

// State of a sequential bit writer ("append bits"), driven by ab_new / ab_add / ab_done.
typedef struct {
  u64* ptr; // output word that the next completed 64 bits are stored to
  u64 tmp;  // pending bits not yet stored; only the low `off` bits may be nonzero
  ux off;   // number of pending bits in tmp; always < 64
} ABState;
// Start a bit writer that stores its output words starting at p.
static ABState ab_new(u64* p) {
  ABState s;
  s.ptr = p;
  s.tmp = 0;
  s.off = 0;
  return s;
}
// Finish a bit writer: store the partially filled final word, if any.
// Nothing is written when the appended bits ended on a word boundary.
static void ab_done(ABState s) {
  if (s.off == 0) return;
  *s.ptr = s.tmp;
}
// Append the low `count` bits of val to the writer's bit stream.
// Bits of val at positions >= count must be zero (assumes count <= 64).
// Each time 64 bits have accumulated, the full word is stored and ptr advances.
FORCE_INLINE void ab_add(ABState* state, u64 val, ux count) {
  assert(count==64 || (val>>count)==0);
  ux start = state->off;   // bits already pending in tmp
  ux end = start + count;  // pending bits once val is appended
  assert((state->tmp>>start) == 0);
  u64 word = state->tmp | val<<start;
  state->off = end&63;
  if (end < 64) {
    state->tmp = word;
    return;
  }
  *state->ptr++ = word;
  // Keep the bits of val that did not fit into the stored word. There are none
  // when start==0, and shifting by 64 would be undefined.
  state->tmp = start==0? 0 : val>>(64-start);
}


// Repack the cells of x so that each occupies ncsz bits: the first pcsz bits of
// a result cell are the original cell's bits, and the remaining bits are zero.
// x's data is read as a raw bit stream (tyany_ptr), so pcsz is measured in bits.
// Parameters:
//   lr   - number of leading axes of x kept in the result shape (result rank is lr+1)
//   cam  - number of cells; the result holds cam*ncsz bits
//          (presumably the caller passes the product of the first lr axes of x)
//   pcsz - current cell size in bits
//   ncsz - padded cell size in bits; a multiple of 8 and greater than pcsz
// Result: a bit array of shape (lr↑SH(x)) ∾ ⟨ncsz⟩, except for the pcsz==1
// shortcuts, which return cpyI8/16/32Arr(x) (presumably keeping x's own shape).
static NOINLINE B zeroPadToCellBits0(B x, usz lr, usz cam, usz pcsz, usz ncsz) { // consumes; for now assumes ncsz is either a multiple of 64, or one of 8,16,32
  assert((ncsz&7) == 0 && RNK(x)>=1 && pcsz<ncsz);
  // printf("zeroPadToCellBits0 ia=%d cam=%d pcsz=%d ncsz=%d\n", IA(x), cam, pcsz, ncsz);
  // 1-bit cells: converting each 0/1 element to an integer of the same value
  // gives a zero-padded layout directly (assumes little-endian — TODO confirm).
  if (pcsz==1) {
    if (ncsz== 8) return taga(cpyI8Arr(x));
    if (ncsz==16) return taga(cpyI16Arr(x));
    if (ncsz==32) return taga(cpyI32Arr(x));
  }
  
  if (lr==UR_MAX) thrM("Rank too large"); // the result needs rank lr+1
  u64* rp;
  Arr* r = m_bitarrp(&rp, cam*ncsz);
  usz* rsh = arr_shAlloc(r, lr+1);
  shcpy(rsh, SH(x), lr);
  rsh[lr] = ncsz;
  u64* xp = tyany_ptr(x);
  
  // TODO widen 8/16-bit cells to 16/32 via cpyC(16|32)Arr
  if (ncsz<=64 && (ncsz&(ncsz-1)) == 0) {
    // Result cells are 8, 16, 32 or 64 bits: each one is a single integer store.
    u64 tmsk = (1ull<<pcsz)-1; // low pcsz bits: one source cell
    #if FAST_PDEP
      if (ncsz<32) {
        // 8/16-bit cells: one pdep scatters 8 (or 4) consecutive source cells
        // into the low pcsz bits of each byte (or 16-bit lane) of an output word.
        assert(ncsz==8 || ncsz==16);
        bool c8 = ncsz==8;
        u64 msk0 = tmsk * (c8? 0x0101010101010101 : 0x0001000100010001);
        ux am = c8? cam/8 : cam/4; // number of completely filled output words
        u32 count = POPC(msk0);    // source bits consumed per output word
        // printf("widen base %04lx %016lx count=%d am=%zu\n", tmsk, msk0, count, am);
        for (ux i=0; i<am; i++) { *(u64*)rp = _pdep_u64(rbuu64(xp, i*count), msk0); rp++; }
        u32 tb = c8? cam&7 : (cam&3)<<1; // bytes of output covered by the leftover cells
        if (tb) {
          u64 msk1 = msk0 & ((1ull<<tb*8)-1);
          // printf("widen tail %4d %016lx count=%d\n", tb, msk1, POPC(msk1));
          *(u64*)rp = _pdep_u64(rbuu64(xp, am*count), msk1);
        }
      }
      else if (ncsz==32) for (ux i=0; i<cam; i++) ((u32*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk;
      else               for (ux i=0; i<cam; i++) ((u64*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk;
    #else
      switch(ncsz) { default: UD;
        case  8: for (ux i=0; i<cam; i++) ((u8* )rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
        case 16: for (ux i=0; i<cam; i++) ((u16*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
        case 32: for (ux i=0; i<cam; i++) ((u32*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
        case 64: for (ux i=0; i<cam; i++) ((u64*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
      }
    #endif
  } else {
    // Cells wider than 64 bits: copy whole u64s, then mask the final partial word.
    assert((ncsz&63) == 0 && ncsz-pcsz < 64 && (pcsz&63) != 0);
    ux pfu64 = pcsz>>6; // previous full u64 count in cell
    u64 msk = (1ull<<(pcsz&63))-1;
    for (ux i = 0; i < cam; i++) {
      for (ux j = 0; j < pfu64; j++) rp[j] = rbuu64(xp, i*pcsz + j*64);
      rp[pfu64] = rbuu64(xp, i*pcsz + pfu64*64) & msk;
      rp+= ncsz>>6;
    }
  }
  decG(x);
  return taga(r);
}
// Pad the cells of x (the subarrays spanning axes `axis` onwards) so that each
// cell occupies 8, 16 or 32 bits when it fits, and otherwise the next multiple
// of 64 bits. Returns x unchanged when it already has that width; otherwise
// zeroPadToCellBits0 repacks it (consuming x).
NOINLINE B widenBitArr(B x, ur axis) {
  assert(isArr(x) && TI(x,elType)!=el_B && axis>=1 && RNK(x)>=axis);
  usz cellBits = shProd(SH(x), axis, RNK(x)) << elwBitLog(TI(x,elType));
  assert(cellBits!=0);
  usz paddedBits;
  if      (cellBits <=  8) paddedBits =  8;
  else if (cellBits <= 16) paddedBits = 16;
  else if (cellBits <= 32) paddedBits = 32;
  else                     paddedBits = (cellBits+63) & ~(usz)63;
  if (paddedBits == cellBits) return x;
  usz cellCount = shProd(SH(x), 0, axis);
  return zeroPadToCellBits0(x, axis, cellCount, cellBits, paddedBits);
}

// Undo cell padding of a bit array. x's cells (axes `axis` onwards) hold xcsz
// bits each; keep only the first ocsz = ×´csh bits of every cell and give the
// result the shape (axis↑SH(x)) ∾ csh. Consumes x.
// When xcsz==ocsz, no bits are dropped and only the cell shape is replaced.
// For now this assumes:
//   - the bits being dropped are zero;
//   - xcsz is a multiple of 8;
//   - there are at most 63 padding bits per cell.
// A non-bit x is only converted with cpyBitArr; csh is not applied to it.
B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) {
  if (TI(x,elType)!=el_bit) return taga(cpyBitArr(x));
  if (axis+cr>UR_MAX) thrM("Rank too large");
  
  usz xcsz = shProd(SH(x), axis, RNK(x)); // bits per cell in x (elements are bits)
  usz ocsz = shProd(csh, 0, cr);          // bits per cell in the result
  // printf("narrowWidenedBitArr ia=%d axis=%d cr=%d ocsz=%d xcsz=%d\n", IA(x), axis, cr, ocsz, xcsz);
  // ocsz may equal xcsz (handled just below), so the bound must be non-strict.
  assert(ocsz<=xcsz);
  if (xcsz==ocsz) {
    // Nothing to drop: reuse x if the shape already matches, otherwise reshape a copy.
    if (RNK(x)-axis == cr && eqShPart(SH(x)+axis, csh, cr)) return x;
    Arr* r = cpyWithShape(x);
    ShArr* rsh = m_shArr(axis+cr);
    shcpy(rsh->a, SH(x), axis);
    shcpy(rsh->a+axis, csh, cr);
    arr_shReplace(r, axis+cr, rsh);
    return taga(r);
  }
  // Preconditions of the repacking paths below.
  assert((xcsz&7) == 0 && ocsz!=0);
  
  usz cam = shProd(SH(x), 0, axis); // number of cells
  u64* rp;
  Arr* r = m_bitarrp(&rp, cam*ocsz);
  usz* rsh = arr_shAlloc(r, axis+cr);
  shcpy(rsh, SH(x), axis);
  shcpy(rsh+axis, csh, cr);
  
  u8* xp = tyany_ptr(x);
  ABState ab = ab_new(rp); // every result bit is written through ab
  if (xcsz<=64 && (xcsz&(xcsz-1)) == 0) {
    // x's cells are 8, 16, 32 or 64 bits: each one is a single integer.
    #if FAST_PDEP
      if (xcsz<32) {
        // 8/16-bit cells: one pext gathers the low ocsz bits of the 8 (or 4)
        // cells held in a u64 of x.
        assert(xcsz==8 || xcsz==16);
        bool c8 = xcsz==8;
        u64 tmsk = (1ull<<ocsz)-1;
        u64 msk0 = tmsk * (c8? 0x0101010101010101 : 0x0001000100010001);
        ux am = c8? cam/8 : cam/4; // number of completely filled input words
        u32 count = POPC(msk0);    // result bits produced per input word
        // printf("narrow base %04lx %016lx count=%d am=%zu\n", tmsk, msk0, count, am);
        for (ux i=0; i<am; i++) { ab_add(&ab, _pext_u64(*(u64*)xp, msk0), count); xp+= 8; }
        u32 tb = c8? cam&7 : (cam&3)<<1; // bytes of x occupied by the leftover cells
        if (tb) {
          u64 msk1 = msk0 & ((1ull<<tb*8)-1);
          // printf("narrow tail %4d %016lx count=%d\n", tb, msk1, POPC(msk1));
          ab_add(&ab, _pext_u64(*(u64*)xp, msk1), POPC(msk1));
        }
      }
      else if (xcsz==32) for (ux i=0; i<cam; i++) ab_add(&ab, ((u32*)xp)[i], ocsz);
      else               for (ux i=0; i<cam; i++) ab_add(&ab, ((u64*)xp)[i], ocsz);
    #else
      switch(xcsz) { default: UD;
        case  8: for (ux i=0; i<cam; i++) ab_add(&ab, ((u8* )xp)[i], ocsz); break; // all assume zero padding
        case 16: for (ux i=0; i<cam; i++) ab_add(&ab, ((u16*)xp)[i], ocsz); break;
        case 32: for (ux i=0; i<cam; i++) ab_add(&ab, ((u32*)xp)[i], ocsz); break;
        case 64: for (ux i=0; i<cam; i++) ab_add(&ab, ((u64*)xp)[i], ocsz); break;
      }
    #endif
  } else {
    // Other multiple-of-8 widths: copy whole u64s of each cell, then its partial tail.
    assert(xcsz-ocsz<64);
    ux rfu64 = ocsz>>6;  // full u64 count per cell in the result
    ux rtail = ocsz&63;  // bits in the final partial u64 of a result cell
    u64 msk = (1ull<<rtail)-1;
    for (ux i = 0; i < cam; i++) {
      for (ux j = 0; j < rfu64; j++) ab_add(&ab, loadu_u64(j + (u64*)xp), 64);
      // Skip the tail load when there is no partial word; it would contribute
      // 0 bits anyway and could read past the end of x.
      if (rtail) ab_add(&ab, loadu_u64(rfu64 + (u64*)xp)&msk, rtail);
      xp+= xcsz>>3;
    }
  }
  ab_done(ab);
  decG(x);
  return taga(r);
}