@@ -52,6 +52,18 @@ theorem pointwiseSup_alt {l : Type v} [CompleteBooleanAlgebra l] {α : Type u} (
5252 apply iSup_congr ; intro a
5353 by_cases h : a ∈ lis <;> simp [h]
5454
55+ @[inline] def List.flatMapTRNoSpecialize (f : α → List β) (as : List α) : List β := go as #[] where
56+ go : List α → Array β → List β
57+ | [], acc => acc.toList
58+ | x::xs, acc => go xs (acc ++ f x)
59+
60+ theorem List.flatMapTRNoSpecialize_eq_flatMap : @List.flatMapTRNoSpecialize = @List.flatMap := by
61+ funext α β f as
62+ let rec go : ∀ as acc, flatMapTRNoSpecialize.go f as acc = acc.toList ++ as.flatMap f
63+ | [], acc => by simp [flatMapTRNoSpecialize.go, flatMap]
64+ | x::xs, acc => by simp [flatMapTRNoSpecialize.go, flatMap, go xs]
65+ exact (go as #[])
66+
5567end BasicStuff
5668
5769/-!
@@ -599,13 +611,22 @@ instance [Functor m] : Functor (TsilT m) where
599611
600612@[always_inline]
601613instance [TsilTCore m] : Bind (TsilT m) where
602- bind := fun xs f => xs.flatMap fun mx => TsilTCore.op mx f
614+ bind := fun xs f =>
615+ match xs with
616+ | [] => []
617+ | [x] => TsilTCore.op x f
618+ | _ => xs.flatMapTRNoSpecialize fun mx => TsilTCore.op mx f
619+
620+ theorem TsilTCore.bind_eq_flatMap [TsilTCore m] (xs : TsilT m α) (f : α → TsilT m β) :
621+ bind xs f = xs.flatMap fun mx => TsilTCore.op mx f := by
622+ rcases xs with _ | ⟨x, _ | ⟨y, xs⟩⟩ <;> (try solve | rfl | simp [bind])
623+ simp [bind, List.flatMapTRNoSpecialize_eq_flatMap]
603624
604625@[always_inline]
605626instance [Monad m] [TsilTCore m] : Monad (TsilT m) where
606627
607628instance [Monad m] [TsilTCore m] : MonadFlatMap'BindDistributive (TsilT m) where
608- bind_distrib := by introv ; simp [MonadFlatMap'.op, Bind.bind ] ; induction l <;> grind
629+ bind_distrib := by introv ; simp [MonadFlatMap'.op, TsilTCore.bind_eq_flatMap ] ; induction l <;> grind
609630
610631section Lawfulness
611632
@@ -625,11 +646,11 @@ class LawfulTsilTCore (m : Type u → Type v) [Monad m] [TsilTCore m] where
625646theorem TsilTCore.bind_cons {α β : Type u} [Monad m] [TsilTCore m]
626647 (mx : m α) (mxs : TsilT m α) (f : α → TsilT m β) :
627648 letI tmp : TsilT m α := (mx :: mxs)
628- (tmp >>= f) = (TsilTCore.op mx f) ++ (mxs >>= f) := by simp [bind ]
649+ (tmp >>= f) = (TsilTCore.op mx f) ++ (mxs >>= f) := by simp [TsilTCore.bind_eq_flatMap ]
629650
630651theorem TsilTCore.bind_append {α β : Type u} [Monad m] [TsilTCore m]
631652 (mx1 mx2 : TsilT m α) (f : α → TsilT m β) :
632- ((mx1 ++ mx2) >>= f) = (mx1 >>= f) ++ (mx2 >>= f) := by simp [bind ]
653+ ((mx1 ++ mx2) >>= f) = (mx1 >>= f) ++ (mx2 >>= f) := by simp [TsilTCore.bind_eq_flatMap ]
633654
634655-- this is required in general
635656instance [Monad m] [LawfulMonad m] [TsilTCore m] [LawfulTsilTCore m] : LawfulMonad (TsilT m) :=
@@ -640,12 +661,12 @@ instance [Monad m] [LawfulMonad m] [TsilTCore m] [LawfulTsilTCore m] : LawfulMon
640661 induction x with
641662 | nil => simp
642663 | cons y xs ih => simp [ih])
643- (pure_bind := by introv ; simp [bind, pure ] ; apply LawfulTsilTCore.pure_op)
664+ (pure_bind := by introv ; simp [bind] ; apply LawfulTsilTCore.pure_op)
644665 (bind_assoc := by
645- introv ; simp [bind ] ; rw [List.flatMap_assoc]
666+ introv ; simp [TsilTCore.bind_eq_flatMap ] ; rw [List.flatMap_assoc]
646667 apply List.flatMap_congr ; intro x _ ; apply LawfulTsilTCore.op_assoc)
647668 (bind_pure_comp := by
648- introv ; simp [bind , pure, Functor.map]
669+ introv ; simp [TsilTCore.bind_eq_flatMap , pure, Functor.map]
649670 induction x with
650671 | nil => simp
651672 | cons y xs ih => simp [ih] ; rw [LawfulTsilTCore.op_single] ; simp)
0 commit comments