summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author: Jamal Mulla <jamaldevacc@gmail.com> 2024-01-31 21:09:25 +0000
committer: GitHub <noreply@github.com> 2024-01-31 22:09:25 +0100
commit: e639e2a045371ab0be51404767a42f22f689cf2c (patch)
tree: 278ae1e1292f3238838dd0050244d9ad48738605 /src
parent: b91c95a498c5959ae391c7ad4fdeb2162e31b73d (diff)
Second attempt with various improvements (#510)
* Initial chunked impl
* Bytes instead of chars
* Improved number parsing
* Custom hashmap
* Graal and some tuning
* Fix segmenting
* Fix casing
* Unsafe
* Inlining hash calc
* Improved loop
* Cleanup
* Speeding up equals
* Simplifying hash
* Replace concurrenthashmap with lock
* Small changes
* Script reorg
* Native
* Lots of inlining and improvements
* Add back length check
* Fixes
* Small changes

---------

Co-authored-by: Jamal Mulla <j.mulla@mwam.com>
Diffstat (limited to 'src')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java 346
1 files changed, 161 insertions, 185 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java b/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java
index 7705885..7daf199 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java
@@ -21,21 +21,32 @@ import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
-import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
-import java.util.*;
+import java.util.Map;
+import java.util.TreeMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
public class CalculateAverage_JamalMulla {
- private static final Map<String, ResultRow> global = new HashMap<>();
+ private static final long ALL_SEMIS = 0x3B3B3B3B3B3B3B3BL;
+ private static final Map<String, ResultRow> global = new TreeMap<>();
private static final String FILE = "./measurements.txt";
private static final Unsafe UNSAFE = initUnsafe();
private static final Lock lock = new ReentrantLock();
- private static final int FNV_32_INIT = 0x811c9dc5;
- private static final int FNV_32_PRIME = 0x01000193;
+ private static final long FXSEED = 0x517cc1b727220a95L;
+
+ private static final long[] masks = {
+ 0x0,
+ 0x00000000000000FFL,
+ 0x000000000000FFFFL,
+ 0x0000000000FFFFFFL,
+ 0x00000000FFFFFFFFL,
+ 0x000000FFFFFFFFFFL,
+ 0x0000FFFFFFFFFFFFL,
+ 0x00FFFFFFFFFFFFFFL
+ };
private static Unsafe initUnsafe() {
try {
@@ -53,12 +64,16 @@ public class CalculateAverage_JamalMulla {
private int max;
private long sum;
private int count;
+ private final long keyStart;
+ private final byte keyLength;
- private ResultRow(int v) {
+ private ResultRow(int v, final long keyStart, final byte keyLength) {
this.min = v;
this.max = v;
this.sum = v;
this.count = 1;
+ this.keyStart = keyStart;
+ this.keyLength = keyLength;
}
public String toString() {
@@ -68,236 +83,197 @@ public class CalculateAverage_JamalMulla {
private double round(double value) {
return Math.round(value) / 10.0;
}
+
}
private record Chunk(Long start, Long length) {
}
- static List<Chunk> getChunks(int numThreads, FileChannel channel) throws IOException {
+ static Chunk[] getChunks(int numThreads, FileChannel channel) throws IOException {
// get all chunk boundaries
final long filebytes = channel.size();
final long roughChunkSize = filebytes / numThreads;
- final List<Chunk> chunks = new ArrayList<>(numThreads);
+ final Chunk[] chunks = new Chunk[numThreads];
final long mappedAddress = channel.map(FileChannel.MapMode.READ_ONLY, 0, filebytes, Arena.global()).address();
long chunkStart = 0;
long chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize);
+ int i = 0;
while (chunkStart < filebytes) {
- // unlikely we need to read more than this many bytes to find the next newline
- MappedByteBuffer mbb = channel.map(FileChannel.MapMode.READ_ONLY, chunkStart + chunkLength,
- Math.min(Math.min(filebytes - chunkStart - chunkLength, chunkLength), 100));
-
- while (mbb.get() != 0xA /* \n */) {
+ while (UNSAFE.getByte(mappedAddress + chunkStart + chunkLength) != 0xA /* \n */) {
chunkLength++;
}
- chunks.add(new Chunk(mappedAddress + chunkStart, chunkLength + 1));
+ chunks[i++] = new Chunk(mappedAddress + chunkStart, chunkLength + 1);
// to skip the nl in the next chunk
chunkStart += chunkLength + 1;
chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize);
}
+
return chunks;
}
- private static class CalculateTask implements Runnable {
+ private static void run(Chunk chunk) {
- private final SimplerHashMap results;
- private final Chunk chunk;
+ // can't have more than 10000 unique keys but want to match max hash
+ final int MAPSIZE = 65536;
+ final ResultRow[] slots = new ResultRow[MAPSIZE];
- public CalculateTask(Chunk chunk) {
- this.results = new SimplerHashMap();
- this.chunk = chunk;
- }
+ byte nameLength;
+ int temp;
+ long hash;
+
+ long i = chunk.start;
+ final long cl = chunk.start + chunk.length;
+ long word;
+ long hs;
+ long start;
+ byte c;
+ int slot;
+ long n;
+ ResultRow slotValue;
+
+ while (i < cl) {
+ start = i;
+ hash = 0;
+
+ word = UNSAFE.getLong(i);
+
+ while (true) {
+ n = word ^ ALL_SEMIS;
+ hs = (n - 0x0101010101010101L) & (~n & 0x8080808080808080L);
+ if (hs != 0)
+ break;
+ hash = (hash ^ word) * FXSEED;
+ i += 8;
+ word = UNSAFE.getLong(i);
+ }
- @Override
- public void run() {
- // no names bigger than this
- final byte[] nameBytes = new byte[100];
- short nameIndex = 0;
- int ot;
- // fnv hash
- int hash = FNV_32_INIT;
-
- long i = chunk.start;
- final long cl = chunk.start + chunk.length;
- while (i < cl) {
- byte c;
- while ((c = UNSAFE.getByte(i++)) != 0x3B /* semi-colon */) {
- nameBytes[nameIndex++] = c;
- hash ^= c;
- hash *= FNV_32_PRIME;
+ i += Long.numberOfTrailingZeros(hs) >> 3;
+
+ // hash of what's left ((hs >>> 7) - 1) masks off the bytes from word that are before the semicolon
+ hash = (hash ^ word & (hs >>> 7) - 1) * FXSEED;
+ nameLength = (byte) (i++ - start);
+
+ // temperature value follows
+ c = UNSAFE.getByte(i++);
+ // we know the val has to be between -99.9 and 99.8
+ // always with a single fractional digit
+ // represented as a byte array of either 4 or 5 characters
+ if (c != 0x2D /* minus sign */) {
+ // could be either n.x or nn.x
+ if (UNSAFE.getByte(i + 2) == 0xA) {
+ temp = (c - 48) * 10; // char 1
}
-
- // temperature value follows
- c = UNSAFE.getByte(i++);
- // we know the val has to be between -99.9 and 99.8
- // always with a single fractional digit
- // represented as a byte array of either 4 or 5 characters
- if (c == 0x2D /* minus sign */) {
- // could be either n.x or nn.x
- if (UNSAFE.getByte(i + 3) == 0xA) {
- ot = (UNSAFE.getByte(i++) - 48) * 10; // char 1
- }
- else {
- ot = (UNSAFE.getByte(i++) - 48) * 100; // char 1
- ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2
- }
- i++; // skip dot
- ot += (UNSAFE.getByte(i++) - 48); // char 2
- ot = -ot;
+ else {
+ temp = (c - 48) * 100; // char 1
+ temp += (UNSAFE.getByte(i++) - 48) * 10; // char 2
+ }
+ temp += (UNSAFE.getByte(++i) - 48); // char 3
+ }
+ else {
+ // could be either n.x or nn.x
+ if (UNSAFE.getByte(i + 3) == 0xA) {
+ temp = (UNSAFE.getByte(i) - 48) * 10; // char 1
+ i += 2;
}
else {
- // could be either n.x or nn.x
- if (UNSAFE.getByte(i + 2) == 0xA) {
- ot = (c - 48) * 10; // char 1
- }
- else {
- ot = (c - 48) * 100; // char 1
- ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2
- }
- i++; // skip dot
- ot += (UNSAFE.getByte(i++) - 48); // char 3
+ temp = (UNSAFE.getByte(i) - 48) * 100; // char 1
+ temp += (UNSAFE.getByte(i + 1) - 48) * 10; // char 2
+ i += 3;
+ }
+ temp += (UNSAFE.getByte(i) - 48); // char 2
+ temp = -temp;
+ }
+ i += 2;
+
+ // xor folding
+ slot = (int) (hash ^ hash >> 32) & 65535;
+
+ // Linear probe for open slot
+ while ((slotValue = slots[slot]) != null && (slotValue.keyLength != nameLength || !unsafeEquals(slotValue.keyStart, start, nameLength))) {
+ slot = (slot + 1) % MAPSIZE;
+ }
+
+ // existing
+ if (slotValue != null) {
+ slotValue.sum += temp;
+ slotValue.count++;
+ if (temp > slotValue.max) {
+ slotValue.max = temp;
+ continue;
}
+ if (temp < slotValue.min)
+ slotValue.min = temp;
- i++;// nl
- hash &= 65535;
- results.putOrMerge(nameBytes, nameIndex, hash, ot);
- // reset
- nameIndex = 0;
- hash = 0x811c9dc5;
}
+ else {
+ // new value
+ slots[slot] = new ResultRow(temp, start, nameLength);
+ }
+ }
- // merge results with overall results
- List<MapEntry> all = results.getAll();
- lock.lock();
- try {
- for (MapEntry me : all) {
- ResultRow rr;
- ResultRow lr = me.row;
- if ((rr = global.get(me.key)) != null) {
- rr.min = Math.min(rr.min, lr.min);
- rr.max = Math.max(rr.max, lr.max);
- rr.count += lr.count;
- rr.sum += lr.sum;
+ // merge results with overall results
+ ResultRow rr;
+ String key;
+ byte[] bytes;
+ lock.lock();
+ try {
+ for (ResultRow resultRow : slots) {
+ if (resultRow != null) {
+ bytes = new byte[resultRow.keyLength];
+ // copy the name bytes
+ UNSAFE.copyMemory(null, resultRow.keyStart, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, resultRow.keyLength);
+ key = new String(bytes, StandardCharsets.UTF_8);
+ if ((rr = global.get(key)) != null) {
+ rr.min = Math.min(rr.min, resultRow.min);
+ rr.max = Math.max(rr.max, resultRow.max);
+ rr.count += resultRow.count;
+ rr.sum += resultRow.sum;
}
else {
- global.put(me.key, lr);
+ global.put(key, resultRow);
}
}
}
- finally {
- lock.unlock();
+ }
+ finally {
+ lock.unlock();
+ }
+
+ }
+
+ static boolean unsafeEquals(final long a_address, final long b_address, final byte b_length) {
+ // byte by byte comparisons are slow, so do as big chunks as possible
+ byte i = 0;
+ for (; i < (b_length & -8); i += 8) {
+ if (UNSAFE.getLong(a_address + i) != UNSAFE.getLong(b_address + i)) {
+ return false;
}
}
+ if (i == b_length)
+ return true;
+ return (UNSAFE.getLong(a_address + i) & masks[b_length - i]) == (UNSAFE.getLong(b_address + i) & masks[b_length - i]);
}
public static void main(String[] args) throws IOException, InterruptedException {
- FileChannel channel = new RandomAccessFile(FILE, "r").getChannel();
int numThreads = 1;
+ FileChannel channel = new RandomAccessFile(FILE, "r").getChannel();
if (channel.size() > 64000) {
numThreads = Runtime.getRuntime().availableProcessors();
}
- List<Chunk> chunks = getChunks(numThreads, channel);
- List<Thread> threads = new ArrayList<>();
- for (Chunk chunk : chunks) {
- Thread thread = new Thread(new CalculateTask(chunk));
+ Chunk[] chunks = getChunks(numThreads, channel);
+ Thread[] threads = new Thread[chunks.length];
+ for (int i = 0; i < chunks.length; i++) {
+ int finalI = i;
+ Thread thread = new Thread(() -> run(chunks[finalI]));
thread.setPriority(Thread.MAX_PRIORITY);
thread.start();
- threads.add(thread);
+ threads[i] = thread;
}
for (Thread t : threads) {
t.join();
}
- // create treemap just to sort
- System.out.println(new TreeMap<>(global));
+ System.out.println(global);
+ channel.close();
}
-
- record MapEntry(String key, ResultRow row) {
- }
-
- static class SimplerHashMap {
- // can't have more than 10000 unique keys but want to match max hash
- final int MAPSIZE = 65536;
- final ResultRow[] slots = new ResultRow[MAPSIZE];
- final byte[][] keys = new byte[MAPSIZE][];
-
- public void putOrMerge(final byte[] key, final short length, final int hash, final int temp) {
- int slot = hash;
- ResultRow slotValue;
-
- // Linear probe for open slot
- while ((slotValue = slots[slot]) != null && (keys[slot].length != length || !unsafeEquals(keys[slot], key, length))) {
- slot++;
- }
-
- // existing
- if (slotValue != null) {
- slotValue.min = Math.min(slotValue.min, temp);
- slotValue.max = Math.max(slotValue.max, temp);
- slotValue.sum += temp;
- slotValue.count++;
- return;
- }
-
- // new value
- slots[slot] = new ResultRow(temp);
- byte[] bytes = new byte[length];
- System.arraycopy(key, 0, bytes, 0, length);
- keys[slot] = bytes;
- }
-
- static boolean unsafeEquals(final byte[] a, final byte[] b, final short length) {
- // byte by byte comparisons are slow, so do as big chunks as possible
- final int baseOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET;
-
- short i = 0;
- // round down to nearest power of 8
- for (; i < (length & -8); i += 8) {
- if (UNSAFE.getLong(a, i + baseOffset) != UNSAFE.getLong(b, i + baseOffset)) {
- return false;
- }
- }
- if (i == length) {
- return true;
- }
- // leftover ints
- for (; i < (length - i & -4); i += 4) {
- if (UNSAFE.getInt(a, i + baseOffset) != UNSAFE.getInt(b, i + baseOffset)) {
- return false;
- }
- }
- if (i == length) {
- return true;
- }
- // leftover shorts
- for (; i < (length - i & -2); i += 2) {
- if (UNSAFE.getShort(a, i + baseOffset) != UNSAFE.getShort(b, i + baseOffset)) {
- return false;
- }
- }
- if (i == length) {
- return true;
- }
- // leftover bytes
- for (; i < (length - i); i++) {
- if (UNSAFE.getByte(a, i + baseOffset) != UNSAFE.getByte(b, i + baseOffset)) {
- return false;
- }
- }
-
- return true;
- }
-
- // Get all pairs
- public List<MapEntry> getAll() {
- final List<MapEntry> result = new ArrayList<>(slots.length);
- for (int i = 0; i < slots.length; i++) {
- ResultRow slotValue = slots[i];
- if (slotValue != null) {
- result.add(new MapEntry(new String(keys[i], StandardCharsets.UTF_8), slotValue));
- }
- }
- return result;
- }
- }
-
}