/** Interval Tree backed by augmented AVL tree

    The AVL implementation is derived from attractivechaos'
    klib (kavl.h) and the derived code remains MIT licensed.

    Enable instrumentation:
        version(instrument)
    This will calculate statistics related to traversal depth

    Author: James S. Blachly, MD <james.blachly@gmail.com>
    Copyright: Copyright (c) 2019 James Blachly
    License: MIT
*/
module intervaltree.avltree;

import intervaltree : BasicInterval, overlaps;

import std.traits : isPointer, PointerTarget, Unqual;

import std.experimental.allocator;
import std.experimental.allocator.building_blocks.region;
import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
import std.experimental.allocator.mallocator : Mallocator;

version(instrument) __gshared int[] _avltree_visited;

// LOL, this compares pointer addresses
//alias cmpfn = (x,y) => ((y < x) - (x < y));
//@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));

/// child node direction
private enum DIR : int
{
    LEFT = 0,
    RIGHT = 1
}

/// Capacity of the fixed-size traversal stacks (maximum supported tree depth)
private enum KAVL_MAX_DEPTH = 64;

/// Number of nodes in the subtree rooted at `p` (0 when `p` is null)
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size(T)(T* p) { return (p ? p.size : 0); }

/// Number of nodes in child subtree `i` (DIR.LEFT / DIR.RIGHT) of `q`; `q` itself must not be null
pragma(inline, true)
@safe @nogc nothrow
auto kavl_size_child(T)(T* q, int i) { return (q.p[i] ? q.p[i].size : 0); }

/// Consumer needs to use this with insert functions (unlike splaytree fns, which take interval directly)
struct IntervalTreeNode(IntervalType)
if (__traits(hasMember, IntervalType, "start") &&
    __traits(hasMember, IntervalType, "end"))
{
    /// sort key
    pragma(inline,true)
    @property @safe @nogc nothrow const
    auto key() { return this.interval.start; }

    IntervalType interval;          /// must at a minimum include members start, end

    IntervalTreeNode*[2] p;         /// 0:left, 1:right
    // no parent pointer in KAVL implementation
    byte balance;                   /// balance factor (signed, 8-bit): decremented when a left descendant is added, incremented for a right one
    uint size;                      /// #elements in subtree
    typeof(IntervalType.end) max;   /// maximum `end` in this $(I subtree)

    /// non-default ctor: construct Node from interval, update max
    /// side note: D is beautiful in that Node(i) will work just fine
    /// without this constructor since its first member is IntervalType interval,
    /// but we need the constructor to update max.
    @safe @nogc nothrow
    this(IntervalType i)
    {
        this.interval = i;  // blit
        this.max = i.end;
    }

    invariant
    {
        // the Interval type itself should include checks, but in case it does not:
        assert(this.interval.start <= this.interval.end, "Interval start must <= end");

        assert(this.max >= this.interval.end, "max must be at least as high as our own end");

        // Ensure children are distinct
        if (this.p[DIR.LEFT] !is null && this.p[DIR.RIGHT] !is null)
        {
            assert(this.p[DIR.LEFT] != this.p[DIR.RIGHT], "Left and right child appear identical");
        }
    }
}

/// Common API across Interval AVL Trees and Interval Splay Trees
alias IntervalTree = IntervalAVLTree;

///
struct IntervalAVLTree(IntervalType)
{
    alias Node = IntervalTreeNode!IntervalType;

    Node *root;     /// tree root

    // NOTE(review): each region is sized as IntervalType.sizeof * 65_536 bytes, but
    // Node (interval + pointers + bookkeeping) is what gets allocated, so a region
    // holds fewer than 65_536 nodes. Confirm this sizing is intentional.
    private AllocatorList!((n) => Region!Mallocator(IntervalType.sizeof * 65_536), NullAllocator) mempool;

    /+
    /// needed for iterator / range
    const(Node)*[KAVL_MAX_DEPTH] itrstack;  /// ?
    const(Node)** top;      /// _right_ points to the right child of *top
    const(Node)* right;     /// _right_ points to the right child of *top
    +/

