author    Yavuz Tas <12643010+yavuztas@users.noreply.github.com> 2024-01-29 21:02:20 +0100
committer GitHub <noreply@github.com> 2024-01-29 21:02:20 +0100
commit    a82cf2ceb7f3832a953f1a68162b03128765612b (patch)
tree      1974c11b0cfd346bc230de29ee57a3a23fcb24c9 /src
parent    886f0cdb4df4fdbee7c99fb1fdd5c72d9608d635 (diff)
improve speed, thanks to the following improvements: (#550)
* improve speed, thanks to the following improvements:
  - loop unrolling and eliminating extra calculations
  - eliminating instance-level variable access
  - quicker equals check, checking long by long chunks instead of bytes
  - update GraalVM version to the latest
* faster equals check
* fix equals bug in 10K, more optimizations on equals and hash calculation
* New solution optimized for Linux/AMD hardware
* Optimize solution, try to fix 10K bug on native
* Optimize solution, move records to a local field
* test timing
* revert accidentally pushed code

---------

Co-authored-by: Yavuz Tas <yavuz.tas@ing.com>
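To illustrate the "checking long by long chunks instead of bytes" idea from the list above, here is a minimal standalone sketch. It deliberately uses a plain byte[] with ByteBuffer instead of Unsafe so it is self-contained and runnable anywhere; the class and method names are illustrative only and do not appear in the patch.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class LongChunkEquals {

    // Compare two byte ranges 8 bytes at a time, falling back to a
    // byte-by-byte check only for the trailing (< 8 byte) part.
    static boolean rangeEquals(byte[] a, int offA, byte[] b, int offB, int length) {
        final ByteBuffer bufA = ByteBuffer.wrap(a).order(ByteOrder.LITTLE_ENDIAN);
        final ByteBuffer bufB = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
        int i = 0;
        while (length - i >= 8) { // whole 8-byte chunks
            if (bufA.getLong(offA + i) != bufB.getLong(offB + i)) {
                return false;
            }
            i += 8;
        }
        for (; i < length; i++) { // leftover tail bytes
            if (a[offA + i] != b[offB + i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        final byte[] x = "Istanbul;12.3".getBytes(StandardCharsets.UTF_8);
        final byte[] y = "Istanbul;45.6".getBytes(StandardCharsets.UTF_8);
        System.out.println(rangeEquals(x, 0, y, 0, 8));  // true  - the station names match
        System.out.println(rangeEquals(x, 0, y, 0, 13)); // false - the temperatures differ
    }
}

The solution in the diff below goes further: it caches the first two words and the masked last word of each key inside the Record itself, so most key comparisons never touch the mapped memory at all.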
Diffstat (limited to 'src')
-rw-r--r--  src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java | 507
1 file changed, 319 insertions, 188 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
index e33fe7e..0e589a4 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
@@ -17,15 +17,16 @@ package dev.morling.onebrc;
import sun.misc.Unsafe;
-import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
-import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import java.util.TreeMap;
import java.util.function.Consumer;
@@ -35,8 +36,9 @@ public class CalculateAverage_yavuztas {
private static final Unsafe UNSAFE = unsafe();
- // Tried all there: MappedByteBuffer, MemorySegment and Unsafe
- // Accessing the memory using Unsafe is still the fastest in my experience
+ // I compared all three: MappedByteBuffer, MemorySegment and Unsafe.
+ // Accessing the memory using Unsafe is still the fastest in my experience.
+ // However, I would never use it in production; a single programming error can crash your whole app.
private static Unsafe unsafe() {
try {
final Field f = Unsafe.class.getDeclaredField("theUnsafe");
@@ -48,296 +50,419 @@ public class CalculateAverage_yavuztas {
}
}
- // Only one object, both for measurements and keys, less object creation in hotpots is always faster
- static class Record {
-
- // keep memory starting address for each segment
- // since we use Unsafe, this is enough to align and fetch the data
- long segment;
- int start;
- int length;
- int hash;
+ /**
+ * Extract the lowest {@code length} bytes from a long word
+ */
+ private static long partial(long word, int length) {
+ final long mask = (~0L) << (length << 3);
+ return word & (~mask);
+ }
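// Worked example (illustration, not from the patch): partial() keeps only the lowest
// `length` bytes of a little-endian word. With length = 3:
//   mask  = ~0L << 24               = 0xFFFFFFFFFF000000
//   ~mask                           = 0x0000000000FFFFFF
//   partial(0x1122334455667788L, 3) = 0x0000000000667788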
- private int min = 1000; // calculations over int is faster than double, we convert to double in the end only once
- private int max = -1000;
+ // Only one object serves both as measurement and key; creating fewer objects in hot spots is always faster
+ private static final class Record {
+
+ private final long start; // memory address of the underlying data
+ private final int length;
+ private final long word1;
+ private final long word2;
+ private final long wordLast;
+ private final int hash;
+ private Record next; // linked list to resolve hash collisions
+
+ private int min; // calculations over int are faster than over double; we convert to double only once at the end
+ private int max;
private long sum;
- private long count;
+ private int count;
- public Record(long segment, int start, int length, int hash) {
- this.segment = segment;
+ public Record(long start, int length, long word1, long word2, long wordLast, int hash, int temp) {
this.start = start;
this.length = length;
+ this.word1 = word1;
+ this.word2 = word2;
+ this.wordLast = wordLast;
this.hash = hash;
+ this.min = temp;
+ this.max = temp;
+ this.sum = temp;
+ this.count = 1;
}
@Override
public boolean equals(Object o) {
final Record record = (Record) o;
- return equals(record.segment, record.start, record.length, record.hash);
+ return equals(record.start, record.word1, record.word2, record.wordLast, record.length);
}
- /**
- * Stateless equals, no Record object needed
- */
- public boolean equals(long segment, int start, int length, int hash) {
- if (this.length != length || this.hash != hash)
- return false;
+ private static boolean notEquals(long address1, long address2, int step) {
+ return UNSAFE.getLong(address1 + step) != UNSAFE.getLong(address2 + step);
+ }
- int i = 0; // bytes mismatch check
- while (i < this.length
- && UNSAFE.getByte(this.segment + this.start + i) == UNSAFE.getByte(segment + start + i)) {
- i++;
+ private static boolean equalsComparingLongs(long start1, long start2, int length) {
+ // first shortcuts
+ if (length < 24)
+ return true;
+ if (length < 32)
+ return !notEquals(start1, start2, 16);
+
+ int step = 24; // starting from 3rd long
+ length -= step;
+ while (length >= 8) { // scan longs
+ if (notEquals(start1, start2, step)) {
+ return false;
+ }
+ length -= 8;
+ step += 8; // 8 bytes
}
- return i == this.length;
+ return true;
}
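// How the shortcuts above work (illustration, not from the patch): word1 and word2 already
// cover bytes 0-15 and wordLast covers the masked tail, so keys shorter than 24 bytes need
// no further memory reads at all. Keys of 24-31 bytes only need the long at offset 16 to be
// checked; anything longer falls into the scanning loop that starts at offset 24.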
- @Override
- public int hashCode() {
- return this.hash;
+ private boolean equals(long start, long word1, long word2, long last, int length) {
+ if (this.word1 != word1)
+ return false;
+ if (this.word2 != word2)
+ return false;
+
+ // the equals check compares whole longs instead of individual bytes, which is faster
+ return equalsComparingLongs(this.start, start, length) && this.wordLast == last;
}
@Override
public String toString() {
final byte[] bytes = new byte[this.length];
- int i = 0;
- while (i < this.length) {
- bytes[i] = UNSAFE.getByte(this.segment + this.start + i++);
- }
-
+ UNSAFE.copyMemory(null, this.start, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, this.length);
return new String(bytes, StandardCharsets.UTF_8);
}
- public Record collect(int temp) {
- this.min = Math.min(this.min, temp);
- this.max = Math.max(this.max, temp);
+ private void collect(int temp) {
+ if (temp < this.min)
+ this.min = temp;
+ if (temp > this.max)
+ this.max = temp;
this.sum += temp;
this.count++;
- return this;
}
- public void merge(Record other) {
- this.min = Math.min(this.min, other.min);
- this.max = Math.max(this.max, other.max);
- this.sum += other.sum;
- this.count += other.count;
+ private void merge(Record that) {
+ if (that.min < this.min)
+ this.min = that.min;
+ if (that.max > this.max)
+ this.max = that.max;
+ this.sum += that.sum;
+ this.count += that.count;
}
- public String measurements() {
+ private String measurements() {
// this runs only once per unique key, so creating a StringBuilder doesn't hurt
final StringBuilder sb = new StringBuilder(14);
- sb.append(this.min / 10.0);
- sb.append("/");
- sb.append(round((this.sum / 10.0) / this.count));
- sb.append("/");
- sb.append(this.max / 10.0);
+ sb.append(round(this.min)).append("/");
+ sb.append(round(1.0 * this.sum / this.count)).append("/");
+ sb.append(round(this.max));
return sb.toString();
}
}
// Inspired by @spullara - customized hashmap on purpose
- // The main difference is we hold only one array instead of two
- static class RecordMap {
+ // The main difference is that we hold only one array instead of two; fewer objects is faster
+ private static final class RecordMap {
- static final int SIZE = 1 << 15; // 32k - bigger bucket size less collisions
- static final int BITMASK = SIZE - 1;
- Record[] keys = new Record[SIZE];
+ // A bigger bucket count means fewer collisions, but you have to find a sweet spot, otherwise it becomes slower.
+ // This size also works well enough for the 10K-station data set
+ private static final int SIZE = 1 << 14; // 16k buckets - enough for 10K stations
+ private static final int BITMASK = SIZE - 1;
+ private final Record[] keys = new Record[SIZE];
- static int hashBucket(int hash) {
+ // int collision;
+
+ private boolean hasNoRecord(int index) {
+ return this.keys[index] == null;
+ }
+
+ private Record getRecord(int index) {
+ return this.keys[index];
+ }
+
+ private static int hashBucket(int hash) {
hash = hash ^ (hash >>> 16); // naive bit spreading but surprisingly decreases collisions :)
return hash & BITMASK; // fast modulo, to find bucket
}
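// Illustration (not from the patch): for a power-of-two SIZE, `hash & (SIZE - 1)` equals
// Math.floorMod(hash, SIZE) but avoids the division. The `hash ^ (hash >>> 16)` step mixes
// the high bits into the low 14 bits that the mask actually uses, e.g.
//   hashBucket(0xABCD0000) == (0xABCD0000 ^ 0x0000ABCD) & 0x3FFF == 0x2BCD
// whereas without the mixing the same input would collapse to bucket 0.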
- void putAndCollect(long segment, int start, int length, int hash, int temp) {
- int bucket = hashBucket(hash);
- Record existing = this.keys[bucket];
- if (existing == null) {
- this.keys[bucket] = new Record(segment, start, length, hash)
- .collect(temp);
+ private void putAndCollect(int hash, int temp, long start, int length, long word1, long word2, long wordLast) {
+ final int bucket = hashBucket(hash);
+ if (hasNoRecord(bucket)) {
+ this.keys[bucket] = new Record(start, length, word1, word2, wordLast, hash, temp);
return;
}
- if (!existing.equals(segment, start, length, hash)) {
- // collision, linear probing to find a slot
- while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(segment, start, length, hash)) {
- // can be stuck here if all the buckets are full :(
- // However, since the data set is max 10K (unique) this shouldn't happen
- // So, I'm happily leave here branchless :)
- }
- if (existing == null) {
- this.keys[bucket & BITMASK] = new Record(segment, start, length, hash)
- .collect(temp);
- return;
- }
+ Record existing = getRecord(bucket);
+ if (existing.equals(start, word1, word2, wordLast, length)) {
existing.collect(temp);
+ return;
}
- else {
- existing.collect(temp);
+
+ // collision++;
+ // find possible slot by scanning the slot linked list
+ while (existing.next != null) {
+ if (existing.next.equals(start, word1, word2, wordLast, length)) {
+ existing.next.collect(temp);
+ return;
+ }
+ existing = existing.next; // go on to next
+ // collision++;
}
+ existing.next = new Record(start, length, word1, word2, wordLast, hash, temp);
}
- void putOrMerge(Record key) {
- int bucket = hashBucket(key.hash);
- Record existing = this.keys[bucket];
- if (existing == null) {
+ private void putOrMerge(Record key) {
+ final int bucket = hashBucket(key.hash);
+ if (hasNoRecord(bucket)) {
+ key.next = null;
this.keys[bucket] = key;
return;
}
- if (!existing.equals(key)) {
- // collision, linear probing to find a slot
- while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(key)) {
- // can be stuck here if all the buckets are full :(
- // However, since the data set is max 10K (unique keys) this shouldn't happen
- // So, I'm happily leave here branchless :)
- }
- if (existing == null) {
- this.keys[bucket & BITMASK] = key;
- return;
- }
+ Record existing = getRecord(bucket);
+ if (existing.equals(key)) {
existing.merge(key);
+ return;
}
- else {
- existing.merge(key);
+
+ // collision++;
+ // find possible slot by scanning the slot linked list
+ while (existing.next != null) {
+ if (existing.next.equals(key)) {
+ existing.next.merge(key);
+ return;
+ }
+ existing = existing.next; // go on to next
+ // collision++;
}
+ key.next = null;
+ existing.next = key;
}
- void forEach(Consumer<Record> consumer) {
+ private void forEach(Consumer<Record> consumer) {
int pos = 0;
Record key;
- while (pos < this.keys.length) {
+ while (pos < SIZE) {
if ((key = this.keys[pos++]) == null) {
continue;
}
+ Record next = key.next;
consumer.accept(key);
+ while (next != null) { // also traverse the records in the collision list
+ final Record tmp = next.next;
+ consumer.accept(next);
+ next = tmp;
+ }
}
}
- void merge(RecordMap other) {
+ private void merge(RecordMap other) {
other.forEach(this::putOrMerge);
}
}
// One actor for one thread, no synchronization
- static class RegionActor {
-
- final FileChannel channel;
- final long startPos;
- final int size;
- final RecordMap map = new RecordMap();
- long segmentAddress;
- int position;
- Thread runner; // each actor has its own thread
-
- public RegionActor(FileChannel channel, long startPos, int size) {
- this.channel = channel;
+ private static final class RegionActor extends Thread {
+
+ private final long startPos; // start of region memory address
+ private final int size;
+
+ private final RecordMap map = new RecordMap();
+
+ public RegionActor(long startPos, int size) {
this.startPos = startPos;
this.size = size;
}
- void accumulate() {
- this.runner = new Thread(() -> {
- try {
- // get the segment memory address, this is the only thing we need for Unsafe
- this.segmentAddress = this.channel.map(FileChannel.MapMode.READ_ONLY, this.startPos, this.size, Arena.global()).address();
- }
- catch (IOException e) {
- // no-op - skip intentionally, no handling for the purpose of this challenge
- }
-
- int start;
- int keyHash;
- int length;
- while (this.position < this.size) {
- byte b;
- start = this.position; // save line start position
- keyHash = UNSAFE.getByte(this.segmentAddress + this.position++); // first byte is guaranteed not to be ';'
- length = 1; // min key length
- while ((b = UNSAFE.getByte(this.segmentAddress + this.position++)) != ';') { // read until semicolon
- keyHash = calculateHash(keyHash, b); // calculate key hash ahead, eleminates one more loop later
- length++;
- }
+ private static long getWord(long address) {
+ return UNSAFE.getLong(address);
+ }
- final int temp = readTemperature();
- this.map.putAndCollect(this.segmentAddress, start, length, keyHash, temp);
+ // hasvalue & haszero
+ // adapted from https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
+ private static long hasSemicolon(long word) {
+ // semicolon pattern
+ final long hasVal = word ^ 0x3B3B3B3B3B3B3B3BL; // hasvalue
+ return ((hasVal - 0x0101010101010101L) & ~hasVal & 0x8080808080808080L); // haszero
+ }
- this.position++; // skip linebreak
- }
- });
- this.runner.start();
+ private static int semicolonPos(long hasVal) {
+ return Long.numberOfTrailingZeros(hasVal) >>> 3;
}
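// Worked example (illustration, not from the patch): XOR-ing with 0x3B3B... turns every ';'
// byte into 0x00, and the classic haszero() trick then lights the 0x80 bit of each zero byte.
// If the first ';' sits at byte index 2, bit 23 of the result is set, and
// Long.numberOfTrailingZeros(...) >>> 3 == 23 >>> 3 == 2 recovers that byte index. Bytes above
// the first match may produce false positives, but the trailing-zero count ignores them.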
- static int calculateHash(int hash, int b) {
- return 31 * hash + b;
+ private static int decimalPos(long numberWord) {
+ return Long.numberOfTrailingZeros(~numberWord & 0x10101000);
}
- // 1. Inspired by @yemreinci - Reading temparature value without Double.parse
- // 2. Inspired by @obourgain - Fetching first 4 bytes ahead, then masking
- int readTemperature() {
- int temp = 0;
- // read 4 bytes ahead
- final int first4 = UNSAFE.getInt(this.segmentAddress + this.position);
- this.position += 3;
+ private static final int MAX_INNER_LOOP_SIZE = 11;
- final byte b1 = (byte) first4; // first byte
- final byte b2 = (byte) ((first4 >> 8) & 0xFF); // second byte
- final byte b3 = (byte) ((first4 >> 16) & 0xFF); // third byte
- if (b1 == '-') {
- if (b3 == '.') {
- temp -= 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte
- this.position++;
- }
- else {
- this.position++; // skip dot
- temp -= 100 * (b2 - '0') + 10 * (b3 - '0') + UNSAFE.getByte(this.segmentAddress + this.position++) - '0'; // fifth byte
- }
- }
- else {
- if (b2 == '.') {
- temp = 10 * (b1 - '0') + b3 - '0';
+ @Override
+ public void run() {
+ long pointer = this.startPos;
+ final long size = pointer + this.size;
+ while (pointer < size) { // line start
+ long hash = 0; // reset hash
+ long s; // semicolon check word
+ final int pos; // semicolon position
+ long word1 = getWord(pointer);
+ if ((s = hasSemicolon(word1)) != 0) {
+ pos = semicolonPos(s);
+ // read temperature
+ final long numberWord = getWord(pointer + pos + 1);
+ final int decimalPos = decimalPos(numberWord);
+ final int temp = convertIntoNumber(decimalPos, numberWord);
+
+ word1 = partial(word1, pos); // last word
+ this.map.putAndCollect(completeHash(hash, word1), temp, pointer, pos, word1, 0, 0);
+
+ pointer += pos + (decimalPos >>> 3) + 4;
}
else {
- temp = 100 * (b1 - '0') + 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte
- this.position++;
+ long word2 = getWord(pointer + 8);
+ if ((s = hasSemicolon(word2)) != 0) {
+ pos = semicolonPos(s);
+ // read temperature
+ final int length = pos + 8;
+ final long numberWord = getWord(pointer + length + 1);
+ final int decimalPos = decimalPos(numberWord);
+ final int temp = convertIntoNumber(decimalPos, numberWord);
+
+ word2 = partial(word2, pos); // last word
+ this.map.putAndCollect(completeHash(hash, word1, word2), temp, pointer, length, word1, word2, 0);
+
+ pointer += length + (decimalPos >>> 3) + 4; // seek to the line end
+ }
+ else {
+ long word = 0;
+ int length = 16;
+ hash = appendHash(hash, word1, word2);
+ // Let the compiler know the loop bound ahead of time so it can unroll automatically.
+ // The max key length is 100 bytes = 13 longs; we've already read 2, so 11 are left
+ for (int i = 0; i < MAX_INNER_LOOP_SIZE; i++) {
+ if ((s = hasSemicolon((word = getWord(pointer + length)))) != 0) {
+ break;
+ }
+ hash = appendHash(hash, word);
+ length += 8;
+ }
+
+ pos = semicolonPos(s);
+ length += pos;
+ // read temperature
+ final long numberWord = getWord(pointer + length + 1);
+ final int decimalPos = decimalPos(numberWord);
+ final int temp = convertIntoNumber(decimalPos, numberWord);
+
+ word = partial(word, pos); // last word
+ this.map.putAndCollect(completeHash(hash, word), temp, pointer, length, word1, word2, word);
+
+ pointer += length + (decimalPos >>> 3) + 4; // seek to the line end
+ }
}
}
+ }
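// Note on the pointer arithmetic above (illustration, not from the patch): after the key of
// `length` bytes the pointer sits on the ';', numberWord starts right after it, and
// `decimalPos >>> 3` is the byte index of the '.' inside numberWord. A line therefore occupies
// length + 1 (';') + dotIndex + 2 ('.' plus one decimal digit) + 1 ('\n') bytes, which is
// exactly the `length + (decimalPos >>> 3) + 4` used to jump to the next line.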
+
+ // Hashes are built by multiplying by the Mersenne prime (1 << 7) - 1 = 127
+ // A shift and a subtraction is faster than a general multiplication on some machines
+ private static long appendHash(long hash, long word) {
+ return (hash << 7) - hash + word;
+ }
+
+ private static long appendHash(long hash, long word1, long word2) {
+ hash = (hash << 7) - hash + word1;
+ return (hash << 7) - hash + word2;
+ }
+
+ private static int completeHash(long hash, long partial) {
+ hash = (hash << 7) - hash + partial;
+ return (int) (hash ^ (hash >>> 25));
+ }
+
+ private static int completeHash(long hash, long word1, long word2) {
+ hash = (hash << 7) - hash + word1;
+ hash = (hash << 7) - hash + word2;
+ return (int) hash ^ (int) (hash >>> 25);
+ }
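// Illustration (not from the patch): (hash << 7) - hash == hash * 127, i.e. multiplication by
// the Mersenne prime 127 done with a shift and a subtraction. The final xor with
// `hash >>> 25` folds the upper bits of the 64-bit accumulator into the lower ones before the
// narrowing cast, so the resulting int hash still depends on every input word.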
- return temp;
+ // Credits to @merrykitty. A magical solution to parse temperature values branchlessly!
+ // Taken without modification; the comments below belong to @merrykitty
+ private static int convertIntoNumber(int decimalSepPos, long numberWord) {
+ final int shift = 28 - decimalSepPos;
+ // signed is -1 if negative, 0 otherwise
+ final long signed = (~numberWord << 59) >> 63;
+ final long designMask = ~(signed & 0xFF);
+ // Align the number to a specific position and transform the ascii code
+ // to actual digit value in each byte
+ final long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L;
+ // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
+ // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
+ // 0x000000UU00TTHH00 +
+ // 0x00UU00TTHH000000 * 10 +
+ // 0xUU00TTHH00000000 * 100
+ // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
+ // This results in our value lies in the bit 32 to 41 of this product
+ // That was close :)
+ final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ final long value = (absValue ^ signed) - signed;
+ return (int) value;
}
/**
* blocks until the map is fully collected
*/
- RecordMap get() throws InterruptedException {
- this.runner.join();
+ private RecordMap get() throws InterruptedException {
+ join();
return this.map;
}
}
private static double round(double value) {
- return Math.round(value * 10.0) / 10.0;
+ return Math.round(value) / 10.0;
}
/**
* Scans the given buffer to the left
*/
- private static long findClosestLineEnd(long start, int size, FileChannel channel) throws IOException {
- final long position = start + size;
- final long left = Math.max(position - 101, 0);
- final ByteBuffer buffer = ByteBuffer.allocate(101); // enough size to find at least one '\n'
- if (channel.read(buffer.clear(), left) != -1) {
- int bufferPos = buffer.position() - 1;
- while (buffer.get(bufferPos) != '\n') {
- bufferPos--;
- size--;
- }
+ private static long findClosestLineEnd(long start, int size) {
+ long position = start + size;
+ while (UNSAFE.getByte(--position) != '\n') {
+ // read until a linebreak
+ size--;
}
return size;
}
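// Example (illustration, not from the patch): if a region happens to end in the middle of
// "Istanbul;12.3\n", size shrinks until the byte at (start + size - 1) is a '\n', so every
// RegionActor begins exactly at the start of a line and no record is split across regions.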
- public static void main(String[] args) throws IOException, InterruptedException {
+ private static boolean isWorkerProcess(String[] args) {
+ return Arrays.asList(args).contains("--worker");
+ }
+
+ private static void runAsWorker() throws Exception {
+ final ProcessHandle.Info info = ProcessHandle.current().info();
+ final List<String> commands = new ArrayList<>();
+ info.command().ifPresent(commands::add);
+ info.arguments().ifPresent(args -> commands.addAll(Arrays.asList(args)));
+ commands.add("--worker");
+
+ new ProcessBuilder()
+ .command(commands)
+ .start()
+ .getInputStream()
+ .transferTo(System.out);
+ }
+
+ public static void main(String[] args) throws Exception {
- var concurrency = Runtime.getRuntime().availableProcessors();
+ // Based on @thomaswue's idea to cut the unmapping delay.
+ // Strangely, the unmapping delay doesn't occur on macOS/M1; however, on Linux/AMD it's substantial - ~200ms
+ if (!isWorkerProcess(args)) {
+ runAsWorker();
+ return;
+ }
+
+ var concurrency = 2 * Runtime.getRuntime().availableProcessors();
final long fileSize = Files.size(FILE);
long regionSize = fileSize / concurrency;
@@ -353,30 +478,36 @@ public class CalculateAverage_yavuztas {
long startPos = 0;
final FileChannel channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ);
+ // get the memory address, this is the only thing we need for Unsafe
+ final long memoryAddress = channel.map(FileChannel.MapMode.READ_ONLY, startPos, fileSize, Arena.global()).address();
+
final RegionActor[] actors = new RegionActor[concurrency];
for (int i = 0; i < concurrency; i++) {
// calculate boundaries
long maxSize = (startPos + regionSize > fileSize) ? fileSize - startPos : regionSize;
// shift position to back until we find a linebreak
- maxSize = findClosestLineEnd(startPos, (int) maxSize, channel);
+ maxSize = findClosestLineEnd(memoryAddress + startPos, (int) maxSize);
- final RegionActor region = (actors[i] = new RegionActor(channel, startPos, (int) maxSize));
- region.accumulate();
+ final RegionActor region = (actors[i] = new RegionActor(memoryAddress + startPos, (int) maxSize));
+ region.start(); // start processing
startPos += maxSize;
}
- final RecordMap output = new RecordMap(); // output to merge all regions
+ final RecordMap output = new RecordMap(); // output to merge all records
for (RegionActor actor : actors) {
final RecordMap partial = actor.get(); // blocks until get the result
output.merge(partial);
+ // System.out.println("collisions: " + partial.collision);
}
// sort and print the result
final TreeMap<String, String> sorted = new TreeMap<>();
- output.forEach(key -> sorted.put(key.toString(), key.measurements()));
+ output.forEach(key -> {
+ sorted.put(key.toString(), key.measurements());
+ });
System.out.println(sorted);
-
+ System.out.close(); // closing the stream will trigger the main process to pick up the output early
}
}