1 /** Interval Tree backed by augmented AVL tree
2 
3 The AVL implementation is derived from attractivechaos'
4 klib (kavl.h) and the derived code remains MIT licensed.
5 
6 Enable instrumentation:
7     version(instrument)
        Records how many nodes each findOverlapsWith query visits (appended to _avltree_visited)
9 
10 Author: James S. Blachly, MD <james.blachly@gmail.com>
11 Copyright: Copyright (c) 2019 James Blachly
12 License: MIT
13 */
14 module intervaltree.avltree;
15 
16 import intervaltree : BasicInterval, overlaps;
17 
18 import std.traits : isPointer, PointerTarget, Unqual;
19 
20 import std.experimental.allocator;
21 import std.experimental.allocator.building_blocks.region;
22 import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
23 import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
24 import std.experimental.allocator.mallocator : Mallocator;
25 
/// Per-query count of visited nodes; findOverlapsWith appends one entry per call when built with version(instrument).
version(instrument) __gshared int[] _avltree_visited;

// NB: the first commented-out lambda below compares its arguments directly; applied to
// Node pointers it would compare pointer addresses rather than the stored intervals.
//alias cmpfn = (x,y) => ((y < x) - (x < y));
//@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
31 
/// Index of a child pointer within a node's two-element `p` array.
private enum DIR : int
{
    LEFT,   /// p[0]: left child (implicitly 0)
    RIGHT,  /// p[1]: right child (implicitly 1)
}
38 
39 ///
40 private enum KAVL_MAX_DEPTH = 64;
41 
42 ///
43 pragma(inline, true)
44 @safe @nogc nothrow
45 auto kavl_size(T)(T* p) { return (p ? p.size : 0); }
46 
47 ///
48 pragma(inline, true)
49 @safe @nogc nothrow
50 auto kavl_size_child(T)(T* q, int i) { return (q.p[i] ? q.p[i].size : 0); }
51 
/// Node stored in an IntervalAVLTree: one interval plus AVL and interval-tree bookkeeping.
/// Consumers work with this node type when calling the AVL tree's functions
/// (unlike the splaytree functions, which take the interval directly).
struct IntervalTreeNode(IntervalType)
if (__traits(hasMember, IntervalType, "start") &&
    __traits(hasMember, IntervalType, "end"))
{
    /// sort key: the interval's start coordinate
    pragma(inline,true)
    @property @safe @nogc nothrow const
    auto key() { return this.interval.start; }

    IntervalType interval;  /// the stored interval; must at a minimum provide members start and end

    IntervalTreeNode*[2] p;     /// children: index 0 (DIR.LEFT) = left, 1 (DIR.RIGHT) = right
    // no parent pointer in the kavl implementation
    byte balance;   /// AVL balance factor (signed, 8-bit)
    uint size;      /// number of nodes in the subtree rooted here, including this node
    typeof(IntervalType.end) max;   /// largest `end` anywhere in this node's $(I subtree)

    /// Construct a node from an interval and initialise `max` to the interval's own end,
    /// which is the correct value for a node without children.
    /// Side note: `Node(i)` would compile without this constructor, because the first
    /// field is `interval`, but `max` would then keep its default value.
    @safe @nogc nothrow
    this(IntervalType i)
    {
        this.interval = i;  // blit
        this.max = i.end;
    }

    invariant
    {
        // the Interval type itself should include checks, but in case it does not:
        assert(this.interval.start <= this.interval.end, "Interval start must be <= end");

        assert(this.max >= this.interval.end, "max must be at least as high as our own end");

        // Ensure children are distinct (a node cannot be both left and right child)
        if (this.p[DIR.LEFT] !is null && this.p[DIR.RIGHT] !is null)
        {
            assert(this.p[DIR.LEFT] != this.p[DIR.RIGHT], "Left and right child appear identical");
        }
    }
}
95 
/// Common API across Interval AVL Trees and Interval Splay Trees:
/// in this module, `IntervalTree` names the AVL-backed implementation.
alias IntervalTree = IntervalAVLTree;
98 
99 ///
100 struct IntervalAVLTree(IntervalType)
101 {
102     alias Node = IntervalTreeNode!IntervalType;
103 
104     Node *root;    /// tree root
105 
106     private AllocatorList!((n) => Region!Mallocator(IntervalType.sizeof * 65_536), NullAllocator) mempool;
107 
108     /+
109     /// needed for iterator / range
110     const(Node)*[KAVL_MAX_DEPTH] itrstack; /// ?
111     const(Node)** top;     /// _right_ points to the right child of *top
112     const(Node)*  right;   /// _right_ points to the right child of *top
113     +/
114 
115     ///@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
116     /*pragma(inline, true) @safe @nogc nothrow int cmpfn(Tx)(Tx x, const(Node)* y)
117     {
118         static if (isPointer!Tx && is(PointerTarget!(Unqual!Tx) == Node))
119             return ((y.interval < x.interval) - (x.interval < y.interval));
120         else static if (is(Tx == IntervalType))
121             return ((y.interval < x) - (x < y.interval));
122         else
123             assert(0);
124     }*/
125 
126     pragma(inline, true)
127     {
128         @safe @nogc nothrow int cmpfn(IntervalType x, inout(Node)* y)
129         {
130             return ((y.interval < x) - (x < y.interval));
131         }
132 
133         @safe @nogc nothrow int cmpfn(inout(Node)* x, inout(Node)* y)
134         {
135             return ((y.interval < x.interval) - (x.interval < y.interval));
136         }
137     }
138 
139     /**
140     * Find a node in the tree
141     *
142     * @param x       node value to find (in)
143     * @param cnt     number of nodes smaller than or equal to _x_; can be NULL (out)
144     *
145     * @return node equal to _x_ if present, or NULL if absent
146     */
147     @trusted    // cannot be @safe: casts away const
148     @nogc nothrow
149     Node *find(const(Node) *x, out uint cnt) {
150 
151         const(Node)* p = this.root;
152 
153         while (p !is null) {
154             const int cmp = cmpfn(x, p);
155             if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
156 
157             if (cmp < 0) p = p.p[DIR.LEFT];         // descend leftward
158             else if (cmp > 0) p = p.p[DIR.RIGHT];   // descend rightward
159             else break;
160         }
161 
162         return cast(Node*)p;    // not allowed in @safe, but is const only within this fn
163     }
164 
165     /** find interval(s) overlapping given interval
166         
167         unlike find interval by key, matching elements could be in left /and/ right subtree
168 
169         We use template type "T" here instead of the enclosing struct's IntervalType
170         so that we can from externally query with any type of interval object
171 
172         TODO: benchmark return Node[]
173     */
174     nothrow 
175     // cannot be safe due to emsi container UnrolledList
176     // cannot be @nogc due to return dynamic array
177     Node*[] findOverlapsWith(T)(T qinterval)
178     if (__traits(hasMember, T, "start") &&
179         __traits(hasMember, T, "end"))
180     {
181         // If the calling library does something stupid like, say, call this method
182         // on a null-pointer let's try to prevent a segfault.
183         // MAINTAINER: there is an identical code block in avltree.d/splaytree.d. Update both.
184         if (&this is null) {
185             debug(intervaltree_debug) {
186                 import core.stdc.stdio : stderr, fprintf;
187                 // The below error is perhaps over specific. In the case of swiftover, we use a hash table
188                 // to map contig->interval tree *, keyed on contig. If DNE it happily returns a null pointer *eyeroll*
189                 fprintf(stderr, "Null context in findOverlapsWith. Your contig probably does not exist.\n");
190             }
191             return [];
192         }
193 
194         Node*[KAVL_MAX_DEPTH] stack = void;
195         int s;
196         version(instrument) int visited;
197 
198         Node*[] ret;
199 
200         Node* current;
201 
202         stack[s++] = this.root;
203 
204         while(s >= 1)
205         {
206             current = stack[--s];
207             version(instrument) visited += 1;
208 
209             // if query interval lies to the right of current tree, skip  
210             if (qinterval.start >= current.max) continue;
211 
212             // if query interval end is left of the current node's start,
213             // look in the left subtree
214             if (qinterval.end <= current.interval.start)
215             {
216                 if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
217                 continue;
218             }
219 
220             // if current node overlaps query interval, save it and search its children
221             if (current.interval.overlaps(qinterval)) ret ~= current;
222             if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
223             if (current.p[DIR.RIGHT]) stack[s++] = current.p[DIR.RIGHT];
224         }
225 
226         version(instrument) _avltree_visited ~= visited;
227         return ret;
228     }
229 
230     /// /* one rotation: (a,(b,c)q)p => ((a,b)p,c)q */
231     pragma(inline, true)
232     @safe @nogc nothrow
233     private
234     Node *rotate1(Node *p, int dir) { /* dir=0 to left; dir=1 to right */
235         const int opp = 1 - dir; /* opposite direction */
236         Node *q = p.p[opp];
237         const uint size_p = p.size;
238         p.size -= q.size - kavl_size_child(q, dir);
239         q.size = size_p;
240         p.p[opp] = q.p[dir];
241         q.p[dir] = p;
242 
243         //JSB: update max
244         q.max = p.max;          // q came to top, can take p (prvious top)'s
245         updateMax(p);
246 
247         return q;
248     }
249 
250     /** two consecutive rotations: (a,((b,c)r,d)q)p => ((a,b)p,(c,d)q)r */
251     pragma(inline, true)
252     @safe @nogc nothrow
253     private
254     Node *rotate2(Node *p, int dir) {
255         int b1;
256         const int opp = 1 - dir;
257         Node* q = p.p[opp];
258         Node* r = q.p[dir];
259         const uint size_x_dir = kavl_size_child(r, dir);
260         r.size = p.size;
261         p.size -= q.size - size_x_dir;
262         q.size -= size_x_dir + 1;
263         p.p[opp] = r.p[dir];
264         r.p[dir] = p;
265         q.p[dir] = r.p[opp];
266         r.p[opp] = q;
267         b1 = dir == 0 ? +1 : -1;
268         if (r.balance == b1) q.balance = 0, p.balance = cast(byte)-b1;
269         else if (r.balance == 0) q.balance = p.balance = 0;
270         else q.balance = cast(byte)b1, p.balance = 0;
271         r.balance = 0;
272 
273         //JSB: update max
274         r.max = p.max;          // r came to top, can take p (prvious top)'s
275         updateMax(p);
276         updateMax(q);
277 
278         return r;
279     }
280 
281     /**
282     * Insert a node to the tree
283     *
284     *   Will update Node .max values on the way down
285     *
286     * @param interval Interval: IntervalType to insert (in)
287     * @param cnt     number of nodes smaller than or equal to _x_; can be NULL (out)
288     *
289     * @return _node*_ if not present in the tree, or the node* equal to interval I.
290     *
291     * Cannot be @safe due to call to @system std.experimental.allocator.make
292     * Cannot use @trusted escape for make as delegates are gc-allocating(?)
293     */
294     @trusted @nogc nothrow Node* insert(IntervalType interval, out uint cnt)
295     {
296         
297         ubyte[KAVL_MAX_DEPTH] stack;
298         Node*[KAVL_MAX_DEPTH] path;
299 
300         Node* bp;
301         Node* bq;
302         Node* p;    // current node in iteration
303         Node* q;    // parent of p
304         Node* r = null; /* _r_ is potentially the new root */
305 
306         int i, which = 0, top, b1, path_len;
307 
308         bp = this.root, bq = null;
309         /* find the insertion location */
310         for (p = bp, q = bq, top = path_len = 0; p; q = p, p = p.p[which]) {
311             const int cmp = cmpfn(interval, p);
312             if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
313             if (cmp == 0) {
314                 // an identical Node is already present here
315                 return p;
316             }
317             if (p.balance != 0)
318                 bq = q, bp = p, top = 0;
319             stack[top++] = which = (cmp > 0);
320             path[path_len++] = p;
321 
322             // JSB: conditionally update max irrespective of whether we add new node, or descend
323             if (interval.end > p.max) p.max = interval.end;
324         }
325 
326         // JSB: an interval will be inserted, create node x (previously x was parameter Node*)
327         Node* x = this.mempool.make!Node(interval);
328 
329         x.balance = 0, x.size = 1, x.p[DIR.LEFT] = x.p[DIR.RIGHT] = null;
330         if (q is null) this.root = x;
331         else q.p[which] = x;
332         if (bp is null) return x;
333         for (i = 0; i < path_len; ++i) ++path[i].size;
334         for (p = bp, top = 0; p != x; p = p.p[stack[top]], ++top) /* update balance factors */
335             if (stack[top] == 0) --p.balance;
336             else ++p.balance;
337         if (bp.balance > -2 && bp.balance < 2) return x; /* balance in [-1, 1] : no re-balance needed */
338         /* re-balance */
339         which = (bp.balance < 0);
340         b1 = which == 0 ? +1 : -1;
341         q = bp.p[1 - which];
342         if (q.balance == b1) {
343             r = rotate1(bp, which);
344             q.balance = bp.balance = 0;
345         } else r = rotate2(bp, which);
346         if (bq is null) this.root = r;
347         else bq.p[bp != bq.p[0]] = r;   // wow
348         return x;
349     }
350 
351     /**
352     * Delete a node from the tree
353     *
354     * @param x       node value to delete; if NULL, delete the first (NB: NOT ROOT!) node (in)
355     *
356     * @return node removed from the tree if present, or NULL if absent
357     */
358     /+
359     #define kavl_erase(suf, proot, x, cnt) kavl_erase_##suf(proot, x, cnt)
360     #define kavl_erase_first(suf, proot) kavl_erase_##suf(proot, 0, 0)
361     +/
362     @trusted    // cannot be @safe: takes &fake
363     @nogc nothrow
364     Node *kavl_erase(const(Node) *x, out uint cnt) {
365         Node* p;
366         Node*[KAVL_MAX_DEPTH] path;
367         Node fake;
368         ubyte[KAVL_MAX_DEPTH] dir;
369         int i, d = 0, cmp;
370         fake.p[DIR.LEFT] = this.root, fake.p[DIR.RIGHT] = null;
371 
372         if (x !is null) {
373             for (cmp = -1, p = &fake; cmp; cmp = cmpfn(x, p)) {
374                 const int which = (cmp > 0);
375                 if (cmp > 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
376                 dir[d] = which;
377                 path[d++] = p;
378                 p = p.p[which];
379                 if (p is null) {
380                     // node not found
381                     return null;
382                 }
383             }
384             cnt += kavl_size_child(p, DIR.LEFT) + 1; /* because p==x is not counted */
385         } else {    // NULL, delete the first node
386             assert(x is null);
387             // Descend leftward as far as possible, set p to this node
388             for (p = &fake, cnt = 1; p; p = p.p[DIR.LEFT])
389                 dir[d] = 0, path[d++] = p;
390             p = path[--d];
391         }
392 
393         for (i = 1; i < d; ++i) --path[i].size;
394 
395         if (p.p[DIR.RIGHT] is null) { /* ((1,.)2,3)4 => (1,3)4; p=2 */
396             path[d-1].p[dir[d-1]] = p.p[DIR.LEFT];
397         } else {
398             Node *q = p.p[DIR.RIGHT];
399             if (q.p[0] is null) { /* ((1,2)3,4)5 => ((1)2,4)5; p=3 */
400                 q.p[0] = p.p[0];
401                 q.balance = p.balance;
402                 path[d-1].p[dir[d-1]] = q;
403                 path[d] = q, dir[d++] = 1;
404                 q.size = p.size - 1;
405             } else { /* ((1,((.,2)3,4)5)6,7)8 => ((1,(2,4)5)3,7)8; p=6 */
406                 Node *r;
407                 int e = d++; /* backup _d_ */
408                 for (;;) {
409                     dir[d] = 0;
410                     path[d++] = q;
411                     r = q.p[0];
412                     if (r.p[0] is null) break;
413                     q = r;
414                 }
415                 r.p[0] = p.p[0];
416                 q.p[0] = r.p[1];
417                 r.p[1] = p.p[1];
418                 r.balance = p.balance;
419                 path[e-1].p[dir[e-1]] = r;
420                 path[e] = r, dir[e] = 1;
421                 for (i = e + 1; i < d; ++i) --path[i].size;
422                 r.size = p.size - 1;
423             }
424         }
425 
426         // Rebalance on the way up
427         while (--d > 0) {
428             Node *q = path[d];
429             int which, other, b1 = 1, b2 = 2;
430             which = dir[d], other = 1 - which;
431             if (which) b1 = -b1, b2 = -b2;
432             q.balance += b1;
433             if (q.balance == b1) break;
434             else if (q.balance == b2) {
435                 Node *r = q.p[other];
436                 if (r.balance == -b1) {
437                     path[d-1].p[dir[d-1]] = rotate2(q, which);
438                 } else {
439                     path[d-1].p[dir[d-1]] = rotate1(q, which);
440                     if (r.balance == 0) {
441                         r.balance = cast(byte) -b1;
442                         q.balance = cast(byte) b1;
443                         break;
444                     } else r.balance = q.balance = 0;
445                 }
446             }
447         }
448         this.root = fake.p[0];
449         return p;
450     }
451 
452     /** update Node n's max from subtrees
453     
454     Params:
455         n = node to update
456     */
457     pragma(inline, true)
458     @safe @nogc nothrow
459     private
460     void updateMax(Node *n) 
461     {
462         import std.algorithm.comparison : max;
463 
464         if (n !is null)
465         {
466             auto localmax = n.interval.end;
467             if (n.p[DIR.LEFT])
468                 localmax = max(n.p[DIR.LEFT].max, localmax);
469             if (n.p[DIR.RIGHT])
470                 localmax = max(n.p[DIR.RIGHT].max, localmax);
471             n.max = localmax;
472 
473         }
474     }
475 
476     // TODO: iterator as InputRange
477 }
unittest
{
    // module-level unit test: insert, find, overlap query, erase
    import std.stdio : write, writeln;
    write(__MODULE__ ~ " unittest ...");

    auto tree = new IntervalAVLTree!BasicInterval;

    auto a = BasicInterval(0, 10);
    auto b = BasicInterval(10, 20);
    auto c = BasicInterval(25, 35);

    auto bnode = new IntervalTreeNode!(BasicInterval)(b);

    uint cnt;
    tree.insert(a, cnt);
    tree.insert(b, cnt);
    tree.insert(c, cnt);

    const auto found = tree.find(bnode, cnt);
    assert(found !is null);
    assert(found.interval == b);
    assert(cnt == 2);   // a and b are <= b

    // Result order follows the internal traversal, so accept either order
    auto o = tree.findOverlapsWith(BasicInterval(15, 30));
    assert(o.length == 2);
    assert((o[0].interval == b && o[1].interval == c) ||
           (o[0].interval == c && o[1].interval == b));

    // a query entirely to the right of every interval finds nothing
    assert(tree.findOverlapsWith(BasicInterval(40, 50)).length == 0);

    // erase b, then it can no longer be found or reported as an overlap
    auto erased = tree.kavl_erase(bnode, cnt);
    assert(erased !is null);
    assert(erased.interval == b);
    assert(tree.find(bnode, cnt) is null);

    o = tree.findOverlapsWith(BasicInterval(15, 30));
    assert(o.length == 1);
    assert(o[0].interval == c);

    writeln("passed");
}