    ///@safe @nogc nothrow alias cmpfn = (x, y) => ((y.interval < x.interval) - (x.interval < y.interval));
    /*pragma(inline, true) @safe @nogc nothrow int cmpfn(Tx)(Tx x, const(Node)* y)
    {
        static if (isPointer!Tx && is(PointerTarget!(Unqual!Tx) == Node))
            return ((y.interval < x.interval) - (x.interval < y.interval));
        else static if (is(Tx == IntervalType))
            return ((y.interval < x) - (x < y.interval));
        else
            assert(0);
    }*/

    pragma(inline, true)
    {
        /// three-way compare of a bare interval against a node's interval: -1, 0, or +1
        @safe @nogc nothrow int cmpfn(IntervalType x, inout(Node)* y)
        {
            return ((y.interval < x) - (x < y.interval));
        }

        /// three-way compare of two nodes' intervals: -1, 0, or +1
        @safe @nogc nothrow int cmpfn(inout(Node)* x, inout(Node)* y)
        {
            return ((y.interval < x.interval) - (x.interval < y.interval));
        }
    }

    /**
     * Find a node in the tree
     *
     * @param x   node value to find (in)
     * @param cnt number of nodes smaller than or equal to _x_ (out; reset to 0 on entry,
     *            and still filled in when _x_ is absent)
     *
     * @return node equal to _x_ if present, or null if absent
     */
    @trusted    // cannot be @safe: casts away const
    @nogc nothrow
    Node *find(const(Node) *x, out uint cnt) {

        const(Node)* p = this.root;

        while (p !is null) {
            const int cmp = cmpfn(x, p);
            if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1;   // left tree plus self

            if (cmp < 0) p = p.p[DIR.LEFT];         // descend leftward
            else if (cmp > 0) p = p.p[DIR.RIGHT];   // descend rightward
            else break;
        }

        return cast(Node*)p;    // not allowed in @safe, but is const only within this fn
    }

    /** find interval(s) overlapping given interval

        unlike find interval by key, matching elements could be in left /and/ right subtree

        We use template type "T" here instead of the enclosing struct's IntervalType
        so that we can from externally query with any type of interval object

        An empty tree yields an empty result.

        TODO: benchmark return Node[]
    */
    nothrow
    // cannot be @nogc due to return dynamic array
    Node*[] findOverlapsWith(T)(T qinterval)
    if (__traits(hasMember, T, "start") &&
        __traits(hasMember, T, "end"))
    {
        Node*[KAVL_MAX_DEPTH] stack = void;
        int s;
        version(instrument) int visited;

        Node*[] ret;

        Node* current;

        // Guard: an empty tree has no root to inspect (avoids null dereference below)
        if (this.root !is null) stack[s++] = this.root;

        while(s >= 1)
        {
            current = stack[--s];
            version(instrument) visited += 1;

            // if query interval lies to the right of current tree, skip
            if (qinterval.start >= current.max) continue;

            // if query interval end is left of the current node's start,
            // look in the left subtree
            if (qinterval.end <= current.interval.start)
            {
                if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
                continue;
            }

            // if current node overlaps query interval, save it and search its children
            if (current.interval.overlaps(qinterval)) ret ~= current;
            if (current.p[DIR.LEFT]) stack[s++] = current.p[DIR.LEFT];
            if (current.p[DIR.RIGHT]) stack[s++] = current.p[DIR.RIGHT];
        }

        version(instrument) _avltree_visited ~= visited;
        return ret;
    }

    /** one rotation: (a,(b,c)q)p => ((a,b)p,c)q

        Params:
            p   = subtree root to rotate down
            dir = 0 to rotate left; 1 to rotate right

        Returns: the new subtree root (p's former child q)
    */
    pragma(inline, true)
    @safe @nogc nothrow
    private
    Node *rotate1(Node *p, int dir) {
        const int opp = 1 - dir;    /* opposite direction */
        Node *q = p.p[opp];
        const uint size_p = p.size;
        p.size -= q.size - kavl_size_child(q, dir);
        q.size = size_p;
        p.p[opp] = q.p[dir];
        q.p[dir] = p;

        //JSB: update max
        q.max = p.max;  // q came to top and now covers p's (the previous top's) whole subtree
        updateMax(p);

        return q;
    }

