summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java 509
1 file changed, 310 insertions, 199 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
index 572c272..cbc1127 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
@@ -46,6 +46,7 @@ public class CalculateAverage_gonix {
TreeMap::new));
System.out.println(res);
+ System.out.close();
}
private static List<MappedByteBuffer> buildChunks(RandomAccessFile file) throws IOException {
@@ -75,248 +76,358 @@ public class CalculateAverage_gonix {
}
return chunks;
}
-}
-class Aggregator {
- private static final int MAX_STATIONS = 10_000;
- private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5;
- private static final int INDEX_SIZE = 1024 * 1024;
- private static final int INDEX_MASK = INDEX_SIZE - 1;
- private static final int FLD_COUNT = 0;
- private static final int FLD_SUM = 1;
- private static final int FLD_MIN = 2;
- private static final int FLD_MAX = 3;
-
- // Poor man's hash map: hash code to offset in `mem`.
- private final int[] index;
-
- // Contiguous storage of key (station name) and stats fields of all
- // unique stations.
- // The idea here is to improve locality so that stats fields would
- // possibly be already in the CPU cache after we are done comparing
- // the key.
- private final long[] mem;
- private int memUsed;
-
- Aggregator() {
- assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2";
- assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS";
-
- index = new int[INDEX_SIZE];
- mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)];
- memUsed = 1;
- }
+ private static class Aggregator {
+ private static final int MAX_STATIONS = 10_000;
+ private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5;
+ private static final int INDEX_SIZE = 1024 * 1024;
+ private static final int INDEX_MASK = INDEX_SIZE - 1;
+ private static final int FLD_COUNT = 0;
+ private static final int FLD_SUM = 1;
+ private static final int FLD_MIN = 2;
+ private static final int FLD_MAX = 3;
+
+ // Poor man's hash map: hash code to offset in `mem`.
+ private final int[] index;
+
+ // Contiguous storage of key (station name) and stats fields of all
+ // unique stations.
+ // The idea here is to improve locality so that stats fields would
+ // possibly be already in the CPU cache after we are done comparing
+ // the key.
+ private final long[] mem;
+ private int memUsed;
- Aggregator processChunk(MappedByteBuffer buf) {
- // To avoid checking if it is safe to read a whole long near the
- // end of a chunk, we copy last couple of lines to a padded buffer
- // and process that part separately.
- int limit = buf.limit();
- int pos = Math.max(limit - 16, -1);
- while (pos >= 0 && buf.get(pos) != '\n') {
- pos--;
+ Aggregator() {
+ assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2";
+ assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS";
+
+ index = new int[INDEX_SIZE];
+ mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)];
+ memUsed = 1;
}
- pos++;
- if (pos > 0) {
- processChunkLongs(buf, pos);
+
+ Aggregator processChunk(MappedByteBuffer buf) {
+ // To avoid checking if it is safe to read a whole long near the
+ // end of a chunk, we copy last couple of lines to a padded buffer
+ // and process that part separately.
+ int limit = buf.limit();
+ int pos = Math.max(limit - 16, -1);
+ while (pos >= 0 && buf.get(pos) != '\n') {
+ pos--;
+ }
+ pos++;
+ if (pos > 0) {
+ processChunkLongs(buf, pos);
+ }
+ int tailLen = limit - pos;
+ var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.nativeOrder());
+ buf.get(pos, tailBuf.array(), 0, tailLen);
+ processChunkLongs(tailBuf, tailLen);
+ return this;
}
- int tailLen = limit - pos;
- var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.nativeOrder());
- buf.get(pos, tailBuf.array(), 0, tailLen);
- processChunkLongs(tailBuf, tailLen);
- return this;
- }
- Aggregator processChunkLongs(ByteBuffer buf, int limit) {
- int pos = 0;
- while (pos < limit) {
-
- int start = pos;
- int hash = 0;
- long tail = 0;
- while (true) {
- // Seen this trick used in multiple other solutions.
- // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
- long tmpLong = buf.getLong(pos);
- long match = tmpLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';'
- match = ((match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L));
- if (match == 0) {
- hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF);
- pos += 8;
+ Aggregator processChunkLongs(ByteBuffer buf, int limit) {
+ int pos = 0;
+ while (pos < limit) {
+
+ int start = pos;
+ long keyLong = buf.getLong(pos);
+ long valueSepMark = valueSepMark(keyLong);
+ if (valueSepMark != 0) {
+ int tailBits = tailBits(valueSepMark);
+ pos += valueOffset(tailBits);
+ // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (1), pos=" + (pos - startAddr);
+ long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
+
+ long valueLong = buf.getLong(pos);
+ int decimalSepMark = decimalSepMark(valueLong);
+ pos += nextKeyOffset(decimalSepMark);
+ // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (1), pos=" + (pos - startAddr);
+ int measurement = decimalValue(decimalSepMark, valueLong);
+
+ add1(buf, start, tailAndLen, hash(hash1(tailAndLen)), measurement);
continue;
}
- int tailBits = Long.numberOfTrailingZeros(match >>> 7);
- long tailMask = ~(-1L << tailBits);
- tail = tmpLong & tailMask;
- hash = ((33 * hash) ^ (int) (tail & 0xFFFFFFFF)) + (int) ((tail >>> 33) & 0xFFFFFFFF);
- pos += tailBits >> 3;
- break;
- }
- hash = (33 * hash) ^ (hash >>> 15);
- int lenInLongs = (pos - start) >> 3;
- long tailAndLen = (tail << 8) | (lenInLongs & 0xFF);
- // assert (buf.get(pos) == ';') : "Expected ';'";
- pos++;
+ pos += 8;
+ long keyLong1 = keyLong;
+ keyLong = buf.getLong(pos);
+ valueSepMark = valueSepMark(keyLong);
+ if (valueSepMark != 0) {
+ int tailBits = tailBits(valueSepMark);
+ pos += valueOffset(tailBits);
+ // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (2), pos=" + (pos - startAddr);
+ long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
+
+ long valueLong = buf.getLong(pos);
+ int decimalSepMark = decimalSepMark(valueLong);
+ pos += nextKeyOffset(decimalSepMark);
+ // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (2), pos=" + (pos - startAddr);
+ int measurement = decimalValue(decimalSepMark, valueLong);
+
+ add2(buf, start, keyLong1, tailAndLen, hash(hash(hash1(keyLong1), tailAndLen)), measurement);
+ continue;
+ }
- int measurement;
- {
- // Seen this trick used in multiple other solutions.
- // Looks like the original author is @merykitty.
- long tmpLong = buf.getLong(pos);
-
- // The 4th binary digit of the ascii of a digit is 1 while
- // that of the '.' is 0. This finds the decimal separator
- // The value can be 12, 20, 28
- int decimalSepPos = Long.numberOfTrailingZeros(~tmpLong & 0x10101000);
- int shift = 28 - decimalSepPos;
- // signed is -1 if negative, 0 otherwise
- long signed = (~tmpLong << 59) >> 63;
- long designMask = ~(signed & 0xFF);
- // Align the number to a specific position and transform the ascii code
- // to actual digit value in each byte
- long digits = ((tmpLong & designMask) << shift) & 0x0F000F0F00L;
-
- // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
- // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
- // 0x000000UU00TTHH00 +
- // 0x00UU00TTHH000000 * 10 +
- // 0xUU00TTHH00000000 * 100
- // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
- // This results in our value lies in the bit 32 to 41 of this product
- // That was close :)
- long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- measurement = (int) ((absValue ^ signed) - signed);
- pos += (decimalSepPos >>> 3) + 3;
+ long hash = hash1(keyLong1);
+ do {
+ pos += 8;
+ hash = hash(hash, keyLong);
+ keyLong = buf.getLong(pos);
+ valueSepMark = valueSepMark(keyLong);
+ } while (valueSepMark == 0);
+ int tailBits = tailBits(valueSepMark);
+ pos += valueOffset(tailBits);
+ // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (N), pos=" + (pos - startAddr);
+ long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
+ hash = hash(hash, tailAndLen);
+
+ long valueLong = buf.getLong(pos);
+ int decimalSepMark = decimalSepMark(valueLong);
+ pos += nextKeyOffset(decimalSepMark);
+ // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (N), pos=" + (pos - startAddr);
+ int measurement = decimalValue(decimalSepMark, valueLong);
+
+ addN(buf, start, tailAndLen, hash(hash), measurement);
}
- // assert (buf.get(pos - 1) == '\n') : "Expected '\\n'";
- add(buf, start, tailAndLen, hash, measurement);
+ return this;
}
- return this;
- }
+ public Stream<Entry> stream() {
+ return Arrays.stream(index)
+ .filter(offset -> offset != 0)
+ .mapToObj(offset -> new Entry(mem, offset));
+ }
- public Stream<Entry> stream() {
- return Arrays.stream(index)
- .filter(offset -> offset != 0)
- .mapToObj(offset -> new Entry(mem, offset));
- }
+ private static long hash1(long value) {
+ return value;
+ }
- private void add(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) {
- int idx = hash & INDEX_MASK;
- for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
- if (update(index[idx], buf, start, tailAndLen, measurement)) {
- return;
- }
+ private static long hash(long hash, long value) {
+ return hash ^ value;
+ }
+
+ private static int hash(long hash) {
+ hash *= 0x9E3779B97F4A7C15L; // Fibonacci hashing multiplier
+ return (int) (hash >>> 39);
}
- index[idx] = create(buf, start, tailAndLen, measurement);
- }
- private int create(ByteBuffer buf, int start, long tailAndLen, int measurement) {
- int offset = memUsed;
+ private static long valueSepMark(long keyLong) {
+ // Seen this trick used in multiple other solutions.
+ // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
+ long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';'
+ match = (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
+ return match;
+ }
- mem[offset] = tailAndLen;
+ private static int tailBits(long valueSepMark) {
+ return Long.numberOfTrailingZeros(valueSepMark >>> 7);
+ }
- int memPos = offset + 1;
- int memEnd = memPos + (int) (tailAndLen & 0xFF);
- int bufPos = start;
- while (memPos < memEnd) {
- mem[memPos] = buf.getLong(bufPos);
- memPos += 1;
- bufPos += 8;
+ private static int valueOffset(int tailBits) {
+ return (int) (tailBits >>> 3) + 1;
}
- mem[memPos + FLD_MIN] = measurement;
- mem[memPos + FLD_MAX] = measurement;
- mem[memPos + FLD_SUM] = measurement;
- mem[memPos + FLD_COUNT] = 1;
- memUsed = memPos + 4;
+ private static long tailAndLen(int tailBits, long keyLong, long keyLen) {
+ long tailMask = ~(-1L << tailBits);
+ long tail = keyLong & tailMask;
+ return (tail << 8) | ((keyLen >> 3) & 0xFF);
+ }
- return offset;
- }
+ private static int decimalSepMark(long value) {
+ // Seen this trick used in multiple other solutions.
+ // Looks like the original author is @merykitty.
- private boolean update(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) {
- var mem = this.mem;
- if (mem[offset] != tailAndLen) {
- return false;
+ // The 4th binary digit of the ascii of a digit is 1 while
+ // that of the '.' is 0. This finds the decimal separator
+ // The value can be 12, 20, 28
+ return Long.numberOfTrailingZeros(~value & 0x10101000);
}
- int memPos = offset + 1;
- int memEnd = memPos + (int) (tailAndLen & 0xFF);
- int bufPos = start;
- while (memPos < memEnd) {
- if (mem[memPos] != buf.getLong(bufPos)) {
- return false;
+
+ private static int decimalValue(int decimalSepMark, long value) {
+ // Seen this trick used in multiple other solutions.
+ // Looks like the original author is @merykitty.
+
+ int shift = 28 - decimalSepMark;
+ // signed is -1 if negative, 0 otherwise
+ long signed = (~value << 59) >> 63;
+ long designMask = ~(signed & 0xFF);
+ // Align the number to a specific position and transform the ascii code
+ // to actual digit value in each byte
+ long digits = ((value & designMask) << shift) & 0x0F000F0F00L;
+
+ // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
+ // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
+ // 0x000000UU00TTHH00 +
+ // 0x00UU00TTHH000000 * 10 +
+ // 0xUU00TTHH00000000 * 100
+ // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
+ // This results in our value lies in the bit 32 to 41 of this product
+ // That was close :)
+ long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ return (int) ((absValue ^ signed) - signed);
+ }
+
+ private static int nextKeyOffset(int decimalSepMark) {
+ return (decimalSepMark >>> 3) + 3;
+ }
+
+ private void add1(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) {
+ int idx = hash & INDEX_MASK;
+ for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
+ if (update1(index[idx], tailAndLen, measurement)) {
+ return;
+ }
}
- memPos += 1;
- bufPos += 8;
+ index[idx] = create(buf, start, tailAndLen, measurement);
}
- mem[memPos + FLD_COUNT] += 1;
- mem[memPos + FLD_SUM] += measurement;
- if (measurement < mem[memPos + FLD_MIN]) {
- mem[memPos + FLD_MIN] = measurement;
+ private void add2(ByteBuffer buf, int start, long keyLong, long tailAndLen, int hash, int measurement) {
+ int idx = hash & INDEX_MASK;
+ for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
+ if (update2(index[idx], keyLong, tailAndLen, measurement)) {
+ return;
+ }
+ }
+ index[idx] = create(buf, start, tailAndLen, measurement);
}
- if (measurement > mem[memPos + FLD_MAX]) {
- mem[memPos + FLD_MAX] = measurement;
+
+ private void addN(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) {
+ int idx = hash & INDEX_MASK;
+ for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
+ if (updateN(index[idx], buf, start, tailAndLen, measurement)) {
+ return;
+ }
+ }
+ index[idx] = create(buf, start, tailAndLen, measurement);
}
- return true;
- }
+ private int create(ByteBuffer buf, int start, long tailAndLen, int measurement) {
+ int offset = memUsed;
- public static class Entry {
- private final long[] mem;
- private final int offset;
- private String key;
+ mem[offset] = tailAndLen;
- Entry(long[] mem, int offset) {
- this.mem = mem;
- this.offset = offset;
+ int memPos = offset + 1;
+ int memEnd = memPos + (int) (tailAndLen & 0xFF);
+ int bufPos = start;
+ while (memPos < memEnd) {
+ mem[memPos] = buf.getLong(bufPos);
+ memPos += 1;
+ bufPos += 8;
+ }
+
+ mem[memPos + FLD_MIN] = measurement;
+ mem[memPos + FLD_MAX] = measurement;
+ mem[memPos + FLD_SUM] = measurement;
+ mem[memPos + FLD_COUNT] = 1;
+ memUsed = memPos + 4;
+
+ return offset;
}
- public String getKey() {
- if (key == null) {
- int pos = this.offset;
- long tailAndLen = mem[pos++];
- int keyLen = (int) (tailAndLen & 0xFF);
- var tmpBuf = ByteBuffer.allocate((keyLen << 3) + 8).order(ByteOrder.nativeOrder());
- for (int i = 0; i < keyLen; i++) {
- tmpBuf.putLong(mem[pos++]);
- }
- long tail = tailAndLen >>> 8;
- tmpBuf.putLong(tail);
- int keyLenBytes = (keyLen << 3) + 8 - (Long.numberOfLeadingZeros(tail) >> 3);
- key = new String(tmpBuf.array(), 0, keyLenBytes, StandardCharsets.UTF_8);
+ private boolean update1(int offset, long tailAndLen, int measurement) {
+ if (mem[offset] != tailAndLen) {
+ return false;
}
- return key;
+ updateStats(offset + 1, measurement);
+ return true;
}
- public Entry add(Entry other) {
- int fldOffset = (int) (mem[offset] & 0xFF) + 1;
- int pos = offset + fldOffset;
- int otherPos = other.offset + fldOffset;
- long[] otherMem = other.mem;
- mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]);
- mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]);
- mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM];
- mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT];
- return this;
+ private boolean update2(int offset, long keyLong, long tailAndLen, int measurement) {
+ if (mem[offset] != tailAndLen || mem[offset + 1] != keyLong) {
+ return false;
+ }
+ updateStats(offset + 2, measurement);
+ return true;
}
- public Entry getValue() {
- return this;
+ private boolean updateN(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) {
+ var mem = this.mem;
+ if (mem[offset] != tailAndLen) {
+ return false;
+ }
+ int memPos = offset + 1;
+ int memEnd = memPos + (int) (tailAndLen & 0xFF);
+ int bufPos = start;
+ while (memPos < memEnd) {
+ if (mem[memPos] != buf.getLong(bufPos)) {
+ return false;
+ }
+ memPos += 1;
+ bufPos += 8;
+ }
+ updateStats(memPos, measurement);
+ return true;
}
- @Override
- public String toString() {
- int pos = offset + (int) (mem[offset] & 0xFF) + 1;
- return round(mem[pos + FLD_MIN])
- + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT])
- + "/" + round(mem[pos + FLD_MAX]);
+ private void updateStats(int memPos, int measurement) {
+ mem[memPos + FLD_COUNT] += 1;
+ mem[memPos + FLD_SUM] += measurement;
+ if (measurement < mem[memPos + FLD_MIN]) {
+ mem[memPos + FLD_MIN] = measurement;
+ }
+ if (measurement > mem[memPos + FLD_MAX]) {
+ mem[memPos + FLD_MAX] = measurement;
+ }
}
- private static double round(double value) {
- return Math.round(value) / 10.0;
+ public static class Entry {
+ private final long[] mem;
+ private final int offset;
+ private String key;
+
+ Entry(long[] mem, int offset) {
+ this.mem = mem;
+ this.offset = offset;
+ }
+
+ public String getKey() {
+ if (key == null) {
+ int pos = this.offset;
+ long tailAndLen = mem[pos++];
+ int keyLen = (int) (tailAndLen & 0xFF);
+ var tmpBuf = ByteBuffer.allocate((keyLen << 3) + 8).order(ByteOrder.nativeOrder());
+ for (int i = 0; i < keyLen; i++) {
+ tmpBuf.putLong(mem[pos++]);
+ }
+ long tail = tailAndLen >>> 8;
+ tmpBuf.putLong(tail);
+ int keyLenBytes = (keyLen << 3) + 8 - (Long.numberOfLeadingZeros(tail) >> 3);
+ key = new String(tmpBuf.array(), 0, keyLenBytes, StandardCharsets.UTF_8);
+ }
+ return key;
+ }
+
+ public Entry add(Entry other) {
+ int fldOffset = (int) (mem[offset] & 0xFF) + 1;
+ int pos = offset + fldOffset;
+ int otherPos = other.offset + fldOffset;
+ long[] otherMem = other.mem;
+ mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]);
+ mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]);
+ mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM];
+ mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT];
+ return this;
+ }
+
+ public Entry getValue() {
+ return this;
+ }
+
+ @Override
+ public String toString() {
+ int pos = offset + (int) (mem[offset] & 0xFF) + 1;
+ return round(mem[pos + FLD_MIN])
+ + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT])
+ + "/" + round(mem[pos + FLD_MAX]);
+ }
+
+ private static double round(double value) {
+ return Math.round(value) / 10.0;
+ }
}
}
+
}