1 /** Interval Tree backed by augmented Splay Tree
2 
3 This is not threadsafe! Every query modifies the tree.
4 
5 Enable instrumentation:
6     version(instrument)
7         This will calculate statistics related to traversal depth
8 
9 Author: James S. Blachly, MD <james.blachly@gmail.com>
10 Copyright: Copyright (c) 2019 James Blachly
11 License: MIT
12 */
13 module intervaltree.splaytree;
14 
15 import intervaltree : BasicInterval, overlaps;
16 
17 import std.experimental.allocator;
18 import std.experimental.allocator.building_blocks.region;
19 import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
20 import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
21 import std.experimental.allocator.mallocator : Mallocator;
22 
23 import mir.random;
24 
/// Instrumentation only (compiled in with `version=instrument`):
/// `findOverlapsWith` appends one entry per query holding the number of
/// nodes that query popped from its traversal stack.
/// Declared `__gshared`, i.e. a single process-wide array rather than thread-local storage.
version(instrument) __gshared int[] _splaytree_visited;
26 
/** A single node of the interval tree: the stored interval, the subtree
    maximum used to prune overlap queries, and parent/child links.

    Probably should not be used directly by consumer.

    `IntervalType` must expose at least the members `start` and `end`.
*/
struct IntervalTreeNode(IntervalType)
if (__traits(hasMember, IntervalType, "start") &&
    __traits(hasMember, IntervalType, "end"))
{
    // NB: `alias key = interval.start;` no longer works with the embedded
    // structs and chain of alias this, hence the property below.

    /// sort key: the start coordinate of the stored interval
    pragma(inline,true)
    @property @safe @nogc nothrow const
    auto key() { return this.interval.start; }

    IntervalType interval;  /// payload; must at a minimum include members start, end
    typeof(IntervalType.end) max;    /// largest `end` recorded for this $(I subtree)

    IntervalTreeNode *parent;   /// parent node (null when this node has no parent)
    IntervalTreeNode *left;     /// left child
    IntervalTreeNode *right;    /// right child

    /// True when this node's interval overlaps the interval stored in `other`.
    pragma(inline, true) @nogc nothrow bool overlaps(const ref IntervalTreeNode other)
    {
        return this.interval.overlaps(other.interval);
    }

    /// Construct a node holding interval `i`.
    /// `Node(i)` would also work without an explicit constructor because
    /// `interval` is the first member, but `max` has to be seeded from the
    /// interval's own end, so the constructor is spelled out.
    @nogc nothrow
    this(IntervalType i) 
    {
        this.max = i.end;
        this.interval = i;  // blit
    }

    /// Returns true if this node is the left child of its parent.
    /// A node without a parent (e.g. the root) yields false.
    @nogc nothrow
    bool isLeftChild()
    {
        return (this.parent !is null) && (this.parent.left is &this);
    }

    invariant
    {
        // the Interval type itself should include checks, but in case it does not:
        assert(this.interval.start <= this.interval.end, "Interval start must <= end");

        assert(this.max >= this.interval.end, "max must be at least as high as our own end");

        // A node with a parent must be reachable from it through one of the two child links
        assert(this.parent is null
                || this.parent.left is &this
                || this.parent.right is &this,
            "Broken parent/child relationship");

        // When both children exist they must be two different nodes
        assert(this.left is null
                || this.right is null
                || this.left != this.right,
            "Left and righ child appear identical");
    }
}
94 
/// Common API across Interval AVL Trees and Interval Splay Trees:
/// within this module the generic name `IntervalTree` refers to `IntervalSplayTree`.
alias IntervalTree = IntervalSplayTree;
97 
/** Interval tree stored as a splay tree whose nodes are augmented with the
    maximum interval end found in their subtree.

    Nodes are ordered by comparing `IntervalType` values with `<`, `>` and `==`,
    so `IntervalType` must support those operators.

    Not thread-safe: lookups (`find`, `findOverlapsWith`, `findxxx`) and `insert`
    may restructure the tree through `splay`.
*/
struct IntervalSplayTree(IntervalType)
{
    alias Node = IntervalTreeNode!IntervalType;

    Node *root;    /// tree root
    Node *cur;      /// current or cursor for iteration

    /// Backing storage for nodes created by `insert`: a list of malloc-backed regions.
    /// No code in this struct releases individual nodes.
    private AllocatorList!((n) => Region!Mallocator(IntervalType.sizeof * 65_536), NullAllocator) mempool;

