diff options
author | Thomas Wuerthinger <thomas.wuerthinger@oracle.com> | 2024-01-31 09:34:15 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-31 09:34:15 +0100 |
commit | a5ce4ba77184d669e67d9633ee20c466ad167742 (patch) | |
tree | 6dcdd8a76da0e7b35d9e8b8a4d6a6526b4e1b5ad /src | |
parent | 7f0e51781190bcfcd4b8624352230a19737b01c8 (diff) |
Added comments to used flags, clean up code, final fine tuning. (#674)
Diffstat (limited to 'src')
-rw-r--r-- | src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java | 234 |
1 files changed, 100 insertions, 134 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 9b21f91..dc4df0c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -27,9 +27,7 @@ import java.util.concurrent.atomic.AtomicLong; * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in * the end. - * - * Runs in 0.40s on an Intel i9-13900K. - * + * Runs in 0.39s on an Intel i9-13900K. * Credit: * Quan Anh Mai for branchless number parsing code * Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea @@ -103,49 +101,111 @@ public class CalculateAverage_thomaswue { return result; } - private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results, List<Result> collectedResults) { + private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, List<Result> collectedResults) { + Result[] results = new Result[HASH_TABLE_SIZE]; + while (true) { + long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; + if (current >= fileEnd) { + return; + } + + long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); + long segmentStart; + if (current == fileStart) { + segmentStart = current; + } + else { + segmentStart = nextNewLine(current) + 1; + } + + long dist = (segmentEnd - segmentStart) / 3; + long midPoint1 = nextNewLine(segmentStart + dist); + long midPoint2 = nextNewLine(segmentStart + dist + dist); + + Scanner scanner1 = new Scanner(segmentStart, midPoint1); + Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); + Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd); + while (true) { + if (!scanner1.hasNext()) { + break; + } + if (!scanner2.hasNext()) { + break; + } + if (!scanner3.hasNext()) { + break; + } + long word1 = scanner1.getLong(); + long word2 = scanner2.getLong(); + long word3 = scanner3.getLong(); + long delimiterMask1 = findDelimiter(word1); + long delimiterMask2 = findDelimiter(word2); + long delimiterMask3 = findDelimiter(word3); + Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults); + Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults); + Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults); + long number1 = scanNumber(scanner1); + long number2 = scanNumber(scanner2); + long number3 = scanNumber(scanner3); + record(existingResult1, number1); + record(existingResult2, number2); + record(existingResult3, number3); + } + + while (scanner1.hasNext()) { + long word = scanner1.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); + } + while (scanner2.hasNext()) { + long word = scanner2.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); + } + while (scanner3.hasNext()) { + long word = scanner3.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); + } + } + } + + private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) { Result existingResult; long word = initialWord; - long pos = initialPos; + long delimiterMask = initialDelimiterMask; long hash; long nameAddress = scanner.pos(); // Search for ';', one long at a time. There are two common cases that a specially treated: // (b) the ';' is found in the first 16 bytes - if (pos != 0) { + if (delimiterMask != 0) { // Special case for when the ';' is found in the first 8 bytes. - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash = word; - - int index = hashToIndex(hash, results); - existingResult = results[index]; - + existingResult = results[hashToIndex(hash, results)]; if (existingResult != null && existingResult.lastNameLong == word) { return existingResult; } - scanner.setPos(nameAddress + pos); } else { // Special case for when the ';' is found in bytes 9-16. - scanner.add(8); hash = word; long prevWord = word; + scanner.add(8); word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + delimiterMask = findDelimiter(word); + if (delimiterMask != 0) { + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash ^= word; - int index = hashToIndex(hash, results); - existingResult = results[index]; - + existingResult = results[hashToIndex(hash, results)]; if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { return existingResult; } - scanner.setPos(nameAddress + pos + 8); } else { // Slow-path for when the ';' could not be found in the first 16 bytes. @@ -153,11 +213,11 @@ public class CalculateAverage_thomaswue { hash ^= word; while (true) { word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + delimiterMask = findDelimiter(word); + if (delimiterMask != 0) { + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash ^= word; break; } @@ -204,7 +264,8 @@ public class CalculateAverage_thomaswue { private static long nextNewLine(long prev) { while (true) { long currentWord = Scanner.UNSAFE.getLong(prev); - long pos = findNewLine(currentWord); + long input = currentWord ^ 0x0A0A0A0A0A0A0A0AL; + long pos = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; if (pos != 0) { prev += Long.numberOfTrailingZeros(pos) >>> 3; break; @@ -216,87 +277,11 @@ public class CalculateAverage_thomaswue { return prev; } - // Main parse loop. - private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart, List<Result> collectedResults) { - Result[] results = new Result[HASH_TABLE_SIZE]; - - while (true) { - long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; - - if (current >= fileEnd) { - return results; - } - - long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); - long segmentStart; - if (current == fileStart) { - segmentStart = current; - } - else { - segmentStart = nextNewLine(current) + 1; - } - - long dist = (segmentEnd - segmentStart) / 3; - long midPoint1 = nextNewLine(segmentStart + dist); - long midPoint2 = nextNewLine(segmentStart + dist + dist); - - Scanner scanner1 = new Scanner(segmentStart, midPoint1); - Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); - Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd); - while (true) { - if (!scanner1.hasNext()) { - break; - } - if (!scanner2.hasNext()) { - break; - } - if (!scanner3.hasNext()) { - break; - } - - long word1 = scanner1.getLong(); - long word2 = scanner2.getLong(); - long word3 = scanner3.getLong(); - long pos1 = findDelimiter(word1); - long pos2 = findDelimiter(word2); - long pos3 = findDelimiter(word3); - Result existingResult1 = findResult(word1, pos1, scanner1, results, collectedResults); - Result existingResult2 = findResult(word2, pos2, scanner2, results, collectedResults); - Result existingResult3 = findResult(word3, pos3, scanner3, results, collectedResults); - long number1 = scanNumber(scanner1); - long number2 = scanNumber(scanner2); - long number3 = scanNumber(scanner3); - record(existingResult1, number1); - record(existingResult2, number2); - record(existingResult3, number3); - } - - while (scanner1.hasNext()) { - long word = scanner1.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); - } - - while (scanner2.hasNext()) { - long word = scanner2.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); - } - - while (scanner3.hasNext()) { - long word = scanner3.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); - } - } - } - private static long scanNumber(Scanner scanPtr) { - scanPtr.add(1); - long numberWord = scanPtr.getLong(); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + long numberWord = scanPtr.getLongAt(scanPtr.pos() + 1); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000L); long number = convertIntoNumber(decimalSepPos, numberWord); - scanPtr.add((decimalSepPos >>> 3) + 3); + scanPtr.add((decimalSepPos >>> 3) + 4); return number; } @@ -316,10 +301,6 @@ public class CalculateAverage_thomaswue { return (int) (hashAsInt & (results.length - 1)); } - private static long mask(long word, long pos) { - return (word << ((7 - pos) << 3)); - } - // Special method to convert a number in the ascii number into an int without branches created by Quan Anh Mai. private static long convertIntoNumber(int decimalSepPos, long numberWord) { int shift = 28 - decimalSepPos; @@ -337,14 +318,7 @@ public class CalculateAverage_thomaswue { private static long findDelimiter(long word) { long input = word ^ 0x3B3B3B3B3B3B3B3BL; - long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; - return tmp; - } - - private static long findNewLine(long word) { - long input = word ^ 0x0A0A0A0A0A0A0A0AL; - long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; - return tmp; + return (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; } private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) { @@ -357,14 +331,13 @@ public class CalculateAverage_thomaswue { r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); } int remainingShift = (64 - (nameLength + 1 - i) << 3); - long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift); - r.lastNameLong = lastWord; + r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift); r.nameAddress = nameAddress; collectedResults.add(r); return r; } - private static class Result { + private static final class Result { long lastNameLong, secondLastNameLong; short min, max; int count; @@ -409,9 +382,10 @@ public class CalculateAverage_thomaswue { } } - private static class Scanner { + private static final class Scanner { private static final sun.misc.Unsafe UNSAFE = initUnsafe(); - private long pos, end; + private long pos; + private final long end; private static sun.misc.Unsafe initUnsafe() { try { @@ -452,13 +426,5 @@ public class CalculateAverage_thomaswue { byte getByteAt(long pos) { return UNSAFE.getByte(pos); } - - long getLongAt(long pos, long[] array) { - return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET); - } - - void setPos(long l) { - this.pos = l; - } } }
\ No newline at end of file |