    /** two consecutive rotations: (a,((b,c)r,d)q)p => ((a,b)p,(c,d)q)r

        Also fixes up the balance factors of p, q and r.

        Returns: the new subtree root (p's former grandchild r)
    */
    pragma(inline, true)
    @safe @nogc nothrow
    private
    Node *rotate2(Node *p, int dir) {
        int b1;
        const int opp = 1 - dir;
        Node* q = p.p[opp];
        Node* r = q.p[dir];
        const uint size_x_dir = kavl_size_child(r, dir);
        r.size = p.size;
        p.size -= q.size - size_x_dir;
        q.size -= size_x_dir + 1;
        p.p[opp] = r.p[dir];
        r.p[dir] = p;
        q.p[dir] = r.p[opp];
        r.p[opp] = q;
        b1 = dir == 0 ? +1 : -1;
        if (r.balance == b1) q.balance = 0, p.balance = cast(byte)-b1;
        else if (r.balance == 0) q.balance = p.balance = 0;
        else q.balance = cast(byte)b1, p.balance = 0;
        r.balance = 0;

        //JSB: update max
        r.max = p.max;  // r came to top and now covers p's (the previous top's) whole subtree
        updateMax(p);
        updateMax(q);

        return r;
    }

    /**
     * Insert a node to the tree
     *
     * Will update Node .max values on the way down
     *
     * @param interval Interval: IntervalType to insert (in)
     * @param cnt      number of nodes smaller than or equal to _interval_ (out; reset to 0 on entry)
     *
     * @return _node*_ if not present in the tree, or the node* equal to interval I.
     *
     * Cannot be @safe due to call to @system std.experimental.allocator.make
     * Cannot use @trusted escape for make as delegates are gc-allocating(?)
     */
    @trusted @nogc nothrow Node* insert(IntervalType interval, out uint cnt)
    {

        ubyte[KAVL_MAX_DEPTH] stack;        // directions taken below bp
        Node*[KAVL_MAX_DEPTH] path;         // every node visited on the way down

        Node* bp;   // deepest visited node with non-zero balance (re-balance candidate)
        Node* bq;   // parent of bp
        Node* p;    // current node in iteration
        Node* q;    // parent of p
        Node* r = null; /* _r_ is potentially the new root */

        int i, which = 0, top, b1, path_len;

        bp = this.root, bq = null;
        /* find the insertion location */
        for (p = bp, q = bq, top = path_len = 0; p; q = p, p = p.p[which]) {
            const int cmp = cmpfn(interval, p);
            if (cmp >= 0) cnt += kavl_size_child(p, DIR.LEFT) + 1;   // left tree plus self
            if (cmp == 0) {
                // an identical Node is already present here
                return p;
            }
            if (p.balance != 0)
                bq = q, bp = p, top = 0;
            stack[top++] = which = (cmp > 0);
            path[path_len++] = p;

            // JSB: conditionally update max irrespective of whether we add new node, or descend
            if (interval.end > p.max) p.max = interval.end;
        }

        // JSB: an interval will be inserted, create node x (previously x was parameter Node*)
        Node* x = this.mempool.make!Node(interval);

        x.balance = 0, x.size = 1, x.p[DIR.LEFT] = x.p[DIR.RIGHT] = null;
        if (q is null) this.root = x;
        else q.p[which] = x;
        if (bp is null) return x;
        for (i = 0; i < path_len; ++i) ++path[i].size;
        for (p = bp, top = 0; p != x; p = p.p[stack[top]], ++top) /* update balance factors */
            if (stack[top] == 0) --p.balance;
            else ++p.balance;
        if (bp.balance > -2 && bp.balance < 2) return x; /* balance in [-1, 1] : no re-balance needed */
        /* re-balance */
        which = (bp.balance < 0);
        b1 = which == 0 ? +1 : -1;
        q = bp.p[1 - which];
        if (q.balance == b1) {
            r = rotate1(bp, which);
            q.balance = bp.balance = 0;
        } else r = rotate2(bp, which);
        if (bq is null) this.root = r;
        else bq.p[bp != bq.p[0]] = r;   // index 0 if bp was bq's left child, 1 if right
        return x;
    }

