1 /** Interval Tree backed by augmented AVL tree
3 The AVL implementation is derived from attractivechaos'
4 klib (kavl.h) and the derived code remains MIT licensed.
6 Enable instrumentation:
7     version(instrument)
        This appends, for every findOverlapsWith call, the number of nodes visited to _avltree_visited
10 Author: James S. Blachly, MD <james.blachly@gmail.com>
11 Copyright: Copyright (c) 2019 James Blachly
12 License: MIT
13 */
14 module intervaltree.avltree;
16 import intervaltree : BasicInterval, overlaps;
18 import std.traits : isPointer, PointerTarget, Unqual;
20 import std.experimental.allocator;
21 import std.experimental.allocator.building_blocks.region;
22 import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
23 import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
24 import std.experimental.allocator.mallocator : Mallocator;
26 version(instrument) __gshared int[] _avltree_visited;
// NOTE: the first commented-out comparator below is wrong: applied to node pointers it compares addresses, not intervals
29 //alias cmpfn = (x,y) => ((y < x) - (x < y));
30 //@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
/// Index of a child slot inside a node's two-element child array `p`.
private enum DIR : int
{
    /// slot holding the left child
    LEFT = 0,
    /// slot holding the right child
    RIGHT = 1,
}
/// Capacity of the fixed-size stack / path arrays used by the tree routines in this module.
private enum int KAVL_MAX_DEPTH = 64;
/// Element count of the subtree rooted at `p`; a null `p` counts as an empty subtree (0).
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size(T)(T* p)
{
    return p is null ? 0 : p.size;
}
/// Element count of the subtree stored in child slot `i` of `q` (0 when that slot is null).
/// `q` itself is dereferenced unconditionally.
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size_child(T)(T* q, int i)
{
    auto child = q.p[i];
    return child is null ? 0 : child.size;
}
/** Tree node: one interval plus the AVL and interval-tree bookkeeping.

In this module `find` and `kavl_erase` take a pointer to a node as the search key,
while `insert` takes the interval itself and builds the node.

`IntervalType` must expose `start` and `end` members (enforced by the template constraint).
*/
struct IntervalTreeNode(IntervalType)
if (__traits(hasMember, IntervalType, "start") &&
    __traits(hasMember, IntervalType, "end"))
{
    /// sort key: the start coordinate of the stored interval
    pragma(inline,true)
    @property @safe @nogc nothrow const
    auto key() { return this.interval.start; }

    IntervalType interval;  /// must at a minimum include members start, end

    IntervalTreeNode*[2] p;     /// 0:left, 1:right
    // no parent pointer in KAVL implementation
    byte balance;   /// balance factor (signed, 8-bit)
    uint size;      /// #elements in subtree
    typeof(IntervalType.end) max;   /// maximum in this $(I subtree)

    /// non-default ctor: construct Node from interval, update max
    /// side note: Node(i) would work without this constructor, since the first
    /// member is IntervalType interval, but the constructor is needed to seed max.
    /// balance, size and the child pointers keep their default (zero / null) values.
    @safe @nogc nothrow
    this(IntervalType i)
    {
        this.interval = i;  // blit
        this.max = i.end;
    }

    invariant
    {
        // the Interval type itself should include checks, but in case it does not:
        assert(this.interval.start <= this.interval.end, "Interval start must <= end");

        assert(this.max >= this.interval.end, "max must be at least as high as our own end");

        // Ensure children are distinct: two non-null child slots must not point at the same node.
        // (If the left slot is non-null and equals the right slot, both are non-null and identical.)
        assert(this.p[DIR.LEFT] is null || this.p[DIR.LEFT] !is this.p[DIR.RIGHT],
               "Left and right child appear identical");
    }
}
96 /// Common API across Interval AVL Trees and Interval Splay Trees
97 alias IntervalTree = IntervalAVLTree;
99 ///
100 struct IntervalAVLTree(IntervalType)
101 {
102     alias Node = IntervalTreeNode!IntervalType;
104     Node *root;    /// tree root
106     private AllocatorList!((n) => Region!Mallocator(IntervalType.sizeof * 65_536), NullAllocator) mempool;
108     /+
109     /// needed for iterator / range
110     const(Node)*[KAVL_MAX_DEPTH] itrstack; /// ?
111     const(Node)** top;     /// _right_ points to the right child of *top
112     const(Node)*  right;   /// _right_ points to the right child of *top
113     +/
115     ///@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
116     /*pragma(inline, true) @safe @nogc nothrow int cmpfn(Tx)(Tx x, const(Node)* y)
117     {
118         static if (isPointer!Tx && is(PointerTarget!(Unqual!Tx) == Node))
119             return ((y.interval < x.interval) - (x.interval < y.interval));
120         else static if (is(Tx == IntervalType))
121             return ((y.interval < x) - (x < y.interval));
122         else
123             assert(0);
124     }*/
126     pragma(inline, true)
127     {
128         @safe @nogc nothrow int cmpfn(IntervalType x, inout(Node)* y)
129         {
130             return ((y.interval < x) - (x < y.interval));
131         }
133         @safe @nogc nothrow int cmpfn(inout(Node)* x, inout(Node)* y)
134         {
135             return ((y.interval < x.interval) - (x.interval < y.interval));
136         }
137     }
139     /**
140     * Find a node in the tree
141     *
142     * @param x       node value to find (in)
143     * @param cnt     number of nodes smaller than or equal to _x_; can be NULL (out)
144     *
145     * @return node equal to _x_ if present, or NULL if absent
146     */
147     @trusted    // cannot be @safe: casts away const
148     @nogc nothrow
149     Node *find(const(Node) *x, out uint cnt) {
151         const(Node)* p = this.root;
153         while (p !is null) {
154             const int cmp = cmpfn(x, p);
155             if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
157             if (cmp < 0) p = p.p[DIR.LEFT];         // descend leftward
158             else if (cmp > 0) p = p.p[DIR.RIGHT];   // descend rightward
159             else break;
160         }
162         return cast(Node*)p;    // not allowed in @safe, but is const only within this fn
163     }
165     /** find interval(s) overlapping given interval
167         unlike find interval by key, matching elements could be in left /and/ right subtree
169         We use template type "T" here instead of the enclosing struct's IntervalType
170         so that we can from externally query with any type of interval object
172         TODO: benchmark return Node[]
173     */
174     nothrow 
175     // cannot be safe due to emsi container UnrolledList
176     // cannot be @nogc due to return dynamic array
177     Node*[] findOverlapsWith(T)(T qinterval)
178     if (__traits(hasMember, T, "start") &&
179         __traits(hasMember, T, "end"))
180     {
181         // If the calling library does something stupid like, say, call this method
182         // on a null-pointer let's try to prevent a segfault.
183         // MAINTAINER: there is an identical code block in avltree.d/splaytree.d. Update both.
184         if (&this is null) {
185             debug(intervaltree_debug) {
186                 import core.stdc.stdio : stderr, fprintf;
187                 // The below error is perhaps over specific. In the case of swiftover, we use a hash table
188                 // to map contig->interval tree *, keyed on contig. If DNE it happily returns a null pointer *eyeroll*
189                 fprintf(stderr, "Null context in findOverlapsWith. Your contig probably does not exist.\n");
190             }
191             return [];
192         }
194         Node*[KAVL_MAX_DEPTH] stack = void;
195         int s;
196         version(instrument) int visited;
198         Node*[] ret;
200         Node* current;
202         stack[s++] = this.root;
204         while(s >= 1)
205         {
206             current = stack[--s];
207             version(instrument) visited += 1;
209             // if query interval lies to the right of current tree, skip  
210             if (qinterval.start >= current.max) continue;
212             // if query interval end is left of the current node's start,
213             // look in the left subtree
214             if (qinterval.end <= current.interval.start)
215             {
216                 if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
217                 continue;
218             }
220             // if current node overlaps query interval, save it and search its children
221             if (current.interval.overlaps(qinterval)) ret ~= current;
222             if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
223             if (current.p[DIR.RIGHT]) stack[s++] = current.p[DIR.RIGHT];
224         }
226         version(instrument) _avltree_visited ~= visited;
227         return ret;
228     }
230     /// /* one rotation: (a,(b,c)q)p => ((a,b)p,c)q */
231     pragma(inline, true)
232     @safe @nogc nothrow
233     private
234     Node *rotate1(Node *p, int dir) { /* dir=0 to left; dir=1 to right */
235         const int opp = 1 - dir; /* opposite direction */
236         Node *q = p.p[opp];
237         const uint size_p = p.size;
238         p.size -= q.size - kavl_size_child(q, dir);
239         q.size = size_p;
240         p.p[opp] = q.p[dir];
241         q.p[dir] = p;
243         //JSB: update max
244         q.max = p.max;          // q came to top, can take p (prvious top)'s
245         updateMax(p);
247         return q;
248     }
250     /** two consecutive rotations: (a,((b,c)r,d)q)p => ((a,b)p,(c,d)q)r */
251     pragma(inline, true)
252     @safe @nogc nothrow
253     private
254     Node *rotate2(Node *p, int dir) {
255         int b1;
256         const int opp = 1 - dir;
257         Node* q = p.p[opp];
258         Node* r = q.p[dir];
259         const uint size_x_dir = kavl_size_child(r, dir);
260         r.size = p.size;
261         p.size -= q.size - size_x_dir;
262         q.size -= size_x_dir + 1;
263         p.p[opp] = r.p[dir];
264         r.p[dir] = p;
265         q.p[dir] = r.p[opp];
266         r.p[opp] = q;
267         b1 = dir == 0 ? +1 : -1;
268         if (r.balance == b1) q.balance = 0, p.balance = cast(byte)-b1;
269         else if (r.balance == 0) q.balance = p.balance = 0;
270         else q.balance = cast(byte)b1, p.balance = 0;
271         r.balance = 0;
273         //JSB: update max
274         r.max = p.max;          // r came to top, can take p (prvious top)'s
275         updateMax(p);
276         updateMax(q);
278         return r;
279     }
281     /**
282     * Insert a node to the tree
283     *
284     *   Will update Node .max values on the way down
285     *
286     * @param interval Interval: IntervalType to insert (in)
287     * @param cnt     number of nodes smaller than or equal to _x_; can be NULL (out)
288     *
289     * @return _node*_ if not present in the tree, or the node* equal to interval I.
290     *
291     * Cannot be @safe due to call to @system std.experimental.allocator.make
292     * Cannot use @trusted escape for make as delegates are gc-allocating(?)
293     */
294     @trusted @nogc nothrow Node* insert(IntervalType interval, out uint cnt)
295     {
297         ubyte[KAVL_MAX_DEPTH] stack;
298         Node*[KAVL_MAX_DEPTH] path;
300         Node* bp;
301         Node* bq;
302         Node* p;    // current node in iteration
303         Node* q;    // parent of p
304         Node* r = null; /* _r_ is potentially the new root */
306         int i, which = 0, top, b1, path_len;
308         bp = this.root, bq = null;
309         /* find the insertion location */
310         for (p = bp, q = bq, top = path_len = 0; p; q = p, p = p.p[which]) {
311             const int cmp = cmpfn(interval, p);
312             if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
313             if (cmp == 0) {
314                 // an identical Node is already present here
315                 return p;
316             }
317             if (p.balance != 0)
318                 bq = q, bp = p, top = 0;
319             stack[top++] = which = (cmp > 0);
320             path[path_len++] = p;
322             // JSB: conditionally update max irrespective of whether we add new node, or descend
323             if (interval.end > p.max) p.max = interval.end;
324         }
326         // JSB: an interval will be inserted, create node x (previously x was parameter Node*)
327         Node* x = this.mempool.make!Node(interval);
329         x.balance = 0, x.size = 1, x.p[DIR.LEFT] = x.p[DIR.RIGHT] = null;
330         if (q is null) this.root = x;
331         else q.p[which] = x;
332         if (bp is null) return x;
333         for (i = 0; i < path_len; ++i) ++path[i].size;
334         for (p = bp, top = 0; p != x; p = p.p[stack[top]], ++top) /* update balance factors */
335             if (stack[top] == 0) --p.balance;
336             else ++p.balance;
337         if (bp.balance > -2 && bp.balance < 2) return x; /* balance in [-1, 1] : no re-balance needed */
338         /* re-balance */
339         which = (bp.balance < 0);
340         b1 = which == 0 ? +1 : -1;
341         q = bp.p[1 - which];
342         if (q.balance == b1) {
343             r = rotate1(bp, which);
344             q.balance = bp.balance = 0;
345         } else r = rotate2(bp, which);
346         if (bq is null) this.root = r;
347         else bq.p[bp != bq.p[0]] = r;   // wow
348         return x;
349     }
351     /**
352     * Delete a node from the tree
353     *
354     * @param x       node value to delete; if NULL, delete the first (NB: NOT ROOT!) node (in)
355     *
356     * @return node removed from the tree if present, or NULL if absent
357     */
358     /+
359     #define kavl_erase(suf, proot, x, cnt) kavl_erase_##suf(proot, x, cnt)
360     #define kavl_erase_first(suf, proot) kavl_erase_##suf(proot, 0, 0)
361     +/
362     @trusted    // cannot be @safe: takes &fake
363     @nogc nothrow
364     Node *kavl_erase(const(Node) *x, out uint cnt) {
365         Node* p;
366         Node*[KAVL_MAX_DEPTH] path;
367         Node fake;
368         ubyte[KAVL_MAX_DEPTH] dir;
369         int i, d = 0, cmp;
370         fake.p[DIR.LEFT] = this.root, fake.p[DIR.RIGHT] = null;
372         if (x !is null) {
373             for (cmp = -1, p = &fake; cmp; cmp = cmpfn(x, p)) {
374                 const int which = (cmp > 0);
375                 if (cmp > 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
376                 dir[d] = which;
377                 path[d++] = p;
378                 p = p.p[which];
379                 if (p is null) {
380                     // node not found
381                     return null;
382                 }
383             }
384             cnt += kavl_size_child(p, DIR.LEFT) + 1; /* because p==x is not counted */
385         } else {    // NULL, delete the first node
386             assert(x is null);
387             // Descend leftward as far as possible, set p to this node
388             for (p = &fake, cnt = 1; p; p = p.p[DIR.LEFT])
389                 dir[d] = 0, path[d++] = p;
390             p = path[--d];
391         }
393         for (i = 1; i < d; ++i) --path[i].size;
395         if (p.p[DIR.RIGHT] is null) { /* ((1,.)2,3)4 => (1,3)4; p=2 */
396             path[d-1].p[dir[d-1]] = p.p[DIR.LEFT];
397         } else {
398             Node *q = p.p[DIR.RIGHT];
399             if (q.p[0] is null) { /* ((1,2)3,4)5 => ((1)2,4)5; p=3 */
400                 q.p[0] = p.p[0];
401                 q.balance = p.balance;
402                 path[d-1].p[dir[d-1]] = q;
403                 path[d] = q, dir[d++] = 1;
404                 q.size = p.size - 1;
405             } else { /* ((1,((.,2)3,4)5)6,7)8 => ((1,(2,4)5)3,7)8; p=6 */
406                 Node *r;
407                 int e = d++; /* backup _d_ */
408                 for (;;) {
409                     dir[d] = 0;
410                     path[d++] = q;
411                     r = q.p[0];
412                     if (r.p[0] is null) break;
413                     q = r;
414                 }
415                 r.p[0] = p.p[0];
416                 q.p[0] = r.p[1];
417                 r.p[1] = p.p[1];
418                 r.balance = p.balance;
419                 path[e-1].p[dir[e-1]] = r;
420                 path[e] = r, dir[e] = 1;
421                 for (i = e + 1; i < d; ++i) --path[i].size;
422                 r.size = p.size - 1;
423             }
424         }
426         // Rebalance on the way up
427         while (--d > 0) {
428             Node *q = path[d];
429             int which, other, b1 = 1, b2 = 2;
430             which = dir[d], other = 1 - which;
431             if (which) b1 = -b1, b2 = -b2;
432             q.balance += b1;
433             if (q.balance == b1) break;
434             else if (q.balance == b2) {
435                 Node *r = q.p[other];
436                 if (r.balance == -b1) {
437                     path[d-1].p[dir[d-1]] = rotate2(q, which);
438                 } else {
439                     path[d-1].p[dir[d-1]] = rotate1(q, which);
440                     if (r.balance == 0) {
441                         r.balance = cast(byte) -b1;
442                         q.balance = cast(byte) b1;
443                         break;
444                     } else r.balance = q.balance = 0;
445                 }
446             }
447         }
448         this.root = fake.p[0];
449         return p;
450     }
452     /** update Node n's max from subtrees
454     Params:
455         n = node to update
456     */
457     pragma(inline, true)
458     @safe @nogc nothrow
459     private
460     void updateMax(Node *n) 
461     {
462         import std.algorithm.comparison : max;
464         if (n !is null)
465         {
466             auto localmax = n.interval.end;
467             if (n.p[DIR.LEFT])
468                 localmax = max(n.p[DIR.LEFT].max, localmax);
469             if (n.p[DIR.RIGHT])
470                 localmax = max(n.p[DIR.RIGHT].max, localmax);
471             n.max = localmax;
473         }
474     }
476     // TODO: iterator as InputRange
477 }
unittest
{
    // module-level smoke test: insert three intervals, look one up by key,
    // then run an overlap query
    import std.stdio : write, writeln;
    write(__MODULE__ ~ " unittest ...");

    alias TestNode = IntervalTreeNode!(BasicInterval);

    auto tree = new IntervalAVLTree!BasicInterval;

    auto a = BasicInterval(0, 10);
    auto b = BasicInterval(10, 20);
    auto c = BasicInterval(25, 35);

    // stand-alone nodes, used as search keys / expected values (never linked into the tree)
    auto anode = new TestNode(a);
    auto bnode = new TestNode(b);
    auto cnode = new TestNode(c);

    uint cnt;
    foreach (iv; [a, b, c])
        tree.insert(iv, cnt);

    const auto found = tree.find(bnode, cnt);
    assert(found.interval == b);

    // TODO, actually not sure that these are returned strictly ordered if there are many
    auto hits = tree.findOverlapsWith(BasicInterval(15, 30));
    assert(hits.length == 2);
    assert(hits[0].interval == bnode.interval);
    assert(hits[1].interval == cnode.interval);

    writeln("passed");
}