A splay tree is a self-balancing binary search tree. It performs operations such as insertion, look-up and removal in O(log n) amortized time. Splaying moves each accessed element to the root, so if the same element is queried again right away, it is reached in O(1) time.
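A fully-fledged splay tree involves the following operations:
- Splay: move the recently accessed node to the root through a series of rotations, chosen by the node's position relative to its parent and grandparent:
  - Zig (the parent is the root)
  - Zig-zig and Zag-zag (the node and its parent are both left children, or both right children)
  - Zig-zag and Zag-zig (one is a left child and the other a right child)
- Join: merge two trees, where every key in the first is smaller than every key in the second
- Split: split one tree into two trees around a given key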
- Insert: add an element
- Delete: remove an element
- Find: look up an element
A good visualization can be found here. More details about splay trees and the operations above can be found on Wikipedia. Here I use a problem called Set with Range Sums to show an application of splay trees.
Set with Range Sums
Description
Implement a data structure that stores a set 𝑆 of integers with the following allowed operations:
- add(𝑖) — add integer 𝑖 into the set S (if it was there already, the set doesn’t change).
- del(𝑖) — remove integer 𝑖 from the set 𝑆 (if there was no such element, nothing happens).
- find(𝑖) — check whether 𝑖 is in the set 𝑆 or not.
- sum(l, r) — output the sum of all elements 𝑣 in 𝑆 such that 𝑙 ≤𝑣 ≤𝑟.
Initially the set 𝑆 is empty. The first line contains 𝑛 — the number of operations. The next 𝑛 lines contain operations. Each operation is one of the following:
- “+ i” — which means add some integer (not 𝑖, see below) to 𝑆,
- “- i” — which means del some integer (not 𝑖, see below) from 𝑆,
- “? i” — which means find some integer (not 𝑖, see below) in 𝑆,
- “s l r” — which means compute the sum of all elements of 𝑆 within some range of values (not from 𝑙 to 𝑟, see below).
To make sure a solution processes the operations online, each operation depends on the result of the previous sum. Denote M = 1 000 000 001 and let x be the result of the last sum operation, or just 0 if there were no sum operations before.
- “+ i” means add((𝑖 + 𝑥) mod 𝑀),
- “- i” means del((𝑖 + 𝑥) mod 𝑀),
- “? i” means find((𝑖 + 𝑥) mod 𝑀),
- “s l r” means sum((𝑙+𝑥) mod 𝑀, (𝑟 + 𝑥) mod 𝑀).
Constraints. 1 ≤ n ≤ 100 000; 0 ≤ i ≤ 10^9.
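To make the decoding rule concrete, here is a minimal sketch; decode is a hypothetical helper name, not part of the problem. It replays a few arguments from Sample 1 below, where x = 3 after the first sum. With 0 ≤ i ≤ 10^9 and x < M, i + x stays below 2^31 - 1, so int would also work; long is used only as a safety margin.

```java
public class DecodeDemo {
    // M from the problem statement
    static final long M = 1_000_000_001L;

    // Hypothetical helper: maps a raw argument i to the real value (i + x) mod M,
    // where x is the result of the last sum operation (0 if there was none).
    static long decode(long i, long x) {
        return (i + x) % M;
    }

    public static void main(String[] args) {
        long x = 3; // result of "s 1 2" in Sample 1
        System.out.println(decode(1_000_000_000L, x)); // 2
        System.out.println(decode(999_999_999L, x));   // 1
        System.out.println(decode(0L, 1));             // 1
    }
}
```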
This is a typical problem for a splay tree if you care about efficiency: Split isolates the range whose sum we need, and Join puts the subtrees back together to restore the structure.
Thoughts
The idea is that each Vertex (tree node) has a key representing an integer, a sum recording the sum of all keys in its subtree, a left pointer to its left child, a right pointer to its right child, and a parent pointer to its parent. Note that sum is not static: it changes whenever the shape of the subtree changes, so it must be recomputed after every rotation, split and merge.
With this design, add is just Insert, del is Delete, and find is Find. How about sum(l, r)? If we split the splay tree into three subtrees, I (keys smaller than l), J (keys from l to r inclusive) and K (keys greater than r), the sum we want is exactly the sum stored at the root of the middle subtree J. Once we have it, we merge the three parts back together. Every splay tree operation runs in $O(\log n)$ amortized time, and a sum query needs only a constant number of splits and merges, so it also costs $O(\log n)$ amortized.
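The three-way partition can be tried out with java.util.TreeSet, whose headSet, subSet and tailSet views correspond to I, J and K. This is only an illustration of the idea, not a splay tree: summing J here walks every element in it, O(k) for k elements, which is exactly the cost that the stored subtree sums avoid. The set {1, 4, 7, 10, 13} and the range [4, 10] are made-up example values.

```java
import java.util.Arrays;
import java.util.NavigableSet;
import java.util.TreeSet;

public class PartitionDemo {
    // Adds up every key in a set view; this visits all k elements, O(k)
    static long sumOf(NavigableSet<Integer> part) {
        long s = 0;
        for (int v : part) s += v;
        return s;
    }

    public static void main(String[] args) {
        TreeSet<Integer> set = new TreeSet<>(Arrays.asList(1, 4, 7, 10, 13));
        int l = 4, r = 10;
        NavigableSet<Integer> partI = set.headSet(l, false);        // keys < l
        NavigableSet<Integer> partJ = set.subSet(l, true, r, true); // l <= keys <= r
        NavigableSet<Integer> partK = set.tailSet(r, false);        // keys > r
        System.out.println(partI + " " + partJ + " " + partK); // [1] [4, 7, 10] [13]
        System.out.println(sumOf(partJ));                      // 21
    }
}
```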
Let’s Code!
Skeleton
First, here is the skeleton of the Java class that handles I/O.
import java.io.*;
import java.util.*;
public class SetRangeSum {
    BufferedReader br;
    PrintWriter out;
    StringTokenizer st;
    boolean eof;
    public static final int MODULO = 1000000001;
    void solve() throws IOException {
        int n = nextInt();
        // Result of the last sum operation (0 before any sum), kept modulo MODULO
        int last_sum_result = 0;
        for (int i = 0; i < n; i++) {
            char type = nextChar();
            switch (type) {
                case '+': {
                    int x = nextInt();
                    insert((x + last_sum_result) % MODULO);
                    break;
                }
                case '-': {
                    int x = nextInt();
                    erase((x + last_sum_result) % MODULO);
                    break;
                }
                case '?': {
                    int x = nextInt();
                    out.println(find((x + last_sum_result) % MODULO) ? "Found" : "Not found");
                    break;
                }
                case 's': {
                    int l = nextInt();
                    int r = nextInt();
                    long res = sum((l + last_sum_result) % MODULO, (r + last_sum_result) % MODULO);
                    out.println(res);
                    last_sum_result = (int) (res % MODULO);
                    break;
                }
            }
        }
    }
    SetRangeSum() throws IOException {
        br = new BufferedReader(new InputStreamReader(System.in));
        out = new PrintWriter(System.out);
        solve();
        out.close();
    }
    public static void main(String[] args) throws IOException {
        new SetRangeSum();
    }
    String nextToken() {
        while (st == null || !st.hasMoreTokens()) {
            try {
                st = new StringTokenizer(br.readLine());
            } catch (Exception e) {
                eof = true;
                return null;
            }
        }
        return st.nextToken();
    }
    int nextInt() throws IOException {
        return Integer.parseInt(nextToken());
    }
    char nextChar() throws IOException {
        return nextToken().charAt(0);
    }
    // The splay tree code below (Vertex, root, splay, merge, split,
    // insert, erase, find, sum) goes here, inside SetRangeSum.
}
Next, we define Vertex as the tree node of the splay tree. Like the rest of the tree code, it goes inside the SetRangeSum class. A vertex holds its value (key), the sum of all keys in its subtree (sum), its left and right children (left, right) and its parent (parent).
// Splay tree implementation
// Vertex of a splay tree
class Vertex {
    int key;
    // Sum of all the keys in the subtree - remember to update
    // it after each operation that changes the tree.
    long sum;
    Vertex left;
    Vertex right;
    Vertex parent;
    Vertex(int key, long sum, Vertex left, Vertex right, Vertex parent) {
        this.key = key;
        this.sum = sum;
        this.left = left;
        this.right = right;
        this.parent = parent;
    }
}
Vertex root = null;
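Because a forgotten update(...) silently corrupts every later sum query, a small consistency checker can help while debugging. The sketch below is hypothetical (verify is not part of the original code) and mirrors the Vertex fields with a simplified constructor. It recomputes each subtree sum and checks the parent links. It is recursive, so it is meant for small test trees only; a degenerate tree with 100 000 vertices could overflow the call stack.

```java
public class SumCheck {
    // Same fields as the article's Vertex; constructor simplified for the demo
    static class Vertex {
        int key;
        long sum;
        Vertex left, right, parent;
        Vertex(int key) { this.key = key; this.sum = key; }
    }

    // Hypothetical debugging helper: returns the true subtree sum of v and
    // throws if a stored sum or a child's parent pointer is inconsistent.
    static long verify(Vertex v) {
        if (v == null) return 0;
        if (v.left != null && v.left.parent != v) throw new IllegalStateException("bad parent at " + v.left.key);
        if (v.right != null && v.right.parent != v) throw new IllegalStateException("bad parent at " + v.right.key);
        long actual = v.key + verify(v.left) + verify(v.right);
        if (actual != v.sum) throw new IllegalStateException("bad sum at " + v.key);
        return actual;
    }

    public static void main(String[] args) {
        // 2 at the root, 1 on the left, 3 on the right
        Vertex root = new Vertex(2);
        root.left = new Vertex(1);
        root.right = new Vertex(3);
        root.left.parent = root;
        root.right.parent = root;
        root.sum = 6;
        System.out.println(verify(root)); // 6
        root.sum = 5; // simulate a forgotten update(...)
        try {
            verify(root);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // bad sum at 2
        }
    }
}
```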
Basic operations
As described above, we need the basic operations Splay, Join, Split, Insert, Delete and Find. Splay is implemented as splay below. It moves a vertex to the root through a series of rotations (smallRotation for a single zig step, bigRotation for a zig-zig or zig-zag step). Because rotations change the structure of the tree, each affected vertex.sum is recomputed by calling update.
// Splays the given vertex up to the root and returns it (the new root).
Vertex splay(Vertex v) {
    if (v == null) return null;
    while (v.parent != null) {
        if (v.parent.parent == null) {
            // Zig: the parent is the root, so one rotation finishes the job
            smallRotation(v);
            break;
        }
        bigRotation(v);
    }
    return v;
}
// Rotates v above its parent, keeping the in-order sequence of keys.
void smallRotation(Vertex v) {
    Vertex parent = v.parent;
    if (parent == null) {
        return;
    }
    Vertex grandparent = parent.parent;
    if (parent.left == v) {
        Vertex m = v.right;
        v.right = parent;
        parent.left = m;
    } else {
        Vertex m = v.left;
        v.left = parent;
        parent.right = m;
    }
    // parent is now a child of v, so recompute it first, then v
    update(parent);
    update(v);
    // Hook v into the slot that parent used to occupy
    v.parent = grandparent;
    if (grandparent != null) {
        if (grandparent.left == parent) {
            grandparent.left = v;
        } else {
            grandparent.right = v;
        }
    }
}
// One double-rotation step: v moves up two levels.
void bigRotation(Vertex v) {
    if (v.parent.left == v && v.parent.parent.left == v.parent) {
        // Zig-zig: v and its parent are both left children
        smallRotation(v.parent);
        smallRotation(v);
    } else if (v.parent.right == v && v.parent.parent.right == v.parent) {
        // Zag-zag: v and its parent are both right children
        smallRotation(v.parent);
        smallRotation(v);
    } else {
        // Zig-zag / zag-zig: v and its parent are on opposite sides
        smallRotation(v);
        smallRotation(v);
    }
}
// Recomputes v.sum from its children and repairs their parent pointers.
void update(Vertex v) {
    if (v == null) return;
    v.sum = v.key + (v.left != null ? v.left.sum : 0) + (v.right != null ? v.right.sum : 0);
    if (v.left != null) {
        v.left.parent = v;
    }
    if (v.right != null) {
        v.right.parent = v;
    }
}
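To see the rotations in action, here is a self-contained copy of update, smallRotation, bigRotation and splay (bigRotation is condensed into an equivalent form, and Vertex gets a simplified constructor), run on two hand-built three-vertex trees. The first splays the bottom of a left chain, a zig-zig step; the second splays a right child of a left child, a zig-zag step. In both cases the in-order sequence of keys is preserved, and the root's sum covers all keys.

```java
import java.util.ArrayList;
import java.util.List;

public class SplayDemo {
    static class Vertex {
        int key;
        long sum;
        Vertex left, right, parent;
        Vertex(int key) { this.key = key; this.sum = key; }
    }

    // Same as update(): recompute sum and repair the children's parent links
    static void update(Vertex v) {
        if (v == null) return;
        v.sum = v.key + (v.left != null ? v.left.sum : 0) + (v.right != null ? v.right.sum : 0);
        if (v.left != null) v.left.parent = v;
        if (v.right != null) v.right.parent = v;
    }

    // Same as smallRotation(): lift v above its parent (one zig or zag)
    static void smallRotation(Vertex v) {
        Vertex parent = v.parent;
        if (parent == null) return;
        Vertex grandparent = parent.parent;
        if (parent.left == v) {
            parent.left = v.right;
            v.right = parent;
        } else {
            parent.right = v.left;
            v.left = parent;
        }
        update(parent);
        update(v);
        v.parent = grandparent;
        if (grandparent != null) {
            if (grandparent.left == parent) grandparent.left = v;
            else grandparent.right = v;
        }
    }

    // Equivalent to bigRotation(): one zig-zig/zag-zag or zig-zag/zag-zig step
    static void bigRotation(Vertex v) {
        boolean sameSide = (v.parent.left == v) == (v.parent.parent.left == v.parent);
        if (sameSide) smallRotation(v.parent); // zig-zig: rotate the parent first
        else smallRotation(v);                 // zig-zag: rotate v twice
        smallRotation(v);
    }

    // Same as splay(): rotate v up until it becomes the root
    static Vertex splay(Vertex v) {
        if (v == null) return null;
        while (v.parent != null) {
            if (v.parent.parent == null) {
                smallRotation(v);
                break;
            }
            bigRotation(v);
        }
        return v;
    }

    static void inorder(Vertex v, List<Integer> out) {
        if (v == null) return;
        inorder(v.left, out);
        out.add(v.key);
        inorder(v.right, out);
    }

    public static void main(String[] args) {
        // Zig-zig: 3 is the root, 2 its left child, 1 the left child of 2
        Vertex a = new Vertex(1), b = new Vertex(2), c = new Vertex(3);
        c.left = b;
        b.left = a;
        update(b);
        update(c);
        Vertex root = splay(a);
        List<Integer> keys = new ArrayList<>();
        inorder(root, keys);
        System.out.println(root.key + " " + root.sum + " " + keys); // 1 6 [1, 2, 3]
        // Zig-zag: 3 is the root, 1 its left child, 2 the right child of 1
        Vertex x = new Vertex(1), y = new Vertex(2), z = new Vertex(3);
        z.left = x;
        x.right = y;
        update(x);
        update(z);
        Vertex root2 = splay(y);
        System.out.println(root2.key + " " + root2.left.key + " " + root2.right.key); // 2 1 3
    }
}
```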
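Join is implemented as merge below. It assumes that every key in left is smaller than every key in right. The smallest vertex of right is splayed to the root of right, where it has no left child, so the whole left tree is attached there and that vertex becomes the root of the merged tree.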
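Vertex merge(Vertex left, Vertex right) {
    if (left == null) return right;
    if (right == null) return left;
    // Walk to the smallest vertex of right and splay it to the top
    while (right.left != null) {
        right = right.left;
    }
    right = splay(right);
    // It now has no left child, so the whole left tree fits there
    right.left = left;
    update(right);
    return right;
}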
Similarly, Split is implemented as split. The find it relies on is given later; at this stage we only need to know that, given the root of a tree and the queried key, find returns two vertices: the result node (the vertex holding key, or else the one with the next bigger key) and the new root (remember that the most recently accessed node is moved to the root). split then divides the tree into two trees, one with all keys smaller than key and one with all keys greater than or equal to key, and returns their two roots.
// Splits the tree into keys < key (left) and keys >= key (right).
VertexPair split(Vertex root, int key) {
    VertexPair result = new VertexPair();
    VertexPair findAndRoot = find(root, key);
    root = findAndRoot.right;
    result.right = findAndRoot.left;
    if (result.right == null) {
        // Every key is smaller than key
        result.left = root;
        return result;
    }
    // Bring the smallest key >= key to the top and cut off its left subtree
    result.right = splay(result.right);
    result.left = result.right.left;
    result.right.left = null;
    if (result.left != null) {
        result.left.parent = null;
    }
    update(result.left);
    update(result.right);
    return result;
}
class VertexPair {
    Vertex left;
    Vertex right;
    VertexPair() {}
    VertexPair(Vertex left, Vertex right) {
        this.left = left;
        this.right = right;
    }
}
Insert is implemented as insert. We use split to cut the tree at the point where x belongs, then merge the left part, the new vertex for x and the right part back together. If the smallest key of the right part is already x, no vertex is created, so the set does not change.
void insert(int x) {
    // left: keys < x; right: keys >= x, with the smallest of them at its root
    VertexPair leftRight = split(root, x);
    Vertex left = leftRight.left;
    Vertex right = leftRight.right;
    Vertex newVertex = null;
    if (right == null || right.key != x) {
        // x is not in the set yet
        newVertex = new Vertex(x, x, null, null, null);
    }
    root = merge(merge(left, newVertex), right);
}
Delete is implemented as erase: we use split twice to separate the vertex with key x (if there is one) from the rest of the splay tree, then merge the remaining parts back together without it.
void erase(int x) {
    // left: keys < x; middle: keys >= x
    VertexPair leftMiddle = split(root, x);
    Vertex left = leftMiddle.left;
    Vertex middle = leftMiddle.right;
    // middle: keys in [x, x + 1), i.e. only x itself; right: keys > x
    VertexPair middleRight = split(middle, x + 1);
    middle = middleRight.left;
    Vertex right = middleRight.right;
    if (middle == null || middle.key != x) {
        root = merge(merge(left, middle), right);
    } else {
        // Drop the vertex holding x
        root = merge(left, right);
    }
}
Lastly, here is the find described previously, together with find(int), which the skeleton calls for the ? queries.
// Searches for the given key in the tree with the given root
// and calls splay for the deepest visited node after that.
// Returns a pair whose left is the result and whose right is the new root.
// If found, result is a pointer to the node with the given key.
// Otherwise, result is a pointer to the node with the smallest
// bigger key (next value in the order).
// If the key is bigger than all keys in the tree,
// then result is null.
VertexPair find(Vertex root, int key) {
    Vertex v = root;
    Vertex last = root;
    Vertex next = null;
    while (v != null) {
        if (v.key >= key && (next == null || v.key < next.key)) {
            next = v;
        }
        last = v;
        if (v.key == key) {
            break;
        }
        if (v.key < key) {
            v = v.right;
        } else {
            v = v.left;
        }
    }
    root = splay(last);
    return new VertexPair(next, root);
}
boolean find(int x) {
    // split() puts every key >= x into the right tree, with the smallest
    // of them at its root, so x is present iff that root holds x.
    VertexPair leftRight = split(root, x);
    boolean found = leftRight.right != null && leftRight.right.key == x;
    // Restore the tree; the root field must point at the merged tree
    root = merge(leftRight.left, leftRight.right);
    return found;
}
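find(int) can also be written without split and merge. A single call to find(root, x) already splays the deepest visited vertex to the top, and its result vertex has key x exactly when x is in the set. The one catch is that splaying changes the root, so the returned new root must be stored back into root; otherwise root keeps pointing at a vertex that may now sit deep inside the tree. The sketch below assumes that approach: contains is a hypothetical name standing in for find(int), and splay is written in a condensed form that performs the same zig, zig-zig and zig-zag steps as the code above.

```java
public class DirectFind {
    static class Vertex {
        int key;
        long sum;
        Vertex left, right, parent;
        Vertex(int key) { this.key = key; this.sum = key; }
    }

    static class VertexPair {
        Vertex left, right;
        VertexPair(Vertex left, Vertex right) { this.left = left; this.right = right; }
    }

    Vertex root = null;

    static void update(Vertex v) {
        if (v == null) return;
        v.sum = v.key + (v.left != null ? v.left.sum : 0) + (v.right != null ? v.right.sum : 0);
        if (v.left != null) v.left.parent = v;
        if (v.right != null) v.right.parent = v;
    }

    // Condensed form of smallRotation(); v must have a parent
    static void rotate(Vertex v) {
        Vertex p = v.parent, g = p.parent;
        if (p.left == v) { p.left = v.right; v.right = p; }
        else { p.right = v.left; v.left = p; }
        update(p);
        update(v);
        v.parent = g;
        if (g != null) {
            if (g.left == p) g.left = v; else g.right = v;
        }
    }

    // Condensed form of splay() with bigRotation() inlined
    static Vertex splay(Vertex v) {
        if (v == null) return null;
        while (v.parent != null) {
            Vertex p = v.parent, g = p.parent;
            if (g != null) {
                boolean zigZig = (g.left == p) == (p.left == v);
                rotate(zigZig ? p : v);
            }
            rotate(v);
        }
        return v;
    }

    // The article's find(Vertex, int): left = result, right = new root
    static VertexPair find(Vertex root, int key) {
        Vertex v = root, last = root, next = null;
        while (v != null) {
            if (v.key >= key && (next == null || v.key < next.key)) next = v;
            last = v;
            if (v.key == key) break;
            v = v.key < key ? v.right : v.left;
        }
        return new VertexPair(next, splay(last));
    }

    // Membership test with one search and one splay, no split/merge
    boolean contains(int x) {
        VertexPair res = find(root, x);
        root = res.right; // find() splayed the tree: adopt the new root
        return res.left != null && res.left.key == x;
    }

    public static void main(String[] args) {
        // Hand-built valid tree: 5 at the root, 3 on the left, 8 on the right
        Vertex a = new Vertex(3), b = new Vertex(5), c = new Vertex(8);
        b.left = a;
        b.right = c;
        update(b);
        DirectFind t = new DirectFind();
        t.root = b;
        System.out.println(t.contains(8) + " " + t.root.key); // true 8
        System.out.println(t.contains(4) + " " + t.root.key); // false 3
    }
}
```

Both versions take O(log n) amortized time; the direct one simply skips the extra splays that split and merge perform.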
Finally, we can compute the range sum with sum. As described above, we split the tree into three parts, the middle one holding exactly the keys in [from, to], so the answer is the sum stored at its root. After the query, we merge the three parts back to restore the splay tree.
Range sum
long sum(int from, int to) {
    // left: keys < from; middle: keys >= from
    VertexPair leftMiddle = split(root, from);
    Vertex left = leftMiddle.left;
    Vertex middle = leftMiddle.right;
    // middle: keys in [from, to]; right: keys > to
    VertexPair middleRight = split(middle, to + 1);
    middle = middleRight.left;
    Vertex right = middleRight.right;
    // The root of the middle tree stores the sum of all its keys
    long ans = 0;
    if (middle != null) {
        ans = middle.sum;
    }
    // Merge the three parts back into one tree
    root = merge(left, merge(middle, right));
    return ans;
}
Examples
The full implementation can be found here. Feel free to test it with some examples below.
Sample 1
Input:
15
? 1
+ 1
? 1
+ 2
s 1 2
+ 1000000000
? 1000000000
- 1000000000
? 1000000000
s 999999999 1000000000
- 2
? 2
- 0
+ 9
s 0 9
Output:
Not found
Found
3
Found
Not found
1
Not found
10
Explanation:
For the first 5 queries, 𝑥 = 0. For the next 5 queries, 𝑥 = 3. For the last 5 queries, 𝑥 = 1. The actual list of operations is:
find(1)
add(1)
find(1)
add(2)
sum(1, 2) → 3
add(2)
find(2) → Found
del(2)
find(2) → Not found
sum(1, 2) → 1
del(3)
find(3) → Not found
del(1)
add(10)
sum(1, 10) → 10
Sample 2
Input:
5
? 0
+ 0
? 0
- 0
? 0
Output:
Not found
Found
Not found
Sample 3
Input:
5
+ 491572259
? 491572259
? 899375874
s 310971296 877523306
+ 352411209
Output:
Found
Not found
491572259
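For cross-checking the splay tree on small random inputs, a brute-force reference can help. The sketch below is a hypothetical test helper, not the article's solution: it keeps the set in a java.util.TreeSet and sums a range by iterating over subSet, which is O(k) per query and therefore too slow for the full constraints. Like the skeleton, it stores the last sum modulo M. After the modulo, the left bound can exceed the right one; the sketch treats that as an empty range, which is also what the splay tree's sum returns in that case.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeSet;

public class ReferenceSolver {
    static final long M = 1_000_000_001L;

    // Runs the encoded operations ("+ i", "- i", "? i", "s l r") on a TreeSet
    // and returns the output lines. O(k) per sum, so only for small tests.
    static List<String> run(List<String> ops) {
        TreeSet<Integer> set = new TreeSet<>();
        List<String> out = new ArrayList<>();
        long x = 0; // result of the last sum (mod M), 0 before any sum
        for (String op : ops) {
            String[] t = op.trim().split("\\s+");
            int a = (int) ((Long.parseLong(t[1]) + x) % M);
            switch (t[0]) {
                case "+":
                    set.add(a);
                    break;
                case "-":
                    set.remove(a);
                    break;
                case "?":
                    out.add(set.contains(a) ? "Found" : "Not found");
                    break;
                case "s": {
                    int b = (int) ((Long.parseLong(t[2]) + x) % M);
                    long s = 0;
                    if (a <= b) { // subSet() throws if the bounds are reversed
                        for (int v : set.subSet(a, true, b, true)) s += v;
                    }
                    out.add(Long.toString(s));
                    x = s % M;
                    break;
                }
                default:
                    throw new IllegalArgumentException("unknown operation: " + op);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> sample1 = Arrays.asList(
                "? 1", "+ 1", "? 1", "+ 2", "s 1 2",
                "+ 1000000000", "? 1000000000", "- 1000000000", "? 1000000000",
                "s 999999999 1000000000", "- 2", "? 2", "- 0", "+ 9", "s 0 9");
        System.out.println(run(sample1));
        // prints [Not found, Found, 3, Found, Not found, 1, Not found, 10]
    }
}
```

Feeding the same operations to both programs and comparing their outputs is a cheap way to catch a missing update(...) call.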