author    Anita SV <anitasvasu@gmail.com>    2024-02-01 03:15:23 -0800
committer GitHub <noreply@github.com>        2024-02-01 12:15:23 +0100
commit    101993f06d1e63e3d56ab57483ff11a3349c47aa (patch)
tree      0eab4b2c52362a2720a14067727863205824d136
parent    bec0cef2d3cd0c0d5d30b66bc58f351dcc912681 (diff)
CA_vaidhy final changes. (#708)
-rw-r--r--    src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java    367
1 file changed, 272 insertions(+), 95 deletions(-)
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java
index 5795077..f63374a 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java
@@ -21,6 +21,7 @@ import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
@@ -37,69 +38,149 @@ public class CalculateAverage_vaidhy<I, T> {
private static final class HashEntry {
private long startAddress;
- private long endAddress;
+ private long keyLength;
private long suffix;
- private int hash;
-
+ private int next;
IntSummaryStatistics value;
}
private static class PrimitiveHashMap {
private final HashEntry[] entries;
+ private final long[] hashes;
+
private final int twoPow;
+ private int next = -1;
PrimitiveHashMap(int twoPow) {
this.twoPow = twoPow;
this.entries = new HashEntry[1 << twoPow];
+ this.hashes = new long[1 << twoPow];
for (int i = 0; i < entries.length; i++) {
this.entries[i] = new HashEntry();
}
}
- public HashEntry find(long startAddress, long endAddress, long suffix, int hash) {
+ public IntSummaryStatistics find(long startAddress, long endAddress, long hash, long suffix) {
int len = entries.length;
- int i = (hash ^ (hash >> twoPow)) & (len - 1);
+ int h = Long.hashCode(hash);
+ int initialIndex = (h ^ (h >> twoPow)) & (len - 1);
+ int i = initialIndex;
+ long lookupLength = endAddress - startAddress;
- do {
+ long hashEntry = hashes[i];
+
+ if (hashEntry == hash) {
HashEntry entry = entries[i];
- if (entry.value == null) {
- return entry;
+ if (lookupLength <= 7) {
+ // This works because hash == suffix when simpleHash is plain XOR.
+ // Since the length is under 8, the suffix carries zero padding,
+ // and UTF-8 strings cannot contain a 0x00 byte in the middle, so a
+ // hash match alone identifies the key and we can stop here.
+ return entry.value;
}
- if (entry.hash == hash) {
- long entryLength = entry.endAddress - entry.startAddress;
- long lookupLength = endAddress - startAddress;
- if ((entryLength == lookupLength) && (entry.suffix == suffix)) {
- boolean found = compareEntryKeys(startAddress, endAddress, entry);
-
- if (found) {
- return entry;
- }
+ boolean found = (entry.suffix == suffix &&
+ compareEntryKeys(startAddress, endAddress, entry.startAddress));
+ if (found) {
+ return entry.value;
+ }
+ }
+
+ if (hashEntry == 0) {
+ HashEntry entry = entries[i];
+ entry.startAddress = startAddress;
+ entry.keyLength = lookupLength;
+ hashes[i] = hash;
+ entry.suffix = suffix;
+ entry.next = next;
+ this.next = i;
+ entry.value = new IntSummaryStatistics();
+ return entry.value;
+ }
+
+ i++;
+ if (i == len) {
+ i = 0;
+ }
+
+ if (i == initialIndex) {
+ return null;
+ }
+
+ do {
+ hashEntry = hashes[i];
+ if (hashEntry == hash) {
+ HashEntry entry = entries[i];
+ if (lookupLength <= 7) {
+ return entry.value;
+ }
+ boolean found = (entry.suffix == suffix &&
+ compareEntryKeys(startAddress, endAddress, entry.startAddress));
+ if (found) {
+ return entry.value;
}
}
+ if (hashEntry == 0) {
+ HashEntry entry = entries[i];
+ entry.startAddress = startAddress;
+ entry.keyLength = lookupLength;
+ hashes[i] = hash;
+ entry.suffix = suffix;
+ entry.next = next;
+ this.next = i;
+ entry.value = new IntSummaryStatistics();
+ return entry.value;
+ }
+
i++;
if (i == len) {
i = 0;
}
- } while (i != hash);
+ } while (i != initialIndex);
return null;
}
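
For keys of at most 7 bytes the probe above never touches memory: the whole key fits in one zero-padded little-endian long, and since simpleHash is currently plain XOR, the hash of such a key is the key itself. (The first probe iteration is also unrolled ahead of the do-while, so the common hit-on-first-slot case skips the wrap-around bookkeeping.) A minimal byte[]-based sketch of that identity; packShortKey is an illustrative name, not from the patch:

    // Packs a short UTF-8 key (at most 7 bytes) into one little-endian
    // long, zero-padded at the top. UTF-8 text never contains an interior
    // 0x00 byte, so no two distinct short keys pack to the same value,
    // which is why hashes[i] == hash already amounts to a key comparison.
    static long packShortKey(byte[] key) { // requires key.length <= 7
        long packed = 0;
        for (int i = key.length - 1; i >= 0; i--) {
            packed = (packed << 8) | (key[i] & 0xFFL);
        }
        return packed; // equals both `suffix` and `hash` for this key
    }
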
- private static boolean compareEntryKeys(long startAddress, long endAddress, HashEntry entry) {
- long entryIndex = entry.startAddress;
+ private static boolean compareEntryKeys(long startAddress, long endAddress, long entryStartAddress) {
+ long entryIndex = entryStartAddress;
long lookupIndex = startAddress;
+ long endAddressStop = endAddress - 7;
- for (; (lookupIndex + 7) < endAddress; lookupIndex += 8) {
+ for (; lookupIndex < endAddressStop; lookupIndex += 8) {
if (UNSAFE.getLong(entryIndex) != UNSAFE.getLong(lookupIndex)) {
return false;
}
entryIndex += 8;
}
+
return true;
}
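
The Unsafe comparison above strides eight bytes at a time and deliberately stops before the last partial word; the trailing bytes need no check here because the caller has already matched the zero-padded suffix. A byte[]-based stand-in under the same little-endian layout (sameFullWords is an illustrative name):

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    // Compare two keys one long at a time; the final (len % 8) bytes are
    // covered by the separate `suffix` equality check, not by this loop.
    static boolean sameFullWords(byte[] a, byte[] b, int len) {
        ByteBuffer va = ByteBuffer.wrap(a).order(ByteOrder.LITTLE_ENDIAN);
        ByteBuffer vb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i + 8 <= len; i += 8) {
            if (va.getLong(i) != vb.getLong(i)) {
                return false;
            }
        }
        return true;
    }
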
+
+ public Iterable<HashEntry> entrySet() {
+ return () -> new Iterator<>() {
+ int scan = next;
+
+ @Override
+ public boolean hasNext() {
+ return scan != -1;
+ }
+
+ @Override
+ public HashEntry next() {
+ HashEntry entry = entries[scan];
+ scan = entry.next;
+ return entry;
+ }
+ };
+ }
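
entrySet() walks an intrusive singly linked list threaded through the table: each insertion in find() stores the previous head index in entry.next and moves the head (this.next = i), so iteration visits only occupied slots instead of scanning all 1 << twoPow buckets. The pattern in isolation, with illustrative names:

    // Intrusive occupancy list: `head` is the most recently inserted
    // slot, each slot stores the previous head, and -1 ends the walk.
    static void forEachOccupied(int head, int[] nextLinks,
                                java.util.function.IntConsumer visit) {
        for (int slot = head; slot != -1; slot = nextLinks[slot]) {
            visit.accept(slot);
        }
    }
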
}
private static final String FILE = "./measurements.txt";
+ private static long simpleHash(long hash, long nextData) {
+ return hash ^ nextData;
+ // return (hash ^ Long.rotateLeft((nextData * C1), R1)) * C2;
+ }
+
private static Unsafe initUnsafe() {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
@@ -145,7 +226,7 @@ public class CalculateAverage_vaidhy<I, T> {
interface MapReduce<I> {
- void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix);
+ void process(long keyStartAddress, long keyEndAddress, long hash, long suffix, int temperature);
I result();
}
@@ -173,9 +254,13 @@ public class CalculateAverage_vaidhy<I, T> {
private final long chunkEnd;
private long position;
- private int hash;
+ private long hash;
+
private long suffix;
- byte[] b = new byte[4];
+
+ private final ByteBuffer buf = ByteBuffer
+ .allocate(8)
+ .order(ByteOrder.LITTLE_ENDIAN);
public LineStream(FileService fileService, long offset, long chunkSize) {
long fileStart = fileService.address();
@@ -186,50 +271,38 @@ public class CalculateAverage_vaidhy<I, T> {
}
public boolean hasNext() {
- return position <= chunkEnd && position < fileEnd;
+ return position <= chunkEnd;
}
public long findSemi() {
- int h = 0;
- long s = 0;
- long i = position;
- while ((i + 3) < fileEnd) {
- // Adding 16 as it is the offset for primitive arrays
- ByteBuffer.wrap(b).putInt(UNSAFE.getInt(i));
-
- if (b[3] == 0x3B) {
- break;
- }
- i++;
- h = ((h << 5) - h) ^ b[3];
- s = (s << 8) ^ b[3];
+ long h = 0;
+ buf.rewind();
- if (b[2] == 0x3B) {
- break;
+ for (long i = position; i < fileEnd; i++) {
+ byte ch = UNSAFE.getByte(i);
+ if (ch == ';') {
+ int discard = buf.remaining();
+ buf.rewind();
+ long nextData = (buf.getLong() << discard) >>> discard;
+ this.suffix = nextData;
+ this.hash = simpleHash(h, nextData);
+ position = i + 1;
+ return i;
}
- i++;
- h = ((h << 5) - h) ^ b[2];
- s = (s << 8) ^ b[2];
-
- if (b[1] == 0x3B) {
- break;
+ if (buf.hasRemaining()) {
+ buf.put(ch);
}
- i++;
- h = ((h << 5) - h) ^ b[1];
- s = (s << 8) ^ b[1];
-
- if (b[0] == 0x3B) {
- break;
+ else {
+ buf.flip();
+ long nextData = buf.getLong();
+ h = simpleHash(h, nextData);
+ buf.rewind();
}
- i++;
- h = ((h << 5) - h) ^ b[0];
- s = (s << 8) ^ b[0];
}
-
this.hash = h;
- this.suffix = s;
- position = i + 1;
- return i;
+ this.suffix = buf.getLong();
+ position = fileEnd;
+ return fileEnd;
}
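
findSemi() now streams bytes into a reusable little-endian ByteBuffer: each time eight bytes accumulate, the completed word is folded into the running hash, and the partial word pending when ';' arrives becomes the zero-padded suffix, folded in last. A byte[]-based sketch of the per-key computation, assuming simpleHash stays plain XOR (hashAndSuffix is an illustrative name, not from the patch):

    // Returns {hash, suffix} for one station name, mirroring the
    // word-at-a-time folding done by findSemi().
    static long[] hashAndSuffix(byte[] name) {
        long hash = 0;
        long word = 0;
        int inWord = 0;
        for (byte b : name) {
            word |= (b & 0xFFL) << (8 * inWord++); // little-endian pack
            if (inWord == 8) {
                hash = hash ^ word; // simpleHash(hash, word)
                word = 0;
                inWord = 0;
            }
        }
        long suffix = word; // zero-padded tail (0 for 8n-byte names)
        return new long[]{ hash ^ suffix, suffix };
    }
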
public long skipLine() {
@@ -258,7 +331,94 @@ public class CalculateAverage_vaidhy<I, T> {
}
}
- private void worker(long offset, long chunkSize, MapReduce<I> lineConsumer) {
+ private static final long START_BYTE_INDICATOR = 0x0101_0101_0101_0101L;
+ private static final long END_BYTE_INDICATOR = START_BYTE_INDICATOR << 7;
+
+ private static final long NEW_LINE_DETECTION = START_BYTE_INDICATOR * '\n';
+
+ private static final long SEMI_DETECTION = START_BYTE_INDICATOR * ';';
+
+ private static final long ALL_ONES = 0xffff_ffff_ffff_ffffL;
+
+ private long findByteOctet(long data, long pattern) {
+ long match = data ^ pattern;
+ return (match - START_BYTE_INDICATOR) & ((~match) & END_BYTE_INDICATOR);
+ }
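
findByteOctet is the classic SWAR "word contains byte n" test from the Bit Twiddling Hacks collection: XOR with a lane-replicated pattern turns matching bytes into 0x00, and (x - 0x0101...) & ~x & 0x8080... sets bit 7 in exactly the zero lanes, with no false positives. bigWorker then turns the lowest set bit into a byte offset with numberOfTrailingZeros(mask) >>> 3, and keeps only the key bytes before the ';' via ALL_ONES >>> (64 - (semiPosition << 3)). A self-contained demo of the core trick (names are illustrative):

    // SWAR search for a byte within a little-endian long.
    public class SwarDemo {
        private static final long ONES = 0x0101_0101_0101_0101L;
        private static final long HIGHS = ONES << 7; // 0x8080...80

        // Sets bit 7 of every byte lane of `data` equal to `b`; 0 if none.
        static long matchByte(long data, byte b) {
            long x = data ^ (ONES * (b & 0xFFL)); // matching lanes -> 0x00
            return (x - ONES) & ~x & HIGHS;
        }

        public static void main(String[] args) {
            long word = 0x633B6261L; // bytes, low to high: 'a' 'b' ';' 'c'
            long mask = matchByte(word, (byte) ';');
            System.out.println(Long.numberOfTrailingZeros(mask) >>> 3); // 2
        }
    }
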
+
+ private void bigWorker(long offset, long chunkSize, MapReduce<I> lineConsumer) {
+ long chunkStart = offset + fileService.address();
+ long chunkEnd = chunkStart + chunkSize;
+ long fileEnd = fileService.address() + fileService.length();
+ long stopPoint = Math.min(chunkEnd + 1, fileEnd);
+
+ boolean skip = offset != 0;
+ for (long position = chunkStart; position < stopPoint;) {
+ if (skip) {
+ long data = UNSAFE.getLong(position);
+ long newLineMask = findByteOctet(data, NEW_LINE_DETECTION);
+ if (newLineMask != 0) {
+ int newLinePosition = Long.numberOfTrailingZeros(newLineMask) >>> 3;
+ skip = false;
+ position = position + newLinePosition + 1;
+ }
+ else {
+ position = position + 8;
+ }
+ continue;
+ }
+
+ long stationStart = position;
+ long stationEnd = -1;
+ long hash = 0;
+ long suffix = 0;
+ do {
+ long data = UNSAFE.getLong(position);
+ long semiMask = findByteOctet(data, SEMI_DETECTION);
+ if (semiMask != 0) {
+ int semiPosition = Long.numberOfTrailingZeros(semiMask) >>> 3;
+ stationEnd = position + semiPosition;
+ position = stationEnd + 1;
+
+ if (semiPosition != 0) {
+ suffix = data & (ALL_ONES >>> (64 - (semiPosition << 3)));
+ }
+ else {
+ suffix = UNSAFE.getLong(position - 8);
+ }
+ hash = simpleHash(hash, suffix);
+ break;
+ }
+ else {
+ hash = simpleHash(hash, data);
+ position = position + 8;
+ }
+ } while (true);
+
+ int temperature = 0;
+ {
+ byte ch = UNSAFE.getByte(position++);
+ boolean negative = false;
+ if (ch == '-') {
+ negative = true;
+ ch = UNSAFE.getByte(position++);
+ }
+ do {
+ if (ch != '.') {
+ temperature *= 10;
+ temperature += (ch ^ '0');
+ }
+ ch = UNSAFE.getByte(position++);
+ } while (ch != '\n');
+ if (negative) {
+ temperature = -temperature;
+ }
+ }
+
+ lineConsumer.process(stationStart, stationEnd, hash, suffix, temperature);
+ }
+ }
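
The inlined temperature parse reads digits directly into tenths of a degree: ch ^ '0' maps '0'..'9' to 0..9, the '.' is skipped, and the sign is applied once at the end, so "-12.3" becomes -123 with no floating point. A byte[]-based sketch of the same loop, assuming the 1BRC input format (optional '-', one or two integer digits, '.', exactly one fractional digit, '\n'):

    // Parses e.g. "-12.3\n" starting at `pos` into -123 (tenths).
    static int parseTemperatureTenths(byte[] line, int pos) {
        byte ch = line[pos++];
        boolean negative = ch == '-';
        if (negative) {
            ch = line[pos++];
        }
        int value = 0;
        while (ch != '\n') {
            if (ch != '.') {                     // skip the decimal point
                value = value * 10 + (ch ^ '0'); // '0'..'9' ^ '0' == 0..9
            }
            ch = line[pos++];
        }
        return negative ? -value : value;
    }
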
+
+ private void smallWorker(long offset, long chunkSize, MapReduce<I> lineConsumer) {
LineStream lineStream = new LineStream(fileService, offset, chunkSize);
if (offset != 0) {
@@ -274,29 +434,58 @@ public class CalculateAverage_vaidhy<I, T> {
while (lineStream.hasNext()) {
long keyStartAddress = lineStream.position;
long keyEndAddress = lineStream.findSemi();
- long keySuffix = lineStream.suffix;
- int keyHash = lineStream.hash;
+ long keyHash = lineStream.hash;
+ long suffix = lineStream.suffix;
long valueStartAddress = lineStream.position;
long valueEndAddress = lineStream.findTemperature();
int temperature = parseDouble(valueStartAddress, valueEndAddress);
- lineConsumer.process(keyStartAddress, keyEndAddress, keyHash, temperature, keySuffix);
+ // System.out.println("Small worker!");
+ lineConsumer.process(keyStartAddress, keyEndAddress, keyHash, suffix, temperature);
}
}
- public T master(long chunkSize, ExecutorService executor) {
- long len = fileService.length();
+ // Example: a tiny file such as "a;0.1\n" yields no big chunks;
+ // the whole file becomes one small chunk (offset 0, size = file size).
+
+ public T master(int shards, ExecutorService executor) {
List<Future<I>> summaries = new ArrayList<>();
+ long len = fileService.length();
+
+ if (len > 128) {
+ long bigChunk = Math.floorDiv(len, shards);
+ long bigChunkReAlign = bigChunk & 0xffff_ffff_ffff_fff8L;
+
+ long smallChunkStart = bigChunkReAlign * shards;
+ long smallChunkSize = len - smallChunkStart;
+
+ for (long offset = 0; offset < smallChunkStart; offset += bigChunkReAlign) {
+ MapReduce<I> mr = chunkProcessCreator.get();
+ final long transferOffset = offset;
+ Future<I> task = executor.submit(() -> {
+ bigWorker(transferOffset, bigChunkReAlign, mr);
+ return mr.result();
+ });
+ summaries.add(task);
+ }
+
+ MapReduce<I> mrLast = chunkProcessCreator.get();
+ Future<I> lastTask = executor.submit(() -> {
+ smallWorker(smallChunkStart, smallChunkSize - 1, mrLast);
+ return mrLast.result();
+ });
+ summaries.add(lastTask);
+ }
+ else {
- for (long offset = 0; offset < len; offset += chunkSize) {
- long workerLength = Math.min(len, offset + chunkSize) - offset;
- MapReduce<I> mr = chunkProcessCreator.get();
- final long transferOffset = offset;
- Future<I> task = executor.submit(() -> {
- worker(transferOffset, workerLength, mr);
- return mr.result();
+ MapReduce<I> mrLast = chunkProcessCreator.get();
+ Future<I> lastTask = executor.submit(() -> {
+ smallWorker(0, len - 1, mrLast);
+ return mrLast.result();
});
- summaries.add(task);
+ summaries.add(lastTask);
}
+
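
master() now plans the split itself instead of taking a chunk size: for files over 128 bytes it cuts `shards` big chunks whose length is rounded down to a multiple of 8 (the mask 0xffff_ffff_ffff_fff8L is ~7L), so bigWorker can always read whole longs without running past its range, and the unaligned remainder becomes a single small tail chunk for the byte-at-a-time smallWorker. A sketch of that arithmetic (planChunks is an illustrative name, not from the patch):

    // Plans {offset, size} pairs: `shards` 8-byte-aligned big chunks
    // followed by one small tail chunk covering the remainder.
    static long[][] planChunks(long len, int shards) {
        long big = Math.floorDiv(len, shards) & ~7L; // multiple of 8
        long[][] chunks = new long[shards + 1][];
        for (int s = 0; s < shards; s++) {
            chunks[s] = new long[]{ s * big, big };
        }
        long tailStart = big * shards;
        chunks[shards] = new long[]{ tailStart, len - tailStart };
        return chunks;
    }
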
List<I> summariesDone = summaries.stream()
.map(task -> {
try {
@@ -336,22 +525,12 @@ public class CalculateAverage_vaidhy<I, T> {
private static class ChunkProcessorImpl implements MapReduce<PrimitiveHashMap> {
// 1 << 14 > 10,000 so it works
- private final PrimitiveHashMap statistics = new PrimitiveHashMap(14);
+ private final PrimitiveHashMap statistics = new PrimitiveHashMap(15);
@Override
- public void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix) {
- HashEntry entry = statistics.find(keyStartAddress, keyEndAddress, suffix, hash);
- if (entry == null) {
- throw new IllegalStateException("Hash table too small :(");
- }
- if (entry.value == null) {
- entry.startAddress = keyStartAddress;
- entry.endAddress = keyEndAddress;
- entry.suffix = suffix;
- entry.hash = hash;
- entry.value = new IntSummaryStatistics();
- }
- entry.value.accept(temperature);
+ public void process(long keyStartAddress, long keyEndAddress, long hash, long suffix, int temperature) {
+ IntSummaryStatistics stats = statistics.find(keyStartAddress, keyEndAddress, hash, suffix);
+ stats.accept(temperature);
}
@Override
@@ -368,13 +547,10 @@ public class CalculateAverage_vaidhy<I, T> {
ChunkProcessorImpl::new,
CalculateAverage_vaidhy::combineOutputs);
- int proc = 2 * Runtime.getRuntime().availableProcessors();
-
- long fileSize = diskFileService.length();
- long chunkSize = Math.ceilDiv(fileSize, proc);
+ int proc = Runtime.getRuntime().availableProcessors();
ExecutorService executor = Executors.newFixedThreadPool(proc);
- Map<String, IntSummaryStatistics> output = calculateAverageVaidhy.master(chunkSize, executor);
+ Map<String, IntSummaryStatistics> output = calculateAverageVaidhy.master(2 * proc, executor);
executor.shutdown();
Map<String, String> outputStr = toPrintMap(output);
@@ -395,11 +571,12 @@ public class CalculateAverage_vaidhy<I, T> {
private static Map<String, IntSummaryStatistics> combineOutputs(
List<PrimitiveHashMap> list) {
- Map<String, IntSummaryStatistics> output = new HashMap<>(10000);
+ Map<String, IntSummaryStatistics> output = HashMap.newHashMap(10000);
for (PrimitiveHashMap map : list) {
- for (HashEntry entry : map.entries) {
+ for (HashEntry entry : map.entrySet()) {
if (entry.value != null) {
- String keyStr = unsafeToString(entry.startAddress, entry.endAddress);
+ String keyStr = unsafeToString(entry.startAddress,
+ entry.startAddress + entry.keyLength);
output.compute(keyStr, (ignore, val) -> {
if (val == null) {