    // NB if change to class, add 'final'
    /** zig: rotate a child of the root node above its parent

        Only relinks nodes and repairs `max`; the caller (`splay`) is
        responsible for reassigning `this.root`.

        Params:
            n = node whose parent has no parent of its own
    */
    pragma(inline, true)
    @safe @nogc nothrow
    private void zig(Node *n) 
    in
    {
        // zig should not be called on empty tree
        assert(n !is null);
        // zig should not be called on root node
        assert( n.parent !is null );
        // zig only to be called on child of root node -- i.e. no grandparent node
        assert( n.parent.parent is null );
    }
    do
    {
        Node *p = n.parent;

        if (p.left == n)    // node is left child of parent
        {
            //Node *A = n.left;   // left child
            Node *B = n.right;  // right child
            //Node *C = p.right;  // sister node

            n.parent = null;    // rotate to top (splay() fn handles tree.root reassignment)

            // place parent (former root) as right child
            n.right  = p;
            p.parent = n;

            // assign former right child to (former) parent's left (our prev pos)
            p.left = B;
            if (B !is null) B.parent = p;
        }
        else                // node is right child of parent
        {
            // safety check during development
            assert(p.right == n);

            //Node *A = p.left;   // sister node
            Node *B = n.left;   // left child
            //Node *C = n.right;  // right child

            n.parent = null;    // rotate to top (splay() fn handles tree.root reassignment)

            // place parent (former root) as left child
            n.left = p;
            p.parent = n;

            // assign former left child to (former) parent's right (our prev pos)
            p.right = B;
            if (B !is null) B.parent = p;
        }

        // update max
        // lemmas (with respect to their positions prior to rotation):
        // 1. scenarios when both root/"parent" and Node n need to be updated may exist
        // 2. A, B, C, D subtrees never need to be updated
        // 3. other subtree of root/"parent" never needs to be updated
        // conclusion: n takes p's max; update p which is now child of n
        n.max = p.max;  // n now at root, can take p (prior root)'s max
        updateMax(p); 
    }

    // NB if change to class, add 'final'
    /** zig-zig: n and its parent are children on the same side
        (left-left or right-right); lifts n above both parent and grandparent.

        Relinks the former grandparent's parent (if any) to point at n and
        repairs `max` on the two demoted nodes.

        Params:
            n = node with a parent and a grandparent, same-side relationship
    */
    //pragma(inline, true)
    @safe @nogc nothrow
    private void zigZig(Node *n) 
    in
    {
        // zig-zig should not be called on empty tree
        assert(n !is null);
        // zig-zig should not be called on the root node
        assert(n.parent !is null);
        // zig-zig requires a grandparent node
        assert(n.parent.parent !is null);
        // relationships must be identical:
        if(n == n.parent.left) assert(n.parent == n.parent.parent.left);
        else if(n == n.parent.right) assert(n.parent == n.parent.parent.right);
        else assert(0);
    }
    do
    {
        // DMD cannot inline this
        version(LDC) pragma(inline, true);
        version(GNU) pragma(inline, true);

        Node *p = n.parent;
        Node *g = p.parent;

        if (p.left == n)
        {
/*
        /g\
       /   \
     /p\   /D\
    /   \
  /n\   /C\
 /   \
/A\  /B\
*/
            //Node *A = n.left;
            Node *B = n.right;
            Node *C = p.right;
            //Node *D = g.right;

            n.parent = g.parent;
            if (n.parent !is null)
            {
                assert( n.parent.left == g || n.parent.right == g);
                if (n.parent.left == g) n.parent.left = n;
                else n.parent.right = n;
            }

            n.right = p;
            p.parent = n;

            p.left = B;
            if (B !is null) B.parent = p;
            p.right = g;
            g.parent = p;

            g.left = C;
            if (C !is null) C.parent = g;

        }
        else    // node is right child of parent
        {
/*
        /g\
       /   \
     /A\   /p\
          /   \
        /B\   /n\
             /   \
            /C\  /D\
*/
            // safety check during development
            assert(p.right == n);

            //Node *A = g.left;
            Node *B = p.left;
            Node *C = n.left;
            //Node *D = n.right;

            n.parent = g.parent;
            if (n.parent !is null)
            {
                assert( n.parent.left == g || n.parent.right == g);
                if (n.parent.left == g) n.parent.left = n;
                else n.parent.right = n;
            }

            n.left = p;
            p.parent = n;

            p.left = g;
            g.parent = p;
            p.right = C;
            if (C !is null) C.parent = p;

            g.right = B;
            if (B !is null) B.parent = g;

        }

        // update max
        // lemmas:
        // 1. A, B, C, D had only a parent changed => never need max updated
        // 2. g, p, or n may need to be changed
        // 3. g -> p -> n after both left zigzig and right zigzig
        // conclusion: can update on g and percolate upward
        // update: never need to update n (prev: g)'s parent or higher
        n.max = g.max;  // take max of prior subtree root (g)
        updateMax(g);   // g is now the deepest of the three, so refresh it first
        updateMax(p);   // p sits above g and reads g's refreshed max
    }

