1 /** Interval Tree backed by augmented AVL tree
2 
3 The AVL implementation is derived from attractivechaos'
4 klib (kavl.h) and the derived code remains MIT licensed.
5 
6 Enable instrumentation:
7     version(instrument)
8         This will calculate statistics related to traversal depth
9 
10 Author: James S. Blachly, MD <james.blachly@gmail.com>
11 Copyright: Copyright (c) 2019 James Blachly
12 License: MIT
13 */
14 module intervaltree.avltree;
15 
16 import intervaltree : BasicInterval, overlaps;
17 
18 import std.traits : isPointer, PointerTarget, Unqual;
19 
20 import std.experimental.allocator;
21 import std.experimental.allocator.building_blocks.region;
22 import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
23 import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
24 import std.experimental.allocator.mallocator : Mallocator;
25 
26 version(instrument) __gshared int[] _avltree_visited;
27 
28 // LOL, this compares pointer addresses
29 //alias cmpfn = (x,y) => ((y < x) - (x < y));
30 //@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
31 
/// Index of a child slot inside a node's two-element `p` array.
private enum DIR : int
{
    LEFT,   /// slot 0: left child
    RIGHT,  /// slot 1: right child
}
38 
/// Length of the fixed-size, stack-allocated arrays (node paths, direction
/// records and traversal stacks) used by the tree's find-overlaps, insert
/// and erase routines.
private enum KAVL_MAX_DEPTH = 64;
41 
/// Size of the subtree rooted at `p` (its `size` member); a null pointer
/// stands for an empty subtree and yields 0.
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size(T)(T* p)
{
    return p is null ? 0 : p.size;
}
46 
/// Size of the subtree hanging off child slot `i` of node `q`
/// (0 when that slot is empty). `q` itself must not be null.
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size_child(T)(T* q, int i)
{
    return q.p[i] is null ? 0 : q.p[i].size;
}
51 
/** Node of the interval AVL tree: one interval plus the AVL bookkeeping.

The AVL tree's `find` and `kavl_erase` take a pointer to this node type
(unlike the splaytree functions, which take the interval directly).

`IntervalType` must expose at least the members `start` and `end`.
*/
struct IntervalTreeNode(IntervalType)
if (__traits(hasMember, IntervalType, "start") &&
    __traits(hasMember, IntervalType, "end"))
{
    IntervalType interval;  /// payload; must at a minimum include members start, end

    IntervalTreeNode*[2] p;     /// children -- 0:left, 1:right (KAVL keeps no parent pointer)
    byte balance;   /// AVL balance factor (signed, 8-bit)
    uint size;      /// number of elements in the subtree rooted here
    typeof(IntervalType.end) max;   /// largest `end` found anywhere in this $(I subtree)

    /// Build a node from an interval.
    ///
    /// `Node(i)` would already work without an explicit constructor because
    /// `interval` is the first member, but `max` has to be seeded from the
    /// interval's own end as well.
    @safe @nogc nothrow
    this(IntervalType i)
    {
        interval = i;   // blit
        max = i.end;
    }

    /// sort key: the interval's start coordinate
    pragma(inline,true)
    @property @safe @nogc nothrow const
    auto key() { return interval.start; }

    invariant
    {
        // the Interval type itself should include checks, but in case it does not:
        assert(interval.start <= interval.end, "Interval start must <= end");

        assert(max >= interval.end, "max must be at least as high as our own end");

        // two non-null children must never be the same node
        const bool bothPresent = p[DIR.LEFT] !is null && p[DIR.RIGHT] !is null;
        if (bothPresent)
            assert(p[DIR.LEFT] != p[DIR.RIGHT], "Left and righ child appear identical");
    }
}
95 
/// Common API across Interval AVL Trees and Interval Splay Trees:
/// in this module `IntervalTree` is simply another name for `IntervalAVLTree`.
alias IntervalTree = IntervalAVLTree;
98 
99 ///
100 struct IntervalAVLTree(IntervalType)
101 {
102     alias Node = IntervalTreeNode!IntervalType;
103 
104     Node *root;    /// tree root
105 
106     private AllocatorList!((n) => Region!Mallocator(IntervalType.sizeof * 65_536), NullAllocator) mempool;
107 
108     /+
109     /// needed for iterator / range
110     const(Node)*[KAVL_MAX_DEPTH] itrstack; /// ?
111     const(Node)** top;     /// _right_ points to the right child of *top
112     const(Node)*  right;   /// _right_ points to the right child of *top
113     +/
114 
115     ///@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
116     /*pragma(inline, true) @safe @nogc nothrow int cmpfn(Tx)(Tx x, const(Node)* y)
117     {
118         static if (isPointer!Tx && is(PointerTarget!(Unqual!Tx) == Node))
119             return ((y.interval < x.interval) - (x.interval < y.interval));
120         else static if (is(Tx == IntervalType))
121             return ((y.interval < x) - (x < y.interval));
122         else
123             assert(0);
124     }*/
125 
126     pragma(inline, true)
127     {
128         @safe @nogc nothrow int cmpfn(IntervalType x, inout(Node)* y)
129         {
130             return ((y.interval < x) - (x < y.interval));
131         }
132 
133         @safe @nogc nothrow int cmpfn(inout(Node)* x, inout(Node)* y)
134         {
135             return ((y.interval < x.interval) - (x.interval < y.interval));
136         }
137     }
138 
139     /**
140     * Find a node in the tree
141     *
142     * @param x       node value to find (in)
143     * @param cnt     number of nodes smaller than or equal to _x_; can be NULL (out)
144     *
145     * @return node equal to _x_ if present, or NULL if absent
146     */
147     @trusted    // cannot be @safe: casts away const
148     @nogc nothrow
149     Node *find(const(Node) *x, out uint cnt) {
150 
151         const(Node)* p = this.root;
152 
153         while (p !is null) {
154             const int cmp = cmpfn(x, p);
155             if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
156 
157             if (cmp < 0) p = p.p[DIR.LEFT];         // descend leftward
158             else if (cmp > 0) p = p.p[DIR.RIGHT];   // descend rightward
159             else break;
160         }
161 
162         return cast(Node*)p;    // not allowed in @safe, but is const only within this fn
163     }
164 
165     /** find interval(s) overlapping given interval
166         
167         unlike find interval by key, matching elements could be in left /and/ right subtree
168 
169         We use template type "T" here instead of the enclosing struct's IntervalType
170         so that we can from externally query with any type of interval object
171 
172         TODO: benchmark return Node[]
173     */
174     nothrow 
175     // cannot be safe due to emsi container UnrolledList
176     // cannot be @nogc due to return dynamic array
177     Node*[] findOverlapsWith(T)(T qinterval)
178     if (__traits(hasMember, T, "start") &&
179         __traits(hasMember, T, "end"))
180     {
181         Node*[KAVL_MAX_DEPTH] stack = void;
182         int s;
183         version(instrument) int visited;
184 
185         Node*[] ret;
186 
187         Node* current;
188 
189         stack[s++] = this.root;
190 
191         while(s >= 1)
192         {
193             current = stack[--s];
194             version(instrument) visited += 1;
195 
196             // if query interval lies to the right of current tree, skip  
197             if (qinterval.start >= current.max) continue;
198 
199             // if query interval end is left of the current node's start,
200             // look in the left subtree
201             if (qinterval.end <= current.interval.start)
202             {
203                 if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
204                 continue;
205             }
206 
207             // if current node overlaps query interval, save it and search its children
208             if (current.interval.overlaps(qinterval)) ret ~= current;
209             if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
210             if (current.p[DIR.RIGHT]) stack[s++] = current.p[DIR.RIGHT];
211         }
212 
213         version(instrument) _avltree_visited ~= visited;
214         return ret;
215     }
216 
217     /// /* one rotation: (a,(b,c)q)p => ((a,b)p,c)q */
218     pragma(inline, true)
219     @safe @nogc nothrow
220     private
221     Node *rotate1(Node *p, int dir) { /* dir=0 to left; dir=1 to right */
222         const int opp = 1 - dir; /* opposite direction */
223         Node *q = p.p[opp];
224         const uint size_p = p.size;
225         p.size -= q.size - kavl_size_child(q, dir);
226         q.size = size_p;
227         p.p[opp] = q.p[dir];
228         q.p[dir] = p;
229 
230         //JSB: update max
231         q.max = p.max;          // q came to top, can take p (prvious top)'s
232         updateMax(p);
233 
234         return q;
235     }
236 
237     /** two consecutive rotations: (a,((b,c)r,d)q)p => ((a,b)p,(c,d)q)r */
238     pragma(inline, true)
239     @safe @nogc nothrow
240     private
241     Node *rotate2(Node *p, int dir) {
242         int b1;
243         const int opp = 1 - dir;
244         Node* q = p.p[opp];
245         Node* r = q.p[dir];
246         const uint size_x_dir = kavl_size_child(r, dir);
247         r.size = p.size;
248         p.size -= q.size - size_x_dir;
249         q.size -= size_x_dir + 1;
250         p.p[opp] = r.p[dir];
251         r.p[dir] = p;
252         q.p[dir] = r.p[opp];
253         r.p[opp] = q;
254         b1 = dir == 0 ? +1 : -1;
255         if (r.balance == b1) q.balance = 0, p.balance = cast(byte)-b1;
256         else if (r.balance == 0) q.balance = p.balance = 0;
257         else q.balance = cast(byte)b1, p.balance = 0;
258         r.balance = 0;
259 
260         //JSB: update max
261         r.max = p.max;          // r came to top, can take p (prvious top)'s
262         updateMax(p);
263         updateMax(q);
264 
265         return r;
266     }
267 
    /**
    * Insert an interval into the tree
    *
    *   A new node is allocated from the tree's memory pool only when no node
    *   comparing equal to `interval` is already present.
    *
    *   Node .max values are updated on the way down: every node visited
    *   before the insertion point (or before an equal node) is reached has
    *   its max raised to interval.end when that is larger.
    *
    * Params:
    *   interval = IntervalType to insert (in)
    *   cnt      = (out) number of visited-path nodes, with their left subtrees,
    *              that compare smaller than or equal to `interval`; an already
    *              present equal node is included in the count
    *
    * Returns: the newly created node if `interval` was not present in the tree,
    *          or the existing node equal to `interval`.
    *
    * Cannot be @safe due to call to @system std.experimental.allocator.make
    * Cannot use @trusted escape for make as delegates are gc-allocating(?)
    */
    @trusted @nogc nothrow Node* insert(IntervalType interval, out uint cnt)
    {

        // stack: directions taken (0 left / 1 right) from bp downwards
        // path:  every node visited from the root downwards
        ubyte[KAVL_MAX_DEPTH] stack;
        Node*[KAVL_MAX_DEPTH] path;

        Node* bp;   // deepest visited node with non-zero balance (initially the root)
        Node* bq;   // parent of bp (null while bp is the root)
        Node* p;    // current node in iteration
        Node* q;    // parent of p
        Node* r = null; /* _r_ is potentially the new root */

        int i, which = 0, top, b1, path_len;

        bp = this.root, bq = null;
        /* find the insertion location */
        for (p = bp, q = bq, top = path_len = 0; p; q = p, p = p.p[which]) {
            const int cmp = cmpfn(interval, p);
            if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
            if (cmp == 0) {
                // an identical Node is already present here
                return p;
            }
            // restart the direction record at the deepest unbalanced node:
            // only nodes from there down can change balance
            if (p.balance != 0)
                bq = q, bp = p, top = 0;
            stack[top++] = which = (cmp > 0);
            path[path_len++] = p;

            // JSB: conditionally update max irrespective of whether we add new node, or descend
            if (interval.end > p.max) p.max = interval.end;
        }

        // JSB: an interval will be inserted, create node x (previously x was parameter Node*)
        Node* x = this.mempool.make!Node(interval);

        x.balance = 0, x.size = 1, x.p[DIR.LEFT] = x.p[DIR.RIGHT] = null;
        // hook the new leaf below the last visited node (or make it the root)
        if (q is null) this.root = x;
        else q.p[which] = x;
        if (bp is null) return x;   // tree was empty: nothing to fix up
        // every node on the search path gained one descendant
        for (i = 0; i < path_len; ++i) ++path[i].size;
        for (p = bp, top = 0; p != x; p = p.p[stack[top]], ++top) /* update balance factors */
            if (stack[top] == 0) --p.balance;
            else ++p.balance;
        if (bp.balance > -2 && bp.balance < 2) return x; /* balance in [-1, 1] : no re-balance needed */
        /* re-balance */
        // which = rotation direction: 1 when bp's balance went negative, else 0;
        // q = bp's child on the side opposite to the rotation direction
        which = (bp.balance < 0);
        b1 = which == 0 ? +1 : -1;
        q = bp.p[1 - which];
        if (q.balance == b1) {
            // q leans the same way as bp: a single rotation suffices
            r = rotate1(bp, which);
            q.balance = bp.balance = 0;
        } else r = rotate2(bp, which);  // rotate2 sets the balance factors itself
        // re-attach the rotated subtree where bp used to hang
        if (bq is null) this.root = r;
        else bq.p[bp != bq.p[0]] = r;   // index 1 if bp was not bq's left child, else 0
        return x;
    }
337 
338     /**
339     * Delete a node from the tree
340     *
341     * @param x       node value to delete; if NULL, delete the first (NB: NOT ROOT!) node (in)
342     *
343     * @return node removed from the tree if present, or NULL if absent
344     */
345     /+
346     #define kavl_erase(suf, proot, x, cnt) kavl_erase_##suf(proot, x, cnt)
347     #define kavl_erase_first(suf, proot) kavl_erase_##suf(proot, 0, 0)
348     +/
349     @trusted    // cannot be @safe: takes &fake
350     @nogc nothrow
351     Node *kavl_erase(const(Node) *x, out uint cnt) {
352         Node* p;
353         Node*[KAVL_MAX_DEPTH] path;
354         Node fake;
355         ubyte[KAVL_MAX_DEPTH] dir;
356         int i, d = 0, cmp;
357         fake.p[DIR.LEFT] = this.root, fake.p[DIR.RIGHT] = null;
358 
359         if (x !is null) {
360             for (cmp = -1, p = &fake; cmp; cmp = cmpfn(x, p)) {
361                 const int which = (cmp > 0);
362                 if (cmp > 0) cnt += kavl_size_child(p, DIR.LEFT) + 1; // left tree plus self
363                 dir[d] = which;
364                 path[d++] = p;
365                 p = p.p[which];
366                 if (p is null) {
367                     // node not found
368                     return null;
369                 }
370             }
371             cnt += kavl_size_child(p, DIR.LEFT) + 1; /* because p==x is not counted */
372         } else {    // NULL, delete the first node
373             assert(x is null);
374             // Descend leftward as far as possible, set p to this node
375             for (p = &fake, cnt = 1; p; p = p.p[DIR.LEFT])
376                 dir[d] = 0, path[d++] = p;
377             p = path[--d];
378         }
379 
380         for (i = 1; i < d; ++i) --path[i].size;
381 
382         if (p.p[DIR.RIGHT] is null) { /* ((1,.)2,3)4 => (1,3)4; p=2 */
383             path[d-1].p[dir[d-1]] = p.p[DIR.LEFT];
384         } else {
385             Node *q = p.p[DIR.RIGHT];
386             if (q.p[0] is null) { /* ((1,2)3,4)5 => ((1)2,4)5; p=3 */
387                 q.p[0] = p.p[0];
388                 q.balance = p.balance;
389                 path[d-1].p[dir[d-1]] = q;
390                 path[d] = q, dir[d++] = 1;
391                 q.size = p.size - 1;
392             } else { /* ((1,((.,2)3,4)5)6,7)8 => ((1,(2,4)5)3,7)8; p=6 */
393                 Node *r;
394                 int e = d++; /* backup _d_ */
395                 for (;;) {
396                     dir[d] = 0;
397                     path[d++] = q;
398                     r = q.p[0];
399                     if (r.p[0] is null) break;
400                     q = r;
401                 }
402                 r.p[0] = p.p[0];
403                 q.p[0] = r.p[1];
404                 r.p[1] = p.p[1];
405                 r.balance = p.balance;
406                 path[e-1].p[dir[e-1]] = r;
407                 path[e] = r, dir[e] = 1;
408                 for (i = e + 1; i < d; ++i) --path[i].size;
409                 r.size = p.size - 1;
410             }
411         }
412 
413         // Rebalance on the way up
414         while (--d > 0) {
415             Node *q = path[d];
416             int which, other, b1 = 1, b2 = 2;
417             which = dir[d], other = 1 - which;
418             if (which) b1 = -b1, b2 = -b2;
419             q.balance += b1;
420             if (q.balance == b1) break;
421             else if (q.balance == b2) {
422                 Node *r = q.p[other];
423                 if (r.balance == -b1) {
424                     path[d-1].p[dir[d-1]] = rotate2(q, which);
425                 } else {
426                     path[d-1].p[dir[d-1]] = rotate1(q, which);
427                     if (r.balance == 0) {
428                         r.balance = cast(byte) -b1;
429                         q.balance = cast(byte) b1;
430                         break;
431                     } else r.balance = q.balance = 0;
432                 }
433             }
434         }
435         this.root = fake.p[0];
436         return p;
437     }
438 
439     /** update Node n's max from subtrees
440     
441     Params:
442         n = node to update
443     */
444     pragma(inline, true)
445     @safe @nogc nothrow
446     private
447     void updateMax(Node *n) 
448     {
449         import std.algorithm.comparison : max;
450 
451         if (n !is null)
452         {
453             int localmax = n.interval.end;
454             if (n.p[DIR.LEFT])
455                 localmax = max(n.p[DIR.LEFT].max, localmax);
456             if (n.p[DIR.RIGHT])
457                 localmax = max(n.p[DIR.RIGHT].max, localmax);
458             n.max = localmax;
459 
460         }
461     }
462 
463     // TODO: iterator as InputRange
464 }
unittest
{
    // module-level smoke test: insert three intervals, then exercise
    // exact lookup and overlap query
    import std.stdio : write, writeln;
    write(__MODULE__ ~ " unittest ...");

    alias Tree = IntervalAVLTree!BasicInterval;
    alias TreeNode = IntervalTreeNode!(BasicInterval);

    auto tree = new Tree;

    auto first = BasicInterval(0, 10);
    auto second = BasicInterval(10, 20);
    auto third = BasicInterval(25, 35);

    // stand-alone nodes, used as lookup probes / expected values only
    auto firstNode = new TreeNode(first);
    auto secondNode = new TreeNode(second);
    auto thirdNode = new TreeNode(third);

    uint rank;
    tree.insert(first, rank);
    tree.insert(second, rank);
    tree.insert(third, rank);

    // exact lookup goes through a probe node
    const auto hit = tree.find(secondNode, rank);
    assert(hit.interval == second);

    // TODO, actually not sure that these are returned strictly ordered if there are many
    auto overlapping = tree.findOverlapsWith(BasicInterval(15, 30));
    assert(overlapping.length == 2);
    assert(overlapping[0].interval == secondNode.interval);
    assert(overlapping[1].interval == thirdNode.interval);

    writeln("passed");
}