summaryrefslogtreecommitdiff
path: root/memtable
diff options
context:
space:
mode:
authorYi Wu <yiwu@fb.com>2017-04-06 13:59:31 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-04-06 14:09:13 -0700
commitdf6f5a377287c3f4f2b38a46a8871aabf53dbffb (patch)
tree80dcc63e3793005ad416346ea8f936aaf56f3b87 /memtable
parent107c5f6a60e94c3cf478ee25e50164fd31b014b0 (diff)
Move memtable related files into memtable directory
Summary: Move memtable related files into memtable directory. Closes https://github.com/facebook/rocksdb/pull/2087 Differential Revision: D4829242 Pulled By: yiwu-arbug fbshipit-source-id: ca70ab6
Diffstat (limited to 'memtable')
-rw-r--r--memtable/hash_cuckoo_rep.cc2
-rw-r--r--memtable/hash_linklist_rep.cc2
-rw-r--r--memtable/hash_skiplist_rep.cc2
-rw-r--r--memtable/inlineskiplist.h900
-rw-r--r--memtable/inlineskiplist_test.cc625
-rw-r--r--memtable/memtable_allocator.cc59
-rw-r--r--memtable/memtable_allocator.h50
-rw-r--r--memtable/memtablerep_bench.cc697
-rw-r--r--memtable/skiplist.h495
-rw-r--r--memtable/skiplist_test.cc387
-rw-r--r--memtable/skiplistrep.cc2
11 files changed, 3217 insertions, 4 deletions
diff --git a/memtable/hash_cuckoo_rep.cc b/memtable/hash_cuckoo_rep.cc
index a8799f1ca..30cb41b7d 100644
--- a/memtable/hash_cuckoo_rep.cc
+++ b/memtable/hash_cuckoo_rep.cc
@@ -16,7 +16,7 @@
#include <vector>
#include "db/memtable.h"
-#include "db/skiplist.h"
+#include "memtable/skiplist.h"
#include "memtable/stl_wrappers.h"
#include "port/port.h"
#include "rocksdb/memtablerep.h"
diff --git a/memtable/hash_linklist_rep.cc b/memtable/hash_linklist_rep.cc
index 48a6d9897..59559d5b7 100644
--- a/memtable/hash_linklist_rep.cc
+++ b/memtable/hash_linklist_rep.cc
@@ -10,7 +10,7 @@
#include <algorithm>
#include <atomic>
#include "db/memtable.h"
-#include "db/skiplist.h"
+#include "memtable/skiplist.h"
#include "monitoring/histogram.h"
#include "port/port.h"
#include "rocksdb/memtablerep.h"
diff --git a/memtable/hash_skiplist_rep.cc b/memtable/hash_skiplist_rep.cc
index 12b47950d..307fad650 100644
--- a/memtable/hash_skiplist_rep.cc
+++ b/memtable/hash_skiplist_rep.cc
@@ -16,7 +16,7 @@
#include "port/port.h"
#include "util/murmurhash.h"
#include "db/memtable.h"
-#include "db/skiplist.h"
+#include "memtable/skiplist.h"
namespace rocksdb {
namespace {
diff --git a/memtable/inlineskiplist.h b/memtable/inlineskiplist.h
new file mode 100644
index 000000000..43bb09ac8
--- /dev/null
+++ b/memtable/inlineskiplist.h
@@ -0,0 +1,900 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional
+// grant of patent rights can be found in the PATENTS file in the same
+// directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found
+// in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// InlineSkipList is derived from SkipList (skiplist.h), but it optimizes
+// the memory layout by requiring that the key storage be allocated through
+// the skip list instance. For the common case of SkipList<const char*,
+// Cmp> this saves 1 pointer per skip list node and gives better cache
+// locality, at the expense of wasted padding from using AllocateAligned
+// instead of Allocate for the keys. The unused padding will be from
+// 0 to sizeof(void*)-1 bytes, and the space savings are sizeof(void*)
+// bytes, so despite the padding the space used is always less than
+// SkipList<const char*, ..>.
+//
+// Thread safety -------------
+//
+// Writes via Insert require external synchronization, most likely a mutex.
+// InsertConcurrently can be safely called concurrently with reads and
+// with other concurrent inserts. Reads require a guarantee that the
+// InlineSkipList will not be destroyed while the read is in progress.
+// Apart from that, reads progress without any internal locking or
+// synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the InlineSkipList is
+// destroyed. This is trivially guaranteed by the code since we never
+// delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the InlineSkipList.
+// Only Insert() modifies the list, and it is careful to initialize a
+// node and use release-stores to publish the nodes in one or more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <stdlib.h>
+#include <algorithm>
+#include <atomic>
+#include "port/port.h"
+#include "util/allocator.h"
+#include "util/random.h"
+
+namespace rocksdb {
+
+template <class Comparator>
+class InlineSkipList {
+ private:
+ struct Node;
+ struct Splice;
+
+ public:
+ static const uint16_t kMaxPossibleHeight = 32;
+
+ // Create a new InlineSkipList object that will use "cmp" for comparing
+ // keys, and will allocate memory using "*allocator". Objects allocated
+ // in the allocator must remain allocated for the lifetime of the
+ // skiplist object.
+ explicit InlineSkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12,
+ int32_t branching_factor = 4);
+
+ // Allocates a key and a skip-list node, returning a pointer to the key
+ // portion of the node. This method is thread-safe if the allocator
+ // is thread-safe.
+ char* AllocateKey(size_t key_size);
+
+ // Allocate a splice using allocator.
+ Splice* AllocateSplice();
+
+ // Inserts a key allocated by AllocateKey, after the actual key value
+ // has been filled in.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ void Insert(const char* key);
+
+ // Inserts a key allocated by AllocateKey with a hint of last insert
+ // position in the skip-list. If hint points to nullptr, a new hint will be
+ // populated, which can be used in subsequent calls.
+ //
+ // It can be used to optimize the workload where there are multiple groups
+ // of keys, and each key is likely to insert to a location close to the last
+ // inserted key in the same group. One example is sequential inserts.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ void InsertWithHint(const char* key, void** hint);
+
+ // Like Insert, but external synchronization is not required.
+ void InsertConcurrently(const char* key);
+
+ // Inserts a node into the skip list. key must have been allocated by
+ // AllocateKey and then filled in by the caller. If UseCAS is true,
+ // then external synchronization is not required, otherwise this method
+ // may not be called concurrently with any other insertions.
+ //
+ // Regardless of whether UseCAS is true, the splice must be owned
+ // exclusively by the current thread. If allow_partial_splice_fix is
+ // true, then the cost of insertion is amortized O(log D), where D is
+ // the distance from the splice to the inserted key (measured as the
+ // number of intervening nodes). Note that this bound is very good for
+ // sequential insertions! If allow_partial_splice_fix is false then
+ // the existing splice will be ignored unless the current key is being
+ // inserted immediately after the splice. allow_partial_splice_fix ==
+ // false has worse running time for the non-sequential case O(log N),
+ // but a better constant factor.
+ template <bool UseCAS>
+ void Insert(const char* key, Splice* splice, bool allow_partial_splice_fix);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const char* key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const char* key) const;
+
+ // Validate correctness of the skip-list.
+ void TEST_Validate() const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const InlineSkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const InlineSkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const char* target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const char* target);
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const InlineSkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ // Immutable after construction
+ Comparator const compare_;
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // seq_splice_ is a Splice used for insertions in the non-concurrent
+ // case. It caches the prev and next found during the most recent
+ // non-concurrent insertion.
+ Splice* seq_splice_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ int RandomHeight();
+
+ Node* AllocateNode(size_t key_size, int height);
+
+ bool Equal(const char* a, const char* b) const {
+ return (compare_(a, b) == 0);
+ }
+
+ bool LessThan(const char* a, const char* b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n". Null n
+ // is considered infinite. n should not be head_.
+ bool KeyIsAfterNode(const char* key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const char* key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev = nullptr) const;
+
+ // Return the latest node with a key < key on bottom_level. Start searching
+ // from root node on the level below top_level.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [bottom_level..top_level-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev, Node* root, int top_level,
+ int bottom_level) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+
+ // Traverses a single level of the list, setting *out_prev to the last
+ // node before the key and *out_next to the first node after. Assumes
+ // that the key is not present in the skip list. On entry, before should
+ // point to a node that is before the key, and after should point to
+ // a node that is after the key. after should be nullptr if a good after
+ // node isn't conveniently available.
+ void FindSpliceForLevel(const char* key, Node* before, Node* after, int level,
+ Node** out_prev, Node** out_next);
+
+ // Recomputes Splice levels from highest_level (inclusive) down to
+ // lowest_level (inclusive).
+ void RecomputeSpliceLevels(const char* key, Splice* splice,
+ int recompute_level);
+
+ // No copying allowed
+ InlineSkipList(const InlineSkipList&);
+ InlineSkipList& operator=(const InlineSkipList&);
+};
+
+// Implementation details follow
+
+template <class Comparator>
+struct InlineSkipList<Comparator>::Splice {
+ // The invariant of a Splice is that prev_[i+1].key <= prev_[i].key <
+ // next_[i].key <= next_[i+1].key for all i. That means that if a
+ // key is bracketed by prev_[i] and next_[i] then it is bracketed by
+ // all higher levels. It is _not_ required that prev_[i]->Next(i) ==
+ // next_[i] (it probably did at some point in the past, but intervening
+ // or concurrent operations might have inserted nodes in between).
+ int height_ = 0;
+ Node** prev_;
+ Node** next_;
+};
+
+// The Node data type is more of a pointer into custom-managed memory than
+// a traditional C++ struct. The key is stored in the bytes immediately
+// after the struct, and the next_ pointers for nodes with height > 1 are
+// stored immediately _before_ the struct. This avoids the need to include
+// any pointer or sizing data, which reduces per-node memory overheads.
+template <class Comparator>
+struct InlineSkipList<Comparator>::Node {
+ // Stores the height of the node in the memory location normally used for
+ // next_[0]. This is used for passing data from AllocateKey to Insert.
+ void StashHeight(const int height) {
+ assert(sizeof(int) <= sizeof(next_[0]));
+ memcpy(&next_[0], &height, sizeof(int));
+ }
+
+ // Retrieves the value passed to StashHeight. Undefined after a call
+ // to SetNext or NoBarrier_SetNext.
+ int UnstashHeight() const {
+ int rv;
+ memcpy(&rv, &next_[0], sizeof(int));
+ return rv;
+ }
+
+ const char* Key() const { return reinterpret_cast<const char*>(&next_[1]); }
+
+ // Accessors/mutators for links. Wrapped in methods so we can add
+ // the appropriate barriers as necessary, and perform the necessary
+ // addressing trickery for storing links below the Node in memory.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return (next_[-n].load(std::memory_order_acquire));
+ }
+
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_[-n].store(x, std::memory_order_release);
+ }
+
+ bool CASNext(int n, Node* expected, Node* x) {
+ assert(n >= 0);
+ return next_[-n].compare_exchange_strong(expected, x);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return next_[-n].load(std::memory_order_relaxed);
+ }
+
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ next_[-n].store(x, std::memory_order_relaxed);
+ }
+
+ // Insert node after prev on specific level.
+ void InsertAfter(Node* prev, int level) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "this" in prev.
+ NoBarrier_SetNext(level, prev->NoBarrier_Next(level));
+ prev->SetNext(level, this);
+ }
+
+ private:
+ // next_[0] is the lowest level link (level 0). Higher levels are
+ // stored _earlier_, so level 1 is at next_[-1].
+ std::atomic<Node*> next_[1];
+};
+
+template <class Comparator>
+inline InlineSkipList<Comparator>::Iterator::Iterator(
+ const InlineSkipList* list) {
+ SetList(list);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SetList(
+ const InlineSkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template <class Comparator>
+inline bool InlineSkipList<Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template <class Comparator>
+inline const char* InlineSkipList<Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->Key();
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->Key());
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Seek(const char* target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekForPrev(
+ const char* target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+int InlineSkipList<Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && height < kMaxPossibleHeight &&
+ rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ assert(height <= kMaxPossibleHeight);
+ return height;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::KeyIsAfterNode(const char* key,
+ Node* n) const {
+ // nullptr n is considered infinite
+ assert(n != head_);
+ return (n != nullptr) && (compare_(n->Key(), key) < 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindGreaterOrEqual(const char* key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ while (true) {
+ Node* next = x->Next(level);
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ int cmp = (next == nullptr || next == last_bigger)
+ ? 1
+ : compare_(next->Key(), key);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev) const {
+ return FindLessThan(key, prev, head_, GetMaxHeight(), 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev,
+ Node* root, int top_level,
+ int bottom_level) const {
+ assert(top_level > bottom_level);
+ int level = top_level - 1;
+ Node* x = root;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ while (true) {
+ Node* next = x->Next(level);
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ if (next != last_not_after && KeyIsAfterNode(key, next)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == bottom_level) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLast() const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <class Comparator>
+uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ assert(x == head_ || compare_(x->Key(), key) < 0);
+ Node* next = x->Next(level);
+ if (next == nullptr || compare_(next->Key(), key) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <class Comparator>
+InlineSkipList<Comparator>::InlineSkipList(const Comparator cmp,
+ Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(max_height),
+ kBranching_(branching_factor),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ compare_(cmp),
+ allocator_(allocator),
+ head_(AllocateNode(0, max_height)),
+ max_height_(1),
+ seq_splice_(AllocateSplice()) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 1 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+
+ for (int i = 0; i < kMaxHeight_; ++i) {
+ head_->SetNext(i, nullptr);
+ }
+}
+
+template <class Comparator>
+char* InlineSkipList<Comparator>::AllocateKey(size_t key_size) {
+ return const_cast<char*>(AllocateNode(key_size, RandomHeight())->Key());
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::AllocateNode(size_t key_size, int height) {
+ auto prefix = sizeof(std::atomic<Node*>) * (height - 1);
+
+ // prefix is space for the height - 1 pointers that we store before
+ // the Node instance (next_[-(height - 1) .. -1]). Node starts at
+ // raw + prefix, and holds the bottom-mode (level 0) skip list pointer
+ // next_[0]. key_size is the bytes for the key, which comes just after
+ // the Node.
+ char* raw = allocator_->AllocateAligned(prefix + sizeof(Node) + key_size);
+ Node* x = reinterpret_cast<Node*>(raw + prefix);
+
+ // Once we've linked the node into the skip list we don't actually need
+ // to know its height, because we can implicitly use the fact that we
+ // traversed into a node at level h to known that h is a valid level
+ // for that node. We need to convey the height to the Insert step,
+ // however, so that it can perform the proper links. Since we're not
+ // using the pointers at the moment, StashHeight temporarily borrow
+ // storage from next_[0] for that purpose.
+ x->StashHeight(height);
+ return x;
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Splice*
+InlineSkipList<Comparator>::AllocateSplice() {
+ // size of prev_ and next_
+ size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1);
+ char* raw = allocator_->AllocateAligned(sizeof(Splice) + array_size * 2);
+ Splice* splice = reinterpret_cast<Splice*>(raw);
+ splice->height_ = 0;
+ splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice));
+ splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size);
+ return splice;
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::Insert(const char* key) {
+ Insert<false>(key, seq_splice_, false);
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::InsertConcurrently(const char* key) {
+ Node* prev[kMaxPossibleHeight];
+ Node* next[kMaxPossibleHeight];
+ Splice splice;
+ splice.prev_ = prev;
+ splice.next_ = next;
+ Insert<true>(key, &splice, false);
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::InsertWithHint(const char* key, void** hint) {
+ assert(hint != nullptr);
+ Splice* splice = reinterpret_cast<Splice*>(*hint);
+ if (splice == nullptr) {
+ splice = AllocateSplice();
+ *hint = reinterpret_cast<void*>(splice);
+ }
+ Insert<false>(key, splice, true);
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::FindSpliceForLevel(const char* key,
+ Node* before, Node* after,
+ int level, Node** out_prev,
+ Node** out_next) {
+ while (true) {
+ Node* next = before->Next(level);
+ assert(before == head_ || next == nullptr ||
+ KeyIsAfterNode(next->Key(), before));
+ assert(before == head_ || KeyIsAfterNode(key, before));
+ if (next == after || !KeyIsAfterNode(key, next)) {
+ // found it
+ *out_prev = before;
+ *out_next = next;
+ return;
+ }
+ before = next;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::RecomputeSpliceLevels(const char* key,
+ Splice* splice,
+ int recompute_level) {
+ assert(recompute_level > 0);
+ assert(recompute_level <= splice->height_);
+ for (int i = recompute_level - 1; i >= 0; --i) {
+ FindSpliceForLevel(key, splice->prev_[i + 1], splice->next_[i + 1], i,
+ &splice->prev_[i], &splice->next_[i]);
+ }
+}
+
+template <class Comparator>
+template <bool UseCAS>
+void InlineSkipList<Comparator>::Insert(const char* key, Splice* splice,
+ bool allow_partial_splice_fix) {
+ Node* x = reinterpret_cast<Node*>(const_cast<char*>(key)) - 1;
+ int height = x->UnstashHeight();
+ assert(height >= 1 && height <= kMaxHeight_);
+
+ int max_height = max_height_.load(std::memory_order_relaxed);
+ while (height > max_height) {
+ if (max_height_.compare_exchange_weak(max_height, height)) {
+ // successfully updated it
+ max_height = height;
+ break;
+ }
+ // else retry, possibly exiting the loop because somebody else
+ // increased it
+ }
+ assert(max_height <= kMaxPossibleHeight);
+
+ int recompute_height = 0;
+ if (splice->height_ < max_height) {
+ // Either splice has never been used or max_height has grown since
+ // last use. We could potentially fix it in the latter case, but
+ // that is tricky.
+ splice->prev_[max_height] = head_;
+ splice->next_[max_height] = nullptr;
+ splice->height_ = max_height;
+ recompute_height = max_height;
+ } else {
+ // Splice is a valid proper-height splice that brackets some
+ // key, but does it bracket this one? We need to validate it and
+ // recompute a portion of the splice (levels 0..recompute_height-1)
+ // that is a superset of all levels that don't bracket the new key.
+ // Several choices are reasonable, because we have to balance the work
+ // saved against the extra comparisons required to validate the Splice.
+ //
+ // One strategy is just to recompute all of orig_splice_height if the
+ // bottom level isn't bracketing. This pessimistically assumes that
+ // we will either get a perfect Splice hit (increasing sequential
+ // inserts) or have no locality.
+ //
+ // Another strategy is to walk up the Splice's levels until we find
+ // a level that brackets the key. This strategy lets the Splice
+ // hint help for other cases: it turns insertion from O(log N) into
+ // O(log D), where D is the number of nodes in between the key that
+ // produced the Splice and the current insert (insertion is aided
+ // whether the new key is before or after the splice). If you have
+ // a way of using a prefix of the key to map directly to the closest
+ // Splice out of O(sqrt(N)) Splices and we make it so that splices
+ // can also be used as hints during read, then we end up with Oshman's
+ // and Shavit's SkipTrie, which has O(log log N) lookup and insertion
+ // (compare to O(log N) for skip list).
+ //
+ // We control the pessimistic strategy with allow_partial_splice_fix.
+ // A good strategy is probably to be pessimistic for seq_splice_,
+ // optimistic if the caller actually went to the work of providing
+ // a Splice.
+ while (recompute_height < max_height) {
+ if (splice->prev_[recompute_height]->Next(recompute_height) !=
+ splice->next_[recompute_height]) {
+ // splice isn't tight at this level, there must have been some inserts
+ // to this
+ // location that didn't update the splice. We might only be a little
+ // stale, but if
+ // the splice is very stale it would be O(N) to fix it. We haven't used
+ // up any of
+ // our budget of comparisons, so always move up even if we are
+ // pessimistic about
+ // our chances of success.
+ ++recompute_height;
+ } else if (splice->prev_[recompute_height] != head_ &&
+ !KeyIsAfterNode(key, splice->prev_[recompute_height])) {
+ // key is from before splice
+ if (allow_partial_splice_fix) {
+ // skip all levels with the same node without more comparisons
+ Node* bad = splice->prev_[recompute_height];
+ while (splice->prev_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ // we're pessimistic, recompute everything
+ recompute_height = max_height;
+ }
+ } else if (KeyIsAfterNode(key, splice->next_[recompute_height])) {
+ // key is from after splice
+ if (allow_partial_splice_fix) {
+ Node* bad = splice->next_[recompute_height];
+ while (splice->next_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ recompute_height = max_height;
+ }
+ } else {
+ // this level brackets the key, we won!
+ break;
+ }
+ }
+ }
+ assert(recompute_height <= max_height);
+ if (recompute_height > 0) {
+ RecomputeSpliceLevels(key, splice, recompute_height);
+ }
+
+ bool splice_is_valid = true;
+ if (UseCAS) {
+ for (int i = 0; i < height; ++i) {
+ while (true) {
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ if (splice->prev_[i]->CASNext(i, splice->next_[i], x)) {
+ // success
+ break;
+ }
+ // CAS failed, we need to recompute prev and next. It is unlikely
+ // to be helpful to try to use a different level as we redo the
+ // search, because it should be unlikely that lots of nodes have
+ // been inserted between prev[i] and next[i]. No point in using
+ // next[i] as the after hint, because we know it is stale.
+ FindSpliceForLevel(key, splice->prev_[i], nullptr, i, &splice->prev_[i],
+ &splice->next_[i]);
+
+ // Since we've narrowed the bracket for level i, we might have
+ // violated the Splice constraint between i and i-1. Make sure
+ // we recompute the whole thing next time.
+ if (i > 0) {
+ splice_is_valid = false;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < height; ++i) {
+ if (i >= recompute_height &&
+ splice->prev_[i]->Next(i) != splice->next_[i]) {
+ FindSpliceForLevel(key, splice->prev_[i], nullptr, i, &splice->prev_[i],
+ &splice->next_[i]);
+ }
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ assert(splice->prev_[i]->Next(i) == splice->next_[i]);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ splice->prev_[i]->SetNext(i, x);
+ }
+ }
+ if (splice_is_valid) {
+ for (int i = 0; i < height; ++i) {
+ splice->prev_[i] = x;
+ }
+ assert(splice->prev_[splice->height_] == head_);
+ assert(splice->next_[splice->height_] == nullptr);
+ for (int i = 0; i < splice->height_; ++i) {
+ assert(splice->next_[i] == nullptr ||
+ compare_(key, splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), key) <= 0);
+ assert(splice->prev_[i + 1] == splice->prev_[i] ||
+ splice->prev_[i + 1] == head_ ||
+ compare_(splice->prev_[i + 1]->Key(), splice->prev_[i]->Key()) <
+ 0);
+ assert(splice->next_[i + 1] == splice->next_[i] ||
+ splice->next_[i + 1] == nullptr ||
+ compare_(splice->next_[i]->Key(), splice->next_[i + 1]->Key()) <
+ 0);
+ }
+ } else {
+ splice->height_ = 0;
+ }
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::Contains(const char* key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->Key())) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::TEST_Validate() const {
+ // Interate over all levels at the same time, and verify nodes appear in
+ // the right order, and nodes appear in upper level also appear in lower
+ // levels.
+ Node* nodes[kMaxPossibleHeight];
+ int max_height = GetMaxHeight();
+ for (int i = 0; i < max_height; i++) {
+ nodes[i] = head_;
+ }
+ while (nodes[0] != nullptr) {
+ Node* l0_next = nodes[0]->Next(0);
+ if (l0_next == nullptr) {
+ break;
+ }
+ assert(nodes[0] == head_ || compare_(nodes[0]->Key(), l0_next->Key()) < 0);
+ nodes[0] = l0_next;
+
+ int i = 1;
+ while (i < max_height) {
+ Node* next = nodes[i]->Next(i);
+ if (next == nullptr) {
+ break;
+ }
+ auto cmp = compare_(nodes[0]->Key(), next->Key());
+ assert(cmp <= 0);
+ if (cmp == 0) {
+ assert(next == nodes[0]);
+ nodes[i] = next;
+ } else {
+ break;
+ }
+ i++;
+ }
+ }
+ for (int i = 1; i < max_height; i++) {
+ assert(nodes[i]->Next(i) == nullptr);
+ }
+}
+
+} // namespace rocksdb
diff --git a/memtable/inlineskiplist_test.cc b/memtable/inlineskiplist_test.cc
new file mode 100644
index 000000000..56ab49b81
--- /dev/null
+++ b/memtable/inlineskiplist_test.cc
@@ -0,0 +1,625 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/inlineskiplist.h"
+#include <set>
+#include <unordered_set>
+#include "rocksdb/env.h"
+#include "util/concurrent_arena.h"
+#include "util/hash.h"
+#include "util/random.h"
+#include "util/testharness.h"
+
+namespace rocksdb {
+
+// Our test skip list stores 8-byte unsigned integers
+typedef uint64_t Key;
+
+static const char* Encode(const uint64_t* key) {
+ return reinterpret_cast<const char*>(key);
+}
+
+static Key Decode(const char* key) {
+ Key rv;
+ memcpy(&rv, key, sizeof(Key));
+ return rv;
+}
+
+struct TestComparator {
+ int operator()(const char* a, const char* b) const {
+ if (Decode(a) < Decode(b)) {
+ return -1;
+ } else if (Decode(a) > Decode(b)) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+typedef InlineSkipList<TestComparator> TestInlineSkipList;
+
+class InlineSkipTest : public testing::Test {
+ public:
+ void Insert(TestInlineSkipList* list, Key key) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list->Insert(buf);
+ keys_.insert(key);
+ }
+
+ void InsertWithHint(TestInlineSkipList* list, Key key, void** hint) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list->InsertWithHint(buf, hint);
+ keys_.insert(key);
+ }
+
+ void Validate(TestInlineSkipList* list) {
+ // Check keys exist.
+ for (Key key : keys_) {
+ ASSERT_TRUE(list->Contains(Encode(&key)));
+ }
+ // Iterate over the list, make sure keys appears in order and no extra
+ // keys exist.
+ TestInlineSkipList::Iterator iter(list);
+ ASSERT_FALSE(iter.Valid());
+ Key zero = 0;
+ iter.Seek(Encode(&zero));
+ for (Key key : keys_) {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(key, Decode(iter.key()));
+ iter.Next();
+ }
+ ASSERT_FALSE(iter.Valid());
+ // Validate the list is well-formed.
+ list->TEST_Validate();
+ }
+
+ private:
+ std::set<Key> keys_;
+};
+
+TEST_F(InlineSkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ Key key = 10;
+ ASSERT_TRUE(!list.Contains(Encode(&key)));
+
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ key = 100;
+ iter.Seek(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(InlineSkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ ConcurrentArena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ char* buf = list.AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list.Insert(buf);
+ }
+ }
+
+ for (Key i = 0; i < R; i++) {
+ if (list.Contains(Encode(&i))) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ uint64_t zero = 0;
+ iter.Seek(Encode(&zero));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ uint64_t max_key = R - 1;
+ iter.SeekForPrev(Encode(&max_key));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+ }
+
+ // Forward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.Seek(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, Decode(iter.key()));
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, Decode(iter.key()));
+ iter.Prev();
+ }
+ }
+ }
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_Sequential) {
+ const int N = 100000;
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hint = nullptr;
+ for (int i = 0; i < N; i++) {
+ Key key = i;
+ InsertWithHint(&list, key, &hint);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHints) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ Key last_key[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ last_key[i] = 0;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + (++last_key[s]);
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHintsRandom) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_CompatibleWithInsertWithoutHint) {
+ const int N = 100000;
+ const int S1 = 100;
+ const int S2 = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ std::unordered_set<Key> used;
+ Key with_hint[S1];
+ Key without_hint[S2];
+ void* hints[S1];
+ for (int i = 0; i < S1; i++) {
+ hints[i] = nullptr;
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ with_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < S2; i++) {
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ without_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S1 + S2);
+ if (s < S1) {
+ Key key = (with_hint[s] << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ } else {
+ Key key = (without_hint[s - S1] << 32) + rnd.Next();
+ Insert(&list, key);
+ }
+ }
+ Validate(&list);
+}
+
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ public:
+ static const uint32_t K = 8;
+
+ private:
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = {k, g};
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ ConcurrentArena arena_;
+
+ // InlineSkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ InlineSkipList<TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: No concurrent calls to WriteStep or ConcurrentWriteStep
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ list_.Insert(buf);
+ current_.Set(k, g);
+ }
+
+ // REQUIRES: No concurrent calls for the same k
+ void ConcurrentWriteStep(uint32_t k) {
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ list_.InsertConcurrently(buf);
+ ASSERT_EQ(g, current_.Get(k) + 1);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ InlineSkipList<TestComparator>::Iterator iter(&list_);
+ iter.Seek(Encode(&pos));
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = Decode(iter.key());
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(Encode(&new_target));
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(InlineSkipTest, ConcurrentReadWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentInsertWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ uint32_t base = rnd.Next();
+ for (int j = 0; j < 4; ++j) {
+ test.ConcurrentWriteStep((base + j) % ConcurrentTest::K);
+ }
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+ std::atomic<uint32_t> next_writer_;
+
+ enum ReaderState { STARTING, RUNNING, DONE };
+
+ explicit TestState(int s)
+ : seed_(s),
+ quit_flag_(false),
+ state_(STARTING),
+ pending_writers_(0),
+ state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ void AdjustPendingWriters(int delta) {
+ mu_.Lock();
+ pending_writers_ += delta;
+ if (pending_writers_ == 0) {
+ state_cv_.Signal();
+ }
+ mu_.Unlock();
+ }
+
+ void WaitForPendingWriters() {
+ mu_.Lock();
+ while (pending_writers_ != 0) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ int pending_writers_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void ConcurrentWriter(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ uint32_t k = state->next_writer_++ % ConcurrentTest::K;
+ state->t_.ConcurrentWriteStep(k);
+ state->AdjustPendingWriters(-1);
+}
+
+static void RunConcurrentRead(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; ++k) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+static void RunConcurrentInsert(int run, int write_parallelism = 4) {
+ Env::Default()->SetBackgroundThreads(1 + write_parallelism,
+ Env::Priority::LOW);
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k += write_parallelism) {
+ state.next_writer_ = rnd.Next();
+ state.AdjustPendingWriters(write_parallelism);
+ for (int p = 0; p < write_parallelism; ++p) {
+ Env::Default()->Schedule(ConcurrentWriter, &state);
+ }
+ state.WaitForPendingWriters();
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentRead1) { RunConcurrentRead(1); }
+TEST_F(InlineSkipTest, ConcurrentRead2) { RunConcurrentRead(2); }
+TEST_F(InlineSkipTest, ConcurrentRead3) { RunConcurrentRead(3); }
+TEST_F(InlineSkipTest, ConcurrentRead4) { RunConcurrentRead(4); }
+TEST_F(InlineSkipTest, ConcurrentRead5) { RunConcurrentRead(5); }
+TEST_F(InlineSkipTest, ConcurrentInsert1) { RunConcurrentInsert(1); }
+TEST_F(InlineSkipTest, ConcurrentInsert2) { RunConcurrentInsert(2); }
+TEST_F(InlineSkipTest, ConcurrentInsert3) { RunConcurrentInsert(3); }
+
+} // namespace rocksdb
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/memtable/memtable_allocator.cc b/memtable/memtable_allocator.cc
new file mode 100644
index 000000000..08a7dbf74
--- /dev/null
+++ b/memtable/memtable_allocator.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/memtable_allocator.h"
+
+#include <assert.h>
+#include "rocksdb/write_buffer_manager.h"
+#include "util/arena.h"
+
+namespace rocksdb {
+
+MemTableAllocator::MemTableAllocator(Allocator* allocator,
+ WriteBufferManager* write_buffer_manager)
+ : allocator_(allocator),
+ write_buffer_manager_(write_buffer_manager),
+ bytes_allocated_(0) {}
+
+MemTableAllocator::~MemTableAllocator() { DoneAllocating(); }
+
+char* MemTableAllocator::Allocate(size_t bytes) {
+ assert(write_buffer_manager_ != nullptr);
+ if (write_buffer_manager_->enabled()) {
+ bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed);
+ write_buffer_manager_->ReserveMem(bytes);
+ }
+ return allocator_->Allocate(bytes);
+}
+
+char* MemTableAllocator::AllocateAligned(size_t bytes, size_t huge_page_size,
+ Logger* logger) {
+ assert(write_buffer_manager_ != nullptr);
+ if (write_buffer_manager_->enabled()) {
+ bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed);
+ write_buffer_manager_->ReserveMem(bytes);
+ }
+ return allocator_->AllocateAligned(bytes, huge_page_size, logger);
+}
+
+void MemTableAllocator::DoneAllocating() {
+ if (write_buffer_manager_ != nullptr) {
+ if (write_buffer_manager_->enabled()) {
+ write_buffer_manager_->FreeMem(
+ bytes_allocated_.load(std::memory_order_relaxed));
+ } else {
+ assert(bytes_allocated_.load(std::memory_order_relaxed) == 0);
+ }
+ write_buffer_manager_ = nullptr;
+ }
+}
+
+size_t MemTableAllocator::BlockSize() const { return allocator_->BlockSize(); }
+
+} // namespace rocksdb
diff --git a/memtable/memtable_allocator.h b/memtable/memtable_allocator.h
new file mode 100644
index 000000000..050e13b36
--- /dev/null
+++ b/memtable/memtable_allocator.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// This is used by the MemTable to allocate write buffer memory. It connects
+// to WriteBufferManager so we can track and enforce overall write buffer
+// limits.
+
+#pragma once
+
+#include <atomic>
+#include "rocksdb/write_buffer_manager.h"
+#include "util/allocator.h"
+
+namespace rocksdb {
+
+class Logger;
+
+class MemTableAllocator : public Allocator {
+ public:
+ explicit MemTableAllocator(Allocator* allocator,
+ WriteBufferManager* write_buffer_manager);
+ ~MemTableAllocator();
+
+ // Allocator interface
+ char* Allocate(size_t bytes) override;
+ char* AllocateAligned(size_t bytes, size_t huge_page_size = 0,
+ Logger* logger = nullptr) override;
+ size_t BlockSize() const override;
+
+ // Call when we're finished allocating memory so we can free it from
+ // the write buffer's limit.
+ void DoneAllocating();
+
+ private:
+ Allocator* allocator_;
+ WriteBufferManager* write_buffer_manager_;
+ std::atomic<size_t> bytes_allocated_;
+
+ // No copying allowed
+ MemTableAllocator(const MemTableAllocator&);
+ void operator=(const MemTableAllocator&);
+};
+
+} // namespace rocksdb
diff --git a/memtable/memtablerep_bench.cc b/memtable/memtablerep_bench.cc
new file mode 100644
index 000000000..7eaffab91
--- /dev/null
+++ b/memtable/memtablerep_bench.cc
@@ -0,0 +1,697 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#define __STDC_FORMAT_MACROS
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run rocksdb tools\n");
+ return 1;
+}
+#else
+
+#include <gflags/gflags.h>
+
+#include <atomic>
+#include <iostream>
+#include <memory>
+#include <thread>
+#include <type_traits>
+#include <vector>
+
+#include "db/dbformat.h"
+#include "db/memtable.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/options.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/write_buffer_manager.h"
+#include "util/arena.h"
+#include "util/mutexlock.h"
+#include "util/stop_watch.h"
+#include "util/testutil.h"
+
+using GFLAGS::ParseCommandLineFlags;
+using GFLAGS::RegisterFlagValidator;
+using GFLAGS::SetUsageMessage;
+
+DEFINE_string(benchmarks, "fillrandom",
+ "Comma-separated list of benchmarks to run. Options:\n"
+ "\tfillrandom -- write N random values\n"
+ "\tfillseq -- write N values in sequential order\n"
+ "\treadrandom -- read N values in random order\n"
+ "\treadseq -- scan the DB\n"
+ "\treadwrite -- 1 thread writes while N - 1 threads "
+ "do random\n"
+ "\t reads\n"
+ "\tseqreadwrite -- 1 thread writes while N - 1 threads "
+ "do scans\n");
+
+DEFINE_string(memtablerep, "skiplist",
+ "Which implementation of memtablerep to use. See "
+ "include/memtablerep.h for\n"
+ " more details. Options:\n"
+ "\tskiplist -- backed by a skiplist\n"
+ "\tvector -- backed by an std::vector\n"
+ "\thashskiplist -- backed by a hash skip list\n"
+ "\thashlinklist -- backed by a hash linked list\n"
+ "\tcuckoo -- backed by a cuckoo hash table");
+
+DEFINE_int64(bucket_count, 1000000,
+ "bucket_count parameter to pass into NewHashSkiplistRepFactory or "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ hashskiplist_height, 4,
+ "skiplist_height parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ hashskiplist_branching_factor, 4,
+ "branching_factor parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ huge_page_tlb_size, 0,
+ "huge_page_tlb_size parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int32(bucket_entries_logging_threshold, 4096,
+ "bucket_entries_logging_threshold parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_bool(if_log_bucket_dist_when_flash, true,
+ "if_log_bucket_dist_when_flash parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ threshold_use_skiplist, 256,
+ "threshold_use_skiplist parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int64(
+ write_buffer_size, 256,
+ "write_buffer_size parameter to pass into NewHashCuckooRepFactory");
+
+DEFINE_int64(
+ average_data_size, 64,
+ "average_data_size parameter to pass into NewHashCuckooRepFactory");
+
+DEFINE_int64(
+ hash_function_count, 4,
+ "hash_function_count parameter to pass into NewHashCuckooRepFactory");
+
+DEFINE_int32(
+ num_threads, 1,
+ "Number of concurrent threads to run. If the benchmark includes writes,\n"
+ "then at most one thread will be a writer");
+
+DEFINE_int32(num_operations, 1000000,
+ "Number of operations to do for write and random read benchmarks");
+
+DEFINE_int32(num_scans, 10,
+ "Number of times for each thread to scan the memtablerep for "
+ "sequential read "
+ "benchmarks");
+
+DEFINE_int32(item_size, 100, "Number of bytes each item should be");
+
+DEFINE_int32(prefix_length, 8,
+ "Prefix length to pass into NewFixedPrefixTransform");
+
+/* VectorRep settings */
+DEFINE_int64(vectorrep_count, 0,
+ "Number of entries to reserve on VectorRep initialization");
+
+DEFINE_int64(seed, 0,
+ "Seed base for random number generators. "
+ "When 0 it is deterministic.");
+
+namespace rocksdb {
+
+namespace {
+struct CallbackVerifyArgs {
+ bool found;
+ LookupKey* key;
+ MemTableRep* table;
+ InternalKeyComparator* comparator;
+};
+} // namespace
+
+// Helper for quickly generating random data.
+class RandomGenerator {
+ private:
+ std::string data_;
+ unsigned int pos_;
+
+ public:
+ RandomGenerator() {
+ Random rnd(301);
+ auto size = (unsigned)std::max(1048576, FLAGS_item_size);
+ test::RandomString(&rnd, size, &data_);
+ pos_ = 0;
+ }
+
+ Slice Generate(unsigned int len) {
+ assert(len <= data_.size());
+ if (pos_ + len > data_.size()) {
+ pos_ = 0;
+ }
+ pos_ += len;
+ return Slice(data_.data() + pos_ - len, len);
+ }
+};
+
+enum WriteMode { SEQUENTIAL, RANDOM, UNIQUE_RANDOM };
+
+class KeyGenerator {
+ public:
+ KeyGenerator(Random64* rand, WriteMode mode, uint64_t num)
+ : rand_(rand), mode_(mode), num_(num), next_(0) {
+ if (mode_ == UNIQUE_RANDOM) {
+ // NOTE: if memory consumption of this approach becomes a concern,
+ // we can either break it into pieces and only random shuffle a section
+ // each time. Alternatively, use a bit map implementation
+ // (https://reviews.facebook.net/differential/diff/54627/)
+ values_.resize(num_);
+ for (uint64_t i = 0; i < num_; ++i) {
+ values_[i] = i;
+ }
+ std::shuffle(
+ values_.begin(), values_.end(),
+ std::default_random_engine(static_cast<unsigned int>(FLAGS_seed)));
+ }
+ }
+
+ uint64_t Next() {
+ switch (mode_) {
+ case SEQUENTIAL:
+ return next_++;
+ case RANDOM:
+ return rand_->Next() % num_;
+ case UNIQUE_RANDOM:
+ return values_[next_++];
+ }
+ assert(false);
+ return std::numeric_limits<uint64_t>::max();
+ }
+
+ private:
+ Random64* rand_;
+ WriteMode mode_;
+ const uint64_t num_;
+ uint64_t next_;
+ std::vector<uint64_t> values_;
+};
+
+class BenchmarkThread {
+ public:
+ explicit BenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : table_(table),
+ key_gen_(key_gen),
+ bytes_written_(bytes_written),
+ bytes_read_(bytes_read),
+ sequence_(sequence),
+ num_ops_(num_ops),
+ read_hits_(read_hits) {}
+
+ virtual void operator()() = 0;
+ virtual ~BenchmarkThread() {}
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* bytes_written_;
+ uint64_t* bytes_read_;
+ uint64_t* sequence_;
+ uint64_t num_ops_;
+ uint64_t* read_hits_;
+ RandomGenerator generator_;
+};
+
+class FillBenchmarkThread : public BenchmarkThread {
+ public:
+ FillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void FillOne() {
+ char* buf = nullptr;
+ auto internal_key_size = 16;
+ auto encoded_len =
+ FLAGS_item_size + VarintLength(internal_key_size) + internal_key_size;
+ KeyHandle handle = table_->Allocate(encoded_len, &buf);
+ assert(buf != nullptr);
+ char* p = EncodeVarint32(buf, internal_key_size);
+ auto key = key_gen_->Next();
+ EncodeFixed64(p, key);
+ p += 8;
+ EncodeFixed64(p, ++(*sequence_));
+ p += 8;
+ Slice bytes = generator_.Generate(FLAGS_item_size);
+ memcpy(p, bytes.data(), FLAGS_item_size);
+ p += FLAGS_item_size;
+ assert(p == buf + encoded_len);
+ table_->Insert(handle);
+ *bytes_written_ += encoded_len;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ FillOne();
+ }
+ }
+};
+
+class ConcurrentFillBenchmarkThread : public FillBenchmarkThread {
+ public:
+ ConcurrentFillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : FillBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ // # of read threads will be total threads - write threads (always 1). Loop
+ // while all reads complete.
+ while ((*threads_done_).load() < (FLAGS_num_threads - 1)) {
+ FillOne();
+ }
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class ReadBenchmarkThread : public BenchmarkThread {
+ public:
+ ReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ static bool callback(void* arg, const char* entry) {
+ CallbackVerifyArgs* callback_args = static_cast<CallbackVerifyArgs*>(arg);
+ assert(callback_args != nullptr);
+ uint32_t key_length;
+ const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
+ if ((callback_args->comparator)
+ ->user_comparator()
+ ->Equal(Slice(key_ptr, key_length - 8),
+ callback_args->key->user_key())) {
+ callback_args->found = true;
+ }
+ return false;
+ }
+
+ void ReadOne() {
+ std::string user_key;
+ auto key = key_gen_->Next();
+ PutFixed64(&user_key, key);
+ LookupKey lookup_key(user_key, *sequence_);
+ InternalKeyComparator internal_key_comp(BytewiseComparator());
+ CallbackVerifyArgs verify_args;
+ verify_args.found = false;
+ verify_args.key = &lookup_key;
+ verify_args.table = table_;
+ verify_args.comparator = &internal_key_comp;
+ table_->Get(lookup_key, &verify_args, callback);
+ if (verify_args.found) {
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ ++*read_hits_;
+ }
+ }
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ }
+};
+
+class SeqReadBenchmarkThread : public BenchmarkThread {
+ public:
+ SeqReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void ReadOneSeq() {
+ std::unique_ptr<MemTableRep::Iterator> iter(table_->GetIterator());
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ // pretend to read the value
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ }
+ ++*read_hits_;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ { ReadOneSeq(); }
+ }
+ }
+};
+
+class ConcurrentReadBenchmarkThread : public ReadBenchmarkThread {
+ public:
+ ConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : ReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class SeqConcurrentReadBenchmarkThread : public SeqReadBenchmarkThread {
+ public:
+ SeqConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written,
+ uint64_t* bytes_read, uint64_t* sequence,
+ uint64_t num_ops, uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : SeqReadBenchmarkThread(table, key_gen, bytes_written, bytes_read,
+ sequence, num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOneSeq();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class Benchmark {
+ public:
+ explicit Benchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence, uint32_t num_threads)
+ : table_(table),
+ key_gen_(key_gen),
+ sequence_(sequence),
+ num_threads_(num_threads) {}
+
+ virtual ~Benchmark() {}
+ virtual void Run() {
+ std::cout << "Number of threads: " << num_threads_ << std::endl;
+ std::vector<port::Thread> threads;
+ uint64_t bytes_written = 0;
+ uint64_t bytes_read = 0;
+ uint64_t read_hits = 0;
+ StopWatchNano timer(Env::Default(), true);
+ RunThreads(&threads, &bytes_written, &bytes_read, true, &read_hits);
+ auto elapsed_time = static_cast<double>(timer.ElapsedNanos() / 1000);
+ std::cout << "Elapsed time: " << static_cast<int>(elapsed_time) << " us"
+ << std::endl;
+
+ if (bytes_written > 0) {
+ auto MiB_written = static_cast<double>(bytes_written) / (1 << 20);
+ auto write_throughput = MiB_written / (elapsed_time / 1000000);
+ std::cout << "Total bytes written: " << MiB_written << " MiB"
+ << std::endl;
+ std::cout << "Write throughput: " << write_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_write_ops_per_thread_;
+ std::cout << "write us/op: " << us_per_op << std::endl;
+ }
+ if (bytes_read > 0) {
+ auto MiB_read = static_cast<double>(bytes_read) / (1 << 20);
+ auto read_throughput = MiB_read / (elapsed_time / 1000000);
+ std::cout << "Total bytes read: " << MiB_read << " MiB" << std::endl;
+ std::cout << "Read throughput: " << read_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_read_ops_per_thread_;
+ std::cout << "read us/op: " << us_per_op << std::endl;
+ }
+ }
+
+ virtual void RunThreads(std::vector<port::Thread>* threads,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ bool write, uint64_t* read_hits) = 0;
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* sequence_;
+ uint64_t num_write_ops_per_thread_;
+ uint64_t num_read_ops_per_thread_;
+ const uint32_t num_threads_;
+};
+
+class FillBenchmark : public Benchmark {
+ public:
+ explicit FillBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, 1) {
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool write,
+ uint64_t* read_hits) override {
+ FillBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits)();
+ }
+};
+
+class ReadBenchmark : public Benchmark {
+ public:
+ explicit ReadBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_operations / FLAGS_num_threads;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool write,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadBenchmarkThread(table_, key_gen_, bytes_written, bytes_read,
+ sequence_, num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ std::cout << "read hit%: "
+ << (static_cast<double>(*read_hits) / FLAGS_num_operations) * 100
+ << std::endl;
+ }
+};
+
+class SeqReadBenchmark : public Benchmark {
+ public:
+ explicit SeqReadBenchmark(MemTableRep* table, uint64_t* sequence)
+ : Benchmark(table, nullptr, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_scans;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool write,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(SeqReadBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+template <class ReadThreadType>
+class ReadWriteBenchmark : public Benchmark {
+ public:
+ explicit ReadWriteBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ =
+ FLAGS_num_threads <= 1
+ ? 0
+ : (FLAGS_num_operations / (FLAGS_num_threads - 1));
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool write,
+ uint64_t* read_hits) override {
+ std::atomic_int threads_done;
+ threads_done.store(0);
+ threads->emplace_back(ConcurrentFillBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits, &threads_done));
+ for (int i = 1; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadThreadType(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits, &threads_done));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+} // namespace rocksdb
+
+void PrintWarnings() {
+#if defined(__GNUC__) && !defined(__OPTIMIZE__)
+ fprintf(stdout,
+ "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n");
+#endif
+#ifndef NDEBUG
+ fprintf(stdout,
+ "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n");
+#endif
+}
+
+int main(int argc, char** argv) {
+ rocksdb::port::InstallStackTraceHandler();
+ SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
+ " [OPTIONS]...");
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ PrintWarnings();
+
+ rocksdb::Options options;
+
+ std::unique_ptr<rocksdb::MemTableRepFactory> factory;
+ if (FLAGS_memtablerep == "skiplist") {
+ factory.reset(new rocksdb::SkipListFactory);
+#ifndef ROCKSDB_LITE
+ } else if (FLAGS_memtablerep == "vector") {
+ factory.reset(new rocksdb::VectorRepFactory);
+ } else if (FLAGS_memtablerep == "hashskiplist") {
+ factory.reset(rocksdb::NewHashSkipListRepFactory(
+ FLAGS_bucket_count, FLAGS_hashskiplist_height,
+ FLAGS_hashskiplist_branching_factor));
+ options.prefix_extractor.reset(
+ rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length));
+ } else if (FLAGS_memtablerep == "hashlinklist") {
+ factory.reset(rocksdb::NewHashLinkListRepFactory(
+ FLAGS_bucket_count, FLAGS_huge_page_tlb_size,
+ FLAGS_bucket_entries_logging_threshold,
+ FLAGS_if_log_bucket_dist_when_flash, FLAGS_threshold_use_skiplist));
+ options.prefix_extractor.reset(
+ rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length));
+ } else if (FLAGS_memtablerep == "cuckoo") {
+ factory.reset(rocksdb::NewHashCuckooRepFactory(
+ FLAGS_write_buffer_size, FLAGS_average_data_size,
+ static_cast<uint32_t>(FLAGS_hash_function_count)));
+ options.prefix_extractor.reset(
+ rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length));
+#endif // ROCKSDB_LITE
+ } else {
+ fprintf(stdout, "Unknown memtablerep: %s\n", FLAGS_memtablerep.c_str());
+ exit(1);
+ }
+
+ rocksdb::InternalKeyComparator internal_key_comp(
+ rocksdb::BytewiseComparator());
+ rocksdb::MemTable::KeyComparator key_comp(internal_key_comp);
+ rocksdb::Arena arena;
+ rocksdb::WriteBufferManager wb(FLAGS_write_buffer_size);
+ rocksdb::MemTableAllocator memtable_allocator(&arena, &wb);
+ uint64_t sequence;
+ auto createMemtableRep = [&] {
+ sequence = 0;
+ return factory->CreateMemTableRep(key_comp, &memtable_allocator,
+ options.prefix_extractor.get(),
+ options.info_log.get());
+ };
+ std::unique_ptr<rocksdb::MemTableRep> memtablerep;
+ rocksdb::Random64 rng(FLAGS_seed);
+ const char* benchmarks = FLAGS_benchmarks.c_str();
+ while (benchmarks != nullptr) {
+ std::unique_ptr<rocksdb::KeyGenerator> key_gen;
+ const char* sep = strchr(benchmarks, ',');
+ rocksdb::Slice name;
+ if (sep == nullptr) {
+ name = benchmarks;
+ benchmarks = nullptr;
+ } else {
+ name = rocksdb::Slice(benchmarks, sep - benchmarks);
+ benchmarks = sep + 1;
+ }
+ std::unique_ptr<rocksdb::Benchmark> benchmark;
+ if (name == rocksdb::Slice("fillseq")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("fillrandom")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::UNIQUE_RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("readrandom")) {
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("readseq")) {
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL,
+ FLAGS_num_operations));
+ benchmark.reset(
+ new rocksdb::SeqReadBenchmark(memtablerep.get(), &sequence));
+ } else if (name == rocksdb::Slice("readwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadWriteBenchmark<
+ rocksdb::ConcurrentReadBenchmarkThread>(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("seqreadwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadWriteBenchmark<
+ rocksdb::SeqConcurrentReadBenchmarkThread>(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else {
+ std::cout << "WARNING: skipping unknown benchmark '" << name.ToString()
+ << std::endl;
+ continue;
+ }
+ std::cout << "Running " << name.ToString() << std::endl;
+ benchmark->Run();
+ }
+
+ return 0;
+}
+
+#endif // GFLAGS
diff --git a/memtable/skiplist.h b/memtable/skiplist.h
new file mode 100644
index 000000000..fe69094fb
--- /dev/null
+++ b/memtable/skiplist.h
@@ -0,0 +1,495 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Thread safety
+// -------------
+//
+// Writes require external synchronization, most likely a mutex.
+// Reads require a guarantee that the SkipList will not be destroyed
+// while the read is in progress. Apart from that, reads progress
+// without any internal locking or synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the SkipList is
+// destroyed. This is trivially guaranteed by the code since we
+// never delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the SkipList.
+// Only Insert() modifies the list, and it is careful to initialize
+// a node and use release-stores to publish the nodes in one or
+// more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <atomic>
+#include <stdlib.h>
+#include "port/port.h"
+#include "util/allocator.h"
+#include "util/random.h"
+
+namespace rocksdb {
+
+template<typename Key, class Comparator>
+class SkipList {
+ private:
+ struct Node;
+
+ public:
+ // Create a new SkipList object that will use "cmp" for comparing keys,
+ // and will allocate memory using "*allocator". Objects allocated in the
+ // allocator must remain allocated for the lifetime of the skiplist object.
+ explicit SkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12, int32_t branching_factor = 4);
+
+ // Insert key into the list.
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ void Insert(const Key& key);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const Key& key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const Key& key) const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const SkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const SkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const Key& key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Key& target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Key& target);
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const SkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ // Immutable after construction
+ Comparator const compare_;
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // Used for optimizing sequential insert patterns. Tricky. prev_[i] for
+ // i up to max_height_ is the predecessor of prev_[0] and prev_height_
+ // is the height of prev_[0]. prev_[0] can only be equal to head before
+ // insertion, in which case max_height_ and prev_height_ are 1.
+ Node** prev_;
+ int32_t prev_height_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ Node* NewNode(const Key& key, int height);
+ int RandomHeight();
+ bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); }
+ bool LessThan(const Key& a, const Key& b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n"
+ bool KeyIsAfterNode(const Key& key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const Key& key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const Key& key, Node** prev = nullptr) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+
+ // No copying allowed
+ SkipList(const SkipList&);
+ void operator=(const SkipList&);
+};
+
+// Implementation details follow
+template<typename Key, class Comparator>
+struct SkipList<Key, Comparator>::Node {
+ explicit Node(const Key& k) : key(k) { }
+
+ Key const key;
+
+ // Accessors/mutators for links. Wrapped in methods so we can
+ // add the appropriate barriers as necessary.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return (next_[n].load(std::memory_order_acquire));
+ }
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_[n].store(x, std::memory_order_release);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return next_[n].load(std::memory_order_relaxed);
+ }
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ next_[n].store(x, std::memory_order_relaxed);
+ }
+
+ private:
+ // Array of length equal to the node height. next_[0] is lowest level link.
+ std::atomic<Node*> next_[1];
+};
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::NewNode(const Key& key, int height) {
+ char* mem = allocator_->AllocateAligned(
+ sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1));
+ return new (mem) Node(key);
+}
+
+template<typename Key, class Comparator>
+inline SkipList<Key, Comparator>::Iterator::Iterator(const SkipList* list) {
+ SetList(list);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SetList(const SkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template<typename Key, class Comparator>
+inline bool SkipList<Key, Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template<typename Key, class Comparator>
+inline const Key& SkipList<Key, Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->key;
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->key);
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Seek(const Key& target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekForPrev(
+ const Key& target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template<typename Key, class Comparator>
+int SkipList<Key, Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ return height;
+}
+
+template<typename Key, class Comparator>
+bool SkipList<Key, Comparator>::KeyIsAfterNode(const Key& key, Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) < 0);
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::
+ FindGreaterOrEqual(const Key& key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ while (true) {
+ Node* next = x->Next(level);
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ int cmp = (next == nullptr || next == last_bigger)
+ ? 1 : compare_(next->key, key);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::FindLessThan(const Key& key, Node** prev) const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ while (true) {
+ Node* next = x->Next(level);
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ if (next != last_not_after && KeyIsAfterNode(key, next)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIUsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast()
+ const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ assert(x == head_ || compare_(x->key, key) < 0);
+ Node* next = x->Next(level);
+ if (next == nullptr || compare_(next->key, key) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+SkipList<Key, Comparator>::SkipList(const Comparator cmp, Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(max_height),
+ kBranching_(branching_factor),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ compare_(cmp),
+ allocator_(allocator),
+ head_(NewNode(0 /* any key will do */, max_height)),
+ max_height_(1),
+ prev_height_(1) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 0 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+ // Allocate the prev_ Node* array, directly from the passed-in allocator.
+ // prev_ does not need to be freed, as its life cycle is tied up with
+ // the allocator as a whole.
+ prev_ = reinterpret_cast<Node**>(
+ allocator_->AllocateAligned(sizeof(Node*) * kMaxHeight_));
+ for (int i = 0; i < kMaxHeight_; i++) {
+ head_->SetNext(i, nullptr);
+ prev_[i] = head_;
+ }
+}
+
+template<typename Key, class Comparator>
+void SkipList<Key, Comparator>::Insert(const Key& key) {
+ // fast path for sequential insertion
+ if (!KeyIsAfterNode(key, prev_[0]->NoBarrier_Next(0)) &&
+ (prev_[0] == head_ || KeyIsAfterNode(key, prev_[0]))) {
+ assert(prev_[0] != head_ || (prev_height_ == 1 && GetMaxHeight() == 1));
+
+ // Outside of this method prev_[1..max_height_] is the predecessor
+ // of prev_[0], and prev_height_ refers to prev_[0]. Inside Insert
+ // prev_[0..max_height - 1] is the predecessor of key. Switch from
+ // the external state to the internal
+ for (int i = 1; i < prev_height_; i++) {
+ prev_[i] = prev_[0];
+ }
+ } else {
+ // TODO(opt): we could use a NoBarrier predecessor search as an
+ // optimization for architectures where memory_order_acquire needs
+ // a synchronization instruction. Doesn't matter on x86
+ FindLessThan(key, prev_);
+ }
+
+ // Our data structure does not allow duplicate insertion
+ assert(prev_[0]->Next(0) == nullptr || !Equal(key, prev_[0]->Next(0)->key));
+
+ int height = RandomHeight();
+ if (height > GetMaxHeight()) {
+ for (int i = GetMaxHeight(); i < height; i++) {
+ prev_[i] = head_;
+ }
+ //fprintf(stderr, "Change height from %d to %d\n", max_height_, height);
+
+ // It is ok to mutate max_height_ without any synchronization
+ // with concurrent readers. A concurrent reader that observes
+ // the new value of max_height_ will see either the old value of
+ // new level pointers from head_ (nullptr), or a new value set in
+ // the loop below. In the former case the reader will
+ // immediately drop to the next level since nullptr sorts after all
+ // keys. In the latter case the reader will use the new node.
+ max_height_.store(height, std::memory_order_relaxed);
+ }
+
+ Node* x = NewNode(key, height);
+ for (int i = 0; i < height; i++) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(i, prev_[i]->NoBarrier_Next(i));
+ prev_[i]->SetNext(i, x);
+ }
+ prev_[0] = x;
+ prev_height_ = height;
+}
+
+template<typename Key, class Comparator>
+bool SkipList<Key, Comparator>::Contains(const Key& key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->key)) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace rocksdb
diff --git a/memtable/skiplist_test.cc b/memtable/skiplist_test.cc
new file mode 100644
index 000000000..cd0e94287
--- /dev/null
+++ b/memtable/skiplist_test.cc
@@ -0,0 +1,387 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree. An additional grant
+// of patent rights can be found in the PATENTS file in the same directory.
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/skiplist.h"
+#include <set>
+#include "rocksdb/env.h"
+#include "util/arena.h"
+#include "util/hash.h"
+#include "util/random.h"
+#include "util/testharness.h"
+
+namespace rocksdb {
+
+typedef uint64_t Key;
+
+struct TestComparator {
+ int operator()(const Key& a, const Key& b) const {
+ if (a < b) {
+ return -1;
+ } else if (a > b) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+class SkipTest : public testing::Test {};
+
+TEST_F(SkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ ASSERT_TRUE(!list.Contains(10));
+
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ iter.Seek(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(SkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ list.Insert(key);
+ }
+ }
+
+ for (int i = 0; i < R; i++) {
+ if (list.Contains(i)) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ iter.Seek(0);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekForPrev(R - 1);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+ }
+
+ // Forward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.Seek(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, iter.key());
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, iter.key());
+ iter.Prev();
+ }
+ }
+ }
+}
+
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ private:
+ static const uint32_t K = 4;
+
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = { k, g };
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ Arena arena_;
+
+ // SkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ SkipList<Key, TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: External synchronization
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ list_.Insert(new_key);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ SkipList<Key, TestComparator>::Iterator iter(&list_);
+ iter.Seek(pos);
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = iter.key();
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(new_target);
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(SkipTest, ConcurrentWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+
+ enum ReaderState {
+ STARTING,
+ RUNNING,
+ DONE
+ };
+
+ explicit TestState(int s)
+ : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void RunConcurrent(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k++) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(SkipTest, Concurrent1) { RunConcurrent(1); }
+TEST_F(SkipTest, Concurrent2) { RunConcurrent(2); }
+TEST_F(SkipTest, Concurrent3) { RunConcurrent(3); }
+TEST_F(SkipTest, Concurrent4) { RunConcurrent(4); }
+TEST_F(SkipTest, Concurrent5) { RunConcurrent(5); }
+
+} // namespace rocksdb
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/memtable/skiplistrep.cc b/memtable/skiplistrep.cc
index d07fa1218..363ebe286 100644
--- a/memtable/skiplistrep.cc
+++ b/memtable/skiplistrep.cc
@@ -3,7 +3,7 @@
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
//
-#include "db/inlineskiplist.h"
+#include "memtable/inlineskiplist.h"
#include "db/memtable.h"
#include "rocksdb/memtablerep.h"
#include "util/arena.h"