    // NB if change to class, add 'final'
    /** zig-zag: n and its parent are children on opposite sides
        (left-right or right-left); n ends up with parent and grandparent as
        its two children.

        Relinks the former grandparent's parent (if any) to point at n and
        repairs `max` on the two demoted nodes.

        Params:
            n = node with a parent and a grandparent, opposite-side relationship
    */
    //pragma(inline, true)
    @safe @nogc nothrow
    private void zigZag(Node *n) 
    in
    {
        // zig-zag should not be called on empty tree
        assert(n !is null);
        // zig-zag should not be called on the root node
        assert(n.parent !is null);
        // zig-zag requires a grandparent node
        assert(n.parent.parent !is null);
        // relationships must be opposite:
        if(n == n.parent.left) assert(n.parent == n.parent.parent.right);
        else if(n == n.parent.right) assert(n.parent == n.parent.parent.left);
        else assert(0);
    }
    do
    {
        // DMD cannot inline this
        version(LDC) pragma(inline, true);
        version(GNU) pragma(inline, true);

        Node *p = n.parent;
        Node *g = p.parent;

        if (p.right == n)
        {
            assert(p.right == n && g.left == p);
/*  node is right child of parent; parent is left child of grandparent
              /g\             /n\
             /   \           /   \
           /p\   /D\   ->  /p\   /g\
          /   \           /   \ /   \
        /A\   /n\        A    B C   D
             /   \
            /B\  /C\
*/
            //Node *A = p.left;
            Node *B = n.left;
            Node *C = n.right;
            //Node *D = g.right;

            n.parent = g.parent;
            n.left = p;
            n.right = g;
            if (n.parent !is null)
            {
                assert( n.parent.left == g || n.parent.right == g);
                if (n.parent.left == g) n.parent.left = n;
                else n.parent.right = n;
            }

            p.parent = n;
            p.right = B;
            if (B !is null) B.parent = p;

            g.parent = n;
            g.left = C;
            if (C !is null) C.parent = g;
        }
        else
        {
            assert(p.left == n && g.right == p);
/*  node is left child of parent; parent is right child of grandparent
         /g\             /n\
        /   \           /   \
      /A\  /p\    ->   /g\   /p\
          /   \       /   \ /   \
        /n\   /D\    A    B C   D
       /   \
      /B\  /C\
*/
            //Node *A = g.left;
            Node *B = n.left;
            Node *C = n.right;
            //Node *D = p.right;

            n.parent = g.parent;
            n.left = g;
            n.right = p;
            if (n.parent !is null)
            {
                assert( n.parent.left == g || n.parent.right == g);
                if (n.parent.left == g) n.parent.left = n;
                else n.parent.right = n;
            }

            p.parent = n;
            p.left = C;
            if (C !is null) C.parent = p;

            g.parent = n;
            g.right = B;
            if (B !is null) B.parent = g;
        }

        // update max
        // lemmas:
        // 1. A, B, C, D had only a parent changed => never need max updated
        // 2. g, p, or n may need to be changed
        // 3. p and g are children of n after left zig-zag or right zig-zag
        // conclusion: updating and percolating upward on both p and g would be wasteful
        n.max = g.max;  // take max of prior subtree root (g)
        updateMax(p);
        updateMax(g);
    }

    // NB if change to class, add 'final'
    /** Bring Node n to top of tree -- probabilistically.

        A random byte is drawn on every call; the rotation loop only runs when
        its upper four bits are all zero (1 call in 16 if `rand!ubyte` is
        uniform), otherwise the tree is left untouched.

        Params:
            n = node to move to the root; must not be null
    */
    @safe @nogc nothrow
    private void splay(Node *n) 
    {
        // probablistically splay:
        // Albers and Karpinski, Randomized splay trees: theoretical and experimental results
        // Information Processing Letters, Volume 81, Issue 4, 28 February 2002
        // http://www14.in.tum.de/personen/albers/papers/ipl02.pdf
        if (rand!ubyte & 0b11110000) return;
        
        while (n.parent !is null)
        {
            const Node *p = n.parent;
            const Node *g = p.parent;
            if (g is null) zig(n);
            else if (g.left == p && p.left == n) zigZig(n);
            else if (g.right== p && p.right== n) zigZig(n);
            else zigZag(n);
        }
        this.root = n;
    }