    /**
     * Delete a node from the tree
     *
     * The augmented .max of every node on the path to the removed position is
     * recomputed, so findOverlapsWith stays correct after deletions.
     *
     * @param x   node value to delete; if null, delete the first (leftmost, i.e.
     *            smallest -- NB: NOT ROOT!) node (in)
     * @param cnt number of nodes smaller than or equal to _x_; 0 if _x_ is absent (out)
     *
     * @return node removed from the tree if present, or null if absent or the tree is empty
     */
    /+
    #define kavl_erase(suf, proot, x, cnt) kavl_erase_##suf(proot, x, cnt)
    #define kavl_erase_first(suf, proot) kavl_erase_##suf(proot, 0, 0)
    +/
    @trusted    // cannot be @safe: takes &fake
    @nogc nothrow
    Node *kavl_erase(const(Node) *x, out uint cnt) {
        // Guard: nothing to erase. Without this, erase-first on an empty tree
        // would index path[-1] below.
        if (this.root is null) return null;

        Node* p;
        Node*[KAVL_MAX_DEPTH] path;
        Node fake;      // sentinel parent: fake.p[LEFT] holds the root during the operation
        ubyte[KAVL_MAX_DEPTH] dir;
        int i, d = 0, cmp;
        fake.p[DIR.LEFT] = this.root, fake.p[DIR.RIGHT] = null;

        if (x !is null) {
            for (cmp = -1, p = &fake; cmp; cmp = cmpfn(x, p)) {
                const int which = (cmp > 0);
                if (cmp > 0) cnt += kavl_size_child(p, DIR.LEFT) + 1;    // left tree plus self
                dir[d] = which;
                path[d++] = p;
                p = p.p[which];
                if (p is null) {
                    // node not found: report cnt = 0 (as upstream kavl does), not a partial count
                    cnt = 0;
                    return null;
                }
            }
            cnt += kavl_size_child(p, DIR.LEFT) + 1; /* because p==x is not counted */
        } else { // NULL, delete the first node
            assert(x is null);
            // Descend leftward as far as possible, set p to this node
            for (p = &fake, cnt = 1; p; p = p.p[DIR.LEFT])
                dir[d] = 0, path[d++] = p;
            p = path[--d];
        }

        for (i = 1; i < d; ++i) --path[i].size;

        if (p.p[DIR.RIGHT] is null) { /* ((1,.)2,3)4 => (1,3)4; p=2 */
            path[d-1].p[dir[d-1]] = p.p[DIR.LEFT];
        } else {
            Node *q = p.p[DIR.RIGHT];
            if (q.p[0] is null) { /* ((1,2)3,4)5 => ((1)2,4)5; p=3 */
                q.p[0] = p.p[0];
                q.balance = p.balance;
                path[d-1].p[dir[d-1]] = q;
                path[d] = q, dir[d++] = 1;
                q.size = p.size - 1;
            } else { /* ((1,((.,2)3,4)5)6,7)8 => ((1,(2,4)5)3,7)8; p=6 */
                Node *r;
                int e = d++; /* backup _d_ */
                for (;;) {
                    dir[d] = 0;
                    path[d++] = q;
                    r = q.p[0];
                    if (r.p[0] is null) break;
                    q = r;
                }
                r.p[0] = p.p[0];
                q.p[0] = r.p[1];
                r.p[1] = p.p[1];
                r.balance = p.balance;
                path[e-1].p[dir[e-1]] = r;
                path[e] = r, dir[e] = 1;
                for (i = e + 1; i < d; ++i) --path[i].size;
                r.size = p.size - 1;
            }
        }

        // Restore the augmented .max along the affected path. The splice above may
        // leave it too high (the removed interval was the maximum) or too low (a
        // successor moved up and now owns a larger subtree). path[1 .. d] runs
        // top-down, so walking it bottom-up means every node sees corrected
        // children; path[0] is the sentinel and is skipped. The rotations below
        // depend on these values being correct.
        for (i = d - 1; i > 0; --i) updateMax(path[i]);

        // Rebalance on the way up
        while (--d > 0) {
            Node *q = path[d];
            int which, other, b1 = 1, b2 = 2;
            which = dir[d], other = 1 - which;
            if (which) b1 = -b1, b2 = -b2;
            q.balance += b1;
            if (q.balance == b1) break;
            else if (q.balance == b2) {
                Node *r = q.p[other];
                if (r.balance == -b1) {
                    path[d-1].p[dir[d-1]] = rotate2(q, which);
                } else {
                    path[d-1].p[dir[d-1]] = rotate1(q, which);
                    if (r.balance == 0) {
                        r.balance = cast(byte) -b1;
                        q.balance = cast(byte) b1;
                        break;
                    } else r.balance = q.balance = 0;
                }
            }
        }
        this.root = fake.p[0];
        return p;
    }

