Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions Algorithm/Data/MutableQuotient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,43 @@ namespace MutableQuotient
def mk (m : α → M) (a : α) : MutableQuotient α m := Mutable.mk ⟦a⟧

@[inline]
def get (x : MutableQuotient α m) (f : α → β) (hf : ∀ a₁ a₂, m a₁ = m a₂ → f a₁ = f a₂) : β :=
def lift (x : MutableQuotient α m) (f : α → β) (hf : ∀ a₁ a₂, m a₁ = m a₂ → f a₁ = f a₂) : β :=
(Mutable.get x).lift f fun _ _ h ↦ hf _ _ (by exact h)

@[simp]
lemma mk_get (m : α → M) (a : α) (f : α → β) (hf : ∀ a₁ a₂, m a₁ = m a₂ → f a₁ = f a₂) :
(mk m a).get f hf = f a :=
rfl
lemma mk_lift (m : α → M) (a : α) (f : α → β) (hf : ∀ a₁ a₂, m a₁ = m a₂ → f a₁ = f a₂) :
(mk m a).lift f hf = f a := by
-- rfl
simp [mk, lift]

@[elab_as_elim, induction_eliminator]
lemma ind {motive : MutableQuotient α m → Prop} (h : ∀ (a : α), motive (mk m a)) (x) : motive x :=
Quotient.ind (motive := fun x ↦ motive (Mutable.mk x)) h (Mutable.get x)
-- Quotient.ind (motive := fun x ↦ motive (Mutable.mk x)) h (Mutable.get x)
Mutable.rec (Quotient.ind h) x

@[inline]
def getMkEq (x : MutableQuotient α m) (f : ∀ a : α, m a = x.get m (fun _ _ ↦ id) → β)
def liftOnMkEq (x : MutableQuotient α m) (f : ∀ a : α, m a = x.lift m (fun _ _ ↦ id) → β)
(hf : ∀ a₁ ha₁ a₂ ha₂, f a₁ ha₁ = f a₂ ha₂) : β :=
(Mutable.get x).liftOnMkEq (fun a ha ↦ f a (congr_arg (Quotient.lift m _) ha))
(fun _ _ _ _ ↦ hf _ _ _ _)

@[inline]
def map (x : MutableQuotient α m) (f : α → α) (hf : ∀ a₁ a₂, m a₁ = m a₂ → m (f a₁) = m (f a₂)) :
MutableQuotient α m :=
get x (fun a ↦ mk m (f a)) (Mutable.ext <| Quotient.sound <| hf · · ·)