    // TBD: state of default ctor inited struct
    // TODO: @disable postblit?

/+
    /// Find interval(s) overlapping query interval qi
    Node*[] intervalsOverlappingWith(IntervalType qi)
    {
        Node*[] ret;    // stack

        Node *cur = root;
        
        if (qi.overlaps(cur)) ret ~= cur;

        // If left subtree's maximum is larger than current root's start,
        // there may be an overlap
        if (cur.left !is null &&
            cur.left.max > cur.key)           /// TODO: check whether should be >=
                break;
    }
+/

    /** find interval by exact match (binary-search descent)

        Params:
            interval = interval to look for; compared with `<`, `>`, `==`

        Returns: the matching node (after offering it to `splay`), or null if
            no node compares equal to `interval`.

        TODO: use augmented tree's 'max' to efficiently bail out early
    */
    @nogc nothrow
    Node *find(IntervalType interval)
    {
        Node *ret;
        Node *current = this.root;

        while (current !is null)
        {
            if (interval < current.interval) current = current.left;
            else if (interval > current.interval) current = current.right;
            else if (interval == current.interval)
            {
                ret = current;
                break;
            }
            else assert(0, "An unexpected inequality occurred");
        }

        if (ret !is null) splay(ret);        // splay to the found node
        // TODO: Benchmark splaying the last node visited when nothing was found

        return ret;
    }

    /** find interval(s) overlapping given interval
        
        unlike find interval by key, matching elements could be in left /and/ right subtree

        We use template type "T" here instead of the enclosing struct's IntervalType
        so that we can from externally query with any type of interval object

        Traversal is an explicit-stack depth-first walk. The stack starts in a
        128-slot inline buffer and moves to a GC-allocated array if a deep tree
        needs more room, so an unbalanced tree cannot write past the buffer.

        After the walk one node is offered to `splay`: the first hit if there
        is one, otherwise the last node visited.

        Params:
            qinterval = query; only its `start` and `end` members (and whatever
                        `overlaps` needs) are used

        Returns: GC-allocated array of pointers to overlapping nodes; empty when
            nothing overlaps or the tree is empty.

        (note: outdated, see below)
        Timing data:
            UnrolledList < D array < SList (emsi)

        Notes:
            Node*[] performed more poorly than UnrolledList on my personal Mac laptop,
            However, dlang GC-backed Node*[] performed BETTER than UnrolledList (<5%, but consistent)
            on linux, perhaps due to no memory pressure and GC not needing to free.
            As this is a bioinformatics tool likely to be run on decent linux machines,
            we will leave as dynamic array.
        TODO: benchmark return Node[]
        TODO: benchmark return Node** and out count vs non-GC container
    */
    nothrow
    Node*[] findOverlapsWith(T)(T qinterval)
    if (__traits(hasMember, T, "start") &&
        __traits(hasMember, T, "end"))
    {
        Node*[] ret;

        // Empty tree: nothing can overlap, and there is no node to inspect or splay.
        if (this.root is null) return ret;

        Node*[128] inlineStack = void;
        Node*[] stack = inlineStack[];  // rebinds to a heap array only if the inline buffer fills up
        int s;
        debug int maxs;
        version(instrument) int visited;

        Node* current;

        stack[s++] = this.root;

        while(s >= 1)
        {
            current = stack[--s];
            version(instrument) visited+=1;

            // At most two pushes happen below; make sure both fit before continuing.
            if (s + 2 > stack.length)
            {
                auto grown = new Node*[stack.length * 2];
                grown[0 .. s] = stack[0 .. s];
                stack = grown;
            }

            // if query interval lies to the right of current tree, skip  
            if (qinterval.start >= current.max) continue;

            // if query interval end is left of the current node's start,
            // look in the left subtree
            if (qinterval.end <= current.interval.start)
            {
                if (current.left) stack[s++] = current.left;
                continue;
            }

            // if current node overlaps query interval, save it and search its children
            if (current.interval.overlaps(qinterval)) ret ~= current;
            if (current.left) stack[s++] = current.left;
            if (current.right) stack[s++] = current.right;

            debug(intervaltree_debug)
            {
                // guard was (s > 64) but there are three increments above, so decrease for safety
                if (s > 60) {
                    import core.stdc.stdio : stderr, fprintf;
                    fprintf(stderr, "FAIL maxs: %d", maxs);
                    assert(0, "stack overflow :-( Please post an issue at https://github.com/blachlylab/intervaltree");
                }
                if (s > maxs) maxs = s;
            }
            
        }

        debug(intervaltree_debug)
        {
            // Observations:
            // Max depth observed, in real world bedcov application is ~13
            // Max depth observed, in real world liftover application is 18 (outlier), ~12-14 max, mode 2!
            // Max depth observed in sorted BED tree/BAM query cov is 28, max ~8, mode 2
            //  - pathological; this is when one interval overlapped thousands to tens of thousands of intervals in tree
            import core.stdc.stdio : stderr, fprintf;
            fprintf(stderr, "maxs: %d\n", maxs);
        }

        version(instrument) _splaytree_visited ~= visited;
        // PERF: splay(current) without the branch led to marked performance degradation, > 10% worse runtime
        if (ret.length > 0)
            splay(ret[0]);
        else
            splay(current); // 3-5% runtime improvement
        return ret;
    }

