Skip to content

Commit 0d3b54a

Browse files
committed
fix: use List.flatMap without @[specialize] to reduce code bloating
1 parent 48d2f74 commit 0d3b54a

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

Loom/MonadAlgebras/NonDetT'/ExtractListCore.lean

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ theorem pointwiseSup_alt {l : Type v} [CompleteBooleanAlgebra l] {α : Type u} (
5252
apply iSup_congr ; intro a
5353
by_cases h : a ∈ lis <;> simp [h]
5454

55+
@[inline] def List.flatMapTRNoSpecialize (f : α → List β) (as : List α) : List β := go as #[] where
56+
go : List α → Array β → List β
57+
| [], acc => acc.toList
58+
| x::xs, acc => go xs (acc ++ f x)
59+
60+
theorem List.flatMapTRNoSpecialize_eq_flatMap : @List.flatMapTRNoSpecialize = @List.flatMap := by
61+
funext α β f as
62+
let rec go : ∀ as acc, flatMapTRNoSpecialize.go f as acc = acc.toList ++ as.flatMap f
63+
| [], acc => by simp [flatMapTRNoSpecialize.go, flatMap]
64+
| x::xs, acc => by simp [flatMapTRNoSpecialize.go, flatMap, go xs]
65+
exact (go as #[])
66+
5567
end BasicStuff
5668

5769
/-!
@@ -599,13 +611,22 @@ instance [Functor m] : Functor (TsilT m) where
599611

600612
@[always_inline]
601613
instance [TsilTCore m] : Bind (TsilT m) where
602-
bind := fun xs f => xs.flatMap fun mx => TsilTCore.op mx f
614+
bind := fun xs f =>
615+
match xs with
616+
| [] => []
617+
| [x] => TsilTCore.op x f
618+
| _ => xs.flatMapTRNoSpecialize fun mx => TsilTCore.op mx f
619+
620+
theorem TsilTCore.bind_eq_flatMap [TsilTCore m] (xs : TsilT m α) (f : α → TsilT m β) :
621+
bind xs f = xs.flatMap fun mx => TsilTCore.op mx f := by
622+
rcases xs with _ | ⟨x, _ | ⟨y, xs⟩⟩ <;> (try solve | rfl | simp [bind])
623+
simp [bind, List.flatMapTRNoSpecialize_eq_flatMap]
603624

604625
@[always_inline]
605626
instance [Monad m] [TsilTCore m] : Monad (TsilT m) where
606627

607628
instance [Monad m] [TsilTCore m] : MonadFlatMap'BindDistributive (TsilT m) where
608-
bind_distrib := by introv ; simp [MonadFlatMap'.op, Bind.bind] ; induction l <;> grind
629+
bind_distrib := by introv ; simp [MonadFlatMap'.op, TsilTCore.bind_eq_flatMap] ; induction l <;> grind
609630

610631
section Lawfulness
611632

@@ -625,11 +646,11 @@ class LawfulTsilTCore (m : Type u → Type v) [Monad m] [TsilTCore m] where
625646
theorem TsilTCore.bind_cons {α β : Type u} [Monad m] [TsilTCore m]
626647
(mx : m α) (mxs : TsilT m α) (f : α → TsilT m β) :
627648
letI tmp : TsilT m α := (mx :: mxs)
628-
(tmp >>= f) = (TsilTCore.op mx f) ++ (mxs >>= f) := by simp [bind]
649+
(tmp >>= f) = (TsilTCore.op mx f) ++ (mxs >>= f) := by simp [TsilTCore.bind_eq_flatMap]
629650

630651
theorem TsilTCore.bind_append {α β : Type u} [Monad m] [TsilTCore m]
631652
(mx1 mx2 : TsilT m α) (f : α → TsilT m β) :
632-
((mx1 ++ mx2) >>= f) = (mx1 >>= f) ++ (mx2 >>= f) := by simp [bind]
653+
((mx1 ++ mx2) >>= f) = (mx1 >>= f) ++ (mx2 >>= f) := by simp [TsilTCore.bind_eq_flatMap]
633654

634655
-- this is required in general
635656
instance [Monad m] [LawfulMonad m] [TsilTCore m] [LawfulTsilTCore m] : LawfulMonad (TsilT m) :=
@@ -640,12 +661,12 @@ instance [Monad m] [LawfulMonad m] [TsilTCore m] [LawfulTsilTCore m] : LawfulMon
640661
induction x with
641662
| nil => simp
642663
| cons y xs ih => simp [ih])
643-
(pure_bind := by introv ; simp [bind, pure] ; apply LawfulTsilTCore.pure_op)
664+
(pure_bind := by introv ; simp [bind] ; apply LawfulTsilTCore.pure_op)
644665
(bind_assoc := by
645-
introv ; simp [bind] ; rw [List.flatMap_assoc]
666+
introv ; simp [TsilTCore.bind_eq_flatMap] ; rw [List.flatMap_assoc]
646667
apply List.flatMap_congr ; intro x _ ; apply LawfulTsilTCore.op_assoc)
647668
(bind_pure_comp := by
648-
introv ; simp [bind, pure, Functor.map]
669+
introv ; simp [TsilTCore.bind_eq_flatMap, pure, Functor.map]
649670
induction x with
650671
| nil => simp
651672
| cons y xs ih => simp [ih] ; rw [LawfulTsilTCore.op_single] ; simp)

0 commit comments

Comments
 (0)