-- def getModify (x : MutableQuotient α m) (f : α → β) (hf : ∀ a₁ a₂, m a₁ = m a₂ → f a₁ = f a₂)
-- (r : α → α) (hr : ∀ a, m (r a) = m a) : β :=
-- (Mutable.getModify x (Quotient.map' r (fun a₁ a₂ h ↦ (hr a₁).trans h |>.trans (hr a₂).symm))
-- (Quotient.ind fun a ↦ Quotient.sound (hr a))).lift f fun _ _ h ↦ hf _ _ (by exact h)
lift x (fun a ↦ mk m (f a)) (Mutable.mk_inj.mpr <| Quotient.sound <| hf · · ·)

@[inline]
def getModify (x : MutableQuotient α m) (fr : α → β × α)
def liftModify (x : MutableQuotient α m) (fr : α → β × α)
(hf : ∀ a₁ a₂, m a₁ = m a₂ → (fr a₁).fst = (fr a₂).fst) (hr : ∀ a, m (fr a).snd = m a) : β :=
Mutable.getModify x (Quotient.lift (let ⟨x, y⟩ := fr ·; (x, ⟦y⟧)) (fun a₁ a₂ h ↦
Mutable.getModify x (Quotient.lift (let ⟨x, y⟩ := fr ·; (x, ⟦y⟧)) (fun a₁ a₂ h ↦
Prod.ext (hf a₁ a₂ h) <| Quotient.sound <| (hr a₁).trans <| h.trans (hr a₂).symm))
(Quotient.ind fun a ↦ Quotient.sound (hr a))

@[simp]
lemma mk_getModify (m : α → M) (a : α) (fr : α → β × α)
lemma mk_liftModify (m : α → M) (a : α) (fr : α → β × α)
(hf : ∀ a₁ a₂, m a₁ = m a₂ → (fr a₁).fst = (fr a₂).fst) (hr : ∀ a, m (fr a).snd = m a) :
(mk m a).getModify fr hf hr = (fr a).fst :=
rfl
(mk m a).liftModify fr hf hr = (fr a).fst := by
-- rfl
simp [mk, liftModify]

end MutableQuotient
4 changes: 2 additions & 2 deletions Algorithm/Data/UnionFind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ instance [OfFn P ι ι id] [OfFn S ι ℕ (fun _ ↦ 1)] : Inhabited (UnionFind

@[inline]
def find (self : UnionFind ι P S) (i : ι) : ι :=
MutableQuotient.getModify self (fun x ↦ x.find i) (by simp (config := { contextual := true }))
MutableQuotient.liftModify self (fun x ↦ x.find i) (by simp (config := { contextual := true }))
(by simp)

def IsRoot (self : UnionFind ι P S) (i : ι) : Prop := self.find i = i
Expand All @@ -477,7 +477,7 @@ def union (self : UnionFind ι P S) (i j : ι) : UnionFind ι P S :=

@[inline]
def size (self : UnionFind ι P S) (i : ι) (hi : self.IsRoot i) : ℕ :=
MutableQuotient.getMkEq self
MutableQuotient.liftOnMkEq self
(fun x hx ↦ x.size i (by
induction self using MutableQuotient.ind
rw [UnionFindImpl.UnionFindWF.isRoot_iff_root, hx, ← hi]
Expand Down
76 changes: 40 additions & 36 deletions Mutable/Mutable.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,67 +17,71 @@ structure Mutable (α : Type u) : Type u where
mk ::
get : α

attribute [extern "lean_mk_Mutable"] Mutable.mk
attribute [extern "lean_Mutable_get"] Mutable.get

variable {α : Type u}

@[extern "lean_Mutable_get_with"]
protected def Mutable.getWith (x : @& Mutable α) (f : α → α) : α :=
x.get
attribute [extern "lean_st_mk_ref"] Mutable.mk
attribute [extern "lean_st_ref_get"] Mutable.get
-/

structure Mutable (α : Type u) : Type u where
private __mk__ ::
private __get__ : α
opaque MutableAux (α : Type u) : Subtype (· = α) := ⟨α, rfl⟩

def Mutable (α : Type u) : Type u := (MutableAux α).val

namespace Mutable
variable {α : Type u} {β : Type v}

@[extern "lean_mk_Mutable"]
def mk (a : α) : Mutable α := __mk__ a
@[extern "lean_st_mk_ref"]
def mk (a : α) : Mutable α := (MutableAux α).2.mpr a

@[extern "lean_st_ref_get", never_extract]
def get (x : @& Mutable α) : α := (MutableAux α).2.mp x

@[extern "lean_Mutable_get"]
def get (a : @& Mutable α) : α := __get__ a
set_option linter.unusedVariables false in
@[extern "lean_Mutable_set", never_extract]
unsafe def set (x : @& Mutable α) (a : α) (b : @& β) : β := b

@[simp] theorem mk_get (x : Mutable α) : mk x.get = x := by simp [mk, Mutable.get]
@[simp] theorem get_mk (x : α) : (mk x).get = x := by simp [mk, Mutable.get]

def rec {motive : Mutable α → Sort _} (h : ∀ a, motive (mk a)) (x : Mutable α) : motive x :=
mk_get x ▸ h _

theorem ext {x y : Mutable α} (get : x.get = y.get) : x = y :=
match x, y, get with | ⟨_⟩, ⟨_⟩, rfl => rfl
theorem ext {x y : Mutable α} (get : x.get = y.get) : x = y := by
simpa [Mutable.get] using congrArg (MutableAux α).2.mpr get

theorem ext_iff {x y : Mutable α} : x = y ↔ x.get = y.get :=
⟨congrArg get, ext⟩

@[simp]
theorem mk_eq_mk {x y : α} : mk x = mk y ↔ x = y :=
ext_iff
theorem mk_inj {x y : α} : mk x = mk y ↔ x = y :=
ext_iff.trans (by simp)

set_option linter.unusedVariables false in
@[extern "lean_Mutable_modify"]
unsafe def getModifyUnsafe (x : @& Mutable α) (f : α → α) : α :=
x.get
unsafe def modifyUnsafe (x : Mutable α) (f : α → α) : α :=
let a := f x.get; x.set a a

set_option linter.unusedVariables false in
unsafe abbrev getModifyImpl (x : Mutable α)
unsafe abbrev modifyImpl (x : Mutable α)
(f : α → α) (hf : ∀ a, f a = a) : α :=
Mutable.getModifyUnsafe x f
Mutable.modifyUnsafe x f

@[implemented_by Mutable.getModifyImpl]
def getModify (x : Mutable α)
@[implemented_by Mutable.modifyImpl]
def modify (x : Mutable α)
(f : α → α) (hf : ∀ a, f a = a) : α :=
f x.get

set_option linter.unusedVariables false in
@[extern "lean_Mutable_modify2"]
unsafe def getModify₂Unsafe (x : @& Mutable α) (f : α → β × α) : β :=
(f x.get).fst
unsafe def getModifyUnsafe (x : Mutable α) (f : α → β × α) : β :=
let (b, a) := f x.get; x.set a b

set_option linter.unusedVariables false in
unsafe abbrev getModify₂Impl (x : Mutable α)
unsafe abbrev getModifyImpl (x : Mutable α)
(f : α → β × α) (hgf : ∀ a, (f a).snd = a) : β :=
Mutable.getModify₂Unsafe x f
Mutable.getModifyUnsafe x f

@[implemented_by Mutable.getModify₂Impl]
def getModify₂ (x : Mutable α)
(f : α → β × α) (hgf : ∀ a, (f a).snd = a) : β :=
@[implemented_by Mutable.getModifyImpl]
def getModify (x : Mutable α) (f : α → β × α) (hgf : ∀ a, (f a).snd = a) : β :=
(f x.get).fst

@[simp]
theorem getModify_mk {a : α} {f : α → β × α} {hgf : ∀ a, (f a).snd = a} :
(mk a).getModify f hgf = (f a).fst := by
simp [getModify]

end Mutable
68 changes: 3 additions & 65 deletions cpp/ffi.cpp
Original file line number Diff line number Diff line change
@@ -1,68 +1,6 @@
#include <lean/lean.h>

struct Mutable {
std::atomic<lean_object *> m_value;
Mutable(lean_object * v) : m_value(v) { }
};

static void Mutable_finalize(void * o) {
lean_dec(static_cast<Mutable *>(o)->m_value);
delete static_cast<Mutable *>(o);
}

static void Mutable_foreach(void * o, b_lean_obj_arg f) {
lean_inc(f);
lean_apply_1(f, static_cast<Mutable *>(o)->m_value);
}

static lean_external_class * g_Mutable_class = nullptr;

static inline lean_object * Mutable_to_lean(Mutable * x) {
if (g_Mutable_class == nullptr) {
g_Mutable_class = lean_register_external_class(Mutable_finalize, Mutable_foreach);
}
return lean_alloc_external(g_Mutable_class, x);
}

static inline Mutable * lean_to_Mutable(b_lean_obj_arg o) {
return static_cast<Mutable *>(lean_get_external_data(o));
}

extern "C" LEAN_EXPORT lean_obj_res lean_mk_Mutable(b_lean_obj_arg o) {
return Mutable_to_lean(new Mutable(o));
}

extern "C" LEAN_EXPORT b_lean_obj_res lean_Mutable_get(b_lean_obj_arg x) {
lean_inc(lean_to_Mutable(x)->m_value);
return lean_to_Mutable(x)->m_value;
}

extern "C" LEAN_EXPORT b_lean_obj_res lean_Mutable_modify(b_lean_obj_arg x, lean_obj_arg f) {
lean_object * c = lean_to_Mutable(x)->m_value.exchange(nullptr);
while (c == nullptr) {
// std::this_thread::yield();
c = lean_to_Mutable(x)->m_value.exchange(nullptr);
}
lean_object * r = lean_apply_1(f, c);
// lean_assert(r != nullptr); /* Closure must return a valid lean object */
// lean_assert(lean_to_Mutable(x)->m_value == nullptr);
lean_inc(r);
lean_to_Mutable(x)->m_value = r;
return r;
}

extern "C" LEAN_EXPORT b_lean_obj_res lean_Mutable_modify2(b_lean_obj_arg x, lean_obj_arg f) {
lean_object * c = lean_to_Mutable(x)->m_value.exchange(nullptr);
while (c == nullptr) {
// std::this_thread::yield();
c = lean_to_Mutable(x)->m_value.exchange(nullptr);
}
lean_object * p = lean_apply_1(f, c);
// lean_assert(p != nullptr); /* Closure must return a valid lean object */
// lean_assert(lean_to_Mutable(x)->m_value == nullptr);
lean_object * r = lean_ctor_get(p, 0); lean_inc(r);
lean_object * n = lean_ctor_get(p, 1); lean_inc(n);
lean_dec_ref(p);
lean_to_Mutable(x)->m_value = n;
return r;
extern "C" LEAN_EXPORT b_lean_obj_res lean_Mutable_set(b_lean_obj_arg ref, lean_obj_arg a, b_lean_obj_arg b) {
lean_st_ref_set(ref, a);
return b;
}
6 changes: 3 additions & 3 deletions scripts/nolints.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@
["docBlame", "MinHeap.headD"],
["docBlame", "MinHeap.tail"],
["docBlame", "MultiBag.ReadOnly"],
["docBlame", "MutableQuotient.get"],
["docBlame", "MutableQuotient.getMkEq"],
["docBlame", "MutableQuotient.getModify"],
["docBlame", "MutableQuotient.lift"],
["docBlame", "MutableQuotient.liftModify"],
["docBlame", "MutableQuotient.liftOnMkEq"],
["docBlame", "MutableQuotient.map"],
["docBlame", "MutableQuotient.mk"],
["docBlame", "OfFn.ofFn"],
Expand Down
33 changes: 15 additions & 18 deletions test/Mutable/thunk1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ protected def Thunk'.pure (a : α) : Thunk' α :=
.mk fun _ ↦ a

protected def Thunk'.get (x : Thunk' α) : α :=
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)

@[inline] protected def Thunk'.map (f : α → β) (x : Thunk' α) : Thunk' β :=
.mk fun _ => f x.get
Expand All @@ -23,46 +23,43 @@ protected def Thunk'.get (x : Thunk' α) : α :=
/-! lean4/tests/lean/thunk.lean -/

/-- info: 1 -/
#guard_msgs in #eval
(Thunk'.pure 1).get
#guard_msgs in #eval (Thunk'.pure 1).get
/-- info: 2 -/
#guard_msgs in #eval (Thunk'.mk fun _ => 2).get
/--
info: 3
4
5
---
info: 5
-/
#guard_msgs in #eval
let t1 := Thunk'.mk fun _ => dbg_trace 4; 5
-- let t2 := Thunk'.mk fun _ => dbg_trace 3; 0
-- let v2 := t2.get
let v2 := dbg_trace 3; 0

let t2 := Thunk'.mk fun _ => dbg_trace 3; 0
let v2 := t2.get
let v1 := t1.get
v1 + v2
/--
info: 6
7
8
---
info: 8
-/
#guard_msgs in #eval
let t1 := Thunk'.pure 8 |>.map fun n => dbg_trace 7; n
-- let t2 := Thunk'.mk fun _ => dbg_trace 6; 0
-- let v2 := t2.get
let v2 := dbg_trace 6; 0

let t2 := Thunk'.mk fun _ => dbg_trace 6; 0
let v2 := t2.get
let v1 := t1.get
v1 + v2

/--
info: 9
10
11
---
info: 11
-/
#guard_msgs in #eval
let t1 := Thunk'.pure 11 |>.bind fun n => dbg_trace 10; Thunk'.pure n
-- let t2 := Thunk'.mk fun _ => dbg_trace 9; 0
-- let v2 := t2.get
let v2 := dbg_trace 9; 0

let t2 := Thunk'.mk fun _ => dbg_trace 9; 0
let v2 := t2.get
let v1 := t1.get
v1 + v2
4 changes: 2 additions & 2 deletions test/Mutable/thunk2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ protected def Thunk'.pure (a : α) : Thunk' α :=
.mk fun _ ↦ a

protected def Thunk'.get (x : Thunk' α) : α :=
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)

/-! lean4/tests/compiler/thunk.lean -/

Expand All @@ -27,4 +27,4 @@ def main : IO Unit :=
IO.println (toString (test (compute 1) 100000))

/-- info: 10000000000 -/
#guard_msgs in #eval main -- TODO: 超时中断
#guard_msgs in #eval main
5 changes: 1 addition & 4 deletions test/Mutable/thunk3.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@ protected def Thunk'.pure (a : α) : Thunk' α :=
.mk fun _ ↦ a

protected def Thunk'.get (x : Thunk' α) : α :=
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)
Mutable.getModify x (fun f ↦ let a := f (); ⟨a, fun _ ↦ a⟩) (fun _ ↦ rfl)

@[inline] protected def Thunk'.map (f : α → β) (x : Thunk' α) : Thunk' β :=
.mk fun _ => f x.get

@[inline] protected def Thunk'.bind (x : Thunk' α) (f : α → Thunk' β) : Thunk' β :=
.mk fun _ => (f x.get).get

def List.sum [OfNat α 0] [Add α] : List α → α :=
foldl (· + ·) 0

#eval show IO Unit from do
let _ : OfNat (Thunk' Nat) 0 := ⟨.pure 0⟩
let _ : Inhabited (Thunk' Nat) := ⟨.pure 0⟩
Expand Down
Loading