    /** find interval by exact key -- NOT overlap

        Explicit-stack variant of `find` that also uses the subtree `max` to
        skip subtrees lying entirely left of the query start.

        Params:
            interval = interval to look for; compared with `<` and `==`

        Returns: the matching node (after offering it to `splay`), or null if
            no match was found.
    */
    Node *findxxx(IntervalType interval)
    {
        Node*[] stack;

        if (this.root !is null)
            stack ~= this.root; // push

        while (stack.length > 0)
        {
            // pop: the top of the stack is the LAST element, i.e. index $-1
            Node *current = stack[$-1];
            stack = stack[0 .. $-1];

            // Check if the interval is to the right of our largest value;
            // if yes, bail out
            if (interval.start >= current.max)  // TODO: check inequality; is >= correct for half-open coords?
                continue;
            
            // If the interval starts less than curent interval,
            // search left subtree
            if (interval < current.interval) {
                if (current.left !is null)
                    stack ~= current.left;
                continue;
            }

            // if the current node is a match, return it
            if (interval == current.interval) 
            {
                splay(current);
                return current;
            }

            // Check left and right subtrees
            if (current.left !is null) stack ~= current.left;
            if (current.right!is null) stack ~= current.right;
            /*
            // If the current node's interval overlaps, include it in results; then check left and right subtrees
            if (interval.start >= current.start && interval.start <= current.end)   // TODO: check inequality for half-open coords
                results ~= current;
            
            if (current.left !is null) stack ~= current.left;
            if (current.right!is null) stack ~= current.right;
            */
        }

        // no match was found
        return null;        
    }

    /// find minimum valued Node (interval); null when the tree is empty
    @safe @nogc nothrow
    Node *findMin() 
    {
        return findSubtreeMin(this.root);
    }
    /// leftmost node of the subtree rooted at `n`; null when `n` is null
    @safe @nogc nothrow
    private static Node* findSubtreeMin(Node *n) 
    {
        Node *current = n;
        if (current is null) return current;
        while (current.left !is null)
            current = current.left;         // descend leftward
        return current;
    }

    /** update Node n's max from subtrees

    Recomputes `n.max` as the largest of n's own interval end and the `max`
    of each existing child. Does nothing when `n` is null. Ancestors are not
    touched.
    
    Params:
        n = node to update
    */
    pragma(inline, true)
    @safe @nogc nothrow
    private
    void updateMax(Node *n) 
    {
        import std.algorithm.comparison : max;

        if (n !is null)
        {
            // same type as the node's `max` field, so wide coordinates are not narrowed to int
            typeof(Node.max) localmax = n.interval.end;
            if (n.left)
                localmax = max(n.left.max, localmax);
            if (n.right)
                localmax = max(n.right.max, localmax);
            n.max = localmax;
        }
    }