    /** update Node n's max from subtrees

        Params:
            n = node to update (no-op when null)
    */
    pragma(inline, true)
    @safe @nogc nothrow
    private
    void updateMax(Node *n)
    {
        import std.algorithm.comparison : max;

        if (n !is null)
        {
            // use the node's own max type (typeof(IntervalType.end)) so wider end types are not truncated
            typeof(n.max) localmax = n.interval.end;
            if (n.p[DIR.LEFT])
                localmax = max(n.p[DIR.LEFT].max, localmax);
            if (n.p[DIR.RIGHT])
                localmax = max(n.p[DIR.RIGHT].max, localmax);
            n.max = localmax;
        }
    }

    // TODO: iterator as InputRange
}

unittest
{
    // module-level unit test
    import std.stdio : write, writeln;
    write(__MODULE__ ~ " unittest ...");

    auto tree = new IntervalAVLTree!BasicInterval;

    auto a = BasicInterval(0, 10);
    auto b = BasicInterval(10, 20);
    auto c = BasicInterval(25, 35);

    auto anode = new IntervalTreeNode!(BasicInterval)(a);
    auto bnode = new IntervalTreeNode!(BasicInterval)(b);
    auto cnode = new IntervalTreeNode!(BasicInterval)(c);

    uint cnt;
    tree.insert(a, cnt);
    tree.insert(b, cnt);
    tree.insert(c, cnt);

    const auto found = tree.find(bnode, cnt);
    assert(found.interval == b);

    // TODO, actually not sure that these are returned strictly ordered if there are many
    auto o = tree.findOverlapsWith(BasicInterval(15, 30));
    assert(o.length == 2);
    assert(o[0].interval == bnode.interval);
    assert(o[1].interval == cnode.interval);

    // empty tree: neither an overlap query nor erase-first may dereference a null root
    auto empty = new IntervalAVLTree!BasicInterval;
    assert(empty.findOverlapsWith(BasicInterval(0, 1)).length == 0);
    assert(empty.kavl_erase(null, cnt) is null);

    // erase must keep .max correct: deleting the root promotes (25,30), which must
    // then carry max = 100 from its new left child (0,100)
    auto t2 = new IntervalAVLTree!BasicInterval;
    t2.insert(BasicInterval(10, 20), cnt);
    t2.insert(BasicInterval(0, 100), cnt);
    t2.insert(BasicInterval(25, 30), cnt);
    auto rootnode = new IntervalTreeNode!(BasicInterval)(BasicInterval(10, 20));
    assert(t2.kavl_erase(rootnode, cnt) !is null);
    auto o2 = t2.findOverlapsWith(BasicInterval(50, 60));
    assert(o2.length == 1);
    assert(o2[0].interval == BasicInterval(0, 100));

    writeln("passed");
}