Skip to content

Commit 2e0c3de

Browse files
nomeata and claude committed
refactor: path compression for DiscrTree
This PR adds path compression to the DiscrTree trie data structure. Instead of a single `.node` constructor, the `Trie` type now has four constructors: `.empty`, `.values`, `.path` (for compressed sequences of keys with no branching), and `.branch`. This reduces the size of large discrimination trees (e.g. mathlib's library_search DiscrTree from 217 MB to 41 MB) and the maximum depth (from 2596 to 27). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 333ab1c commit 2e0c3de

7 files changed

Lines changed: 277 additions & 100 deletions

File tree

src/Lean/Meta/DiscrTree/Basic.lean

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def hasNoindexAnnotation (e : Expr) : Bool :=
1919
annotation? `noindex e |>.isSome
2020

2121
instance : Inhabited (Trie α) where
22-
default := .node #[] #[]
22+
default := .empty
2323

2424
instance : Inhabited (DiscrTree α) where
2525
default := {}
@@ -60,9 +60,17 @@ def Key.format : Key → Format
6060
instance : ToFormat Key := ⟨Key.format⟩
6161

6262
partial def Trie.format [ToFormat α] : Trie α → Format
63-
| .node vs cs => Format.group $ Format.paren $
64-
"node" ++ (if vs.isEmpty then Format.nil else " " ++ Std.format vs)
65-
++ Format.join (cs.toList.map fun ⟨k, c⟩ => Format.line ++ Format.paren (Std.format k ++ " => " ++ format c))
63+
| .empty => "empty"
64+
| .values vs t => Format.group $ Format.paren $
65+
"values" ++ (if vs.isEmpty then Format.nil else " " ++ Std.format vs)
66+
++ Format.line ++ format t
67+
| .path ks t => Format.group $ Format.paren $
68+
"path" ++ (if ks.isEmpty then Format.nil else " " ++ Std.format ks)
69+
++ Format.line ++ format t
70+
| .branch cs => Format.group $ Format.paren $
71+
-- label matches the constructor name, like the "values" and "path" cases above
"branch"
72+
++ Format.join (cs.toList.map fun ⟨k, c⟩ =>
73+
Format.line ++ Format.paren (Std.format k ++ " => " ++ format c))
6674

6775
instance [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩
6876

@@ -122,13 +130,13 @@ where
122130
r := r.push (← go)
123131
return r
124132

133+
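/-- Creates a trie containing exactly one entry: the keys of `keys` from index `i` onwards (stored as a single `.path`, omitted when no keys remain), ending in a `.values` leaf that holds only `v`. -/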
125134
private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α :=
126-
if h : i < keys.size then
127-
let k := keys[i]
128-
let c := createNodes keys v (i+1)
129-
.node #[] #[(k, c)]
135+
let t := .values #[v] .empty
136+
if i < keys.size then
137+
.path (keys.extract i keys.size) t
130138
else
131-
.node #[v] #[]
139+
t
132140

133141
/--
134142
If `vs` contains an element `v'` such that `v == v'`, then replace `v'` with `v`.
@@ -149,18 +157,78 @@ where
149157
vs.push v
150158
termination_by vs.size - i
151159

152-
private partial def insertAux [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
153-
| i, .node vs cs =>
154-
if h : i < keys.size then
160+
/--
161+
Calculate the length of the common prefix of two arrays of keys.
162+
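The parameter `i` is the offset into `ks2` at which the comparison starts; `ks1` is always compared from index 0, so the result `j` means `ks1[0:j]` equals `ks2[i:i+j]`.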
163+
-/
164+
def commonPrefix (ks1 : Array Key) (ks2 : Array Key) (i : Nat) : Nat :=
165+
go 0
166+
where
167+
go (j : Nat) : Nat :=
168+
if h1 : j < ks1.size then
169+
if h2 : j + i < ks2.size then
170+
if ks1[j] == ks2[j + i] then
171+
go (j + 1)
172+
else
173+
j
174+
else
175+
j
176+
else
177+
j
178+
termination_by ks1.size - j
179+
180+
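/-- Smart constructor for a two-child `.branch` that stores the children in ascending key order, which the binary search and binary insertion on branch children rely on. -/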
181+
private def branch2 (k1 : Key) (t1 : Trie α) (k2 : Key) (t2 : Trie α) : Trie α :=
182+
if k1 < k2 then
183+
.branch #[(k1, t1), (k2, t2)]
184+
else
185+
.branch #[(k2, t2), (k1, t1)]
186+
187+
/-- Smart constructor ensuring that `.values` constructors are not nested -/
188+
-- Not recursive, so no `partial` is needed: either merge `v` into an existing `.values` node or wrap `t` in a new one.
private def insertHere [BEq α] (v : α) : Trie α → Trie α
189+
| .values vs t => .values (insertVal vs v) t
190+
| t => .values #[v] t
191+
192+
private partial def insertAux [BEq α] (keys : Array Key) (v : α) (i : Nat) (t : Trie α) :
193+
Trie α :=
194+
if h : i < keys.size then
195+
-- we have to walk down the tree some more
196+
match t with
197+
| .empty => createNodes keys v i
198+
-- Keys remain, so `v` belongs further down; the values already stored at this node must be kept and only the subtree below them is rebuilt.
| .values vs t => .values vs (insertAux keys v i t)
199+
| .path ks t =>
200+
let j := commonPrefix ks keys i
201+
let t' := -- the new tree after the common prefix
202+
if h1 : j < ks.size then
203+
if h2 : i + j < keys.size then
204+
-- we must branch at offset j
205+
let k1 := ks[j]
206+
let t1 := if j + 1 < ks.size then .path (ks.extract (j + 1) ks.size) t else t
207+
let k2 := keys[i + j]
208+
let t2 := createNodes keys v (i + j + 1)
209+
branch2 k1 t1 k2 t2
210+
else
211+
-- the entry keys are a prefix of the path in the node: split the path, insert the value
212+
.values #[v] (.path (ks.extract j ks.size) t)
213+
else
214+
-- the node path is a prefix of the new entry
215+
insertAux keys v (i + j) t
216+
if 0 < j then
217+
-- add a .path for the common prefix, if present
218+
.path (ks.extract 0 j) t'
219+
else
220+
t'
221+
| .branch cs =>
155222
let k := keys[i]
156223
let c := Id.run $ cs.binInsertM
157224
(fun a b => a.1 < b.1)
158225
(fun ⟨_, s⟩ => let c := insertAux keys v (i+1) s; (k, c)) -- merge with existing
159226
(fun _ => let c := createNodes keys v (i+1); (k, c))
160227
(k, default)
161-
.node vs c
162-
else
163-
.node (insertVal vs v) cs
228+
.branch c
229+
else
230+
-- this is where we need to insert the value
231+
insertHere v t
164232

165233
def insertKeyValue [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
166234
if keys.isEmpty then panic! "invalid key sequence"
@@ -174,6 +242,10 @@ def insertKeyValue [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : Dis
174242
let c := insertAux keys v 1 c
175243
{ root := d.root.insert k c }
176244

245+
def getValues : Trie α → Array α
246+
| .values vs _ => vs
247+
| _ => #[]
248+
177249
@[deprecated insertKeyValue (since := "2026-01-02")]
178250
def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
179251
insertKeyValue d keys v

src/Lean/Meta/DiscrTree/Main.lean

Lines changed: 105 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -429,27 +429,27 @@ private abbrev getUnifyKeyArgs (e : Expr) (root : Bool) : MetaM (Key × Array Ex
429429
private def getStarResult (d : DiscrTree α) : Array α :=
430430
let result : Array α := .mkEmpty initCapacity
431431
match d.root.find? .star with
432-
| none => result
433-
| some (.node vs _) => result ++ vs
432+
| none => result
433+
| some t => result ++ getValues t
434434

435435
private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
436436
cs.binSearch (k, default) (fun a b => a.1 < b.1)
437437

438438
private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
439-
match c with
440-
| .node vs cs =>
441-
if todo.isEmpty then
442-
return result ++ vs
443-
else if cs.isEmpty then
444-
return result
445-
else
439+
if todo.isEmpty then
440+
return result ++ getValues c
441+
else
442+
match c with
443+
| .empty => return result
444+
| .values _ t => getMatchLoop todo t result
445+
| .branch cs =>
446+
if cs.isEmpty then return result else -- should not happen
446447
let e := todo.back!
447448
let todo := todo.pop
448449
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
449-
let (k, args) ← getMatchKeyArgs e (root := false)
450450
/- We must always visit `Key.star` edges since they are wildcards.
451-
Thus, `todo` is not used linearly when there is `Key.star` edge
452-
and there is an edge for `k` and `k != Key.star`. -/
451+
Thus, `todo` is not used linearly when there is `Key.star` edge
452+
and there is an edge for `k` and `k != Key.star`. -/
453453
let visitStar (result : Array α) : MetaM (Array α) :=
454454
if first.1 == .star then
455455
getMatchLoop todo first.2 result
@@ -459,10 +459,38 @@ private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Arr
459459
match findKey cs k with
460460
| none => return result
461461
| some c => getMatchLoop (todo ++ args) c.2 result
462+
let (k, args) ← getMatchKeyArgs e (root := false)
462463
let result ← visitStar result
463464
match k with
464465
| .star => return result
465466
| _ => visitNonStar k args result
467+
| .path ks t =>
468+
let rec loop (todo : Array Expr) (result : Array α) (i : Nat) : MetaM (Array α) := do
469+
-- the following logic is a copy of the .branch case, as if `cs` is a singleton
470+
if h : i < ks.size then
471+
if todo.isEmpty then
472+
return result
473+
let e := todo.back!
474+
let todo := todo.pop
475+
let k' := ks[i]
476+
let visitStar (result : Array α) : MetaM (Array α) :=
477+
if k' == .star then
478+
loop todo result (i + 1)
479+
else
480+
return result
481+
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
482+
if k == k' then
483+
loop (todo ++ args) result (i + 1)
484+
else
485+
return result
486+
let (k, args) ← getMatchKeyArgs e (root := false)
487+
let result ← visitStar result
488+
match k with
489+
| .star => return result
490+
| _ => visitNonStar k args result
491+
else
492+
getMatchLoop todo t result
493+
loop todo result 0
466494

467495
private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
468496
match d.root.find? k with
@@ -544,8 +572,11 @@ private partial def getAllValuesForKey (d : DiscrTree α) (k : Key) (result : Ar
544572
where
545573
go (trie : Trie α) (result : Array α) : Array α := Id.run do
546574
match trie with
547-
| .node vs cs =>
548-
let mut result := result ++ vs
575+
| .empty => return result
576+
| .values vs t => go t (result ++ vs)
577+
| .path _ t => go t result
578+
| .branch cs =>
579+
let mut result := result
549580
for (_, trie) in cs do
550581
result := go trie result
551582
return result
@@ -576,33 +607,66 @@ partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
576607
| some c => process 0 args c result
577608
where
578609
process (skip : Nat) (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
579-
match skip, c with
580-
| skip+1, .node _ cs =>
581-
if cs.isEmpty then
582-
return result
583-
else
584-
cs.foldlM (init := result) fun result ⟨k, c⟩ => process (skip + k.arity) todo c result
585-
| 0, .node vs cs => do
586-
if todo.isEmpty then
587-
return result ++ vs
588-
else if cs.isEmpty then
589-
return result
590-
else
591-
let e := todo.back!
592-
let todo := todo.pop
593-
let (k, args) ← getUnifyKeyArgs e (root := false)
594-
let visitStar (result : Array α) : MetaM (Array α) :=
595-
let first := cs[0]!
596-
if first.1 == .star then
597-
process 0 todo first.2 result
598-
else
610+
if skip == 0 && todo.isEmpty then
611+
return result ++ getValues c
612+
else match c with
613+
| .empty => return result
614+
| .values _ t => process skip todo t result
615+
| .branch cs =>
616+
match skip with
617+
| skip+1 =>
618+
if cs.isEmpty then
599619
return result
600-
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
601-
match findKey cs k with
602-
| none => return result
603-
| some c => process 0 (todo ++ args) c.2 result
604-
match k with
605-
| .star => cs.foldlM (init := result) fun result ⟨k, c⟩ => process k.arity todo c result
606-
| _ => visitNonStar k args (← visitStar result)
620+
else
621+
cs.foldlM (init := result) fun result ⟨k, c⟩ => process (skip + k.arity) todo c result
622+
| 0 => do
623+
if cs.isEmpty then return result else -- should not happen
624+
let e := todo.back!
625+
let todo := todo.pop
626+
let (k, args) ← getUnifyKeyArgs e (root := false)
627+
let visitStar (result : Array α) : MetaM (Array α) :=
628+
let first := cs[0]!
629+
if first.1 == .star then
630+
process 0 todo first.2 result
631+
else
632+
return result
633+
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
634+
match findKey cs k with
635+
| none => return result
636+
| some c => process 0 (todo ++ args) c.2 result
637+
match k with
638+
| .star => cs.foldlM (init := result) fun result ⟨k, c⟩ => process k.arity todo c result
639+
| _ => visitNonStar k args (← visitStar result)
640+
| .path ks t =>
641+
let rec loop (skip : Nat) (todo : Array Expr) (result : Array α) (i : Nat) : MetaM (Array α) :=
642+
if h : i < ks.size then
643+
match skip with
644+
| skip+1 =>
645+
let k' := ks[i]
646+
loop (skip + k'.arity) todo result (i + 1)
647+
| 0 => do
648+
if todo.isEmpty then
649+
return result
650+
else
651+
let e := todo.back!
652+
let todo := todo.pop
653+
let k' := ks[i]
654+
let visitStar (result : Array α) : MetaM (Array α) :=
655+
if k' == .star then
656+
loop 0 todo result (i + 1)
657+
else
658+
return result
659+
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
660+
if k' == k then
661+
loop 0 (todo ++ args) result (i + 1)
662+
else
663+
return result
664+
let (k, args) ← getUnifyKeyArgs e (root := false)
665+
match k with
666+
| .star => loop k'.arity todo result (i + 1)
667+
| _ => visitNonStar k args (← visitStar result)
668+
else
669+
process skip todo t result
670+
loop skip todo result 0
607671

608672
end Lean.Meta.DiscrTree

src/Lean/Meta/DiscrTree/Types.lean

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ instance : Hashable Key := ⟨Key.hash⟩
3838
Discrimination tree trie. See `DiscrTree`.
3939
-/
4040
inductive Trie (α : Type) where
41-
| node (vs : Array α) (children : Array (Key × Trie α)) : Trie α
41+
| empty
42+
| values (vs : Array α) (t : Trie α)
43+
| path (ks : Array Key) (t : Trie α)
44+
| branch (children : Array (Key × Trie α))
4245

4346
end DiscrTree
4447

0 commit comments

Comments
 (0)