    /** insert interval, updating "max" on the way down

        Params:
            i = interval to store

        Returns: pointer to the newly created node, or to the already present
            node when one compares equal to `i`. The returned node is offered
            to `splay` (except when it becomes the root of an empty tree).
    */
    // TODO: unit test degenerate start intervals (i.e. [10, 11), [10, 13) )
    @trusted @nogc nothrow Node* insert(IntervalType i)
    {
        // if empty tree, assign a new root and return
        if (this.root is null)
        {
            //this.root = new Node(i);   // heap alloc
            this.root = this.mempool.make!Node(i);
            return this.root;
        }

        Node *current = this.root;

        // TODO: can maybe speed this up by pulling the "add here and return" code out 
        while (current !is null)
        {
            // conditionally update max irrespective of whether we add new node, or descend
            if (i.end > current.max) current.max = i.end;

            if (i < current.interval)           // Look at left subtree
            {
                if (current.left is null)       // add here and return
                {
                    //Node *newNode = new Node(i);   // heap alloc
                    Node *newNode = this.mempool.make!Node(i);
                    current.left = newNode;
                    newNode.parent = current;

                    splay(newNode);
                    return newNode;
                }
                else current = current.left;    // descend leftward
            }
            else if (i > current.interval)      // Look at right subtree
            {
                if (current.right is null)      // add here and return
                {
                    //Node *newNode = new Node(i);    // heap alloc
                    Node *newNode = this.mempool.make!Node(i);
                    current.right = newNode;
                    newNode.parent = current;

                    splay(newNode);
                    return newNode;
                }
                else current = current.right;   // descend rightward
            }
            else                                // Aleady exists
            {
                assert(i == current.interval);
                splay(current);
                return current;
            }
        }

        assert(0, "Unexpectedly, current is null");
    }

    /** remove interval

        NOTE: declaration only -- no body is provided in this struct.

        Returns:
            * True if interval i removed
            * False if interval not found

        TODO: check that the this.cur is not being removed, if so, also advance it to next
    */
    bool remove(IntervalType i);

    /// iterator functions: reset (the next `iteratorNext` restarts at the minimum)
    @safe @nogc nothrow
    void iteratorReset()
    {
        this.cur = null;
    }
    /** iterator functions: next

        Advances `cur` to its in-order successor and returns it. When `cur` is
        null the walk (re)starts at the minimum node. After the last node has
        been returned, the next call sets `cur` to null and returns null, which
        also leaves the iterator reset.
    */
    @safe @nogc nothrow
    Node *iteratorNext()
    {
        if (this.cur is null)   // initial condition
        {
            this.cur = findMin();
            return this.cur;
        }

        if (this.cur.right !is null)    // there is a right subtree
        {
            // descend right, then find the minimum
            this.cur = findSubtreeMin(this.cur.right);
            return this.cur;
        }

        // No right subtree: climb while we are a right child
        // (really, "not the left child" -- a parentless node returns false, hence the parent test)
        while (!this.cur.isLeftChild() && this.cur.parent)
            this.cur = this.cur.parent; // ascend one level

        // Two ways to get here:
        //  * cur is a left child  -> its parent is the successor
        //  * cur has no parent    -> we climbed off the top; parent is null,
        //                            so cur becomes null and iteration is over
        this.cur = this.cur.parent;
        return this.cur;
    }
}
// Smoke test: insert five adjacent intervals, print the in-order walk after
// each insert, then exercise exact lookup and overlap queries.
unittest
{
    import std.stdio: writeln, writefln;

    IntervalSplayTree!BasicInterval t;

    // Print every stored interval in sorted order.
    // The walk runs until iteratorNext() yields null, which leaves the cursor reset.
    void dumpInOrder()
    {
        while(t.iteratorNext() !is null)
            writefln("Value in order: %s", *t.cur);
    }

    foreach (iv; [BasicInterval(0, 100),
                  BasicInterval(100, 200),
                  BasicInterval(200, 300),
                  BasicInterval(300, 400),
                  BasicInterval(400, 500)])
    {
        writefln("Inserted node: %s", *t.insert(iv));
        dumpInOrder();
    }

    // exact lookup: same start but different end must not match
    const auto n0 = t.find(BasicInterval(200, 250));
    assert(n0 is null);

    // exact lookup: identical interval is found
    const auto n1 = t.find(BasicInterval(200, 300));
    assert(n1.interval == BasicInterval(200, 300));

    writeln("\n---\n");

    dumpInOrder();

    writefln("\nOne more shows it's been reset: %s", *t.iteratorNext());

    writeln("---\nCheck overlaps:");

    auto o1 = t.findOverlapsWith(BasicInterval(150, 250));
    auto o2 = t.findOverlapsWith(BasicInterval(150, 350));
    auto o3 = t.findOverlapsWith(BasicInterval(300, 400));
    writefln("o1: %s", o1);
    writefln("o2: %s", o2);
    writefln("o3: %s", o3);

}