From 83fa45b2991710dd784de4220d37308d5fb6dd18 Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Mon, 30 Sep 2024 00:36:37 -0400 Subject: [PATCH 1/8] refactor stateT and writerT monads to use the definitions in extlib. --- theories/Basics/Basics.v | 62 ++-------- theories/Basics/MonadProp.v | 1 - theories/Basics/MonadState.v | 70 +++++------ theories/Events/Dependent.v | 9 +- theories/Events/FailFacts.v | 1 - theories/Events/Map.v | 16 ++- theories/Events/MapDefault.v | 15 +-- theories/Events/MapDefaultFacts.v | 20 ++-- theories/Events/State.v | 25 ++-- theories/Events/StateFacts.v | 185 ++++++++++++++++-------------- theories/Events/Writer.v | 42 ++++--- theories/Interp/Handler.v | 2 - theories/Interp/HandlerFacts.v | 1 - theories/Props/Leaf.v | 32 +++--- 14 files changed, 243 insertions(+), 238 deletions(-) diff --git a/theories/Basics/Basics.v b/theories/Basics/Basics.v index c25cbca0..f34b84ed 100644 --- a/theories/Basics/Basics.v +++ b/theories/Basics/Basics.v @@ -12,8 +12,10 @@ From Coq Require Import From ExtLib Require Import Structures.Functor Structures.Monad + Structures.Monoid Data.Monads.StateMonad Data.Monads.ReaderMonad + Data.Monads.WriterMonad Data.Monads.OptionMonad Data.Monads.EitherMonad. @@ -45,44 +47,6 @@ Definition idM {E : Type -> Type} : E ~> E := fun _ e => e. (** [void] is a shorthand for [Empty_set]. *) Notation void := Empty_set. -(** ** Common monads and transformers. *) - -Module Monads. - -Definition identity (a : Type) : Type := a. - -Definition stateT (s : Type) (m : Type -> Type) (a : Type) : Type := - s -> m (prod s a). -Definition state (s a : Type) := s -> prod s a. - -Definition run_stateT {s m a} (x : stateT s m a) : s -> m (s * a)%type := x. - -Definition liftState {s a f} `{Functor f} (fa : f a) : Monads.stateT s f a := - fun s => pair s <$> fa. - -Definition readerT (r : Type) (m : Type -> Type) (a : Type) : Type := - r -> m a. -Definition reader (r a : Type) := r -> a. - -Definition writerT (w : Type) (m : Type -> Type) (a : Type) : Type := - m (prod w a). -Definition writer := prod. - -#[global] Instance Functor_stateT {m s} {Fm : Functor m} : Functor (stateT s m) - := {| - fmap _ _ f := fun run s => fmap (fun sa => (fst sa, f (snd sa))) (run s) - |}. - -#[global] Instance Monad_stateT {m s} {Fm : Monad m} : Monad (stateT s m) - := {| - ret _ a := fun s => ret (s, a) - ; bind _ _ t k := fun s => - sa <- t s ;; - k (snd sa) (fst sa) - |}. - -End Monads. - (** ** Loop operator *) (** [iter]: A primitive for general recursion. @@ -113,22 +77,20 @@ Polymorphic Class MonadIter (M : Type -> Type) : Type := | inr r => inr (r, snd is') end) (i, s)). -#[global] Polymorphic Instance MonadIter_stateT0 {M S} {MM : Monad M} {AM : MonadIter M} - : MonadIter (Monads.stateT S M) := - fun _ _ step i s => - iter (fun si => - let s := fst si in - let i := snd si in - si' <- step i s;; - ret match snd si' with - | inl i' => inl (fst si', i') - | inr r => inr (fst si', r) - end) (s, i). - #[global] Instance MonadIter_readerT {M S} {AM : MonadIter M} : MonadIter (readerT S M) := fun _ _ step i => mkReaderT (fun s => iter (fun i => runReaderT (step i) s) i). +#[global] Instance MonadIter_writerT {M W} {MM : Monad M} {AM : MonadIter M} (Monoid_W : Monoid W) + : MonadIter (writerT Monoid_W M) := + fun _ _ step i => mkWriterT _ ( + iter (fun iw => + iw' <- runWriterT (step (PPair.pfst iw)) ;; + ret match PPair.pfst iw' with + | inl i' => inl (PPair.ppair i' (monoid_plus Monoid_W (PPair.psnd iw) (PPair.psnd iw'))) + | inr r => inr (PPair.ppair r (monoid_plus Monoid_W (PPair.psnd iw) (PPair.psnd iw'))) + end) (PPair.ppair i (monoid_unit Monoid_W))). + #[global] Instance MonadIter_optionT {M} {MM : Monad M} {AM : MonadIter M} : MonadIter (optionT M) := fun _ _ step i => mkOptionT ( diff --git a/theories/Basics/MonadProp.v b/theories/Basics/MonadProp.v index 29de11c7..bdc3b712 100644 --- a/theories/Basics/MonadProp.v +++ b/theories/Basics/MonadProp.v @@ -13,7 +13,6 @@ From ITree Require Import Basics.CategoryKleisli Basics.Monad. -Import ITree.Basics.Basics.Monads. Import CatNotations. Local Open Scope cat_scope. Local Open Scope cat. diff --git a/theories/Basics/MonadState.v b/theories/Basics/MonadState.v index 40bee1c8..c927a10e 100644 --- a/theories/Basics/MonadState.v +++ b/theories/Basics/MonadState.v @@ -4,7 +4,8 @@ From Coq Require Import Morphisms. From ExtLib Require Import - Structures.Monad. + Structures.Monad + Data.Monads.StateMonad. From ITree Require Import Basics.Basics @@ -13,7 +14,8 @@ From ITree Require Import Basics.CategoryKleisliFacts Basics.Monad. -Import ITree.Basics.Basics.Monads. +Existing Instance Monad_stateT. + Import CatNotations. Local Open Scope cat_scope. Local Open Scope cat. @@ -28,7 +30,7 @@ Section State. Global Instance Eq1_stateTM : Eq1 (stateT S M). Proof. - exact (fun a => pointwise_relation _ eq1). + exact (fun a s1 s2 => pointwise_relation _ eq1 (runStateT s1) (runStateT s2)). Defined. Global Instance Eq1Equivalence_stateTM : @Eq1Equivalence (stateT S M) _ Eq1_stateTM. @@ -44,20 +46,21 @@ Section State. Proof. constructor. - cbn. intros a b f x. - repeat red. intros s. + repeat red. intros s. cbn. rewrite bind_ret_l. reflexivity. - cbn. intros a x. - repeat red. intros s. - assert (EQM _ (bind (x s) (fun sa : S * a => ret (fst sa, snd sa))) (bind (x s) (fun sa => ret sa))). + repeat red. intros s. cbn. + assert (EQM _ (bind (runStateT x s) (fun '(a, s1) => ret (a, s1))) + (bind (runStateT x s) (fun axs => ret axs))). { apply Proper_bind. reflexivity. intros. repeat red. destruct a0; reflexivity. } rewrite H. rewrite bind_ret_r. reflexivity. - cbn. intros a b c x f g. - repeat red. intros s. + repeat red. intros s. cbn. rewrite bind_bind. apply Proper_bind. + reflexivity. - + reflexivity. + + intros []. reflexivity. - repeat red. intros a b x y H x0 y0 H0 s. apply Proper_bind. + apply H. @@ -69,20 +72,20 @@ Section State. Context {IM: Iter (Kleisli M) sum}. Context {CM: Iterative (Kleisli M) sum}. - Definition iso {a b:Type} (sab : S * (a + b)) : (S * a) + (S * b) := - match sab with - | (s, inl x) => inl (s, x) - | (s, inr y) => inr (s, y) + Definition iso {a b:Type} (abs : (a + b) * S) : (a * S) + (b * S) := + match abs with + | (inl x, s) => inl (x, s) + | (inr y, s) => inr (y, s) end. - Definition iso_inv {a b:Type} (sab : (S * a) + (S * b)) : S * (a + b) := - match sab with - | inl (s, a) => (s, inl a) - | inr (s, b) => (s, inr b) + Definition iso_inv {a b:Type} (abs : (a * S) + (b * S)) : (a + b) * S := + match abs with + | inl (a, s) => (inl a, s) + | inr (b, s) => (inr b, s) end. - Definition internalize {a b:Type} (f : Kleisli (stateT S M) a b) : Kleisli M (S * a) (S * b) := - fun (sa : S * a) => f (snd sa) (fst sa). + Definition internalize {a b:Type} (f : Kleisli (stateT S M) a b) : Kleisli M (a * S) (b * S) := + fun '(a, s) => runStateT (f a) s. Lemma internalize_eq {a b:Type} (f g : Kleisli (stateT S M) a b) : eq2 f g <-> eq2 (internalize f) (internalize g). @@ -94,7 +97,7 @@ Section State. - intros. repeat red. intros. unfold internalize in H. - specialize (H (a1, a0)). + specialize (H (a0, a1)). apply H. Qed. @@ -112,8 +115,8 @@ Section State. Qed. - Lemma internalize_pure {a b c} (f : Kleisli (stateT S M) a b) (g : S * b -> S * c) : - internalize f >>> pure g ⩯ (internalize (f >>> (fun b s => ret (g (s,b))))). + Lemma internalize_pure {a b c} (f : Kleisli (stateT S M) a b) (g : b * S -> c * S) : + internalize f >>> pure g ⩯ (internalize (f >>> (fun b => mkStateT (fun s => ret (g (b,s)))))). Proof. repeat red. destruct a0. @@ -129,8 +132,8 @@ Section State. Global Instance Iter_stateTM : Iter (Kleisli (stateT S M)) sum. Proof. exact - (fun (a b : Type) (f : a -> S -> M (S * (a + b))) (x:a) (s:S) => - iter ((internalize f) >>> (pure iso)) (s, x)). + (fun (a b : Type) (f : a -> stateT S M (a + b)) (x:a) => + mkStateT (fun (s:S) => iter ((internalize f) >>> (pure iso)) (x, s))). Defined. Global Instance Proper_Iter_stateTM : forall a b, @Proper (Kleisli (stateT S M) a (a + b) -> (Kleisli (stateT S M) a b)) (eq2 ==> eq2) iter. @@ -144,7 +147,7 @@ Section State. cbn. apply Proper_bind. - apply H. - - repeat red. destruct a2 as [s' [x1|y1]]; reflexivity. + - repeat red. destruct a2 as [[x1|y1] s']; reflexivity. Qed. Global Instance IterUnfold_stateTM : IterUnfold (Kleisli (stateT S M)) sum. @@ -156,13 +159,14 @@ Section State. intros a0 s. unfold cat, Cat_Kleisli. unfold iter, Iter_stateTM. + cbn. rewrite iterative_unfold. (* SAZ: why isn't iter_unfold exposed here? *) unfold cat, Cat_Kleisli. simpl. rewrite bind_bind. apply Proper_bind. + reflexivity. - + repeat red. destruct a1 as [s' [x | y]]; simpl. + + repeat red. destruct a1 as [[x | y] s']; simpl. * unfold pure. rewrite bind_ret_l. reflexivity. * unfold pure. rewrite bind_ret_l. @@ -185,7 +189,7 @@ Section State. rewrite! bind_bind. apply Proper_bind. - reflexivity. - - repeat red. destruct a2 as [s' [x | y]]; simpl. + - repeat red. destruct a2 as [[x | y] s']; simpl. + unfold pure. rewrite bind_ret_l. cbn. unfold cat, Cat_Kleisli. cbn. rewrite bind_bind. @@ -204,7 +208,7 @@ Section State. Qed. Lemma internalize_pure_iso {a b c} (f : Kleisli (stateT S M) a (b + c)) : - ((internalize f) >>> pure iso) ⩯ (fun sa => (bind (f (snd sa) (fst sa))) (fun sbc => ret (iso sbc))). + ((internalize f) >>> pure iso) ⩯ (fun axs => (bind (let '(a, s) := axs in runStateT (f a) s)) (fun bcs => ret (iso bcs))). Proof. reflexivity. Qed. @@ -212,7 +216,7 @@ Section State. Lemma eq2_to_eq1 : forall a b (f g : Kleisli (stateT S M) a b) (x:a) (s:S), eq2 f g -> - eq1 (f x s) (g x s). + eq1 (runStateT (f x) s) (runStateT (g x) s). Proof. intros a b f g x s H. apply H. @@ -233,7 +237,7 @@ Section State. apply Proper_bind. - reflexivity. - repeat red. - destruct a1 as [s' [x | y]]. + destruct a1 as [[x | y] s']. + unfold pure. rewrite bind_ret_l. unfold case_, Case_Kleisli, Function.case_sum. @@ -263,7 +267,7 @@ Section State. apply Proper_bind. - reflexivity. - repeat red. - destruct a2 as [s [x | y]]. + destruct a2 as [[x | y] s]. + unfold pure. rewrite bind_ret_l. cbn. @@ -274,7 +278,7 @@ Section State. apply Proper_bind. * reflexivity. * repeat red. - destruct a2 as [s' [x' | y]]. + destruct a2 as [[x' | y] s']. ** cbn. rewrite bind_ret_l. unfold case_, Case_Kleisli, Function.case_sum. reflexivity. ** cbn. rewrite bind_ret_l. unfold case_, Case_Kleisli, Function.case_sum. @@ -298,7 +302,7 @@ Section State. eapply iterative_proper_iter. eapply Proper_cat_Kleisli. - assert (internalize (fun (x:a) (s:S) => iter (internalize f >>> pure iso) (s, x)) + assert (internalize (fun (x:a) => mkStateT (fun (s:S) => iter (internalize f >>> pure iso) (x, s))) ⩯ iter (internalize f >>> pure iso)). { repeat red. @@ -325,7 +329,7 @@ Section State. apply Proper_bind. - reflexivity. - repeat red. - destruct a3 as [s' [x | [y | z]]]. + destruct a3 as [[x | [y | z]] s']. + rewrite bind_ret_l. cbn. unfold id_, Id_Kleisli, pure. rewrite bind_ret_l. diff --git a/theories/Events/Dependent.v b/theories/Events/Dependent.v index 78610d36..ee8b581e 100644 --- a/theories/Events/Dependent.v +++ b/theories/Events/Dependent.v @@ -14,13 +14,16 @@ *) (* begin hide *) + +From ExtLib Require Import + Data.Monads.IdentityMonad. + From ITree Require Import Basics.Basics Core.ITreeDefinition Indexed.Sum Core.Subevent. -Import Basics.Basics.Monads. (* end hide *) Variant depE {I : Type} (F : I -> Type) : Type -> Type := @@ -32,8 +35,8 @@ Definition dep {I F E} `{depE F -< E} (i : I) : itree E (F i) := trigger (Dep (F := F) i). Definition undep {I F} (f : forall i : I, F i) : - depE F ~> identity := + depE F ~> ident := fun _ d => match d with - | Dep i => f i + | Dep i => mkIdent (f i) end. diff --git a/theories/Events/FailFacts.v b/theories/Events/FailFacts.v index 113b01db..23859e73 100644 --- a/theories/Events/FailFacts.v +++ b/theories/Events/FailFacts.v @@ -25,7 +25,6 @@ From ITree Require Import Interp.RecursionFacts. Import ITreeNotations. -Import ITree.Basics.Basics.Monads. Local Open Scope itree_scope. Import Monads. diff --git a/theories/Events/Map.v b/theories/Events/Map.v index 64f6b904..0ddc4368 100644 --- a/theories/Events/Map.v +++ b/theories/Events/Map.v @@ -9,6 +9,9 @@ Import ListNotations. From ExtLib.Structures Require Maps. +From ExtLib Require Import + Data.Monads.StateMonad. + From ITree Require Import Basics.Basics Basics.CategoryOps @@ -19,7 +22,6 @@ From ITree Require Import Interp.Interp Events.State. -Import ITree.Basics.Basics.Monads. (* end hide *) Section Map. @@ -50,15 +52,17 @@ Section Map. Context {M : Map K V map}. Definition handle_map {E} : mapE ~> stateT map (itree E) := - fun _ e env => + fun _ e => match e with - | Insert k v => Ret (Maps.add k v env, tt) - | Lookup k => Ret (env, Maps.lookup k env) - | Remove k => Ret (Maps.remove k env, tt) + | Insert k v => mkStateT (fun env => Ret (tt, Maps.add k v env)) + | Lookup k => mkStateT (fun env => Ret (Maps.lookup k env, env)) + | Remove k => mkStateT (fun env => Ret (tt, Maps.remove k env)) end. + (* not sure why case_ requires manual parameters *) Definition run_map {E} : itree (mapE +' E) ~> stateT map (itree E) := - interp_state (case_ handle_map pure_state). + interp_state (case_ (bif := sum1) (c := stateT map (itree E)) + handle_map pure_state). End Map. diff --git a/theories/Events/MapDefault.v b/theories/Events/MapDefault.v index c2ef6c64..d8115e0d 100644 --- a/theories/Events/MapDefault.v +++ b/theories/Events/MapDefault.v @@ -5,7 +5,8 @@ Set Implicit Arguments. Set Contextual Implicit. From ExtLib Require Import - Core.RelDec. + Core.RelDec + Data.Monads.StateMonad. From ExtLib.Structures Require Functor Monoid Maps. @@ -23,7 +24,6 @@ From ITree Require Import Interp.Handler Events.State. -Import ITree.Basics.Basics.Monads. (* end hide *) Section Map. @@ -56,11 +56,11 @@ Section Map. end. Definition handle_map {E d} : mapE d ~> stateT map (itree E) := - fun _ e env => + fun _ e => match e with - | Insert k v => Ret (Maps.add k v env, tt) - | LookupDef k => Ret (env, lookup_default k d env) - | Remove k => Ret (Maps.remove k env, tt) + | Insert k v => mkStateT (fun env => Ret (tt, Maps.add k v env)) + | LookupDef k => mkStateT (fun env => Ret (lookup_default k d env, env)) + | Remove k => mkStateT (fun env => Ret (tt, Maps.remove k env)) end. (* SAZ: I think that all of these [run_foo] functions should be renamed @@ -69,7 +69,8 @@ Section Map. strange to define [interp_map] in terms of [interp_state]. *) Definition interp_map {E d} : itree (mapE d +' E) ~> stateT map (itree E) := - interp_state (case_ (C := IFun) handle_map pure_state). + interp_state (case_ (bif := sum1) (C := IFun) (c := stateT map (itree E)) + handle_map pure_state). (* The appropriate notation of the equivalence on the state associated with diff --git a/theories/Events/MapDefaultFacts.v b/theories/Events/MapDefaultFacts.v index 7829dc30..b9696fdd 100644 --- a/theories/Events/MapDefaultFacts.v +++ b/theories/Events/MapDefaultFacts.v @@ -7,7 +7,8 @@ Set Contextual Implicit. From Coq Require Import Morphisms. From ExtLib Require Import - Core.RelDec. + Core.RelDec + Data.Monads.StateMonad. From ExtLib.Structures Require Maps. @@ -23,7 +24,6 @@ From ITree Require Import Events.StateFacts Events.MapDefault. -Import ITree.Basics.Basics.Monads. Import Structures.Maps. (* end hide *) @@ -120,7 +120,7 @@ Section MapFacts. Definition map_default_eq d {E} : (stateT map (itree E) R1) -> (stateT map (itree E) R2) -> Prop := - fun t1 t2 => forall s1 s2, (@eq_map _ _ _ _ d) s1 s2 -> eutt (prod_rel (@eq_map _ _ _ _ d) RR) (t1 s1) (t2 s2). + fun t1 t2 => forall s1 s2, (@eq_map _ _ _ _ d) s1 s2 -> eutt (prod_rel RR (@eq_map _ _ _ _ d)) (runStateT t1 s1) (runStateT t2 s2). End Relations. @@ -153,7 +153,7 @@ Section MapFacts. Lemma handle_map_eq : forall d E X (s1 s2 : map) (m : mapE K d X), (@eq_map _ _ _ _ d) s1 s2 -> - eutt (prod_rel (@eq_map _ _ _ _ d) eq) (handle_map m s1) ((handle_map m s2) : itree E (map * X)). + eutt (prod_rel eq (@eq_map _ _ _ _ d)) (runStateT (handle_map m) s1) ((runStateT (handle_map m) s2) : itree E (X * map)). Proof. intros. destruct m; cbn; red; apply eqit_Ret; constructor; cbn; auto. @@ -210,22 +210,24 @@ Section MapFacts. rewrite! unfold_interp_state. punfold H0. red in H0. revert s1 s2 H1. - induction H0; intros; subst; simpl; pclearbot. - - eret. + induction H0; intros; subst; cbn; pclearbot. + - eret. - etau. - ebind. - apply pbc_intro_h with (RU := prod_rel (@eq_map _ _ _ _ d) eq). + apply pbc_intro_h with (RU := prod_rel eq (@eq_map _ _ _ _ d)). { (* SAZ: I must be missing some lemma that should solve this case *) unfold case_. unfold Case_sum1, case_sum1. destruct e. apply handle_map_eq. assumption. unfold pure_state. pstep. econstructor. intros. constructor. pfold. econstructor. constructor; auto. - } - intros. destruct H as [HH1 ->]. + } + intros. destruct H as [-> HH1]. estep; constructor. ebase. - rewrite tau_euttge, unfold_interp_state. + apply IHeqitF. eauto. - rewrite tau_euttge, unfold_interp_state. + apply IHeqitF. eauto. Qed. diff --git a/theories/Events/State.v b/theories/Events/State.v index 4d264400..7ac60e68 100644 --- a/theories/Events/State.v +++ b/theories/Events/State.v @@ -5,7 +5,8 @@ (* begin hide *) From ExtLib Require Import Structures.Functor - Structures.Monad. + Structures.Monad + Data.Monads.StateMonad. From ITree Require Import Basics.Basics @@ -17,7 +18,7 @@ From ITree Require Import Core.Subevent Interp.Interp. -Import ITree.Basics.Basics.Monads. +Existing Instance Monad_stateT. Local Open Scope itree_scope. (* end hide *) @@ -32,6 +33,7 @@ Definition interp_state {E M S} itree E ~> stateT S M := interp h. Arguments interp_state {E M S FM MM IM} h [T]. +Arguments interp_state : simpl never. Section State. @@ -45,27 +47,30 @@ Section State. Definition put {E} `{stateE -< E} : S -> itree E unit := embed Put. Definition h_state {E} : stateE ~> stateT S (itree E) := - fun _ e s => - match e with - | Get => Ret (s, s) - | Put s' => Ret (s', tt) - end. + fun _ e => + mkStateT (fun s => + match e with + | Get => Ret (s, s) + | Put s' => Ret (tt, s') + end). (* SAZ: this is the instance for the hypothetical "Trigger E M" typeclass. Class Trigger E M := trigger : E ~> M *) Definition pure_state {S E} : E ~> stateT S (itree E) - := fun _ e s => Vis e (fun x => Ret (s, x)). + := fun _ e => mkStateT (fun s => Vis e (fun x => Ret (x, s))). + (* not sure why case_ requires the manual parameters *) Definition run_state {E} : itree (stateE +' E) ~> stateT S (itree E) - := interp_state (case_ h_state pure_state). + := interp_state (case_ (bif := sum1) (c := stateT S (itree E)) + h_state pure_state). End State. Arguments get {S E _}. Arguments put {S E _}. -Arguments run_state {S E} [_] _ _. +Arguments run_state {S E} [_] _. (** An extensional stateful handler *) diff --git a/theories/Events/StateFacts.v b/theories/Events/StateFacts.v index ca0405f2..37b8c49d 100644 --- a/theories/Events/StateFacts.v +++ b/theories/Events/StateFacts.v @@ -5,6 +5,8 @@ From Coq Require Import Program.Tactics Morphisms. From Paco Require Import paco. +From ExtLib Require Import Data.Monads.StateMonad. + From ITree Require Import Basics.Basics Basics.Category @@ -27,41 +29,61 @@ Import ITreeNotations. Local Open Scope itree_scope. -Import Monads. (* end hide *) Definition _interp_state {E F S R} (f : E ~> stateT S (itree F)) (ot : itreeF E R _) - : stateT S (itree F) R := fun s => + : stateT S (itree F) R := mkStateT (fun s => match ot with - | RetF r => Ret (s, r) - | TauF t => Tau (interp_state f t s) - | VisF e k => f _ e s >>= (fun sx => Tau (interp_state f (k (snd sx)) (fst sx))) - end. + | RetF r => Ret (r, s) + | TauF t => Tau (runStateT (interp_state f t) s) + | VisF e k => runStateT (f _ e) s >>= (fun xs => Tau (runStateT (interp_state f (k (fst xs))) (snd xs))) + end). -Lemma unfold_interp_state {E F S R} (h : E ~> Monads.stateT S (itree F)) +Lemma unfold_interp_state {E F S R} (h : E ~> stateT S (itree F)) (t : itree E R) s : eq_itree eq - (interp_state h t s) - (_interp_state h (observe t) s). + (runStateT (interp_state h t) s) + (runStateT (_interp_state h (observe t)) s). Proof. - unfold interp_state, interp, Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree; cbn. + unfold interp_state, interp, Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree; cbn. rewrite unfold_iter; cbn. destruct observe; cbn. - rewrite 2 bind_ret_l. reflexivity. - rewrite 2 bind_ret_l. reflexivity. - - rewrite bind_map, bind_bind; cbn. setoid_rewrite bind_ret_l. - apply eqit_bind; reflexivity. + - rewrite 2 bind_bind; cbn. + apply eqit_bind; [reflexivity|]. + intros []. + rewrite 2 bind_ret_l. + reflexivity. Qed. +Definition eq_stateT {S M X Y} + (RM : forall R1 R2, (R1 -> R2 -> Prop) -> M R1 -> M R2 -> Prop) + (RXY : X * S -> Y * S -> Prop) + : (stateT S M X) -> (stateT S M Y) -> Prop := + fun t1 t2 => forall s, RM _ _ RXY (runStateT t1 s) (runStateT t2 s). + #[global] -Instance eq_itree_interp_state {E F S R} (h : E ~> Monads.stateT S (itree F)) : - Proper (eq_itree eq ==> eq ==> eq_itree eq) +Instance eq_stateT_runStateT {S M R} + (RM : forall R1 R2, (R1 -> R2 -> Prop) -> M R1 -> M R2 -> Prop) : + Proper (eq_stateT RM eq ==> eq ==> (RM _ _ eq)) (@runStateT S _ R). +Proof. + unfold Proper, respectful. + intros x y Heq_stateT sx sy Heq. + specialize (Heq_stateT sy). + rewrite Heq. + apply Heq_stateT. +Qed. + +#[global] +Instance eq_itree_interp_state {E F S R} (h : E ~> stateT S (itree F)) : + Proper (eq_itree eq ==> (eq_stateT (@eq_itree _) eq)) (@interp_state _ _ _ _ _ _ h R). Proof. revert_until R. - ginit. pcofix CIH. intros h x y H0 x2 _ []. + ginit. pcofix CIH. intros h x y H0 s. rewrite !unfold_interp_state. punfold H0; repeat red in H0. destruct H0; subst; pclearbot; try discriminate; cbn. @@ -73,31 +95,31 @@ Proof. Qed. Lemma interp_state_ret {E F : Type -> Type} {R S : Type} - (f : forall T, E T -> S -> itree F (S * T)%type) + (f : forall T, E T -> stateT S (itree F) T) (s : S) (r : R) : - (interp_state f (Ret r) s) ≅ (Ret (s, r)). + (runStateT (interp_state f (Ret r)) s) ≅ (Ret (r, s)). Proof. rewrite itree_eta. reflexivity. Qed. Lemma interp_state_vis {E F : Type -> Type} {S T U : Type} - (e : E T) (k : T -> itree E U) (h : E ~> Monads.stateT S (itree F)) (s : S) - : interp_state h (Vis e k) s - ≅ h T e s >>= fun sx => Tau (interp_state h (k (snd sx)) (fst sx)). + (e : E T) (k : T -> itree E U) (h : E ~> stateT S (itree F)) (s : S) + : runStateT (interp_state h (Vis e k)) s + ≅ runStateT (h T e) s >>= fun xs => Tau (runStateT (interp_state h (k (fst xs))) (snd xs)). Proof. rewrite unfold_interp_state; reflexivity. Qed. Lemma interp_state_tau {E F : Type -> Type} S {T : Type} - (t : itree E T) (h : E ~> Monads.stateT S (itree F)) (s : S) - : interp_state h (Tau t) s ≅ Tau (interp_state h t s). + (t : itree E T) (h : E ~> stateT S (itree F)) (s : S) + : runStateT (interp_state h (Tau t)) s ≅ Tau (runStateT (interp_state h t) s). Proof. rewrite unfold_interp_state; reflexivity. Qed. Lemma interp_state_trigger_eqit {E F : Type -> Type} {R S : Type} - (e : E R) (f : E ~> Monads.stateT S (itree F)) (s : S) - : (interp_state f (ITree.trigger e) s) ≅ (f _ e s >>= fun x => Tau (Ret x)). + (e : E R) (f : E ~> stateT S (itree F)) (s : S) + : (runStateT (interp_state f (ITree.trigger e)) s) ≅ (runStateT (f _ e) s >>= fun x => Tau (Ret x)). Proof. unfold ITree.trigger. rewrite interp_state_vis. eapply eqit_bind; try reflexivity. @@ -105,8 +127,8 @@ Proof. Qed. Lemma interp_state_trigger {E F : Type -> Type} {R S : Type} - (e : E R) (f : E ~> Monads.stateT S (itree F)) (s : S) - : interp_state f (ITree.trigger e) s ≈ f _ e s. + (e : E R) (f : E ~> stateT S (itree F)) (s : S) + : runStateT (interp_state f (ITree.trigger e)) s ≈ runStateT (f _ e) s. Proof. unfold ITree.trigger. rewrite interp_state_vis. match goal with @@ -118,12 +140,12 @@ Proof. Qed. Lemma interp_state_bind {E F : Type -> Type} {A B S : Type} - (f : forall T, E T -> S -> itree F (S * T)%type) + (f : forall T, E T -> stateT S (itree F) T) (t : itree E A) (k : A -> itree E B) (s : S) : - (interp_state f (t >>= k) s) + (runStateT (interp_state f (t >>= k)) s) ≅ - (interp_state f t s >>= fun st => interp_state f (k (snd st)) (fst st)). + (runStateT (interp_state f t) s >>= fun ts => runStateT (interp_state f (k (fst ts))) (snd ts)). Proof. revert t k s. ginit. pcofix CIH. @@ -131,11 +153,11 @@ Proof. rewrite unfold_bind. rewrite (unfold_interp_state f t). destruct (observe t). - - cbn. rewrite !bind_ret_l. cbn. + - cbn. rewrite bind_ret_l. cbn. apply reflexivity. - - cbn. rewrite !bind_tau, interp_state_tau. + - rewrite interp_state_tau. cbn. rewrite bind_tau. gstep. econstructor. gbase. apply CIH. - - cbn. rewrite interp_state_vis, bind_bind. + - rewrite interp_state_vis. cbn. rewrite bind_bind. guclo eqit_clo_bind. econstructor. + reflexivity. + intros u2 ? []. @@ -147,14 +169,14 @@ Qed. #[global] Instance eutt_interp_state {E F: Type -> Type} {S : Type} - (h : E ~> Monads.stateT S (itree F)) R RR : - Proper (eutt RR ==> eq ==> eutt (prod_rel eq RR)) (@interp_state E (itree F) S _ _ _ h R). + (h : E ~> stateT S (itree F)) R RR : + Proper (eutt RR ==> eq_stateT (@eutt _) (prod_rel RR eq)) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until RR. einit. ecofix CIH. intros. rewrite !unfold_interp_state. punfold H0. red in H0. - induction H0; intros; subst; simpl; pclearbot. + induction H0; intros; subst; pclearbot; cbn. - eret. - etau. - ebind. econstructor; [reflexivity|]. @@ -166,14 +188,14 @@ Qed. #[global] Instance eutt_interp_state_eq {E F: Type -> Type} {S : Type} - (h : E ~> Monads.stateT S (itree F)) R : - Proper (eutt eq ==> eq ==> eutt eq) (@interp_state E (itree F) S _ _ _ h R). + (h : E ~> stateT S (itree F)) R : + Proper (eutt eq ==> eq_stateT (@eutt _) eq) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until R. einit. ecofix CIH. intros. rewrite !unfold_interp_state. punfold H0. red in H0. - induction H0; intros; subst; simpl; pclearbot. + induction H0; intros; subst; cbn; pclearbot. - eret. - etau. - ebind. econstructor; [reflexivity|]. @@ -187,17 +209,17 @@ Qed. Lemma eutt_interp_state_aloop {E F S I I' A A'} (RA : A -> A' -> Prop) (RI : I -> I' -> Prop) (RS : S -> S -> Prop) - (h : E ~> Monads.stateT S (itree F)) + (h : E ~> stateT S (itree F)) (t1 : I -> itree E (I + A)) (t2 : I' -> itree E (I' + A')): (forall i i' s1 s2, RS s1 s2 -> RI i i' -> - eutt (prod_rel RS (sum_rel RI RA)) - (interp_state h (t1 i) s1) - (interp_state h (t2 i') s2)) -> + eutt (prod_rel (sum_rel RI RA) RS) + (runStateT (interp_state h (t1 i)) s1) + (runStateT (interp_state h (t2 i')) s2)) -> (forall i i' s1 s2, RS s1 s2 -> RI i i' -> - eutt (fun a b => RS (fst a) (fst b) /\ RA (snd a) (snd b)) - (interp_state h (ITree.iter t1 i) s1) - (interp_state h (ITree.iter t2 i') s2)). + eutt (fun a b => RA (fst a) (fst b) /\ RS (snd a) (snd b)) + (runStateT (interp_state h (ITree.iter t1 i)) s1) + (runStateT (interp_state h (ITree.iter t2 i')) s2)). Proof. intro Ht. einit. ecofix CIH. intros. @@ -205,43 +227,43 @@ Proof. rewrite 2 interp_state_bind. ebind; econstructor. - eapply Ht; auto. - - intros [s1' i1'] [s2' i2'] [? []]; cbn. - + rewrite 2 interp_state_tau. auto with paco. - + rewrite 2 interp_state_ret. auto with paco. + - intros [i1' s1'] [i2' s2'] [[] ?]. + + rewrite 2 interp_state_tau. cbn. auto with paco. + + rewrite 2 interp_state_ret. cbn. auto with paco. Qed. Lemma eutt_interp_state_iter {E F S A A' B B'} (RA : A -> A' -> Prop) (RB : B -> B' -> Prop) (RS : S -> S -> Prop) - (h : E ~> Monads.stateT S (itree F)) + (h : E ~> stateT S (itree F)) (t1 : A -> itree E (A + B)) (t2 : A' -> itree E (A' + B')) : (forall ca ca' s1 s2, RS s1 s2 -> RA ca ca' -> - eutt (prod_rel RS (sum_rel RA RB)) - (interp_state h (t1 ca) s1) - (interp_state h (t2 ca') s2)) -> + eutt (prod_rel (sum_rel RA RB) RS) + (runStateT (interp_state h (t1 ca)) s1) + (runStateT (interp_state h (t2 ca')) s2)) -> (forall a a' s1 s2, RS s1 s2 -> RA a a' -> - eutt (fun a b => RS (fst a) (fst b) /\ RB (snd a) (snd b)) - (interp_state h (iter (C := ktree _) t1 a) s1) - (interp_state h (iter (C := ktree _) t2 a') s2)). + eutt (fun a b => RB (fst a) (fst b) /\ RS (snd a) (snd b)) + (runStateT (interp_state h (iter (C := ktree _) t1 a)) s1) + (runStateT (interp_state h (iter (C := ktree _) t2 a')) s2)). Proof. apply eutt_interp_state_aloop. Qed. Lemma eutt_eq_interp_state_iter {E F S} (f: E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)): - forall i s, interp_state f (ITree.iter t i) s ≈ - Basics.iter (fun i => interp_state f (t i)) i s. + forall i s, runStateT (interp_state f (ITree.iter t i)) s ≈ + runStateT (Basics.iter (fun i => interp_state f (t i)) i) s. Proof. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in *; cbn. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. cbn. ginit. gcofix CIH; intros i s. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. setoid_rewrite bind_ret_l. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. reflexivity. - intros [s' []] _ []; cbn. + intros [[] s'] _ []; cbn. - rewrite interp_state_tau. gstep; constructor. auto with paco. @@ -249,16 +271,16 @@ Proof. Qed. Lemma eutt_interp_state_loop {E F S A B C} (RS : S -> S -> Prop) - (h : E ~> Monads.stateT S (itree F)) + (h : E ~> stateT S (itree F)) (t1 t2 : C + A -> itree E (C + B)) : (forall ca s1 s2, RS s1 s2 -> - eutt (fun a b => RS (fst a) (fst b) /\ snd a = snd b) - (interp_state h (t1 ca) s1) - (interp_state h (t2 ca) s2)) -> + eutt (fun a b => fst a = fst b /\ RS (snd a) (snd b)) + (runStateT (interp_state h (t1 ca)) s1) + (runStateT (interp_state h (t2 ca)) s2)) -> (forall a s1 s2, RS s1 s2 -> - eutt (fun a b => RS (fst a) (fst b) /\ snd a = snd b) - (interp_state h (loop (C := ktree E) t1 a) s1) - (interp_state h (loop (C := ktree E) t2 a) s2)). + eutt (fun a b => fst a = fst b /\ RS (snd a) (snd b)) + (runStateT (interp_state h (loop (C := ktree E) t1 a)) s1) + (runStateT (interp_state h (loop (C := ktree E) t2 a)) s2)). Proof. intros. unfold loop, bimap, Bimap_Coproduct, case_, Case_Kleisli, Function.case_sum, id_, Id_Kleisli, cat, Cat_Kleisli, inr_, Inr_Kleisli, inl_, Inl_Kleisli, lift_ktree_; cbn. @@ -268,36 +290,31 @@ Proof. subst. eapply eutt_clo_bind; eauto. intros. - cbn in H2; destruct (snd u1); rewrite <- (proj2 H2). + cbn in H2; destruct (fst u1); rewrite <- (proj1 H2). - rewrite bind_ret_l, 2 interp_state_ret. pstep. constructor. - split; cbn; auto using (proj1 H2). + split; cbn; auto using (proj2 H2). - rewrite bind_ret_l, 2 interp_state_ret. pstep. constructor. - split; cbn; auto using (proj1 H2). + split; cbn; auto using (proj2 H2). Qed. -(* SAZ: These are probably too specialized. *) -Definition state_eq {E S X} - : (stateT S (itree E) X) -> (stateT S (itree E) X) -> Prop := - fun t1 t2 => forall s, eq_itree eq (t1 s) (t2 s). - Lemma interp_state_iter {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) (t' : I -> stateT S (itree F) (I + A)) - (EQ_t : forall i, state_eq (State.interp_state f (t i)) (t' i)) - : forall i, state_eq (State.interp_state f (ITree.iter t i)) + (EQ_t : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (t i)) (t' i)) + : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (ITree.iter t i)) (Basics.iter t' i). Proof. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in *; cbn. - ginit. pcofix CIH; intros i s. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. + ginit. pcofix CIH; intros i s. cbn. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. setoid_rewrite bind_ret_l. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. - apply EQ_t. - - intros [s' []] _ []; cbn. + - intros [[] s'] _ []; cbn. + rewrite interp_state_tau. gstep; constructor. auto with paco. @@ -306,7 +323,7 @@ Qed. Lemma interp_state_iter' {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) - : forall i, state_eq (State.interp_state f (ITree.iter t i)) + : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (ITree.iter t i)) (Basics.iter (fun i => State.interp_state f (t i)) i). Proof. eapply interp_state_iter. @@ -317,10 +334,10 @@ Qed. Lemma interp_state_iter'_eutt {E F S} (f: E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) (t': I -> stateT S (itree F) (I + A)) - (Heq: forall i s, interp_state f (t i) s ≈ (t' i) s): - forall i s, interp_state f (ITree.iter t i) s ≈ Basics.iter t' i s. + (Heq: forall i s, runStateT (interp_state f (t i)) s ≈ runStateT (t' i) s): + forall i s, runStateT (interp_state f (ITree.iter t i)) s ≈ runStateT (Basics.iter t' i) s. Proof. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in *; cbn. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *; cbn. ginit. gcofix CIH; intros i s. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. @@ -328,7 +345,7 @@ Proof. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. - apply Heq. - - intros [s' []] _ []; cbn. + - intros [[] s'] _ []; cbn. + rewrite interp_state_tau. gstep; constructor. auto with paco. diff --git a/theories/Events/Writer.v b/theories/Events/Writer.v index 9cbff741..b106c3c2 100644 --- a/theories/Events/Writer.v +++ b/theories/Events/Writer.v @@ -11,6 +11,9 @@ From Coq Require Import Import ListNotations. From ExtLib Require Import + Data.List + Data.Monads.StateMonad + Data.Monads.WriterMonad Structures.Functor Structures.Monad Structures.Monoid. @@ -27,7 +30,6 @@ From ITree Require Import Interp.Handler Events.State. -Import Basics.Basics.Monads. (* end hide *) (** Event to output values of type [W]. *) @@ -43,41 +45,49 @@ Definition tell {W E} `{writerE W -< E} : W -> itree E unit := (** Note that this handler appends new outputs to the front of the list. *) Definition handle_writer_list {W E} : writerE W ~> stateT (list W) (itree E) - := fun _ e s => + := fun _ e => match e with - | Tell w => Ret (w :: s, tt) + | Tell w => mkStateT (fun s => Ret (tt, w :: s)) end. +(* not sure why case_ requires the manual parameters *) Definition run_writer_list_state {W E} : itree (writerE W +' E) ~> stateT (list W) (itree E) - := interp_state (case_ handle_writer_list pure_state). + := interp_state (case_ (bif := sum1) (c := stateT (list W) (itree E)) + handle_writer_list pure_state). Arguments run_writer_list_state {W E} [T]. (** Returns the outputs in order: the first output at the head, the last output and the end of the list. *) Definition run_writer_list {W E} - : itree (writerE W +' E) ~> writerT (list W) (itree E) + : itree (writerE W +' E) ~> stateT (list W) (itree E) := fun _ t => - ITree.map (fun wsx => (rev' (fst wsx), snd wsx)) - (run_writer_list_state t []). + mkStateT (fun _ => ITree.map (fun wsx => (fst wsx, rev' (snd wsx))) + (runStateT (run_writer_list_state t) [])). Arguments run_writer_list {W E} [T]. (** When [W] is a monoid, we can also use that to append the outputs together. *) -Definition handle_writer {W E} (Monoid_W : Monoid W) - : writerE W ~> stateT W (itree E) - := fun _ e s => +Definition handle_writer {W E} `{Monoid_W : Monoid W} + : writerE W ~> writerT Monoid_W (itree E) + := fun _ e => match e with - | Tell w => Ret (monoid_plus Monoid_W s w, tt) + | Tell w => MonadWriter.tell w end. +Definition pure_writer {W E} `{Monoid_W : Monoid W} + : E ~> writerT Monoid_W (itree E) + := fun _ e => + mkWriterT Monoid_W + (Vis e (fun x => Ret (PPair.ppair x (monoid_unit Monoid_W)))). + +(* not sure why case_ requires the manual parameters *) Definition run_writer {W E} (Monoid_W : Monoid W) - : itree (writerE W +' E) ~> writerT W (itree E) - := fun _ t => - interp_state (M := itree E) - (case_ (handle_writer Monoid_W) pure_state) t - (monoid_unit Monoid_W). + : itree (writerE W +' E) ~> writerT Monoid_W (itree E) + := interp (M := writerT Monoid_W (itree E)) + (case_ (bif := sum1) (c := writerT Monoid_W (itree E)) + handle_writer pure_writer). Arguments run_writer {W E} Monoid_W [T]. diff --git a/theories/Interp/Handler.v b/theories/Interp/Handler.v index 2518dc8c..ed43d55d 100644 --- a/theories/Interp/Handler.v +++ b/theories/Interp/Handler.v @@ -18,8 +18,6 @@ From ITree Require Import Interp.Interp Interp.Recursion. -Import ITree.Basics.Basics.Monads. - Local Open Scope itree_scope. (* end hide *) diff --git a/theories/Interp/HandlerFacts.v b/theories/Interp/HandlerFacts.v index 435a7647..0218ac3a 100644 --- a/theories/Interp/HandlerFacts.v +++ b/theories/Interp/HandlerFacts.v @@ -22,7 +22,6 @@ From ITree Require Import Interp.RecursionFacts. Import ITreeNotations. -Import ITree.Basics.Basics.Monads. Local Open Scope itree_scope. diff --git a/theories/Props/Leaf.v b/theories/Props/Leaf.v index aebbb2e3..4f76ec9d 100644 --- a/theories/Props/Leaf.v +++ b/theories/Props/Leaf.v @@ -1,6 +1,8 @@ (** * Leaves of an Interaction Tree *) (* begin hide *) +From ExtLib Require Import Data.Monads.StateMonad. + From ITree Require Import Basics.Utils Basics.HeterogeneousRelations @@ -375,12 +377,12 @@ Qed. (* Inverts [sr' ∈ interp_state h (ITree.iter body i)] into a post-condition on both retun value and state, like Leaf_iter_inv. *) Lemma Leaf_interp_state_iter_inv {E F S R I}: - forall (h: E ~> Monads.stateT S (itree F)) (body: I -> itree E (I + R)) + forall (h: E ~> stateT S (itree F)) (body: I -> itree E (I + R)) (RS: S -> Prop) (RI: I -> Prop) (RR: R -> Prop) (s: S) (i: I), - (forall s i, RS s -> RI i -> (forall sx', sx' ∈ interp_state h (body i) s -> - prod_pred RS (sum_pred RI RR) sx')) -> + (forall s i, RS s -> RI i -> (forall sx', sx' ∈ runStateT (interp_state h (body i)) s -> + prod_pred (sum_pred RI RR) RS sx')) -> RS s -> RI i -> - forall sr', sr' ∈ interp_state h (ITree.iter body i) s -> prod_pred RS RR sr'. + forall sr', sr' ∈ runStateT (interp_state h (ITree.iter body i)) s -> prod_pred RR RS sr'. Proof. setoid_rewrite <- has_post_Leaf_equiv. setoid_rewrite has_post_post_strong. @@ -389,8 +391,8 @@ Proof. set (eRR := fun (r1 r2: R) => r1 = r2 /\ RR r1). set (eRS := fun (s1 s2: S) => s1 = s2 /\ RS s1). - set (R1 := (fun x y : S * R => x = y /\ prod_pred RS RR x)). - set (R2 := (fun a b : S * R => eRS (fst a) (fst b) /\ eRR (snd a) (snd b))). + set (R1 := (fun x y : R * S => x = y /\ prod_pred RR RS x)). + set (R2 := (fun a b : R * S => eRR (fst a) (fst b) /\ eRS (snd a) (snd b))). assert (HR1R2: eq_rel R1 R2) by (compute; intuition; subst; now try destruct y). unfold has_post_strong; fold R1; rewrite (eutt_equiv _ _ HR1R2). @@ -398,12 +400,12 @@ Proof. [| subst eRS; intuition | subst eRI; intuition]. intros i1 ? s1 ? [<- Hs1] [<- Hi1]. - set (R3 := (fun x y : S * (I + R) => x = y /\ prod_pred RS (sum_pred RI RR) x)). - set (R4 := (prod_rel eRS (sum_rel eRI eRR))). + set (R3 := (fun x y : (I + R) * S => x = y /\ prod_pred (sum_pred RI RR) RS x)). + set (R4 := (prod_rel (sum_rel eRI eRR) eRS)). assert (HR3R4: eq_rel R3 R4). - { split; intros [? [|]] [? [|]]; compute. + { split; intros [[|] ?] [[|] ?]; compute. 1-4: intros [[]]; dintuition; cbn; intuition. - all: intros [[[=->] ?] HZ]; inversion HZ; intuition now subst. } + all: intros [HZ [[=->] ?]]; inversion HZ; intuition now subst. } rewrite <- (eutt_equiv _ _ HR3R4). now apply Hinv. @@ -456,9 +458,9 @@ Proof. reflexivity. Qed. -Lemma Leaf_interp_state_subtree_inv {E F S R} (h: E ~> Monads.stateT S (itree F)) +Lemma Leaf_interp_state_subtree_inv {E F S R} (h: E ~> stateT S (itree F)) (t u: itree E R) (s: S): - subtree u t -> has_post (interp_state h u s) (fun x => snd x ∈ t). + subtree u t -> has_post (runStateT (interp_state h u) s) (fun x => fst x ∈ t). Proof. revert t u s. ginit. gcofix CIH; intros * Hsub. rewrite (itree_eta u) in Hsub. @@ -481,12 +483,12 @@ Proof. apply Leaf_interp_subtree_inv. apply SubtreeRefl; reflexivity. Qed. -Lemma Leaf_interp_state_inv {E F S R} (h: E ~> Monads.stateT S (itree F)) +Lemma Leaf_interp_state_inv {E F S R} (h: E ~> stateT S (itree F)) (t: itree E R) s x: - x ∈ interp_state h t s -> snd x ∈ t. + x ∈ runStateT (interp_state h t) s -> fst x ∈ t. Proof. intros Hleaf. - apply (has_post_Leaf (interp_state h t s) (fun x => snd x ∈ t)); auto. + apply (has_post_Leaf (runStateT (interp_state h t) s) (fun x => fst x ∈ t)); auto. apply Leaf_interp_state_subtree_inv. apply SubtreeRefl; reflexivity. Qed. From f0a6606c6aa45cdb155685e57b88004facd4f173 Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Wed, 2 Oct 2024 00:16:18 -0400 Subject: [PATCH 2/8] refactored usages of stateT in extra/ --- extra/Dijkstra/StateDelaySpec.v | 59 ++++++++-------- extra/Dijkstra/StateIOTrace.v | 38 ++++++---- extra/Dijkstra/StateSpecT.v | 106 +++++++++++++--------------- extra/IForest.v | 2 - extra/Secure/SecureStateHandler.v | 34 ++++----- extra/Secure/SecureStateHandlerPi.v | 23 +++--- 6 files changed, 134 insertions(+), 128 deletions(-) diff --git a/extra/Dijkstra/StateDelaySpec.v b/extra/Dijkstra/StateDelaySpec.v index c95d184c..cc66cb34 100644 --- a/extra/Dijkstra/StateDelaySpec.v +++ b/extra/Dijkstra/StateDelaySpec.v @@ -1,5 +1,6 @@ From ExtLib Require Import Data.List + Data.Monads.StateMonad Structures.Monad. From Paco Require Import paco. @@ -42,21 +43,21 @@ Section StateDelaySpec. Definition StateDelayObs := EffectObsStateT St DelaySpec Delay. - Definition StateDelayMonadMorph := MonadMorphimStateT St DelaySpec Delay. + Definition StateDelayMonadMorph := MonadMorphismStateT St DelaySpec Delay. - Definition PrePost A : Type := (Delay (St * A) -> Prop ) * (St -> Prop). + Definition PrePost A : Type := (Delay (A * St) -> Prop ) * (St -> Prop). Definition PrePostRef {A : Type} (m : StateDelay A) (pp : PrePost A) : Prop := let '(post,pre) := pp in - forall s, pre s -> post (m s). + forall s, pre s -> post (runStateT m s). Program Definition encode {A : Type} (pp : PrePost A) : StateDelaySpec A := let '(post,pre) := pp in - fun s p => pre s /\ (forall r, post r -> p r). + mkStateT (fun s p => pre s /\ (forall r, post r -> p r)). Definition verify_cond {A : Type} := DijkstraProp StateDelay StateDelaySpec StateDelayObs A. - Lemma encode_correct : forall (A : Type) (pre : St -> Prop) (post : Delay (St * A) -> Prop) + Lemma encode_correct : forall (A : Type) (pre : St -> Prop) (post : Delay (A * St) -> Prop) (m : StateDelay A), resp_eutt post -> (PrePostRef m (post,pre) <-> verify_cond (encode (post,pre)) m). Proof. @@ -65,7 +66,7 @@ Section StateDelaySpec. - repeat red. simpl. intros. destruct p as [p Hp]. simpl in H1. destruct H1 as [Hpre Himp]. auto. - repeat red in H0. simpl in H0. - set (exist _ post H) as p. enough ((m s) ∈ p); auto. + set (exist _ post H) as p. enough ((runStateT m s) ∈ p); auto. apply H0. auto. Qed. @@ -73,12 +74,12 @@ Section StateDelaySpec. Definition PrePostPairRef {A : Type} (pppp : PrePostPair A) (m : StateDelay A) := let '((post0, pre0), (post1, pre1)) := pppp in - forall s, (pre0 s -> post0 (m s)) /\ (pre1 s -> post1 (m s)) . + forall s, (pre0 s -> post0 (runStateT m s)) /\ (pre1 s -> post1 (runStateT m s)) . Program Definition encode_pair {A : Type} (pppp : PrePostPair A) : StateDelaySpec A:= let '((post0, pre0), (post1, pre1)) := pppp in - fun s (p : DelaySpecInput (St * A)) => - (pre0 s /\ (forall r, post0 r -> p r)) \/ (pre1 s /\ forall r, post1 r -> p r). + mkStateT (fun s (p : DelaySpecInput (A * St)) => + (pre0 s /\ (forall r, post0 r -> p r)) \/ (pre1 s /\ forall r, post1 r -> p r)). Next Obligation. destruct H0 as [H0 | H1]. - destruct H0 as [Hp Hr]. left. auto. @@ -86,7 +87,7 @@ Section StateDelaySpec. Qed. Lemma encode_pair_correct : forall (A : Type) (pre0 pre1 : St -> Prop) - (post0 post1 : Delay (St * A) -> Prop ) (m : StateDelay A), + (post0 post1 : Delay (A * St) -> Prop ) (m : StateDelay A), let pp : PrePostPair A := ((post0,pre0),(post1,pre1)) in resp_eutt post0 -> resp_eutt post1 -> (PrePostPairRef pp m <-> verify_cond (encode_pair pp) m). @@ -97,20 +98,20 @@ Section StateDelaySpec. destruct H2 as [ [Hs Hp] | [Hs Hp] ]; simpl in *; auto. - repeat red in H1. simpl in *. split; intros. - + set (exist _ post0 H) as p. enough ((m s) ∈ p ); auto. + + set (exist _ post0 H) as p. enough ((runStateT m s) ∈ p ); auto. apply H1. left. split; auto. - + set (exist _ post1 H0) as p. enough ((m s) ∈ p ); auto. + + set (exist _ post1 H0) as p. enough ((runStateT m s) ∈ p ); auto. apply H1. right. split; auto. Qed. Definition PrePostList A : Type := list (PrePost A). Definition PrePostListRef {A : Type} (ppl : PrePostList A) (m : StateDelay A) := - forall s, List.Forall (fun pp : PrePost A=> let (post,pre) := pp in pre s -> post (m s) ) ppl. + forall s, List.Forall (fun pp : PrePost A=> let (post,pre) := pp in pre s -> post (runStateT m s) ) ppl. Program Definition encode_list {A : Type} (ppl : PrePostList A) : StateDelaySpec A := - fun s (p : DelaySpecInput (St * A) ) => - List.Exists (fun pp : PrePost A => let (post,pre) := pp in pre s /\ forall r, post r -> p r) ppl. + mkStateT (fun s (p : DelaySpecInput (A * St) ) => + List.Exists (fun pp : PrePost A => let (post,pre) := pp in pre s /\ forall r, post r -> p r) ppl). Next Obligation. induction H0; eauto. destruct x as [post pre]. destruct H0 as [Hs Hr]. left. auto. @@ -129,7 +130,7 @@ Section StateDelaySpec. + destruct a as [post pre]. inversion H1; subst. * destruct H3. auto. - assert ((pre s -> post (m s)) ); auto. + assert ((pre s -> post (runStateT m s)) ); auto. intros. inversion Hrefine; subst; auto. * apply IHppl; auto. -- inversion H; auto. @@ -142,24 +143,24 @@ Section StateDelaySpec. { inversion H. auto. } set (exist _ post Heutt) as p. specialize (Henc p) as Hencp. constructor; intros. - + enough ((m s) ∈ p ); auto. apply Hencp. + + enough ((runStateT m s) ∈ p ); auto. apply Hencp. left. split; auto. + apply IHppl; auto. * inversion H. auto. * clear IHppl. intros. apply H0. eauto. Qed. - Definition DynPrePost A : Type := (St -> Prop) * (St -> Delay (St * A) -> Prop). + Definition DynPrePost A : Type := (St -> Prop) * (St -> Delay (A * St) -> Prop). Definition DynPrePostRef {A : Type} (pp : DynPrePost A) (m : StateDelay A) := let (pre,post) := pp in - forall s, pre s -> post s (m s). + forall s, pre s -> post s (runStateT m s). Program Definition encode_dyn {A : Type} (pp : DynPrePost A) : StateDelaySpec A := let (pre,post) := pp in - fun s p => pre s /\ forall r, post s r -> p r. + mkStateT (fun s p => pre s /\ forall r, post s r -> p r). - Lemma encode_dyn_correct : forall (A : Type) (pre : St -> Prop) (post : St -> Delay (St * A) -> Prop ) (m : StateDelay A), + Lemma encode_dyn_correct : forall (A : Type) (pre : St -> Prop) (post : St -> Delay (A * St) -> Prop ) (m : StateDelay A), (forall s, resp_eutt (post s)) -> (DynPrePostRef (pre,post) m <-> verify_cond (encode_dyn (pre,post) ) m). Proof. intros. unfold verify_cond, DijkstraProp. split; intros. @@ -174,7 +175,7 @@ Section StateDelaySpec. Forall (fun pp => DynPrePostRef pp m) ppl. Program Definition encode_list_dyn {A : Type} (ppl : list (DynPrePost A)) : StateDelaySpec A := - fun s p => List.Exists (fun pp : DynPrePost A => let (pre,post) := pp in pre s /\ forall r, post s r -> p r ) ppl. + mkStateT (fun s p => List.Exists (fun pp : DynPrePost A => let (pre,post) := pp in pre s /\ forall r, post s r -> p r ) ppl). Next Obligation. induction H0; eauto. left. destruct x as [pre post]. destruct H0 as [Hs Hr]. split; auto. @@ -192,7 +193,7 @@ Section StateDelaySpec. + destruct a as [pre post]. inversion H1; subst. * destruct H2. - assert ((pre s -> post s (m s)) ); auto. + assert ((pre s -> post s (runStateT m s)) ); auto. intros. inversion Hrefine; subst; auto. * apply IHppl; auto. -- inversion H; auto. @@ -204,7 +205,7 @@ Section StateDelaySpec. { inversion H. auto. } constructor; intros. + red. intros. set (exist _ (post s) (Heutt s)) as p. - specialize (H0 s p). enough ((m s) ∈ p); auto. apply H0. + specialize (H0 s p). enough ((runStateT m s) ∈ p); auto. apply H0. left. split; auto. + apply IHppl; auto. * inversion H. auto. @@ -213,11 +214,11 @@ Section StateDelaySpec. Qed. Lemma combine_prepost_aux : forall (A B : Type) (pre1 pre2 : St -> Prop) - (post1 : Delay (St * A) -> Prop ) (post2 : Delay (St * B) -> Prop) + (post1 : Delay (A * St) -> Prop ) (post2 : Delay (B * St) -> Prop) (m : StateDelay A) (f : A -> StateDelay B), verify_cond (encode (post1,pre1) ) m -> (forall (a : A) (s : St), (* this condition is not exactly what i want*) - post1 (Ret (s,a) ) -> post2 (f a s) ) -> + post1 (Ret (a,s) ) -> post2 (runStateT (f a) s) ) -> (post1 ITree.spin -> post2 ITree.spin) -> resp_eutt post1 -> verify_cond (encode (post2, pre1) ) (bind m f). @@ -225,7 +226,7 @@ Section StateDelaySpec. intros. repeat red in H. repeat red. intros. destruct p as [p Hp]. simpl in *. destruct H3. - destruct (eutt_reta_or_div (m s)); basic_solve. + destruct (eutt_reta_or_div (runStateT m s)); basic_solve. - destruct a as [s' a]. cbn in H5. rewrite <- H5, bind_ret_l; cbn. apply H4, H0. rewrite H5. apply (H s (exist _ post1 H2)); auto. @@ -234,10 +235,10 @@ Section StateDelaySpec. Qed. Lemma combine_prepost : forall (A B : Type) (pre1 pre2 : St -> Prop) - (post1 : Delay (St * A) -> Prop ) (post2 : Delay (St * B) -> Prop) + (post1 : Delay (A * St) -> Prop ) (post2 : Delay (B * St) -> Prop) (m : StateDelay A) (f : A -> StateDelay B), verify_cond (encode (post1,pre1) ) m -> - (forall a s, post1 (Ret (s,a)) -> pre2 s) -> + (forall a s, post1 (Ret (a,s)) -> pre2 s) -> (forall a, verify_cond (encode (post2,pre2) ) (f a) ) -> (post1 ITree.spin -> post2 ITree.spin) -> resp_eutt post1 -> diff --git a/extra/Dijkstra/StateIOTrace.v b/extra/Dijkstra/StateIOTrace.v index 5d399fbe..0c21ebd6 100644 --- a/extra/Dijkstra/StateIOTrace.v +++ b/extra/Dijkstra/StateIOTrace.v @@ -6,7 +6,8 @@ From ExtLib Require Import Data.String Structures.Monad Core.RelDec - Data.Map.FMapAList. + Data.Map.FMapAList + Data.Monads.StateMonad. From Paco Require Import paco. @@ -55,17 +56,17 @@ Definition SIOSpecEq := StateSpecTEq env (TraceSpec IO). Definition SIOObs := EffectObsStateT env (TraceSpec IO) (itree IO). -Definition SIOMorph :=MonadMorphimStateT env (TraceSpec IO) (itree IO). +Definition SIOMorph :=MonadMorphismStateT env (TraceSpec IO) (itree IO). Definition verify_cond {A : Type} := DijkstraProp (stateT env (itree IO)) StateIOSpec SIOObs A. (*Predicate on initial state and initial log*) Definition StateIOSpecPre : Type := env -> ev_list IO -> Prop. (*Predicate on final log and possible return value*) -Definition StateIOSpecPost (A : Type) : Type := itrace IO (env * A) -> Prop. +Definition StateIOSpecPost (A : Type) : Type := itrace IO (A * env) -> Prop. Program Definition encode {A} (pre : StateIOSpecPre) (post : StateIOSpecPost A) : StateIOSpec A := - fun s log p => pre s log /\ (forall tr, post tr -> p tr). + mkStateT (fun s log p => pre s log /\ (forall tr, post tr -> p tr)). Section PrintMults. @@ -114,13 +115,12 @@ Section PrintMults. alist_add _ V v s. Definition handleIOStateE (A : Type) (ev : (StateE +' IO) A) : stateT env (itree IO) A := - fun s => match ev with | inl1 ev' => match ev' with - | GetE V => Ret (s, lookup_default V 0 s) - | PutE V v => Ret (Maps.add V v s, tt) end - | inr1 ev' => Vis ev' (fun x => Ret (s,x) ) + | GetE V => mkStateT (fun s => Ret (lookup_default V 0 s, s)) + | PutE V v => mkStateT (fun s => Ret (tt, Maps.add V v s)) end + | inr1 ev' => mkStateT (fun s => Vis ev' (fun x => Ret (x,s))) end. Ltac unf_res := unfold resum, ReSum_id, id_, Id_IFun in *. @@ -174,6 +174,8 @@ Section PrintMults. let H' := fresh H in match type of H with ?P -> _ => assert (H' : P); try (specialize (H H'); clear H') end. + Arguments interp_state : simpl never. + Lemma print_mults_sats_spec : verify_cond (encode print_mults_pre print_mults_post) (interp_state handleIOStateE print_mults). Proof. @@ -208,11 +210,11 @@ Section PrintMults. assert (RAnsRef IO unit nat (evans nat Read n) tt Read n); auto with itree. apply H6 in H. pclearbot. auto. } - clear Href ev. subst. rewrite bind_ret_l in H. simpl in *. rewrite interp_state_bind in H. - rewrite interp_state_trigger in H. simpl in *. rewrite bind_ret_l in H. - simpl in *. + clear Href ev. subst. rewrite bind_ret_l in H. cbn in *. rewrite interp_state_bind in H. + rewrite interp_state_trigger in H. cbn in *. rewrite bind_ret_l in H. + cbn in *. specialize (@interp_state_iter' (StateE +' IO) ) as Hiter. - unfold state_eq in Hiter. rewrite Hiter in H. clear Hiter. + unfold eq_stateT in Hiter. rewrite Hiter in H. clear Hiter. remember (Maps.add X n s) as si. assert (si = alist_add RelDec_string X n s); try (subst; auto; fail). @@ -240,7 +242,8 @@ Section PrintMults. (*This block shows how to proceed through the loop body*) rename H0 into H. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in H. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in H. + cbn in H. rewrite unfold_iter in H. match type of H with _ ⊑ ITree.bind _ ?k0 => remember k0 as k end. @@ -270,7 +273,10 @@ Section PrintMults. remember (lookup_default Y 0 si) as m. eapply CIH with (Maps.add Y (n + m) si); try apply lookup_eq. 2: { rewrite lookup_neq; subst; auto. } - rewrite tau_eutt in Hk1. setoid_rewrite bind_trigger in Hk1. + rewrite tau_eutt in Hk1. + (* TODO: not sure why this is failing *) + (* + setoid_rewrite bind_trigger in Hk1. setoid_rewrite interp_state_vis in Hk1. cbn in *. rewrite bind_ret_l in Hk1. rewrite tau_eutt in Hk1. setoid_rewrite bind_vis in Hk1. @@ -285,6 +291,8 @@ Section PrintMults. H : _ ⊑ ITree.iter _ (?s1, _) |- _ ⊑ ITree.iter _ (?s2, _) => enough (Hseq : s2 = s1) end; try rewrite Hseq; auto. subst. rewrite Nat.add_comm. auto. - Qed. + *) + admit. + Admitted. End PrintMults. diff --git a/extra/Dijkstra/StateSpecT.v b/extra/Dijkstra/StateSpecT.v index 1777cf7e..fd9de49c 100644 --- a/extra/Dijkstra/StateSpecT.v +++ b/extra/Dijkstra/StateSpecT.v @@ -2,6 +2,7 @@ From Coq Require Import Morphisms. From ExtLib Require Import + Data.Monads.StateMonad Structures.Monad. From Paco Require Import paco. @@ -9,6 +10,7 @@ From Paco Require Import paco. From ITree Require Import ITree ITreeFacts + Basics.MonadState Props.Infinite. From ITree.Extra Require Import @@ -36,8 +38,10 @@ Section StateSpecT. Definition StateSpecT (A : Type) := stateT S W A. + #[global] Instance Monad_StateSpecT : Monad StateSpecT := Monad_stateT S MonadW. + #[global] Instance StateSpecTOrder : OrderM StateSpecT := - fun A (w1 w2 : stateT S W A) => forall (s : S), w1 s <≈ w2 s. + fun A (w1 w2 : stateT S W A) => forall (s : S), runStateT w1 s <≈ runStateT w2 s. #[global] Instance StateSpecTOrderedLaws : OrderedMonad StateSpecT. Proof. @@ -51,25 +55,10 @@ Section StateSpecT. repeat red in Hlf. apply Hlf. Qed. - #[global] Instance StateSpecTEq : Eq1 StateSpecT := - fun _ w1 w2 => forall s, (w1 s ≈ w2 s)%monad. + #[global] Instance StateSpecTEq : Eq1 StateSpecT := Eq1_stateTM W S. - #[global] Instance StateSpecTMonadLaws : MonadLawsE StateSpecT. - Proof. - destruct MonadLawsW. - constructor. - - intros A B f a. intro. repeat red. - cbn. - rewrite bind_ret_l. simpl. reflexivity. - - intros A w. intro. cbn. - etransitivity; [ | apply bind_ret_r ]. - eapply Proper_bind; [ reflexivity | ]. - intros []; reflexivity. - - intros A B C w f g. intro. cbn. rewrite bind_bind. reflexivity. - - intros A B w1 w2 Hw k1 k2 Hk. - cbn. do 2 red. intros. do 2 red in Hw. rewrite Hw. do 3 red in Hk. - setoid_rewrite Hk. reflexivity. - Qed. + #[global] Instance StateSpecTMonadLaws : @MonadLawsE StateSpecT (Eq1_stateTM W S) _ := + @MonadLawsE_stateTM W S _ _ EquivRel MonadLawsW. Section Observation. Context (M : Type -> Type). @@ -77,17 +66,20 @@ Section StateSpecT. Context {EffectObsMW : EffectObs M W}. Context {MonadMorphismMW : MonadMorphism M W EffectObsMW}. - #[global] Instance EffectObsStateT : EffectObs (stateT S M) (stateT S W) := fun _ m s => obs _ (m s). + #[global] Instance EffectObsStateT : EffectObs (stateT S M) (stateT S W) := + fun T (m : stateT S M T) => mkStateT (fun s => obs _ (runStateT m s)). - #[global] Instance MonadMorphimStateT : MonadMorphism (stateT S M) (stateT S W) EffectObsStateT. + #[global] Instance MonadMorphismStateT : MonadMorphism (stateT S M) (stateT S W) EffectObsStateT. Proof. destruct MonadMorphismMW. constructor. - - intros. repeat red. intros. specialize (ret_pres (S * A)%type (s,a) ). + - intros. repeat red. intros. specialize (ret_pres (A * S)%type (a,a0)). cbn. rewrite <- ret_pres. reflexivity. - - intros. repeat red. intros. cbn. specialize (bind_pres (S * A)%type (S * B)%type ). - unfold obs, EffectObsStateT. - rewrite bind_pres. reflexivity. + - intros. repeat red. intros. cbn. specialize (bind_pres (A * S)%type (B * S)%type). + rewrite bind_pres. + apply Proper_bind. + + reflexivity. + + intros []. reflexivity. Qed. End Observation. @@ -111,45 +103,45 @@ Section LoopInvarSpecific. Definition State (A : Type) := stateT S Delay A. - Instance StateIter : MonadIter State := MonadIter_stateT0. + Instance StateIter : MonadIter State := MonadIter_stateT. - Definition reassoc {A B : Type} (t : Delay (S * (A + B) ) ) : Delay ((S * A) + (S * B) ) := - t >>= (fun '(s,ab) => + Definition reassoc {A B : Type} (t : Delay ((A + B) * S)) : Delay ((A * S) + (B * S)) := + t >>= (fun '(ab,s) => match ab with - | inl a => ret (inl (s , a)) - | inr b => ret (inr (s , b)) + | inl a => ret (inl (a,s)) + | inr b => ret (inr (b,s)) end ). - Definition iso_arrow {A B : Type} (f : A -> State B) (g : (S * A) -> Delay (S * B) ) := - forall (a : A) (s : S), f a s ≈ g (s,a). + Definition iso_arrow {A B : Type} (f : A -> State B) (g : (A * S) -> Delay (B * S)) := + forall (a : A) (s : S), runStateT (f a) s ≈ g (a,s). Definition decurry_flip {A B C : Type} (f : A -> B -> C) : B * A -> C := fun '(b,a) => f a b. (*this is just decurry*) - Definition iso_destatify_arrow {A B : Type} (f : A -> State (A + B) ) : (S * A) -> Delay ((S * A) + (S * B)) := - fun '(s,a) => reassoc (f a s). + Definition iso_destatify_arrow {A B : Type} (f : A -> State (A + B) ) : (A * S) -> Delay ((A * S) + (B * S)) := + fun '(a,s) => reassoc (runStateT (f a) s). (*should be able to use original*) Lemma loop_invar_state: forall (A B : Type) (g : A -> State (A + B)) (a : A) (s : S) - (p : Delay ( S * B) -> Prop) (q : Delay ((S * A) + (S * B)) -> Prop ) + (p : Delay (B * S) -> Prop) (q : Delay ((A * S) + (B * S)) -> Prop ) (Hp : resp_eutt p) (Hq : resp_eutt q) , - (q (reassoc (g a s) )) -> + (q (reassoc (runStateT (g a) s) )) -> (q -+> p) -> (forall t, q t -> q (t >>= (iter_lift ( iso_destatify_arrow g) ))) -> - (p \1/ any_infinite) (MonadIter_stateT0 _ _ g a s) . + (p \1/ any_infinite) (runStateT (MonadIter_stateT _ _ g a) s) . Proof. intros. set (iso_destatify_arrow g) as g'. - enough ((p \1/ any_infinite) (ITree.iter g' (s,a) )). - - assert (ITree.iter g' (s,a) ≈ iter g a s). + enough ((p \1/ any_infinite) (ITree.iter g' (a,s) )). + - assert (ITree.iter g' (a,s) ≈ runStateT (iter g a) s). + unfold g', iso_destatify_arrow. unfold iter, Iter_Kleisli, Basics.iter, MonadIterDelay, StateIter, - MonadIter_stateT0, reassoc. unfold Basics.iter. + MonadIter_stateT, reassoc. unfold Basics.iter. unfold MonadIterDelay. eapply eutt_iter. intro. destruct a0 as [a' s']. simpl. eapply eutt_clo_bind; try reflexivity. intros. - subst. destruct u2. simpl. destruct s1; reflexivity. + subst. destruct u2. simpl. destruct s0; reflexivity. + assert (Hpdiv : resp_eutt (p \1/ any_infinite)). { intros t1 t2 Heutt. split; intros; basic_solve. - left. eapply Hp; eauto. symmetry. auto. @@ -161,13 +153,15 @@ Section LoopInvarSpecific. - eapply loop_invar; eauto. Qed. - Definition state_iter_arrow_rel {A B S : Type} (g : A -> stateT S Delay (A + B) ) '(s0,a0) '(s1,a1) := - g a0 s0 ≈ Ret (s1, inl a1). + Definition state_iter_arrow_rel {A B S : Type} (g : A -> stateT S Delay (A + B) ) '(a0,s0) '(a1,s1) := + runStateT (g a0) s0 ≈ Ret (inl a1, s1). + + Locate not_wf. Lemma iter_inl_spin_state : forall (A B S : Type) (g : A -> stateT S Delay (A + B) ) (a : A) (s : S), - not_wf_from (state_iter_arrow_rel g ) (s,a) -> MonadIter_stateT0 _ _ g a s ≈ ITree.spin. + not_wf_from (state_iter_arrow_rel g ) (a,s) -> runStateT (MonadIter_stateT _ _ g a) s ≈ ITree.spin. Proof. - intros. unfold MonadIter_stateT0. + intros. unfold MonadIter_stateT. apply iter_inl_spin. (*seems to require some coinduciton*) generalize dependent a. generalize dependent s. pcofix CIH. intros. pinversion H0; try apply not_wf_F_mono'. pfold. @@ -178,27 +172,27 @@ Section LoopInvarSpecific. Qed. Lemma iter_wf_converge_state : forall (A B S : Type) (g : A -> stateT S Delay (A + B) ) (a : A) (s : S), - (forall a s, exists ab, g a s ≈ Ret ab) -> - wf_from (state_iter_arrow_rel g) (s,a) -> - exists (p : S * B), MonadIter_stateT0 _ _ g a s ≈ Ret p. + (forall a s, exists ab, runStateT (g a) s ≈ Ret ab) -> + wf_from (state_iter_arrow_rel g) (a,s) -> + exists (p : B * S), runStateT (MonadIter_stateT _ _ g a) s ≈ Ret p. Proof. - intros. unfold MonadIter_stateT0, Basics.iter, MonadIterDelay. + intros. unfold MonadIter_stateT, Basics.iter, MonadIterDelay. apply iter_wf_converge. - eapply wf_from_sub_rel; try apply H0. repeat intro. unfold iter_arrow_rel in *. unfold state_iter_arrow_rel. clear H0 a s. - destruct x as [s a]. simpl in *. destruct y as [s' a']. - destruct (eutt_reta_or_div (g a s)); basic_solve. + destruct x as [a s]. simpl in *. destruct y as [a' s']. + destruct (eutt_reta_or_div (runStateT (g a) s)); basic_solve. + rewrite <- H0. rewrite <- H0 in H1. simpl in H1. rewrite bind_ret_l in H1. - simpl in *. destruct a0. simpl in *. destruct s1; basic_solve. + simpl in *. destruct a0. simpl in *. destruct s0; basic_solve. reflexivity. + apply div_spin_eutt in H0. rewrite H0 in H1. rewrite <- spin_bind in H1. symmetry in H1. exfalso. eapply not_ret_eutt_spin; eauto. - - clear H0 a s. intros [s a]. specialize (H a s). basic_solve. - destruct ab as [s' [a' | b] ]. - + exists (inl (s',a') ). simpl. rewrite H. rewrite bind_ret_l. simpl. + - clear H0 a s. intros [a s]. specialize (H a s). basic_solve. + destruct ab as [[a' | b] s']. + + exists (inl (a',s')). simpl. rewrite H. rewrite bind_ret_l. simpl. reflexivity. - + exists (inr (s',b)). simpl. rewrite H. rewrite bind_ret_l. simpl. + + exists (inr (b,s')). simpl. rewrite H. rewrite bind_ret_l. simpl. reflexivity. Qed. diff --git a/extra/IForest.v b/extra/IForest.v index db7fa83f..5dd0c07f 100644 --- a/extra/IForest.v +++ b/extra/IForest.v @@ -28,8 +28,6 @@ From Coq Require Import Relations Morphisms. -Import ITree.Basics.Basics.Monads. - Import MonadNotation. Import CatNotations. Local Open Scope monad_scope. diff --git a/extra/Secure/SecureStateHandler.v b/extra/Secure/SecureStateHandler.v index 31a65198..df15e538 100644 --- a/extra/Secure/SecureStateHandler.v +++ b/extra/Secure/SecureStateHandler.v @@ -1,5 +1,7 @@ From Coq Require Import Morphisms. +From ExtLib Require Import Data.Monads.StateMonad. + From ITree Require Import Basics.HeterogeneousRelations Axioms @@ -46,7 +48,7 @@ Context (l : L). Definition state_eqit_secure {R1 R2 : Type} (b1 b2 : bool) (RR : R1 -> R2 -> Prop) (m1 : stateT S (itree E2) R1) (m2 : stateT S (itree E2) R2) := - forall s1 s2, RS s1 s2 -> eqit_secure Label priv2 (prod_rel RS RR) b1 b2 l (m1 s1) (m2 s2). + forall s1 s2, RS s1 s2 -> eqit_secure Label priv2 (prod_rel RR RS) b1 b2 l (runStateT m1 s1) (runStateT m2 s2). Definition top2 {R1 R2} (r1 : R1) (r2 : R2) : Prop := True. @@ -55,12 +57,12 @@ Definition secure_in_nonempty_context {R} (m : stateT S (itree E2) R) := forall r' : R, state_eqit_secure true true top2 m (ret r'). Definition secure_in_empty_context {R} (m : stateT S (itree E2) R) := - state_eqit_secure true true (@top2 R R) m (fun s => ITree.spin). + state_eqit_secure true true (@top2 R R) m (mkStateT (fun s => ITree.spin)). -Inductive terminates (s1 : S) (P : forall A, E2 A -> Prop) : forall {A : Type}, itree E2 (S * A) -> Prop := -| terminates_ret {R : Type} : forall (r : R) (s2 : S), RS s1 s2 -> terminates s1 P (Ret (s2, r)) -| terminates_tau : forall A (t : itree E2 (S * A)) , terminates s1 P t -> terminates s1 P (Tau t) -| terminates_vis {A R : Type} : forall (e : E2 A) (k : A -> itree E2 (S * R)) , (forall v, terminates s1 P (k v)) -> P A e -> terminates s1 P (Vis e k) +Inductive terminates (s1 : S) (P : forall A, E2 A -> Prop) : forall {A : Type}, itree E2 (A * S) -> Prop := +| terminates_ret {R : Type} : forall (r : R) (s2 : S), RS s1 s2 -> terminates s1 P (Ret (r, s2)) +| terminates_tau : forall A (t : itree E2 (A * S)) , terminates s1 P t -> terminates s1 P (Tau t) +| terminates_vis {A R : Type} : forall (e : E2 A) (k : A -> itree E2 (R * S)) , (forall v, terminates s1 P (k v)) -> P A e -> terminates s1 P (Vis e k) . Variant diverges_with' {E : Type -> Type} (P : forall A, E A -> Prop) (A : Type) (F : itree E A -> Prop) : itree' E A -> Prop := @@ -92,7 +94,7 @@ Proof. do 2 red. intros t1 t2 Heq. apply EqAxiom.bisimulation_is_eq in Heq. subst; tauto. Qed. -#[global] Instance proper_terminate {R s} {P : forall A, E2 A -> Prop} : Proper (eq_itree (@eq (S *R )) ==> iff) (terminates s P). +#[global] Instance proper_terminate {R s} {P : forall A, E2 A -> Prop} : Proper (eq_itree (@eq (R * S)) ==> iff) (terminates s P). Proof. red. intros t1 t2 Heq. apply EqAxiom.bisimulation_is_eq in Heq. subst; tauto. Qed. @@ -224,7 +226,7 @@ Qed. Lemma silent_terminates_eqit_secure_ret : forall R (m : stateT S (itree E2) R), nonempty R -> - (forall s, terminates s (fun B e => ~ leq (priv2 _ e) l /\ nonempty B) (m s) ) <-> forall r' : R, state_eqit_secure true true top2 m (ret r'). + (forall s, terminates s (fun B e => ~ leq (priv2 _ e) l /\ nonempty B) (runStateT m s) ) <-> forall r' : R, state_eqit_secure true true top2 m (ret r'). Proof. split; intros. - red. intros. specialize (H0 s1). @@ -235,8 +237,8 @@ Proof. pstep_reverse. eapply H2; eauto. - cbn in *. red in H0. assert (RS s s). reflexivity. inv H. - specialize (H0 a s s H1). remember (m s) as t. clear Heqt. - punfold H0. red in H0. cbn in H0. remember (RetF (s,a) ) as oret. remember (observe t) as ot. + specialize (H0 a s s H1). remember (runStateT m s) as t. clear Heqt. + punfold H0. red in H0. cbn in H0. remember (RetF (a,s) ) as oret. remember (observe t) as ot. hinduction H0 before E1; intros; try discriminate; use_simpobs. + rewrite Heqot. injection Heqoret; intros; subst. destruct r1, H. cbn in *. constructor. symmetry. auto. @@ -246,16 +248,16 @@ Qed. Variant handler_respects_priv (A : Type) (e : E1 A) : Prop := | respect_private (SECCHECK : ~ leq (priv1 _ e) l) - (FINCHECK : forall s, terminates s (fun _ e' => ~ leq (priv2 _ e') l) (handler A e s)) + (FINCHECK : forall s, terminates s (fun _ e' => ~ leq (priv2 _ e') l) (runStateT (handler A e) s)) | respect_public (SECCHECK : leq (priv1 _ e) l) (RESCHECK : state_eqit_secure true true eq (handler A e) (handler A e)) . Variant handler_respects_priv' (A : Type) (e : E1 A) : Prop := | respect_private_ne (SECCHECK : ~ leq (priv1 _ e) l) (SIZECHECK : nonempty A) - (FINCHECK : forall s, terminates s (fun B e' => ~ leq (priv2 _ e') l /\ nonempty B ) (handler A e s) ) + (FINCHECK : forall s, terminates s (fun B e' => ~ leq (priv2 _ e') l /\ nonempty B ) (runStateT (handler A e) s) ) | respect_private_e (SECCHECK : ~ leq (priv1 _ e) l) (SIZECHECK : empty A) - (DIVCHECK : forall s, diverges_with (fun _ e' => ~ leq (priv2 _ e') l ) (handler A e s) ) + (DIVCHECK : forall s, diverges_with (fun _ e' => ~ leq (priv2 _ e') l ) (runStateT (handler A e) s) ) | respect_public' (SECCHECK : leq (priv1 _ e) l) (RESCHECK : state_eqit_secure true true eq (handler A e) (handler A e)) . @@ -264,7 +266,7 @@ Context (Hhandler : forall A (e : E1 A), handler_respects_priv' A e). Lemma diverge_with_respectful_handler : forall (R : Type) (t : itree E1 R), diverges_with (fun _ e => ~ leq (priv1 _ e) l ) t -> - forall s, diverges_with (fun _ e => ~ leq (priv2 _ e) l) (interp_state handler t s). + forall s, diverges_with (fun _ e => ~ leq (priv2 _ e) l) (runStateT (interp_state handler t) s). Proof. intro R. pcofix CIH. intros t Hdiv s. pinversion Hdiv; use_simpobs. - rewrite H. rewrite interp_state_tau. pfold. constructor. right. eapply CIH; eauto. @@ -299,7 +301,7 @@ Proof. (* could use the bind closure here, but maybe we can do manually for now*) repeat setoid_rewrite <- interp_state_tau. inv Hhandler; try contradiction. specialize (RESCHECK s1 s2 Hs). - eapply secure_eqit_bind'; eauto. intros [] [] []. simpl in *. subst. + eapply secure_eqit_bind'; eauto. intros [] [] []. cbn in *. subst. repeat rewrite interp_state_tau. pfold. constructor. right. eapply CIH; eauto. apply H. - pclearbot. rewrite Heqot1. rewrite Heqot2. @@ -321,7 +323,7 @@ Proof. - pclearbot. rewrite Heqot1. rewrite Heqot2. repeat rewrite interp_state_vis. specialize (Hhandler _ e1) as He1. specialize (Hhandler _ e2) as He2. inv He1; inv He2; try contradiction; try contra_size. - eapply secure_eqit_bind' with (RR := prod_rel RS (fun _ _ => True)). + eapply secure_eqit_bind' with (RR := prod_rel (fun _ _ => True) RS). + intros [] [] ?. pstep. constructor. right. apply CIH. apply H. simpl. apply H0. + specialize (FINCHECK s1). specialize (FINCHECK0 s2). diff --git a/extra/Secure/SecureStateHandlerPi.v b/extra/Secure/SecureStateHandlerPi.v index 5a87a3c5..e8e979a5 100644 --- a/extra/Secure/SecureStateHandlerPi.v +++ b/extra/Secure/SecureStateHandlerPi.v @@ -1,4 +1,7 @@ From Coq Require Import Morphisms. + +From ExtLib Require Import Data.Monads.StateMonad. + From ITree Require Import Axioms ITree @@ -41,7 +44,7 @@ Context (l : L). Definition state_pi_eqit_secure {R1 R2 : Type} (b1 b2 : bool) (RR : R1 -> R2 -> Prop) (m1 : stateT S (itree E2) R1) (m2 : stateT S (itree E2) R2) := - forall s1 s2, RS s1 s2 -> pi_eqit_secure Label priv2 (prod_rel RS RR) b1 b2 l (m1 s1) (m2 s2). + forall s1 s2, RS s1 s2 -> pi_eqit_secure Label priv2 (prod_rel RR RS) b1 b2 l (runStateT m1 s1) (runStateT m2 s2). Definition top2 {R1 R2} (r1 : R1) (r2 : R2) : Prop := True. @@ -49,7 +52,7 @@ Definition secure_in_nonempty_context {R} (m : stateT S (itree E2) R) := forall r' : R, state_pi_eqit_secure true true top2 m (ret r'). Definition secure_in_empty_context {R} (m : stateT S (itree E2) R) := - state_pi_eqit_secure true true (@top2 R R) m (fun s => ITree.spin). + state_pi_eqit_secure true true (@top2 R R) m (mkStateT (fun s => ITree.spin)). Lemma diverges_with_spin : forall E A P, diverges_with P (@ITree.spin E A). @@ -83,7 +86,7 @@ Qed. Lemma silent_terminates_eqit_secure_ret : forall R (m : stateT S (itree E2) R), nonempty R -> - (forall s, terminates S RS E2 s (fun B e => ~ leq (priv2 _ e) l /\ nonempty B) (m s) ) -> forall r' : R, state_pi_eqit_secure true true top2 m (ret r'). + (forall s, terminates S RS E2 s (fun B e => ~ leq (priv2 _ e) l /\ nonempty B) (runStateT m s) ) -> forall r' : R, state_pi_eqit_secure true true top2 m (ret r'). Proof. red. intros. specialize (H0 s1). cbn. induction H0. @@ -160,7 +163,7 @@ Proof. gstep. constructor; auto. rewrite interp_state_vis. specialize (Hhandler A e). inv Hhandler; try contradiction. red in RESCHECK. apply RESCHECK in Hs as He. - remember (handler A e s1) as t3. clear Heqt3. + remember (runStateT (handler A e) s1) as t3. clear Heqt3. cbn in He. generalize dependent t3. gcofix CIH'. intros t3 Ht3. pinversion Ht3; use_simpobs; subst. + destruct H4. cbn in *. destruct r1. cbn in *. @@ -176,7 +179,7 @@ Proof. gstep. constructor; auto. rewrite interp_state_vis. specialize (Hhandler A e). inv Hhandler; try contradiction. red in RESCHECK. symmetry in Hs. apply RESCHECK in Hs as He. - remember (handler A e s2) as t3. clear Heqt3. + remember (runStateT (handler A e) s2) as t3. clear Heqt3. cbn in He. generalize dependent t3. gcofix CIH'. intros t3 Ht3. pinversion Ht3; use_simpobs; subst. + destruct H4. cbn in *. destruct r1. cbn in *. @@ -192,23 +195,23 @@ Proof. - pclearbot. rewrite H, H0. repeat rewrite interp_state_vis. specialize (Hhandler A e1) as He1. specialize (Hhandler B e2) as He2. inv He1; inv He2; try contradiction. - eapply pi_secure_eqit_bind' with (RR := prod_rel RS top2); eauto. + eapply pi_secure_eqit_bind' with (RR := prod_rel top2 RS); eauto. + intros [? ?] [? ?] [? ?]. cbn in *. pstep. constructor. right. eapply CIH; eauto. apply H1. + cbn in *. apply pi_eqit_secure_RR_imp with - (RR1 := rcompose (prod_rel RS (@top2 A unit)) (prod_rel RS top2) ). + (RR1 := rcompose (prod_rel (@top2 A unit) RS) (prod_rel top2 RS) ). { intros. inv H2. destruct REL1. destruct REL2. split; auto. etransitivity; eauto. } eapply pi_eqit_secure_trans_ret; eauto. apply pi_eqit_secure_sym. apply pi_eqit_secure_RR_imp with - (RR1 := prod_rel RS top2). + (RR1 := prod_rel top2 RS). { intros. inv H2. split; auto. symmetry. auto. } eapply RESCHECK0. reflexivity. - apply simpobs in H0. rewrite <- itree_eta in H0. rewrite H0. rewrite H. rewrite interp_state_vis. specialize (Hhandler A e). inv Hhandler; try contradiction. red in RESCHECK. apply RESCHECK in Hs as He. - remember (handler A e s1) as t3. clear Heqt3. + remember (runStateT (handler A e) s1) as t3. clear Heqt3. cbn in He. generalize dependent t3. gcofix CIH'. intros t3 Ht3. pinversion Ht3; use_simpobs; subst. + destruct H4. cbn in *. destruct r1. cbn in *. @@ -224,7 +227,7 @@ Proof. pclearbot. rewrite H0. rewrite interp_state_vis. specialize (Hhandler A e). inv Hhandler; try contradiction. red in RESCHECK. symmetry in Hs. apply RESCHECK in Hs as He. - remember (handler A e s2) as t3. clear Heqt3. + remember (runStateT (handler A e) s2) as t3. clear Heqt3. cbn in He. generalize dependent t3. gcofix CIH'. intros t3 Ht3. pinversion Ht3; use_simpobs; subst. + destruct H4. cbn in *. destruct r1. cbn in *. From a2efc92e052002537b788215d936ec436715f7bb Mon Sep 17 00:00:00 2001 From: Li-yao Xia Date: Thu, 3 Oct 2024 00:53:03 +0200 Subject: [PATCH 3/8] Fix Dijkstra.StateIOTrace with updated stateT --- extra/Dijkstra/StateIOTrace.v | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/extra/Dijkstra/StateIOTrace.v b/extra/Dijkstra/StateIOTrace.v index 0c21ebd6..84cdcfe9 100644 --- a/extra/Dijkstra/StateIOTrace.v +++ b/extra/Dijkstra/StateIOTrace.v @@ -258,7 +258,9 @@ Section PrintMults. setoid_rewrite bind_ret_l in H. unf_res. punfold H. red in H. cbn in *. - dependent induction H. + remember (interp_state handleIOStateE) as h eqn:Eh. + revert Eh. + dependent induction H; intros Eh. 2:{ rewrite <- x. constructor; auto. eapply IHruttF; eauto; reflexivity. } inversion H; ddestruction; subst; ddestruction; try contradiction. subst. specialize (H0 tt tt). @@ -274,8 +276,6 @@ Section PrintMults. eapply CIH with (Maps.add Y (n + m) si); try apply lookup_eq. 2: { rewrite lookup_neq; subst; auto. } rewrite tau_eutt in Hk1. - (* TODO: not sure why this is failing *) - (* setoid_rewrite bind_trigger in Hk1. setoid_rewrite interp_state_vis in Hk1. cbn in *. rewrite bind_ret_l in Hk1. rewrite tau_eutt in Hk1. @@ -286,13 +286,11 @@ Section PrintMults. rewrite interp_state_ret in Hk1. rewrite bind_ret_l in Hk1. cbn in *. rewrite tau_eutt in Hk1. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree. + unfold Basics.iter, MonadIter_itree. match goal with H : _ ⊑ ITree.iter _ (?s1, _) |- _ ⊑ ITree.iter _ (?s2, _) => enough (Hseq : s2 = s1) end; try rewrite Hseq; auto. subst. rewrite Nat.add_comm. auto. - *) - admit. - Admitted. +Qed. End PrintMults. From 35a05c3c4b144f45bc0e7b570296a9ba6c466d95 Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Thu, 3 Oct 2024 16:50:00 -0400 Subject: [PATCH 4/8] Generalize eq_stateT to a fully heterogenous stateT_rel --- extra/Dijkstra/StateIOTrace.v | 3 ++- theories/Events/StateFacts.v | 42 +++++++++++++++++------------------ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/extra/Dijkstra/StateIOTrace.v b/extra/Dijkstra/StateIOTrace.v index 84cdcfe9..1de2c2c9 100644 --- a/extra/Dijkstra/StateIOTrace.v +++ b/extra/Dijkstra/StateIOTrace.v @@ -214,7 +214,8 @@ Section PrintMults. rewrite interp_state_trigger in H. cbn in *. rewrite bind_ret_l in H. cbn in *. specialize (@interp_state_iter' (StateE +' IO) ) as Hiter. - unfold eq_stateT in Hiter. rewrite Hiter in H. clear Hiter. + unfold stateT_rel, HeterogeneousRelations.subrelationH in Hiter. + rewrite Hiter in H; eauto. clear Hiter. remember (Maps.add X n s) as si. assert (si = alist_add RelDec_string X n s); try (subst; auto; fail). diff --git a/theories/Events/StateFacts.v b/theories/Events/StateFacts.v index 37b8c49d..9a4872b2 100644 --- a/theories/Events/StateFacts.v +++ b/theories/Events/StateFacts.v @@ -59,31 +59,30 @@ Proof. reflexivity. Qed. -Definition eq_stateT {S M X Y} - (RM : forall R1 R2, (R1 -> R2 -> Prop) -> M R1 -> M R2 -> Prop) - (RXY : X * S -> Y * S -> Prop) - : (stateT S M X) -> (stateT S M Y) -> Prop := - fun t1 t2 => forall s, RM _ _ RXY (runStateT t1 s) (runStateT t2 s). +Definition stateT_rel {S1 S2 M X Y} + (RM : (X * S1 -> Y * S2 -> Prop) -> M (X * S1)%type -> M (Y * S2)%type -> Prop) + (RS : S1 -> S2 -> Prop) + (RXY : X * S1 -> Y * S2 -> Prop) + : (stateT S1 M X) -> (stateT S2 M Y) -> Prop := + fun t1 t2 => subrelationH RS (fun s1 s2 => RM RXY (runStateT t1 s1) (runStateT t2 s2)). #[global] Instance eq_stateT_runStateT {S M R} - (RM : forall R1 R2, (R1 -> R2 -> Prop) -> M R1 -> M R2 -> Prop) : - Proper (eq_stateT RM eq ==> eq ==> (RM _ _ eq)) (@runStateT S _ R). + (RM : (R * S -> R * S -> Prop) -> M (R * S)%type -> M (R * S)%type -> Prop) : + Proper (stateT_rel RM eq eq ==> eq ==> (RM eq)) (@runStateT S _ R). Proof. unfold Proper, respectful. intros x y Heq_stateT sx sy Heq. - specialize (Heq_stateT sy). - rewrite Heq. - apply Heq_stateT. + apply Heq_stateT, Heq. Qed. #[global] Instance eq_itree_interp_state {E F S R} (h : E ~> stateT S (itree F)) : - Proper (eq_itree eq ==> (eq_stateT (@eq_itree _) eq)) + Proper (eq_itree eq ==> (stateT_rel eq_itree eq eq)) (@interp_state _ _ _ _ _ _ h R). Proof. revert_until R. - ginit. pcofix CIH. intros h x y H0 s. + ginit. pcofix CIH. intros h x y H0 s1 s2 Heq. rewrite !unfold_interp_state. punfold H0; repeat red in H0. destruct H0; subst; pclearbot; try discriminate; cbn. @@ -170,7 +169,7 @@ Qed. #[global] Instance eutt_interp_state {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R RR : - Proper (eutt RR ==> eq_stateT (@eutt _) (prod_rel RR eq)) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt RR ==> stateT_rel eutt eq (prod_rel RR eq)) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until RR. einit. ecofix CIH. intros. @@ -189,7 +188,7 @@ Qed. #[global] Instance eutt_interp_state_eq {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R : - Proper (eutt eq ==> eq_stateT (@eutt _) eq) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt eq ==> stateT_rel eutt eq eq) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until R. einit. ecofix CIH. intros. @@ -302,18 +301,18 @@ Qed. Lemma interp_state_iter {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) (t' : I -> stateT S (itree F) (I + A)) - (EQ_t : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (t i)) (t' i)) - : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (ITree.iter t i)) + (EQ_t : forall i, stateT_rel eq_itree eq eq (State.interp_state f (t i)) (t' i)) + : forall i, stateT_rel eq_itree eq eq (State.interp_state f (ITree.iter t i)) (Basics.iter t' i). Proof. unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. - ginit. pcofix CIH; intros i s. cbn. + ginit. pcofix CIH; intros i s1 s2. cbn. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. setoid_rewrite bind_ret_l. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. - - apply EQ_t. + - apply EQ_t. auto. - intros [[] s'] _ []; cbn. + rewrite interp_state_tau. gstep; constructor. @@ -323,12 +322,13 @@ Qed. Lemma interp_state_iter' {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) - : forall i, eq_stateT (@eq_itree _) eq (State.interp_state f (ITree.iter t i)) + : forall i, stateT_rel eq_itree eq eq (State.interp_state f (ITree.iter t i)) (Basics.iter (fun i => State.interp_state f (t i)) i). Proof. eapply interp_state_iter. - intros i. - red. reflexivity. + intros i s1 s2 Heq. + rewrite Heq. + reflexivity. Qed. Lemma interp_state_iter'_eutt {E F S} (f: E ~> stateT S (itree F)) {I A} From b34b8bf4e94a0d4b51573781df197db048e5309a Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Mon, 7 Oct 2024 13:32:35 -0400 Subject: [PATCH 5/8] change state relation to be heterogenous --- extra/Dijkstra/StateIOTrace.v | 1 - theories/Events/Map.v | 12 ++++++------ theories/Events/MapDefault.v | 12 ++++++------ theories/Events/StateFacts.v | 29 ++++++++++++++--------------- 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/extra/Dijkstra/StateIOTrace.v b/extra/Dijkstra/StateIOTrace.v index 1de2c2c9..4252d10a 100644 --- a/extra/Dijkstra/StateIOTrace.v +++ b/extra/Dijkstra/StateIOTrace.v @@ -214,7 +214,6 @@ Section PrintMults. rewrite interp_state_trigger in H. cbn in *. rewrite bind_ret_l in H. cbn in *. specialize (@interp_state_iter' (StateE +' IO) ) as Hiter. - unfold stateT_rel, HeterogeneousRelations.subrelationH in Hiter. rewrite Hiter in H; eauto. clear Hiter. remember (Maps.add X n s) as si. diff --git a/theories/Events/Map.v b/theories/Events/Map.v index 0ddc4368..4f99723a 100644 --- a/theories/Events/Map.v +++ b/theories/Events/Map.v @@ -52,12 +52,12 @@ Section Map. Context {M : Map K V map}. Definition handle_map {E} : mapE ~> stateT map (itree E) := - fun _ e => - match e with - | Insert k v => mkStateT (fun env => Ret (tt, Maps.add k v env)) - | Lookup k => mkStateT (fun env => Ret (Maps.lookup k env, env)) - | Remove k => mkStateT (fun env => Ret (tt, Maps.remove k env)) - end. + fun _ e => mkStateT (fun env => + match e in mapE R return itree E (R * _) with + | Insert k v => Ret (tt, Maps.add k v env) + | Lookup k => Ret (Maps.lookup k env, env) + | Remove k => Ret (tt, Maps.remove k env) + end). (* not sure why case_ requires manual parameters *) Definition run_map {E} : itree (mapE +' E) ~> stateT map (itree E) := diff --git a/theories/Events/MapDefault.v b/theories/Events/MapDefault.v index d8115e0d..47384204 100644 --- a/theories/Events/MapDefault.v +++ b/theories/Events/MapDefault.v @@ -56,12 +56,12 @@ Section Map. end. Definition handle_map {E d} : mapE d ~> stateT map (itree E) := - fun _ e => - match e with - | Insert k v => mkStateT (fun env => Ret (tt, Maps.add k v env)) - | LookupDef k => mkStateT (fun env => Ret (lookup_default k d env, env)) - | Remove k => mkStateT (fun env => Ret (tt, Maps.remove k env)) - end. + fun _ e => mkStateT (fun env => + match e in mapE _ R return itree E (R * _) with + | Insert k v => Ret (tt, Maps.add k v env) + | LookupDef k => Ret (lookup_default k d env, env) + | Remove k => Ret (tt, Maps.remove k env) + end). (* SAZ: I think that all of these [run_foo] functions should be renamed [interp_foo]. That would be more consistent with the idea that diff --git a/theories/Events/StateFacts.v b/theories/Events/StateFacts.v index 9a4872b2..ead06022 100644 --- a/theories/Events/StateFacts.v +++ b/theories/Events/StateFacts.v @@ -59,17 +59,16 @@ Proof. reflexivity. Qed. -Definition stateT_rel {S1 S2 M X Y} - (RM : (X * S1 -> Y * S2 -> Prop) -> M (X * S1)%type -> M (Y * S2)%type -> Prop) +Definition stateT_rel {S1 S2 M1 M2 R1 R2} (RS : S1 -> S2 -> Prop) - (RXY : X * S1 -> Y * S2 -> Prop) - : (stateT S1 M X) -> (stateT S2 M Y) -> Prop := - fun t1 t2 => subrelationH RS (fun s1 s2 => RM RXY (runStateT t1 s1) (runStateT t2 s2)). + (RM : M1 (R1 * S1)%type -> M2 (R2 * S2)%type -> Prop) + : (stateT S1 M1 R1) -> (stateT S2 M2 R2) -> Prop := + fun t1 t2 => forall s1 s2, RS s1 s2 -> RM (runStateT t1 s1) (runStateT t2 s2). #[global] Instance eq_stateT_runStateT {S M R} - (RM : (R * S -> R * S -> Prop) -> M (R * S)%type -> M (R * S)%type -> Prop) : - Proper (stateT_rel RM eq eq ==> eq ==> (RM eq)) (@runStateT S _ R). + (RM : M (R * S)%type -> M (R * S)%type -> Prop) : + Proper (stateT_rel eq RM ==> eq ==> RM) (@runStateT S _ R). Proof. unfold Proper, respectful. intros x y Heq_stateT sx sy Heq. @@ -78,7 +77,7 @@ Qed. #[global] Instance eq_itree_interp_state {E F S R} (h : E ~> stateT S (itree F)) : - Proper (eq_itree eq ==> (stateT_rel eq_itree eq eq)) + Proper (eq_itree eq ==> (stateT_rel eq (eq_itree eq))) (@interp_state _ _ _ _ _ _ h R). Proof. revert_until R. @@ -169,7 +168,7 @@ Qed. #[global] Instance eutt_interp_state {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R RR : - Proper (eutt RR ==> stateT_rel eutt eq (prod_rel RR eq)) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt RR ==> stateT_rel eq (eutt (prod_rel RR eq))) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until RR. einit. ecofix CIH. intros. @@ -188,7 +187,7 @@ Qed. #[global] Instance eutt_interp_state_eq {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R : - Proper (eutt eq ==> stateT_rel eutt eq eq) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt eq ==> stateT_rel eq (eutt eq)) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until R. einit. ecofix CIH. intros. @@ -301,18 +300,18 @@ Qed. Lemma interp_state_iter {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) (t' : I -> stateT S (itree F) (I + A)) - (EQ_t : forall i, stateT_rel eq_itree eq eq (State.interp_state f (t i)) (t' i)) - : forall i, stateT_rel eq_itree eq eq (State.interp_state f (ITree.iter t i)) + (EQ_t : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (t i)) (t' i)) + : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (ITree.iter t i)) (Basics.iter t' i). Proof. unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. - ginit. pcofix CIH; intros i s1 s2. cbn. + ginit. pcofix CIH; intros i s1 s2 Heq. cbn. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. setoid_rewrite bind_ret_l. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. - - apply EQ_t. auto. + - apply EQ_t, Heq. - intros [[] s'] _ []; cbn. + rewrite interp_state_tau. gstep; constructor. @@ -322,7 +321,7 @@ Qed. Lemma interp_state_iter' {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) - : forall i, stateT_rel eq_itree eq eq (State.interp_state f (ITree.iter t i)) + : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (ITree.iter t i)) (Basics.iter (fun i => State.interp_state f (t i)) i). Proof. eapply interp_state_iter. From f976fa8e59d87f33a59efbcdc1f952c6ee1a6196 Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Mon, 7 Oct 2024 21:30:43 -0400 Subject: [PATCH 6/8] Update tutorial proofs --- theories/Events/StateFacts.v | 36 +++++++----- tutorial/Asm.v | 36 ++++++------ tutorial/AsmOptimization.v | 56 ++++++++++--------- tutorial/Imp.v | 16 +++--- tutorial/Imp2AsmCorrectness.v | 90 +++++++++++++++++------------- tutorial/Introduction.v | 22 +++++--- tutorial/extract-imptest/ImpTest.v | 2 +- 7 files changed, 145 insertions(+), 113 deletions(-) diff --git a/theories/Events/StateFacts.v b/theories/Events/StateFacts.v index ead06022..ee120b50 100644 --- a/theories/Events/StateFacts.v +++ b/theories/Events/StateFacts.v @@ -59,29 +59,36 @@ Proof. reflexivity. Qed. +(* Definition stateT_rel {S1 S2 M1 M2 R1 R2} (RS : S1 -> S2 -> Prop) (RM : M1 (R1 * S1)%type -> M2 (R2 * S2)%type -> Prop) : (stateT S1 M1 R1) -> (stateT S2 M2 R2) -> Prop := fun t1 t2 => forall s1 s2, RS s1 s2 -> RM (runStateT t1 s1) (runStateT t2 s2). +*) +Definition eq_stateT {S M R} + (RM : M (R * S)%type -> M (R * S)%type -> Prop) + : stateT S M R -> stateT S M R -> Prop := + fun t1 t2 => forall s, RM (runStateT t1 s) (runStateT t2 s). #[global] Instance eq_stateT_runStateT {S M R} (RM : M (R * S)%type -> M (R * S)%type -> Prop) : - Proper (stateT_rel eq RM ==> eq ==> RM) (@runStateT S _ R). + Proper (eq_stateT RM ==> eq ==> RM) (@runStateT S _ R). Proof. unfold Proper, respectful. intros x y Heq_stateT sx sy Heq. - apply Heq_stateT, Heq. + rewrite Heq. + apply Heq_stateT. Qed. #[global] Instance eq_itree_interp_state {E F S R} (h : E ~> stateT S (itree F)) : - Proper (eq_itree eq ==> (stateT_rel eq (eq_itree eq))) + Proper (eq_itree eq ==> (eq_stateT (eq_itree eq))) (@interp_state _ _ _ _ _ _ h R). Proof. revert_until R. - ginit. pcofix CIH. intros h x y H0 s1 s2 Heq. + ginit. pcofix CIH. intros h x y H0 s. rewrite !unfold_interp_state. punfold H0; repeat red in H0. destruct H0; subst; pclearbot; try discriminate; cbn. @@ -168,7 +175,7 @@ Qed. #[global] Instance eutt_interp_state {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R RR : - Proper (eutt RR ==> stateT_rel eq (eutt (prod_rel RR eq))) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt RR ==> eq_stateT (eutt (prod_rel RR eq))) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until RR. einit. ecofix CIH. intros. @@ -187,7 +194,7 @@ Qed. #[global] Instance eutt_interp_state_eq {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R : - Proper (eutt eq ==> stateT_rel eq (eutt eq)) (@interp_state E (itree F) S _ _ _ h R). + Proper (eutt eq ==> eq_stateT (eutt eq)) (@interp_state E (itree F) S _ _ _ h R). Proof. repeat intro. subst. revert_until R. einit. ecofix CIH. intros. @@ -300,18 +307,18 @@ Qed. Lemma interp_state_iter {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) (t' : I -> stateT S (itree F) (I + A)) - (EQ_t : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (t i)) (t' i)) - : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (ITree.iter t i)) - (Basics.iter t' i). + (EQ_t : forall i, eq_stateT (eq_itree eq) (State.interp_state f (t i)) (t' i)) + : forall i, eq_stateT (eq_itree eq) (State.interp_state f (ITree.iter t i)) + (Basics.iter t' i). Proof. unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. - ginit. pcofix CIH; intros i s1 s2 Heq. cbn. + ginit. pcofix CIH; intros i s. cbn. rewrite 2 unfold_iter; cbn. rewrite !bind_bind. setoid_rewrite bind_ret_l. rewrite interp_state_bind. guclo eqit_clo_bind; econstructor; eauto. - - apply EQ_t, Heq. + - apply EQ_t. - intros [[] s'] _ []; cbn. + rewrite interp_state_tau. gstep; constructor. @@ -321,12 +328,11 @@ Qed. Lemma interp_state_iter' {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) - : forall i, stateT_rel eq (eq_itree eq) (State.interp_state f (ITree.iter t i)) - (Basics.iter (fun i => State.interp_state f (t i)) i). + : forall i, eq_stateT (eq_itree eq) (State.interp_state f (ITree.iter t i)) + (Basics.iter (fun i => State.interp_state f (t i)) i). Proof. eapply interp_state_iter. - intros i s1 s2 Heq. - rewrite Heq. + intros i s. reflexivity. Qed. diff --git a/tutorial/Asm.v b/tutorial/Asm.v index 410fe4fd..6b94462a 100644 --- a/tutorial/Asm.v +++ b/tutorial/Asm.v @@ -13,7 +13,9 @@ From Coq Require Import Setoid RelationClasses. -Require Import ExtLib.Structures.Monad. +From ExtLib Require Import + Structures.Monad + Data.Monads.StateMonad. (* SAZ: Should we add ITreeMonad to ITree? *) From ITree Require Import @@ -287,16 +289,16 @@ Instance RelDec_reg : RelDec (@eq reg) := RelDec_from_dec eq Nat.eq_dec. (** Both environments and memory events can be interpreted as "map" events, exactly as we did for _Imp_. *) -Definition h_reg {E: Type -> Type} `{mapE reg 0 -< E} - : Reg ~> itree E := +Definition h_reg + : Reg ~> itree (mapE reg 0) := fun _ e => match e with | GetReg x => lookup_def x | SetReg x v => insert x v end. -Definition h_memory {E : Type -> Type} `{mapE addr 0 -< E} : - Memory ~> itree E := +Definition h_memory : + Memory ~> itree (mapE addr 0) := fun _ e => match e with | Load x => lookup_def x @@ -314,16 +316,17 @@ Definition memory := alist addr value. do a bit of post-processing to swap the order of the "state components" introduced by the interpretation. *) +Check interp_map. Definition interp_asm {E A} (t : itree (Reg +' Memory +' E) A) : - memory -> registers -> itree E (memory * (registers * A)) := - let h := bimap h_reg (bimap h_memory (id_ _)) in + memory -> registers -> itree E (A * registers * memory) := + let h := bimap h_reg (bimap h_memory (id_ E)) in let t' := interp h t in - fun mem regs => interp_map (interp_map t' regs) mem. + fun mem regs => runStateT (interp_map (runStateT (interp_map t') regs)) mem. (** We can then define an evaluator for closed assembly programs by interpreting both store and heap events into two instances of [mapE], and running them both in the empty initial environments. *) -Definition run_asm (p : asm 1 0) : itree Exit (memory * (registers * fin 0)) := +Definition run_asm (p : asm 1 0) : itree Exit (fin 0 * registers * memory) := interp_asm (denote_asm p Fin.f0) empty empty. (* SAZ: Should some of thes notions of equivalence be put into the library? @@ -345,25 +348,25 @@ Section InterpAsmProperties. (** This interpreter is compatible with the equivalence-up-to-tau. *) #[global] Instance eutt_interp_asm {R}: - Proper (@eutt E R R eq ==> eq ==> eq ==> @eutt E' (prod memory (prod registers R)) (prod _ (prod _ R)) eq) interp_asm. + Proper (@eutt E R R eq ==> eq ==> eq ==> eutt eq) interp_asm. Proof. repeat intro. unfold interp_asm. - unfold interp_map. - rewrite H0. + cbn. + subst. rewrite H. - rewrite H1. reflexivity. Qed. (** [interp_asm] commutes with [Ret]. *) Lemma interp_asm_ret: forall {R} (r: R) (regs : registers) (mem: memory), @eutt E' _ _ eq (interp_asm (ret r) mem regs) - (ret (mem, (regs, r))). + (ret (r, regs, mem)). Proof. unfold interp_asm, interp_map. intros. unfold ret at 1, Monad_itree. + cbn. rewrite interp_ret, 2 interp_state_ret. reflexivity. Qed. @@ -371,7 +374,8 @@ Section InterpAsmProperties. (** [interp_asm] commutes with [bind]. *) Lemma interp_asm_bind: forall {R S} (t: itree E R) (k: R -> itree _ S) (regs : registers) (mem: memory), @eutt E' _ _ eq (interp_asm (ITree.bind t k) mem regs) - (ITree.bind (interp_asm t mem regs) (fun '(mem', (regs', x)) => interp_asm (k x) mem' regs')). + (ITree.bind (interp_asm t mem regs) + (fun '(x, regs', mem') => interp_asm (k x) mem' regs')). Proof. intros. @@ -384,7 +388,7 @@ Section InterpAsmProperties. { reflexivity. } intros. rewrite H. - destruct u2 as [g' [l' x]]. + destruct u2 as [[x l'] g']. reflexivity. Qed. diff --git a/tutorial/AsmOptimization.v b/tutorial/AsmOptimization.v index 712f3251..d0e40b2f 100644 --- a/tutorial/AsmOptimization.v +++ b/tutorial/AsmOptimization.v @@ -28,7 +28,8 @@ From ExtLib Require Import Structures.Monad Structures.Maps Programming.Show - Data.Map.FMapAList. + Data.Map.FMapAList + Data.Monads.StateMonad. Import ListNotations. Open Scope string_scope. @@ -99,8 +100,8 @@ Global Instance EQ_memory_eqv : Equivalence (EQ_memory). constructor; typeclasses eauto. Qed. -Definition rel_asm {B} : memory * (registers * B) -> memory * (registers * B) -> Prop := - prod_rel EQ_memory (prod_rel (EQ_registers 0) eq). +Definition rel_asm {B} : (B * registers * memory) -> (B * registers * memory) -> Prop := + prod_rel (prod_rel eq (EQ_registers 0)) EQ_memory. Global Hint Unfold rel_asm: core. @@ -121,7 +122,7 @@ Definition optimization_correct {E} `{Exit -< E} A B (opt:optimization) : Prop : forall (p : asm A B), eq_asm_EQ (E := E) p (opt p). -Definition EQ_asm {E A} (f g : memory -> registers -> itree E (memory * (registers * A))) : Prop := +Definition EQ_asm {E A} (f g : memory -> registers -> itree E (A * registers * memory)) : Prop := forall mem1 mem2 regs1 regs2, EQ_memory mem1 mem2 -> EQ_registers 0 regs1 regs2 -> @@ -143,17 +144,17 @@ Proof. reflexivity. } intros. - destruct H as [J1 [J2 J3]]; subst. + destruct H as [[J1 J2] J3]; subst. unfold interp_asm. unfold interp_map. - destruct u2 as [? [? []]]. + destruct u2 as [[[] ?] ?]. rewrite interp_ret. do 2 rewrite interp_state_ret. apply eqit_Ret. auto. Qed. Lemma interp_asm_ret {E A} (x:A) mem reg : - interp_asm (Ret x) mem reg ≈ (Ret (mem, (reg, x)) : itree E _). + interp_asm (Ret x) mem reg ≈ (Ret (x, reg, mem) : itree E _). Proof. unfold interp_asm, interp_map. rewrite interp_ret. @@ -274,14 +275,15 @@ Proof. reflexivity. Qed. -Lemma interp_state_iter' {E F } S (f : E ~> Monads.stateT S (itree F)) {I A} +Lemma interp_state_iter' {E F } S (f : E ~> stateT S (itree F)) {I A} (t : I -> itree E (I + A)) - : forall i, state_eq (State.interp_state f (ITree.iter t i)) - (Basics.iter (fun i => State.interp_state f (t i)) i). + : forall i, eq_stateT (eq_itree eq) + (State.interp_state f (ITree.iter t i)) + (Basics.iter (fun i => State.interp_state f (t i)) i). Proof. eapply interp_state_iter. - intros i. - red. reflexivity. + intros i s. + reflexivity. Qed. (* peephole optimizations --------------------------------------------------- *) @@ -334,8 +336,8 @@ Lemma ph_blk_append_correct {E} {HasExit : Exit -< E} : eapply eutt_clo_bind. apply H2; auto. intros. - destruct H0 as [J1 [J2 J3]]. - destruct u1 as [? [? []]], u2 as [? [? []]]. cbn in *. + destruct H0 as [[J1 J2] J3]. + destruct u1 as [[[] ?] ?], u2 as [[[] ?] ?]. cbn in *. apply HP; auto. Qed. @@ -414,18 +416,18 @@ Proof. repeat rewrite interp_ret. repeat rewrite interp_state_ret. apply eqit_Ret. constructor; auto. - - intros. destruct H0 as [J1 [J2 J3]]. + - intros. destruct H0 as [[J1 J2] J3]. subst. cbn. unfold CategorySub.from_bif, FromBifunctor_ktree_fin. repeat rewrite interp_ret. repeat rewrite interp_state_ret. apply eqit_Ret. constructor; cbn; auto. constructor; cbn; auto. - rewrite J3. reflexivity. } + rewrite J1. reflexivity. } intros. - destruct H0 as [J1 [J2 J3]]; subst. - simpl in *. + destruct H0 as [[J1 J2] J3]; subst. + cbn in *. unfold denote_bks. unfold iter, CategorySub.Iter_sub. repeat rewrite interp_iter. @@ -433,21 +435,21 @@ Proof. cbn. assert (JJ := @interp_state_iter'). red in JJ. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in *. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in *. cbn in *. repeat rewrite JJ. eapply eutt_iter' with (RI := rel_asm); cbn; auto. - intros j1 j2 [K1 [K2 ->]]; cbn. + intros j1 j2 [[-> K1] K2]; cbn. rewrite !interp_bind, !interp_state_bind, !bind_bind. (* Slow! *) apply (@eutt_clo_bind _ _ _ _ _ _ rel_asm); - [|intros ? ? [? [? ->]]]; cbn. + [|intros ? ? [[-> ?] ?]]; cbn. { refine (peephole_block_correct _ _ _ _ _ _ _ _ _ _ _ _); eauto. } unfold CategorySub.to_bif, ToBifunctor_ktree_fin. apply (@eutt_clo_bind _ _ _ _ _ _ rel_asm); - [|intros ? ? [? [? ->]]]; cbn. + [|intros ? ? [[-> ?] ?]]; cbn. { rewrite bind_ret_l. unfold case_, Case_sum1, Case_Kleisli, case_sum. @@ -508,17 +510,17 @@ Proof. intros E i. unfold eq_asm_denotations_EQ. intros. - destruct i; simpl; try apply interp_asm_ret_tt; auto; try reflexivity. + destruct i; cbn; try apply interp_asm_ret_tt; auto; try reflexivity. destruct src. - + simpl. rewrite !bind_ret_l. + + cbn. rewrite !bind_ret_l. apply interp_asm_ret_tt; auto; try reflexivity. - + simpl. + + cbn. destruct (Nat.eq_dec dest r). * subst. rewrite Nat.eqb_refl. - simpl. + cbn. rewrite interp_asm_ret. rewrite interp_asm_GetReg. @@ -527,7 +529,7 @@ Proof. rewrite interp_vis. cbn. repeat rewrite interp_state_bind. - unfold CategoryOps.cat, Cat_Handler, Handler.cat. simpl. + unfold CategoryOps.cat, Cat_Handler, Handler.cat. cbn. unfold inl_, Inl_sum1_Handler, Handler.inl_, Handler.htrigger. unfold insert. unfold embed, Embeddable_itree, Embeddable_forall, inl_, embed. diff --git a/tutorial/Imp.v b/tutorial/Imp.v index ef832efc..2344ec8b 100644 --- a/tutorial/Imp.v +++ b/tutorial/Imp.v @@ -51,7 +51,8 @@ From ExtLib Require Import Data.String Structures.Monad Structures.Traversable - Data.List. + Data.List + Data.Monads.StateMonad. From ITree Require Import ITree @@ -324,12 +325,12 @@ forall eff, {pf:E -< eff == F[E]} (t : itree eff A) interp pf h h' t : M A *) -Definition interp_imp {E A} (t : itree (ImpState +' E) A) : stateT env (itree E) A := +Definition interp_imp {E A} (t : itree (ImpState +' E) A) : env -> (itree E) (A * env) := let t' := interp (bimap handle_ImpState (id_ E)) t in - interp_map t'. + runStateT (interp_map t'). -Definition eval_imp (s: stmt) : itree void1 (env * unit) := +Definition eval_imp (s: stmt) : itree void1 (unit * env) := interp_imp (denote_imp s) empty. (** Equipped with this evaluator, we can now compute. @@ -358,20 +359,21 @@ Section InterpImpProperties. (** This interpreter is compatible with the equivalence-up-to-tau. *) Global Instance eutt_interp_imp {R}: - Proper (@eutt E R R eq ==> eq ==> @eutt E' (prod (env) R) (prod _ R) eq) + Proper (@eutt E R R eq ==> eq ==> (eutt eq)) interp_imp. Proof. repeat intro. unfold interp_imp. unfold interp_map. - rewrite H0. eapply eutt_interp_state_eq; auto. + rewrite H0. + eapply eutt_interp_state_eq; auto. rewrite H. reflexivity. Qed. (** [interp_imp] commutes with [bind]. *) Lemma interp_imp_bind: forall {R S} (t: itree E R) (k: R -> itree E S) (g : env), (interp_imp (ITree.bind t k) g) - ≅ (ITree.bind (interp_imp t g) (fun '(g', x) => interp_imp (k x) g')). + ≅ (ITree.bind (interp_imp t g) (fun '(x, g') => interp_imp (k x) g')). Proof. intros. unfold interp_imp. diff --git a/tutorial/Imp2AsmCorrectness.v b/tutorial/Imp2AsmCorrectness.v index 88a97b61..92215872 100644 --- a/tutorial/Imp2AsmCorrectness.v +++ b/tutorial/Imp2AsmCorrectness.v @@ -79,7 +79,8 @@ From ExtLib Require Import Core.RelDec Structures.Monad Structures.Maps - Data.Map.FMapAList. + Data.Map.FMapAList + Data.Monads.StateMonad. Import ListNotations. Open Scope string_scope. @@ -156,15 +157,15 @@ Section Simulation_Relation. - The "stack" of temporaries used to compute intermediate results is left untouched. *) - Definition sim_rel l_asm n: (env * value) -> (memory * (registers * unit)) -> Prop := - fun '(g_imp', v) '(g_asm', (l_asm', _)) => + Definition sim_rel l_asm n: (value * env) -> (unit * registers * memory) -> Prop := + fun '(v, g_imp') '(_, l_asm', g_asm') => Renv g_imp' g_asm' /\ (* we don't corrupt any of the imp variables *) alist_In n l_asm' v /\ (* we get the right value *) (forall m, m < n -> forall v, (* we don't mess with anything on the "stack" *) alist_In m l_asm v <-> alist_In m l_asm' v). Lemma sim_rel_find : forall g_asm g_imp l_asm l_asm' n v, - sim_rel l_asm n (g_imp, v) (g_asm, (l_asm', tt)) -> + sim_rel l_asm n (v, g_imp) (tt, l_asm', g_asm) -> alist_find n l_asm' = Some v. Proof. intros. @@ -193,7 +194,7 @@ Section Simulation_Relation. (** [sim_rel] can be initialized from [Renv]. *) Lemma sim_rel_add: forall g_asm l_asm g_imp n v, Renv g_imp g_asm -> - sim_rel l_asm n (g_imp, v) (g_asm, (alist_add n v l_asm, tt)). + sim_rel l_asm n (v, g_imp) (tt, alist_add n v l_asm, g_asm). Proof. intros. split; [| split]. @@ -205,14 +206,14 @@ Section Simulation_Relation. (** [Renv] can be recovered from [sim_rel]. *) Lemma sim_rel_Renv: forall l_asm n s1 l v1 s2 v2, - sim_rel l_asm n (s2,v2) (s1,(l,v1)) -> Renv s2 s1 . + sim_rel l_asm n (v2,s2) (v1,l,s1) -> Renv s2 s1 . Proof. intros ? ? ? ? ? ? ? H; apply H. Qed. Lemma sim_rel_find_tmp_n: forall l_asm g_asm' n l_asm' g_imp' v, - sim_rel l_asm n (g_imp',v) (g_asm', (l_asm', tt)) -> + sim_rel l_asm n (v,g_imp') (tt, l_asm', g_asm') -> alist_In n l_asm' v. Proof. intros ? ? ? ? ? ? [_ [H _]]; exact H. @@ -223,7 +224,7 @@ Section Simulation_Relation. Lemma sim_rel_find_tmp_lt_n: forall l_asm g_asm' n m l_asm' g_imp' v, m < n -> - sim_rel l_asm n (g_imp',v) (g_asm', (l_asm', tt)) -> + sim_rel l_asm n (v, g_imp') (tt, l_asm', g_asm') -> alist_find m l_asm = alist_find m l_asm'. Proof. intros ? ? ? ? ? ? ? ineq [_ [_ H]]. @@ -240,8 +241,8 @@ Section Simulation_Relation. Lemma sim_rel_find_tmp_n_trans: forall l_asm n l_asm' l_asm'' g_asm' g_asm'' g_imp' g_imp'' v v', - sim_rel l_asm n (g_imp',v) (g_asm', (l_asm', tt)) -> - sim_rel l_asm' (S n) (g_imp'',v') (g_asm'', (l_asm'', tt)) -> + sim_rel l_asm n (v,g_imp') (tt, l_asm', g_asm') -> + sim_rel l_asm' (S n) (v',g_imp'') (tt, l_asm'', g_asm'') -> alist_In n l_asm'' v. Proof. intros. @@ -271,10 +272,10 @@ Section Simulation_Relation. Lemma sim_rel_binary_op: forall (l_asm l_asm' l_asm'' : registers) (g_asm' g_asm'' : memory) (g_imp' g_imp'' : env) (n v v' : nat) - (Hsim : sim_rel l_asm n (g_imp', v) (g_asm', (l_asm', tt))) - (Hsim': sim_rel l_asm' (S n) (g_imp'', v') (g_asm'', (l_asm'', tt))) + (Hsim : sim_rel l_asm n (v, g_imp') (tt, l_asm', g_asm')) + (Hsim': sim_rel l_asm' (S n) (v', g_imp'') (tt, l_asm'', g_asm'')) (op: nat -> nat -> nat), - sim_rel l_asm n (g_imp'', op v v') (g_asm'', (alist_add n (op v v') l_asm'', tt)). + sim_rel l_asm n (op v v', g_imp'') (tt, alist_add n (op v v') l_asm'', g_asm''). Proof. intros. split; [| split]. @@ -325,8 +326,8 @@ Section Bisimulation. Context {A B : Type}. Context (RAB : A -> B -> Prop). (* relation on Imp / Asm values *) - Definition state_invariant (a : Imp.env * A) (b : Asm.memory * (Asm.registers * B)) := - Renv (fst a) (fst b) /\ (RAB (snd a) (snd (snd b))). + Definition state_invariant (a : A * Imp.env) (b : B * Asm.registers * Asm.memory) := + Renv (snd a) (snd b) /\ (RAB (fst a) (fst (fst b))). Definition bisimilar {E} (t1 : itree (ImpState +' E) A) (t2 : itree (Reg +' Memory +' E) B) := forall g_asm g_imp l, @@ -364,7 +365,7 @@ Section Bisimulation. { eapply H; auto. } intros. destruct u1 as [? ?]. - destruct u2 as [? [? ?]]. + destruct u2 as [[? ?] ?]. unfold state_invariant in H2. simpl in H2. destruct H2. subst. eapply H0; eauto. @@ -380,14 +381,14 @@ Section Bisimulation. bisimilar S (iter (C := ktree _) t1 x) (iter (C := ktree _) t2 x'). Proof. - unfold bisimilar, interp_asm, interp_imp, interp_map. + unfold bisimilar, interp_asm, interp_imp, interp_map. cbn. intros. rewrite 2 interp_iter. unfold iter, Iter_Kleisli. pose proof @interp_state_iter'. red in H2. unfold Basics.iter, MonadIter_itree. rewrite 2 H2. - unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree; cbn. + unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree; cbn. rewrite H2. apply (eutt_iter' (state_invariant R)). intros. @@ -406,13 +407,13 @@ Section Bisimulation. *) Lemma sim_rel_get_tmp0: forall {E} n l l' g_asm g_imp v, - sim_rel l' n (g_imp,v) (g_asm, (l,tt)) -> + sim_rel l' n (v,g_imp) (tt,l,g_asm) -> (interp_asm ((trigger (GetReg n)) : itree (Reg +' Memory +' E) value) - g_asm l) - ≈ (Ret (g_asm, (l, v))). + g_asm l) + ≈ (Ret (v, l, g_asm)). Proof. intros. - unfold interp_asm. + unfold interp_asm. cbn. rewrite interp_trigger. cbn. unfold interp_map. @@ -585,6 +586,9 @@ End Linking. Section Correctness. + Arguments interp_imp : simpl never. + Arguments interp_asm : simpl never. + (** Correctness of expressions. We strengthen [bisimilar]: initial environments are still related by [Renv], @@ -600,11 +604,10 @@ Section Correctness. (interp_imp (denote_expr e) g_imp) (interp_asm (denote_list (compile_expr n e)) g_asm l). Proof. - induction e; simpl; intros. + induction e; cbn; intros. - (* Var case *) (* We first compute and eliminate taus on both sides. *) - force_left. - rewrite tau_eutt. + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. @@ -619,6 +622,7 @@ Section Correctness. - (* Literal case *) (* We reduce both sides to Ret constructs *) + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. red; rewrite <-eqit_Ret. @@ -637,7 +641,7 @@ Section Correctness. eapply eutt_clo_bind. { eapply IHe1; assumption. } (* We obtain new related environments *) - intros [g_imp' v] [g_asm' [l' []]] HSIM. + intros [g_imp' v] [[[] l'] g_asm'] HSIM. (* The Induction hypothesis on [e2] relates the second itrees *) rewrite interp_asm_bind. rewrite interp_imp_bind. @@ -645,8 +649,9 @@ Section Correctness. { eapply IHe2. eapply sim_rel_Renv; eassumption. } (* And we once again get new related environments *) - intros [g_imp'' v'] [g_asm'' [l'' []]] HSIM'. + intros [g_imp'' v'] [[[] l''] g_asm''] HSIM'. (* We can now reduce down to Ret constructs that remains to be related *) + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. red. rewrite <- eqit_Ret. @@ -665,7 +670,7 @@ Section Correctness. eapply eutt_clo_bind. { eapply IHe1; assumption. } (* We obtain new related environments *) - intros [g_imp' v] [g_asm' [l' []]] HSIM. + intros [v g_imp'] [[[] l'] g_asm'] HSIM. (* The Induction hypothesis on [e2] relates the second itrees *) rewrite interp_asm_bind. rewrite interp_imp_bind. @@ -673,8 +678,9 @@ Section Correctness. { eapply IHe2. eapply sim_rel_Renv; eassumption. } (* And we once again get new related environments *) - intros [g_imp'' v'] [g_asm'' [l'' []]] HSIM'. + intros [g_imp'' v'] [[[] l''] g_asm''] HSIM'. (* We can now reduce down to Ret constructs that remains to be related *) + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. red. rewrite <- eqit_Ret. @@ -693,7 +699,7 @@ Section Correctness. eapply eutt_clo_bind. { eapply IHe1; assumption. } (* We obtain new related environments *) - intros [g_imp' v] [g_asm' [l' []]] HSIM. + intros [v g_imp'] [[[] l'] g_asm'] HSIM. (* The Induction hypothesis on [e2] relates the second itrees *) rewrite interp_asm_bind. rewrite interp_imp_bind. @@ -701,8 +707,9 @@ Section Correctness. { eapply IHe2. eapply sim_rel_Renv; eassumption. } (* And we once again get new related environments *) - intros [g_imp'' v'] [g_asm'' [l'' []]] HSIM'. + intros [g_imp'' v'] [[[] l''] g_asm''] HSIM'. (* We can now reduce down to Ret constructs that remain to be related *) + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. red. rewrite <- eqit_Ret. @@ -735,10 +742,11 @@ Section Correctness. { eapply compile_expr_correct; eauto. } (* Once again, we get related environments *) - intros [g_imp' v] [g_asm' [l' y]] HSIM. + intros [g_imp' v] [[y l'] g_asm'] HSIM. simpl in HSIM. (* We can now reduce to Ret constructs *) + unfold interp_imp, interp_asm, interp_map, State.interp_state. tau_steps. red. rewrite <- eqit_Ret. @@ -809,6 +817,8 @@ Section Correctness. Notation Inr_Kleisli := Inr_Kleisli. + Arguments State.interp_state : simpl nomatch. + (** Correctness of the compiler. After interpretation of the [Locals], the source _Imp_ statement denoted as an [itree] and the compiled _Asm_ program denoted @@ -840,6 +850,7 @@ Section Correctness. intros []; simpl. repeat intro. + unfold interp_imp, interp_asm, interp_map, State.interp_state. force_left; force_right. Transparent eutt. red. rewrite <- eqit_Ret; auto. @@ -871,19 +882,19 @@ Section Correctness. { apply compile_expr_correct; auto. } (* We get in return [sim_rel] related environments *) - intros [g_imp' v] [g_asm' [l' x]] HSIM. + intros [v g_imp'] [[x l'] g_asm'] HSIM. (* We know that interpreting [GetVar tmp_if] is eutt to [Ret (g_asm,v)] *) generalize HSIM; intros EQ. eapply sim_rel_get_tmp0 in EQ. unfold tmp_if. rewrite interp_asm_bind. rewrite EQ; clear EQ. - rewrite bind_ret_; simpl. + rewrite bind_ret_; cbn. (* We can weaken [sim_rel] down to [Renv] *) apply sim_rel_Renv in HSIM. (* And finally conclude in both cases *) - destruct v; simpl; auto. + destruct v; cbn; auto. - (* While *) (* We commute [denote_asm] with [while_asm], and restructure the @@ -909,6 +920,7 @@ Section Correctness. 2:{ repeat intro. unfold to_bif, ToBifunctor_ktree_fin. rewrite !bind_ret_l. cbn. force_left. force_right. + unfold interp_imp, interp_asm. cbn. red; rewrite <- eqit_Ret; auto. unfold state_invariant. simpl. split; auto. @@ -926,7 +938,7 @@ Section Correctness. eapply eutt_clo_bind. { apply compile_expr_correct; auto. } - intros [g_imp' v] [g_asm' [l' x]] HSIM. + intros [v g_imp'] [[x l'] g_asm'] HSIM. rewrite !interp_asm_bind. rewrite !bind_bind. @@ -940,8 +952,9 @@ Section Correctness. (* We can weaken [sim_rel] down to [Renv] *) apply sim_rel_Renv in HSIM. (* And now consider both cases *) - destruct v; simpl; auto. + destruct v; cbn; auto. + (* The false case is trivial *) + unfold interp_imp, interp_asm. force_left; force_right. red. rewrite <- eqit_Ret. @@ -954,7 +967,8 @@ Section Correctness. rewrite !bind_bind. eapply eutt_clo_bind. { eapply IHs; auto. } - intros [g_imp'' v''] [g_asm'' [l'' x']] [HSIM' ?]. + intros [v'' g_imp''] [[x' l''] g_asm''] [HSIM' ?]. + unfold interp_imp, interp_asm. force_right; force_left. apply eqit_Ret. setoid_rewrite split_fin_sum_L_L_f1. diff --git a/tutorial/Introduction.v b/tutorial/Introduction.v index f76ab318..e051994c 100644 --- a/tutorial/Introduction.v +++ b/tutorial/Introduction.v @@ -16,7 +16,9 @@ From Coq Require Import From ExtLib Require Import Monad Traversable - Data.List. + Data.List + Structures.MonadState + Data.Monads.StateMonad. From ITree Require Import Simple. @@ -25,6 +27,8 @@ Import ListNotations. Import ITreeNotations. Import MonadNotation. Open Scope monad_scope. + +Existing Instance Monad_stateT. (* end hide *) (** * Events *) @@ -70,25 +74,25 @@ Definition write_one : itree ioE unit := - [void1] is the empty event (so the resulting ITree can trigger no event). *) -Compute Monads.stateT (list nat) (itree void1) unit. +Compute stateT (list nat) (itree void1) unit. Print void1. Definition handle_io - : forall R, ioE R -> Monads.stateT (list nat) (itree void1) R - := fun R e log => + : forall R, ioE R -> stateT (list nat) (itree void1) R + := fun R e => match e with - | Input => ret (log, [0]) - | Output o => ret (log ++ o, tt) + | Input => log <- get ;; ret [0] + | Output o => log <- get ;; put (log ++ o) ;; ret tt end. (** [interp] lifts any handler into an _interpreter_, of type [forall R, itree ioE R -> M R]. *) Definition interp_io - : forall R, itree ioE R -> itree void1 (list nat * R) - := fun R t => Monads.run_stateT (interp handle_io t) []. + : forall R, itree ioE R -> itree void1 (R * list nat) + := fun R t => runStateT (interp handle_io t) []. (** We can now interpret [write_one]. *) -Definition interpreted_write_one : itree void1 (list nat * unit) +Definition interpreted_write_one : itree void1 (unit * list nat) := interp_io _ write_one. (** Intuitively, [interp_io] replaces every [ITree.trigger] in the diff --git a/tutorial/extract-imptest/ImpTest.v b/tutorial/extract-imptest/ImpTest.v index 468876eb..37cf8ba2 100644 --- a/tutorial/extract-imptest/ImpTest.v +++ b/tutorial/extract-imptest/ImpTest.v @@ -18,7 +18,7 @@ Fixpoint run {A} (n : nat) (t : itree void1 A) : option A := end. Definition run_ (n : N) (s : stmt) : option env := - option_map fst (run (N.to_nat n) (eval_imp s)). + option_map snd (run (N.to_nat n) (eval_imp s)). Require Extraction. Require ExtrOcamlBasic. From 2c6f53d22e76209b6df52baccec0b3d6f3342216 Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Mon, 7 Oct 2024 21:57:49 -0400 Subject: [PATCH 7/8] Update introduction --- examples/IntroductionSolutions.v | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/IntroductionSolutions.v b/examples/IntroductionSolutions.v index 3a6f9d89..2b9e64bb 100644 --- a/examples/IntroductionSolutions.v +++ b/examples/IntroductionSolutions.v @@ -21,7 +21,8 @@ From Coq Require Import From ExtLib Require Import Monad Traversable - Data.List. + Data.List + Data.Monads.StateMonad. From ITree Require Import Simple. @@ -30,6 +31,8 @@ Import ListNotations. Import ITreeNotations. Import MonadNotation. Open Scope monad_scope. + +Existing Instance Monad_stateT. (* end hide *) (** * Events *) @@ -75,25 +78,25 @@ Definition write_one : itree ioE unit := - [void1] is the empty event (so the resulting ITree can trigger no event). *) -Compute Monads.stateT (list nat) (itree void1) unit. +Compute stateT (list nat) (itree void1) unit. Print void1. Definition handle_io - : forall R, ioE R -> Monads.stateT (list nat) (itree void1) R - := fun R e log => - match e with - | Input => ret (log, [0]) - | Output o => ret (log ++ o, tt) - end. + : forall R, ioE R -> stateT (list nat) (itree void1) R + := fun R e => mkStateT (fun log => + match e in ioE R return itree void1 (R * list nat) with + | Input => ret ([0], log) + | Output o => ret (tt, log ++ o) + end). (** [interp] lifts any handler into an _interpreter_, of type [forall R, itree ioE R -> M R]. *) Definition interp_io - : forall R, itree ioE R -> itree void1 (list nat * R) - := fun R t => Monads.run_stateT (interp handle_io t) []. + : forall R, itree ioE R -> itree void1 (R * list nat) + := fun R t => runStateT (interp handle_io t) []. (** We can now interpret [write_one]. *) -Definition interpreted_write_one : itree void1 (list nat * unit) +Definition interpreted_write_one : itree void1 (unit * list nat) := interp_io _ write_one. (** Intuitively, [interp_io] replaces every [ITree.trigger] in the From 5c3ee42a79e891582a3421f602dc9837bdad4bcd Mon Sep 17 00:00:00 2001 From: Justin Frank Date: Mon, 9 Dec 2024 16:05:34 -0500 Subject: [PATCH 8/8] IN PROGRESS --- extra/Dijkstra/DelaySpecMonad.v | 2 +- hoare_example/Imp.v | 19 ++++--- hoare_example/ImpHoare.v | 98 ++++++++++++++++----------------- theories/Events/StateFacts.v | 17 ++++++ 4 files changed, 77 insertions(+), 59 deletions(-) diff --git a/extra/Dijkstra/DelaySpecMonad.v b/extra/Dijkstra/DelaySpecMonad.v index 46ee81b5..2d6e690d 100644 --- a/extra/Dijkstra/DelaySpecMonad.v +++ b/extra/Dijkstra/DelaySpecMonad.v @@ -58,7 +58,7 @@ Delimit Scope delayspec_scope with delayspec. Notation "a ∈ b" := (proj1_sig (A := _ -> Prop) b a) (at level 70) : delayspec_scope. Notation "a ∋ b" := (proj1_sig (A := _ -> Prop) a b) (at level 70, only parsing) : delayspec_scope. -Definition Delay (A : Type) := itree void1 A. +Definition Delay := itree void1. #[global] Instance EqMDelay : Eq1 Delay := @ITreeMonad.Eq1_ITree void1. #[global] Instance MonadDelay : Monad Delay := @Monad_itree void1. diff --git a/hoare_example/Imp.v b/hoare_example/Imp.v index de66c30e..3cb83799 100644 --- a/hoare_example/Imp.v +++ b/hoare_example/Imp.v @@ -107,6 +107,7 @@ Module ImpNotations. if b then BTrue else BFalse. Coercion bool_to_bexp : bool >-> bexp. + Declare Scope imp_scope. Bind Scope imp_scope with aexp. Bind Scope imp_scope with bexp. Delimit Scope imp_scope with imp. @@ -361,7 +362,8 @@ From ITree Require Import From ExtLib Require Import Core.RelDec Structures.Maps - Data.Map.FMapAList. + Data.Map.FMapAList + Data.Monads.StateMonad. (* end hide *) (** We provide an _ITree event handler_ to interpret away [ImpState] events. We @@ -402,8 +404,8 @@ Definition interp_imp {E A} (t : itree (ImpState +' E) A) : stateT env (itree E interp_map t'. -Definition eval_imp (s: com) : itree void1 (env * unit) := - interp_imp (denote_com s) empty. +Definition eval_imp (s: com) : itree void1 (unit * env) := + runStateT (interp_imp (denote_com s)) empty. (** Equipped with this evaluator, we can now compute. Naturally since Coq is total, we cannot do it directly inside of it. @@ -433,20 +435,19 @@ Section InterpImpProperties. (** This interpreter is compatible with the equivalence-up-to-tau. *) Global Instance eutt_interp_imp {R}: - Proper (@eutt E R R eq ==> eq ==> @eutt E' (prod (env) R) (prod _ R) eq) + Proper (@eutt E R R eq ==> eq_stateT (@eutt E' (prod R (env)) (prod R _) eq)) interp_imp. Proof. repeat intro. unfold interp_imp. - unfold interp_map. - rewrite H0. eapply eutt_interp_state_eq; auto. - rewrite H. reflexivity. + rewrite H. + reflexivity. Qed. (** [interp_imp] commutes with [bind]. *) Lemma interp_imp_bind: forall {R S} (t: itree E R) (k: R -> itree E S) (g : env), - (interp_imp (ITree.bind t k) g) - ≅ (ITree.bind (interp_imp t g) (fun '(g', x) => interp_imp (k x) g')). + runStateT (interp_imp (ITree.bind t k)) g + ≅ ITree.bind (runStateT (interp_imp t) g) (fun '(x, g') => runStateT (interp_imp (k x)) g'). Proof. intros. unfold interp_imp. diff --git a/hoare_example/ImpHoare.v b/hoare_example/ImpHoare.v index b689f31b..35810656 100644 --- a/hoare_example/ImpHoare.v +++ b/hoare_example/ImpHoare.v @@ -1,10 +1,12 @@ From Coq Require Import Arith Lia (* nia *) Morphisms + Classes.RelationClasses . From ExtLib Require Import Data.String + Data.Monads.StateMonad . From ITree Require Import @@ -40,15 +42,15 @@ Definition denote_imp (c : com) : stateT env Delay unit := interp_imp (denote_com c). Definition hoare_triple (P Q : env -> Prop) (c : com) : Prop := - forall (s s' :env), P s -> (denote_imp c s ≈ ret (s',tt)) -> Q s'. + forall (s s' :env), P s -> (runStateT (denote_imp c) s ≈ runStateT (ret tt) s') -> Q s'. -Definition lift_imp_post (P : env -> Prop) : Delay (env * unit) -> Prop := - fun (t : Delay (env * unit) ) => (exists (s : env), ret (s, tt) ≈ t /\ P s). +Definition lift_imp_post (P : env -> Prop) : Delay (unit * env) -> Prop := + fun (t : Delay (unit * env) ) => (exists (s : env), ret (tt,s) ≈ t /\ P s). Notation "{{ P }} c {{ Q }}" := (hoare_triple P Q c) (at level 70). Definition is_bool (E : Type -> Type) (bc : bool) (be : bexp) (s : env) : Prop := - @interp_imp E bool (denote_bexp be) s ≈ ret (s, bc). + runStateT (@interp_imp E bool (denote_bexp be)) s ≈ ret (bc, s). Definition is_true (b : bexp) (s : env) : Prop := is_bool void1 true b s. @@ -61,12 +63,13 @@ Ltac unf_intep := unfold interp_imp, interp_map, interp_state, interp, Basics.it *) Lemma aexp_term : forall (E : Type -> Type) (ae : aexp) (s : env), - exists (n : nat), @interp_imp void1 _ (denote_aexp ae) s ≈ Ret (s,n). + exists (n : nat), runStateT (@interp_imp void1 _ (denote_aexp ae)) s ≈ Ret (n,s). Proof. intros. induction ae. - exists n. cbn. tau_steps. reflexivity. (*getvar case, extract to a lemma*) - - cbn. exists (lookup_default x 0 s). + - unfold interp_imp, interp_map, interp_state. + cbn. exists (lookup_default x 0 s). tau_steps. reflexivity. - basic_solve. exists (n0 + n)%nat. cbn. setoid_rewrite interp_imp_bind. rewrite IHae1. @@ -83,7 +86,7 @@ Proof. Qed. Lemma bools_term : forall (be : bexp) (s : env), - exists (bc : bool), @interp_imp void1 _ (denote_bexp be) s ≈ Ret (s,bc). + exists (bc : bool), runStateT (@interp_imp void1 _ (denote_bexp be)) s ≈ Ret (bc,s). Proof. intros. induction be. - exists true. cbn. unfold interp_imp, interp_map, interp_state. repeat rewrite interp_ret. @@ -120,11 +123,12 @@ Lemma hoare_seq : forall (c1 c2 : com) (P Q R : env -> Prop), {{P}} c1 {{Q}} -> {{P}} c1 ;;; c2 {{R}}. Proof. unfold hoare_triple. intros c1 c2 P Q R Hc1 Hc2 s s' Hs Hs'. - unfold denote_imp in Hs'. cbn in Hs'. rewrite interp_imp_bind in Hs'. + unfold denote_imp in Hs'. cbn in Hs'. unfold Delay in Hs'. rewrite interp_imp_bind in Hs'. fold (denote_imp c1) in Hs'. fold (denote_imp c2) in Hs'. - destruct (eutt_reta_or_div (denote_imp c1 s) ); basic_solve. - - destruct a as [s'' [] ]. rewrite <- H in Hs'. setoid_rewrite bind_ret_l in Hs'. symmetry in H. + destruct (eutt_reta_or_div (runStateT (denote_imp c1) s) ); basic_solve. + - destruct a as [[] s'']. rewrite <- H in Hs'. setoid_rewrite bind_ret_l in Hs'. symmetry in H. eapply Hc2; eauto. + eapply Hc1; eauto. - apply div_spin_eutt in H. rewrite H in Hs'. rewrite <- spin_bind in Hs'. symmetry in Hs'. exfalso. eapply not_ret_eutt_spin. eauto. Qed. @@ -137,27 +141,28 @@ Proof. unfold hoare_triple. intros c1 c2 b P Q Hc1 Hc2 s s' Hs. unfold denote_imp. cbn. destruct (classic_bool b s). - - unfold is_true, is_bool in H. rewrite interp_imp_bind. + - unfold is_true, is_bool in H. unfold Delay. rewrite interp_imp_bind. rewrite H. setoid_rewrite bind_ret_l. apply Hc1. auto. - - unfold is_false, is_bool in H. rewrite interp_imp_bind. + - unfold is_false, is_bool in H. unfold Delay. rewrite interp_imp_bind. rewrite H. setoid_rewrite bind_ret_l. apply Hc2. auto. Qed. Definition app {A B : Type} (f : A -> B) (a : A) := f a. -Definition run_state_itree {A S : Type} {E : Type -> Type} (s : S) (m : stateT S (itree E) A ) : itree E (S * A) := - m s. +Definition run_state_itree {A S : Type} {E : Type -> Type} (s : S) (m : stateT S (itree E) A ) : itree E (A * S) := + runStateT m s. -Global Instance EqStateEq {S R: Type} {E : Type -> Type} : Equivalence (@state_eq E R S). +Global Instance EqStateEq {S R: Type} {E : Type -> Type} : Equivalence (@eq_stateT S (itree E) R (eq_itree eq)). Proof. constructor; repeat intro. - reflexivity. - - unfold state_eq in H. symmetry. auto. - - unfold state_eq in *. rewrite H. auto. + - unfold eq_stateT in H. symmetry. auto. + - unfold eq_stateT in *. rewrite H. auto. Qed. +(* Global Instance run_state_proper_eq_itree {E : Type -> Type} {S R : Type} {s : S} : - Proper (@state_eq E S R ==> eq_itree eq) (@run_state_itree R S E s). + Proper (@eq_stateT S (itree E) R (eq_itree eq)) (@run_state_itree R S E s). Proof. repeat intro. unfold run_state_itree. unfold state_eq in H. rewrite H. reflexivity. Qed. @@ -174,7 +179,7 @@ Global Instance eutt_proper_under_interp_state Proof. repeat intro. unfold interp_state. rewrite H. reflexivity. Qed. - +*) (* Check (case_ (handle_map (V := value) pure_state ) ). @@ -191,8 +196,8 @@ Section interp_state_eq_iter. Context (a : A). - Lemma interp_state_eq_iter : state_eq (interp_state f (ITree.iter g a) ) - (MonadIter_stateT0 _ _ (fun a0 => interp_state f (g a0)) a). + Lemma interp_state_eq_iter : eq_stateT (eq_itree eq) (interp_state f (ITree.iter g a)) + (MonadIter_stateT _ _ (fun a0 => interp_state f (g a0)) a). Proof. unfold ITree.iter, Iter_Kleisli, Basics.iter, MonadIter_itree. eapply interp_state_iter; reflexivity. @@ -202,54 +207,49 @@ End interp_state_eq_iter. Set Default Timeout 15. Global Instance proper_state_eq_iter {S: Type} : - Proper (@state_eq void1 S (unit + unit) ==> @state_eq void1 S (unit) ) (fun body => @MonadIter_stateT0 Delay S _ _ unit unit (fun _ : unit => body) tt ). + Proper (@eq_stateT S (itree void1) (unit + unit) (eq_itree eq) ==> @eq_stateT S (itree void1) unit (eq_itree eq)) + (fun body => @MonadIter_stateT Delay S _ _ unit unit (fun _ : unit => body) tt ). Proof. repeat intro. - unfold MonadIter_stateT0, Basics.iter, MonadIterDelay. eapply eq_itree_iter. - repeat intro. subst. destruct y0 as [s' [] ]. + unfold MonadIter_stateT, Basics.iter, MonadIterDelay. eapply eq_itree_iter. + repeat intro. subst. destruct y0 as [[] s']. simpl. specialize (H s'). rewrite H. reflexivity. Qed. Lemma interp_state_bind_state : forall (E F : Type -> Type) (A B S : Type) - (h : forall T : Type, E T -> S -> itree F (S * T) ) (t : itree E A) + (h : forall T : Type, E T -> stateT S (itree F) T) (t : itree E A) (k : A -> itree E B), - state_eq (interp_state h (ITree.bind t k)) - (bind (interp_state h t) (fun a => interp_state h (k a) ) ). + eq_stateT (eq_itree eq) + (interp_state h (ITree.bind t k)) + (bind (interp_state h t) (fun a => interp_state h (k a))). Proof. - unfold state_eq. intros. eapply interp_state_bind. + Locate interp_state_bind'. + unfold eq_stateT. intros. eapply interp_state_bind'. Qed. Definition state_eq2 {E : Type -> Type} {A B S : Type} (k1 k2 : A -> stateT S (itree E) B ) : Prop := - forall a, state_eq (k1 a) (k2 a). - -Lemma eq_itree_clo_bind {E : Type -> Type} {R1 R2 : Type} : - forall (RR : R1 -> R2 -> Prop) (U1 U2 : Type) (UU : U1 -> U2 -> Prop) - (t1 : itree E U1) (t2 : itree E U2) - (k1 : U1 -> itree E R1) (k2 : U2 -> itree E R2), - eq_itree UU t1 t2 -> - (forall (u1 : U1) (u2 : U2), UU u1 u2 -> eq_itree RR (k1 u1) (k2 u2) ) -> - eq_itree RR (ITree.bind t1 k1) (ITree.bind t2 k2). -Proof. - intros. unfold eq_itree in *. eapply eqit_bind'; eauto. -Qed. - + forall a, eq_stateT (eq_itree eq) (k1 a) (k2 a). Global Instance bind_state_eq2 {E : Type -> Type} {A B S : Type} {m : stateT S (itree E) A} : - Proper (@state_eq2 E A B S ==> @state_eq E S B) (bind m). + Proper (@state_eq2 E A B S ==> @eq_stateT S (itree E) B (eq_itree eq)) (bind m). Proof. - repeat intro. unfold state_eq2, state_eq in H. cbn. + repeat intro. unfold state_eq2, eq_stateT in H. cbn. eapply eq_itree_clo_bind; try reflexivity. intros. subst. destruct u2 as [s' a]. simpl. rewrite H. reflexivity. Qed. (*can actually make this nicer*) -Lemma compile_while : forall (b : bexp) (c : com), - ((denote_imp ( WHILE b DO c END )) ≈ MonadIter_stateT0 unit unit - (fun _ : unit => bind (interp_imp (denote_bexp b)) - (fun b : bool => if b - then bind (denote_imp c) (fun _ : unit => interp_imp (Ret (inl tt)) ) - else interp_imp (Ret (inr tt))) ) tt)%monad. +Lemma compile_while : + forall (b : bexp) (c : com), + (denote_imp ( WHILE b DO c END ) ≈ + MonadIter_stateT unit unit + (fun _ : unit => + bind (interp_imp (denote_bexp b)) + (fun b : bool => + if b + then denote_imp c ;; interp_imp (Ret (inl tt)) + else interp_imp (Ret (inr tt)))) tt)%monad. Proof. intros. simpl. unfold denote_imp. simpl. unfold while. unfold interp_imp at 1, interp_map at 1. cbn. red. red. intros. symmetry. diff --git a/theories/Events/StateFacts.v b/theories/Events/StateFacts.v index ee120b50..e5f393ff 100644 --- a/theories/Events/StateFacts.v +++ b/theories/Events/StateFacts.v @@ -172,6 +172,23 @@ Proof. auto with paco. Qed. +Lemma interp_state_bind' {E F : Type -> Type} {A B S : Type} + (f : forall T, E T -> stateT S (itree F) T) + (t : itree E A) (k : A -> itree E B) + (s : S) : + (runStateT (interp_state f (t >>= k)) s) + ≅ + (runStateT (interp_state f t) s >>= fun '(t, s) => runStateT (interp_state f (k t)) s). +Proof. + rewrite interp_state_bind. + apply eq_itree_clo_bind with (UU := eq). + - reflexivity. + - intros. + subst. + destruct u2. + reflexivity. +Qed. + #[global] Instance eutt_interp_state {E F: Type -> Type} {S : Type} (h : E ~> stateT S (itree F)) R RR :