(** * Library Puf *)

(** This file implements the standard union-find algorithm, with path
    compression in [find], on top of adjustable references. *)

Require Import MyInt31 Classical List.
Require Import Vbase Veq Varith Wlog.
Require Recdef.
Require Import Aref.
Require Parray.

Set Implicit Arguments.
Unset Strict Implicit.
Local Open Scope int31.

(** * Part 1. Axiomatization of a union-find data structure. *)

Module Type UNION_FIND.

  (** The type of union-find structures over [int31] elements. *)
  Parameter t : Type.
  (** [create n] builds a structure in which every element [x <? n] is its
      own representative (see [size_create] and [find_create]). *)
  Parameter create : int31 -> t.
  (** Number of elements tracked by the structure. *)
  Parameter size : t -> int31.
  (** [find uf x] returns the representative of the class of [x]. *)
  Parameter find : t -> int31 -> int31.
  (** [union uf a b] merges the classes of [a] and [b] (see [find_union]). *)
  Parameter union : t -> int31 -> int31 -> t.

  (** Ranking test on representatives.  It decides which representative
      [union] keeps: when [rank_le uf (find uf a) (find uf b)] holds,
      [union uf a b] keeps [find uf b]; otherwise it keeps [find uf a]
      (see [find_union]). *)
  Parameter rank_le : t -> int31 -> int31 -> bool.

  (** The ranking test is total on representatives: if the first does not
      rank below the second, then the second ranks below the first. *)
  Axiom rank_leC :
    forall uf a b,
      rank_le uf (find uf a) (find uf b) = false ->
      rank_le uf (find uf b) (find uf a).

  (** ** Properties of [size] *)
  (** The size is positive.  [create n] has size [n], or 1 when [0 <? n]
      fails.  [union] preserves the size, and representatives are always
      below the size. *)
  Axiom size_pos : forall uf, 0 <? size uf.
  Axiom size_create : forall n, size (create n) = if 0 <? n then n else 1.
  Axiom size_union : forall uf a b, size (union uf a b) = size uf.
  Axiom ltu_find_size : forall uf a, find uf a <? size uf.

  (** ** Properties of [find] *)
  (** In a fresh structure, every in-range element is its own representative. *)
  Axiom find_create :
    forall n x (LTU: x <? n), find (create n) x = x.
  (** A representative is its own representative. *)
  Axiom find_find :
    forall uf x, find uf (find uf x) = find uf x.
  (** Effect of [union] on [find].  The class whose representative is not
      kept is redirected to the kept representative.  Every other element
      keeps its representative. *)
  Axiom find_union : forall uf a b x,
  if rank_le uf (find uf a) (find uf b) then
    find (union uf a b) x = find uf x /\ find uf x <> find uf a \/
    find (union uf a b) x = find uf b /\ find uf x = find uf a
  else
    find (union uf a b) x = find uf x /\ find uf x <> find uf b \/
    find (union uf a b) x = find uf a /\ find uf x = find uf b.

  (** [union] is almost commutative: [union uf a b = union uf b a] whenever
      [rank_le uf (find uf a) (find uf b)] is false. *)
  Axiom union_comm :
    forall uf a b, rank_le uf (find uf a) (find uf b) = false ->
    union uf a b = union uf b a.

  (** Optimized special case of [union] for arguments [a] and [b] that are
      already known to be representatives ([Pa] and [Pb]). *)
  Parameter fast_union : forall uf a (Pa: find uf a = a) b (Pb: find uf b = b), t.
  Arguments fast_union : clear implicits.

  (** [fast_union] agrees with [union] when [a] and [b] are in different
      classes and [rank_le] holds between their representatives. *)
  Axiom fast_unionE : forall t a Pa b Pb,
    find t a <> find t b ->
    rank_le t (find t a) (find t b) ->
    fast_union t a Pa b Pb = union t a b.

  (** Return the list of representatives.  This is meant for specification
      purposes only, e.g. to assert that the set of representatives is
      finite.  The list has no duplicates, and it contains exactly the
      elements that are their own representative. *)
  Parameter reps : t -> list int31.
  Axiom reps_uniq : forall t, NoDup (reps t).
  Axiom repsOK : forall t r, In r (reps t) <-> find t r = r.

  (** [copy uf] is equal to [uf] ([copyE]).  Presumably it exists to obtain
      a physically separate copy of the underlying arrays (TODO confirm). *)
  Parameter copy : t -> t.
  Axiom copyE : forall uf, copy uf = uf.

End UNION_FIND.

(** * Part 2. Implementation of the union-find data structure. *)

Module Puf : UNION_FIND.

(** ** Parent arrays and root lookup *)

(** [get_rel a x y] holds when [y] is not its own parent in the parent
    array [a] and its parent is [x].  An accessibility proof for this
    relation bounds the parent chain that starts at an index.  It is the
    termination argument of [get_aux] and [find_aux]. *)
Definition get_rel (a : Parray.t int31) (x y : int31) : Prop :=
  x <> y /\ Parray.get a y = x.

(* [get_aux] and [find_aux] take all of their arguments explicitly. *)
Unset Implicit Arguments.

(** [get_aux a x WF] follows parent pointers in [a] starting from [x].  It
    returns the first index [r] with [Parray.get a r = r], the root.  It
    does not modify [a].  Each recursive call receives a sub-proof of
    [WF]. *)
Program Fixpoint get_aux (a : Parray.t int31) (x : int31) (WF : Acc (get_rel a) x) :=
  let y := Parray.get a x in
  match y == x with
  | true => x
  | false => get_aux a y (match WF with Acc_intro WF => WF _ _ end)
  end.
(* Obligation: in the [false] branch the parent [y] differs from [x]. *)
Next Obligation. by destruct (eqP (Parray.get a x) x). Qed.

(** [find_aux a x WF] returns a pair [(b, r)].
    - [r] is the root reached from [x], the same as [get_aux]; see
      [find_aux1].
    - [b] is [a] with every non-root index visited on the way re-pointed
      directly at [r] (path compression). *)
Program Fixpoint find_aux (a : Parray.t int31) (x : int31) (WF : Acc (get_rel a) x) :=
  let y := Parray.get a x in
  match y == x with
  | true => (a, x)
  | false =>
    let '(f, r) := @find_aux a y (match WF with Acc_intro WF => WF _ _ end) in
    (Parray.set f x r, r)
  end.
(* Obligation: in the [false] branch the parent [y] differs from [x]. *)
Next Obligation. by destruct (eqP (Parray.get a x) x). Qed.

Set Implicit Arguments.

(** The index returned by [get_aux] is a root. *)
Lemma get_get_aux a x (WF: Acc (get_rel a) x):
  Parray.get a (get_aux a x WF) = get_aux a x WF.
Proof.
  ins; generalize WF, WF at 2; induction WF; ins.
  by destruct WF; simpl; des_eqrefl;
     destruct WF0; simpl; des_eqrefl; try congruence;
     [|apply H0]; revert EQ; case eqP.
Qed.

(** The root computed by [find_aux] coincides with [get_aux]. *)
Lemma find_aux1 a y WF b z :
  find_aux a y WF = (b, z) ->
  get_aux a y WF = z.
Proof.
  generalize WF; induction[b] WF; ins.
  destruct WF; simpls; do 2 des_eqrefl; desf; try congruence.
  generalize EQ; ins; revert EQ.
  case eqP; desf; intros.
   eapply H0; try done.
  etransitivity; try eapply Heq; instantiate; repeat f_equal; intros; apply proof_irrelevance.
Qed.

(* Let [auto] use the positivity of array lengths. *)
Hint Resolve Parray.length_pos.

(** Path compression does not change the array length. *)
Lemma length_find_aux a y WF b z :
  find_aux a y WF = (b, z) ->
  Parray.length b = Parray.length a.
Proof.
  generalize WF; induction[b z] WF; ins.
  destruct WF; ins; des_eqrefl; desf.
  generalize EQ; case eqP; ins.
  by rewrite Parray.length_set; eapply H0; eauto.
Qed.

(** Roots of [a] remain roots after path compression, provided every entry
    of [a] is in bounds. *)
Lemma find_aux_same_reps a y WF b z x :
  find_aux a y WF = (b, z) ->
  (forall x, Parray.get a x <? Parray.length a) ->
  Parray.get a x = x ->
  Parray.get b x = x.
Proof.
  generalize WF; induction[b z] WF; ins.
  destruct WF; ins; des_eqrefl; desf; subst b z; try done.
  rewrite Parray.gs2, (length_find_aux Heq), <- H3, H2, H3.
  desf; [by generalize EQ; case eqP; ins; subst x0|].
  eby generalize EQ; case eqP; ins; eapply H0.
Qed.

(** After path compression, an in-bounds start index [y] points directly at
    the root [z]. *)
Lemma find_aux_helper1 a y WF b z :
  find_aux a y WF = (b, z) ->
  y <? Parray.length a ->
  Parray.get b y = z.
Proof.
  generalize WF; induction[b z] WF; ins.
  destruct WF; ins; des_eqrefl; desf.
    by revert EQ; case eqb31P.
  rewrite Parray.gs; desf; clarify.
  apply length_find_aux in Heq; congruence.
Qed.

(** The root returned by [find_aux] is a root of the compressed array. *)
Lemma find_aux_helper2 a y WF b z :
  find_aux a y WF = (b, z) ->
  (forall x, Parray.get a x <? Parray.length a) ->
  Parray.get b z = z.
Proof.
  ins; assert (X := find_aux1 H).
    eapply find_aux_same_reps; eauto.
  subst; apply get_get_aux.
Qed.

(** Setting an out-of-bounds index leaves the array unchanged. *)
Lemma su A (a : Parray.t A) x v :
  Parray.set a x v
  = (if x <? Parray.length a then Parray.set a x v else a).
Proof.
  desf; unfold Parray.set; des_eqrefl; congruence.
Qed.

(** ** Iterated parent steps *)

(** [pow f n x] applies [f] to [x] [n] times.  The definition applies [f]
    first and then recurses on the result; [pow_S] gives the other
    unfolding. *)
Fixpoint pow A (f : A -> A) n x :=
  match n with
    | O => x
    | S n => pow f n (f x)
  end.

(** Outer unfolding of [pow]: the last application of [f] can be pulled out. *)
Lemma pow_S A (f: A -> A) n x :
  pow f (S n) x = f (pow f n x).
Proof.
  induct[x] n.
Qed.

(** Iterating [n1 + n2] times is iterating [n2] times, then [n1] times. *)
Lemma pow_plus A (f: A -> A) n1 n2 x :
  pow f (n1 + n2) x = pow f n1 (pow f n2 x).
Proof.
  by induction[x] n1; ins; rewrite IHn1, <- pow_S.
Qed.

(** [getaux2 a x n] states that [n] is exactly the number of parent steps
    from [x] to its root in [a].  The [n]-th ancestor of [x] is a root, and
    none of the earlier ancestors is. *)
Definition getaux2 a x n :=
  Parray.get a (pow (Parray.get a) n x) = pow (Parray.get a) n x /\
  (forall i, i < n -> Parray.get a (pow (Parray.get a) i x) <> pow (Parray.get a) i x).

(** Suppose [n] is the distance from [x] to its root.  Then the parent chain
    from [x] can only come back to [x] after zero steps or after more than
    [n] steps. *)
Lemma getaux2a a x n i :
  getaux2 a x n ->
  pow (Parray.get a) i x = x ->
  (i = 0 \/ n < i)%nat.
Proof.
  destruct 1; ins.
  destruct (lenP i 0); [by destruct i; vauto|].
  destruct (lenP i n); vauto.
  assert (EQ: (n = (n - i) + i)%nat) by (rewrite subnK; auto).
  rewrite EQ, pow_plus, H1 in H.
  edestruct H0; try eapply H; eauto.
  by rewrite EQ at 2; rewrite <- (addn0 (n - i)) at 1; rewrite ltn_add2l.
Qed.

(** [get_aux] returns the [n]-th ancestor of [x], where [n] is the exact
    distance from [x] to its root. *)
Lemma get_auxD a x WFx :
  exists n, getaux2 a x n /\ get_aux a x WFx = pow (Parray.get a) n x.
Proof.
  generalize WFx; induction WFx; destruct WFx; ins; des_eqrefl;
  generalize EQ; symmetry in EQ; [apply/eqP in EQ | apply/neqP in EQ]; ins.
   by exists O; split; vauto.
  exploit H0; vauto; ins; unfold getaux2 in *; desf.
  eexists (S _); repeat split; ins; eauto.
  destruct i; eauto.
Qed.

(** Conversely, if the [n]-th ancestor of [x] is a root, then some
    accessibility proof for [x] makes [get_aux] return it. *)
Lemma get_auxI a x n :
  Parray.get a (pow (Parray.get a) n x) = pow (Parray.get a) n x ->
  exists WFx, get_aux a x WFx = pow (Parray.get a) n x.
Proof.
  induction[x] n; ins.
    eexists (Acc_intro _ _); ins; des_eqrefl; try done.
    by exfalso; revert EQ; case eqP.
  edestruct IHn as [WFx M]; eauto.
  simpl; subst; desf; eexists (Acc_intro _ _); ins; des_eqrefl.
    revert EQ; case eqP; try congruence.
    by clear; induction n; ins; try rewrite e; auto.
  by rewrite <- M; f_equal; apply proof_irrelevance.
Grab Existential Variables.
  destruct 1; congruence.
  destruct 1; congruence.
Qed.

(** Re-pointing any index [x] at a root [z] preserves every parent chain
    of length [n] that already ended in [z]. *)
Lemma helper :
  forall a z x n y,
    Parray.get a z = z ->
    pow (Parray.get a) n y = z ->
    pow (Parray.get (Parray.set a x z)) n y = z.
Proof.
  induction n; ins; eapply IHn; eauto.
  rewrite Parray.gu, Parray.length_set; desf; rewrite Parray.gs; desf; subst; eauto.
  - revert H; generalize (pow (Parray.get a) n (Parray.get a y)) as z.
    by clear; induction n; ins; rewrite H, IHn.
  - revert H; generalize (pow (Parray.get a) n (Parray.get a y)) as z.
    by clear; induction n; ins; rewrite H, IHn.
  - by rewrite (Parray.gu a y), Heq.
Qed.

(** Main characterisation of [find_aux].  It applies to an in-bounds start
    index [y] when all entries of [a] are in bounds.  Let [n] be the
    distance from [y] to its root [z].
    - [z] is a root of the compressed array [b].
    - For every [m], the [m]-th ancestor of [y] in [a] still reaches [z]
      after [n - m] steps, both in [a] and in [b].
    - At in-bounds indices off the path from [y], [b] agrees with [a]. *)
Lemma find_aux2 :
  forall a y WF b z
    (FIND: find_aux a y WF = (b, z))
    (LTU: forall x, Parray.get a x <? Parray.length a)
    (LTUy: y <? Parray.length a),
  exists n, getaux2 a y n /\ pow (Parray.get a) n y = z
    /\ Parray.get b z = z
    /\ (forall m,
          pow (Parray.get a) (n - m) (pow (Parray.get a) m y) = z /\
          pow (Parray.get b) (n - m) (pow (Parray.get a) m y) = z)
    /\ (forall x (LT: x <? Parray.length a)
                 (NIN: forall m, x <> pow (Parray.get a) m y),
        Parray.get b x = Parray.get a x).
Proof.
  intros until WF; generalize WF; induction WF; destruct WF; ins; des_eqrefl;
  revert FIND; generalize EQ; symmetry in EQ; [apply/eqP in EQ | apply/neqP in EQ]; ins; desf.
    exists O; repeat split; try rewrite sub0n; vauto; simpl;
    ins; clear - EQ;
    by induction m; ins; desf; repeat split; ins; congruence.

  exploit H0; vauto; clear H0; intro M; desf.

  assert (AUX: getaux2 a x (S n)).
    by split; [|destruct i]; try apply M.

  exists (S n); repeat (split; try done).

    rewrite Parray.gs; rewrite ?(length_find_aux Heq); desf; try done.
    by change (pow _ _ _) with (pow (Parray.get a) (S n) x); rewrite pow_S.

    by destruct m; ins; eapply M2.

    destruct m; ins.
      rewrite Parray.gss; rewrite ?(length_find_aux Heq); try done.
      eapply helper; try done.
      revert M1; generalize (pow (Parray.get a) n (Parray.get a x)); clear.
      by induction n; ins; rewrite IHn; congruence.
    by eapply helper; try eapply M2.

  ins; rewrite Parray.gso; try eby erewrite length_find_aux.
    by apply M3; try done; apply (fun m => NIN (S m)).
    by intros ->; specialize (NIN O).
Qed.

(** Two iterates of [f] on [x] that are both fixed points of [f] are equal. *)
Lemma helper2:
  forall A (f: A -> A) n n' x,
    f (pow f n x) = pow f n x ->
    f (pow f n' x) = pow f n' x ->
    pow f n x = pow f n' x.
Proof.
  induction n; ins.
    by induction n'; ins; rewrite H in *; eauto.
  destruct n'; ins; eauto.
  by rewrite H0 in *; specialize (IHn O); eauto.
Qed.

(** Every entry of the compressed array is an entry of the original array.
    This is used to show that all entries stay in bounds. *)
Lemma find_aux_get_weak :
  forall a y WF b z (FIND: find_aux a y WF = (b, z)) x,
  exists x', Parray.get b x = Parray.get a x'.
Proof.
  ins; cut (forall x, x <? Parray.length b ->
            exists x', Parray.get b x = Parray.get a x').
    by intro H; rewrite Parray.gu; desf; eapply H.
  clear x; revert b z FIND; generalize WF; induction WF; destruct WF.
  ins; des_eqrefl; desf; vauto.
  rewrite Parray.length_set in *.
  rewrite Parray.gs; desf; subst; vauto; generalize EQ; case eqP; ins;
    try eapply H0; vauto.
  eapply find_aux1 in Heq; subst.
  edestruct (get_auxD) as (nn & _ & ->).
  destruct nn; try rewrite pow_S; vauto.
Qed.

(** Path compression preserves the root of every in-bounds [x].  If the
    [n]-th ancestor of [x] in [a] is a root, then it is still a root of [b],
    and it is reached from [x] in [b] after some number [m'] of steps. *)
Lemma find_aux2b :
  forall a y WF b z
    (FIND: find_aux a y WF = (b, z))
    (LTU: forall x, Parray.get a x <? Parray.length a)
    (LTUy: y <? Parray.length a) x n
    (AUX : Parray.get a (pow (Parray.get a) n x) = pow (Parray.get a) n x)
    (LTUx: x <? Parray.length a),
  exists m',
    Parray.get b (pow (Parray.get a) n x) = (pow (Parray.get a) n x) /\
    pow (Parray.get b) m' x = pow (Parray.get a) n x.
Proof.
  induction [x] n; ins.
    by exists O; ins; split; try done; eapply find_aux_same_reps; try edone; try eapply AUX.
  exploit find_aux2; eauto; intro K; desf.

  destruct (classic (exists n', x = pow (Parray.get a) n' y)) as [[n' ->]|C].

  destruct (K2 n') as (M1 & M2); try split; ins.
    exists ((n0 - n')%nat); split; ins; eauto using find_aux_same_reps.
rewrite M2.
rewrite <- pow_S, <- !pow_plus.
by apply helper2; [apply K | rewrite pow_plus, pow_S ].

  exploit IHn; try split; try edone.
  intro N; desf.
  exists (S m'); split; ins.
  rewrite K3; eauto.
Qed.

(** For an out-of-bounds index [x], [get_aux] reaches the same root as it
    does from index [0]. *)
Lemma get_aux_outside_range a x
  (LTUx: (x <? Parray.length a) = false)
  (WFz: Acc (get_rel a) 0)
  (LTU: forall x, Parray.get a x <? Parray.length a) :
  exists WFx', get_aux a x WFx' = get_aux a 0 WFz.
Proof.
  (* [des_eqrefl] splits on the root test at index [0] made by [get_aux a 0]. *)
  destruct WFz; simpl; des_eqrefl.

  (** Case 1. *)
  eexists (Acc_intro _ (fun x y => Acc_intro _ _)); simpl; des_eqrefl.
    by revert EQ EQ0; do 2 case eqP; try done; rewrite (Parray.gu a x), LTUx; intros ->.
  assert (X := sym_eq EQ); apply/eqP in X.
  des_eqrefl; [|exfalso]; revert EQ1; case eqP; try done;
  rewrite (Parray.gu a x), LTUx, X; congruence.

  (** Case 2. *)
  assert (N: forall k1, get_rel a k1 x -> get_rel a k1 0).
    revert EQ; case eqP; try done.
    by destruct 3; subst; split; rewrite (Parray.gu a x), LTUx in *.

  eexists (Acc_intro _ (fun k1 k2 => a0 _ (N k1 k2))); simpl; des_eqrefl.
    by revert EQ0; case eqP; rewrite Parray.gu, LTUx; ins; subst; rewrite LTU in LTUx.

  match goal with |- context [get_aux _ _ ?x ] => generalize x end.
  rewrite (Parray.gu a x), LTUx; ins; f_equal; apply proof_irrelevance.
Grab Existential Variables.
  ins; eapply a0.
  destruct y, H; subst; rewrite (Parray.gu a x), LTUx in *.
  by revert EQ; case eqP; ins; rewrite e in *.
Qed.

(** Path compression does not change the root of any index.  [get_aux] on
    the compressed array [b] agrees with [get_aux] on [a], for a suitable
    accessibility proof. *)
Lemma find_aux3 a y WF b z x
  (FIND: find_aux a y WF = (b, z))
  (WFx: Acc (get_rel a) x)
  (LTU: forall x, Parray.get a x <? Parray.length a) :
  exists WFx', get_aux b x WFx' = get_aux a x WFx.
Proof.
  revert WF FIND.
  wlogC y / (y <? Parray.length a).
    destruct WF; simpl; des_eqrefl; ins; desf; vauto.
    by rewrite su, (length_find_aux Heq), H in *; eauto.
  intro LTUy; intros.
  assert (M := get_auxD WFx); desf; rewrite M0; clear M0; destruct M as [M M'].
  case_eq (x <? Parray.length a) as LTUx.
    exploit find_aux2b; eauto.
    by intros (? & ? & X); rewrite <- X in *; apply get_auxI.
  destruct n; simpls.
    by rewrite Parray.gu, LTUx in M; subst; rewrite LTU in LTUx.
  rewrite (Parray.gu a x), LTUx in *.
  destruct (find_aux2b FIND LTU LTUy (x:=0) (n:=S n)) as (? & ? & X); auto.
  simpls; rewrite <- X in *; eapply get_auxI in H; desf.
  edestruct (get_aux_outside_range (a:=b)) as [? Y];
    try rewrite (length_find_aux FIND); try edone.
    by intro w; destruct (find_aux_get_weak FIND w) as [? ->].
  rewrite <- Y in H; eauto.
Qed.


(** Path compression preserves well-foundedness of the parent relation. *)
Lemma find_aux_wf:
  forall a (WF: well_founded (get_rel a)) x b y
    (FIND : find_aux a x (WF x) = (b, y))
    (LTU: forall x, Parray.get a x <? Parray.length a),
   well_founded (get_rel b).
Proof.
  intros; intro z; edestruct (find_aux3 FIND (WF z)); eauto.
Qed.

(** ** The union-find structure *)

(** Invariant of a parent array of length [length]:
    - its length is [length];
    - every entry is in bounds;
    - every parent chain is finite. *)
Definition closed_arr_cond length (a: Parray.t int31) :=
  Parray.length a = length
  /\ (forall x, Parray.get a x <? Parray.length a)
  /\ well_founded (get_rel a).

(** Parent arrays of length [length] that satisfy the invariant. *)
Definition closed_arr length := { a | closed_arr_cond length a }.

(** [get_closed a x] is the root of [x] in the closed parent array [a].  The
    accessibility proof comes from the well-foundedness part of the
    invariant. *)
Definition get_closed n (a:closed_arr n) x :=
  let (arr, CLOS) := a in
  @get_aux arr x (proj2 (proj2 CLOS) x).

(** Internal representation: a rank array [ranks] and an adjustable
    reference [parr] to a closed parent array of the same length.  The
    reference is indexed by the root function [get_closed].  TODO confirm
    the exact contract of [aref] in [Aref]. *)
Record t_internal := mkT
  { ranks : Parray.t int31 ;
    parr : aref (@get_closed (Parray.length ranks)) }.
Definition t := t_internal.

(** The only [int31] below 1 is 0. *)
Lemma ltu1_implies_eq0 x : x <? 1 -> x = 0.
Proof.
  intro H; apply Z.ltb_lt in H.
  change (phi 1) with 1%Z in *.
  pose proof (phi_bounded x).
  assert (X : phi x = 0%Z) by omega.
  by apply (f_equal phi_inv) in X; rewrite phi_inv_phi in *.
Qed.

(** Reading index 0 of [Parray.init n f] yields [f 0], for any [n]. *)
Lemma ginit0 A n f :
  Parray.get (Parray.init (A:=A) n f) 0 = f 0.
Proof.
  by unfold Parray.init; des_eqrefl;
     unfold Parray.get; des_eqrefl; rewrite Parray.uginit.
Qed.

(** [create i] pairs the rank array [Parray.create i 0] with the parent
    array [Parray.init i id], in which every in-bounds index is its own
    parent.  The obligation shows that this parent array satisfies
    [closed_arr_cond]. *)
Program Definition create (i: int31) : t :=
  {| ranks := Parray.create i 0;
     parr := aref_val _ (Parray.init i id) |}.
Next Obligation.
  split; [by rewrite Parray.length_init, Parray.length_create|].
  split; intro x; unfold get_rel; [|constructor; intro y];
  rewrite Parray.gu; ins; desf; rewrite ?ginit0, Parray.length_init in *; desf.
    by rewrite Parray.ginit.
    by apply ltu1_implies_eq0 in Heq; subst; rewrite ginit0.

    by rewrite Parray.ginit in *; vauto.
    by apply ltu1_implies_eq0 in Heq; subst; rewrite ginit0 in *; vauto.
    by constructor; intros z []; rewrite ginit0; ins; desf.
    by constructor; intros z []; rewrite ginit0; ins; desf.
Qed.

(** [find uf x] returns the root of [x] in the parent array held by
    [parr uf] (see [findE]).  It runs [find_aux], which also produces a
    path-compressed array.  Presumably [aref_getu] stores that array back
    into the reference (TODO confirm against [Aref]).  The three
    obligations show that the compressed array:
    - still satisfies [closed_arr_cond];
    - has the same root function;
    - is returned together with that root as the second component. *)
Program Definition find (uf: t) : int31 -> int31 :=
  aref_getu (f:= @get_closed _)
    (fun a x => @find_aux (proj1_sig a) x (proj2 (proj2 (proj2_sig a)) x)) _ _ (parr uf).
Next Obligation.
  destruct a as (a & M & M' & M''); simpl.
  remember (find_aux a x ((proj2 (proj2 (conj M (conj M' M'')))) x)) as t; destruct t;
    symmetry in Heqt; ins.
  red; rewrite (length_find_aux Heqt).
  split; [done|]; split; [|eapply find_aux_wf; eauto].
    by intro ww; destruct (find_aux_get_weak Heqt ww) as [? ->].
Qed.
Next Obligation.
  extensionality y; destruct x as [x M]; simpl.
  match goal with |- context [get_aux _ _ ?x ] => generalize x end.
  remember (find_aux x a (proj2 (proj2 M) a)) as t; destruct t; ins.
  symmetry in Heqt.
  edestruct find_aux3 as [? H]; [eapply Heqt | apply M |].
  rewrite <- H; f_equal; apply proof_irrelevance.
Qed.
Next Obligation.
  destruct x as [x M]; simpl.
  remember (find_aux x a (proj2 (proj2 M) a)) as z; symmetry in Heqz; destruct z.
  eby erewrite find_aux1.
Qed.


(** If [parr uf] currently holds [v], then [find uf] is the root function of [v]. *)
Lemma findE (uf : t) v (EQ : parr uf = aref_val _ v) x :
  find uf x = get_closed v x.
Proof.
  by unfold find; rewrite aref_getuE, EQ, aref_get_val.
Qed.

(** Same as [findE], with [get_closed] unfolded. *)
Lemma findE2 (uf : t) v (EQ : parr uf = aref_val _ v) x :
  find uf x = get_aux (sval v) x (proj2 (proj2 (proj2_sig v)) x).
Proof.
  by unfold get_closed, find in *; destruct uf, v; ins; subst; rewrite aref_getuE, aref_get_val.
Qed.

(** Roots are in bounds when all entries are. *)
Lemma get_aux_ltu :
  forall v (P: forall x, Parray.get v x <? Parray.length v) a WF,
    get_aux v a WF <? Parray.length v.
Proof.
  ins; destruct (get_auxD WF) as ([] & ? & ->); [|rewrite pow_S]; auto.
  destruct H; simpls; rewrite <- H; auto.
Qed.

(** [rank_le uf a b] holds unless the rank of [a] is strictly smaller than
    the rank of [b]. *)
Definition rank_le (uf : t) (a b : int31) : bool :=
  negb (Parray.get (ranks uf) a <? Parray.get (ranks uf) b).

(** Let [r] be a root of [v].  Setting any index [a] to [r] leaves [r] a
    root: [get_aux] from [r] still returns [r]. *)
Lemma get_aux_set_get_aux :
  forall v a b WF WF',
    (forall x, Parray.get v x <? Parray.length v) ->
    get_aux (Parray.set v a (get_aux v b WF)) (get_aux v b WF) WF' =
    get_aux v b WF.
Proof.
  ins; destruct WF'; simpl; des_eqrefl; try done.
  match goal with |- get_aux _ _ ?x = _ => generalize x end.
  symmetry in EQ; apply/neqP in EQ;
    instantiate; rewrite Parray.gs, get_get_aux in EQ; desf; try congruence;
    auto using get_aux_ltu.
Qed.

(** Variant of [get_aux_outside_range].  For an out-of-bounds [x], [get_aux]
    from [x] equals [get_aux] from [0] for some accessibility proof of [0]. *)
Lemma get_aux_outside_range2 :
  forall (v: Parray.t int31) (LTU : forall x, Parray.get v x <? Parray.length v)
    x (OUT: (x <? Parray.length v) = false) WFx,
  exists WF, get_aux v x WFx = get_aux v 0 WF.
Proof.
  ins; destruct WFx; ins; des_eqrefl.
    by symmetry in EQ; apply/eqP in EQ; rewrite <- EQ, LTU in OUT.
  match goal with |- context [get_aux _ _ ?x = _] => generalize x end.
  symmetry in EQ; apply/neqP in EQ.
  rewrite Parray.gu, OUT; intro WF'.
  eexists (Acc_intro _ _); simpl; des_eqrefl.
  destruct WF'; simpl; des_eqrefl; [|exfalso]; assert (EQ2 := EQ0);
    symmetry in EQ0; apply/eqP in EQ0; rewrite EQ0 in EQ1; congruence.
  f_equal; apply proof_irrelevance.
Grab Existential Variables.
  by destruct 1; subst.
Qed.

(** Effect of linking the root of [a] below the root of [b].  In the
    updated array, indices whose old root was that of [a] now reach the
    root of [b]; all other indices keep their old root. *)
Lemma get_aux_upd :
  forall a b (v: Parray.t int31)
    (LTU : forall x, Parray.get v x <? Parray.length v)
    WFa WFb x WFx WF',
   get_aux (Parray.set v (get_aux v a WFa) (get_aux v b WFb)) x WF'
   = if get_aux v x WFx == get_aux v a WFa
     then get_aux v b WFb
     else get_aux v x WFx.
Proof.
  intros until x.
  wlogC x / (x <? Parray.length v).
    ins; edestruct get_aux_outside_range2 with (WFx := WFx) as [? ->]; try done.
    edestruct get_aux_outside_range2 with (WFx := WF') as [? ->]; auto;
      rewrite Parray.length_set; ins.
    by rewrite Parray.gu, Parray.length_set, Parray.gs; desf; auto using get_aux_ltu.
  intro LTUx; ins; destruct (get_auxD WFx) as [n [[A _] ->]].
  revert x LTUx A WFx WF'.
  induction n; ins; destruct WF'; simpl; des_eqrefl.

+ symmetry in EQ; apply/eqP in EQ; instantiate; rewrite Parray.gs in EQ; desf; try congruence.

+ match goal with |- get_aux _ _ ?x = _ => generalize x end.
  rewrite Parray.gs; try done;
  symmetry in EQ; apply/neqP in EQ; instantiate; rewrite Parray.gs in EQ; desf;
    try congruence; try done.
  by intros; apply get_aux_set_get_aux.

+ symmetry in EQ; apply/eqP in EQ; instantiate; rewrite Parray.gs in EQ; desf; try congruence.
    by rewrite EQ, <- Heq in *; destruct Heq0; clear - EQ;
       induction[x EQ] n; ins; rewrite EQ; eauto.
    by subst; clear; induction n; simpls; rewrite get_get_aux.
  by rewrite EQ; clear - EQ; induction[x EQ] n; ins; rewrite EQ; eauto.

+ match goal with |- get_aux _ _ ?x = _ => generalize x end.
  rewrite Parray.gs; [case eqP|done].
    intros <-; rewrite get_get_aux in *; desf.
      by intros; apply get_aux_set_get_aux.
    by destruct Heq; clear; induction n; ins; rewrite get_get_aux.
  intros; apply IHn; auto.
  case (eqP (Parray.get v x) x); ins; try congruence.
  by destruct WFx as [WFx]; apply WFx.
Qed.

(** Linking the representative of [b] below the representative of [a]
    preserves the parent-array invariant.  The length index matches the
    updated rank array. *)
Lemma union_helper1 :
 forall (uf : t) a b (r : closed_arr (Parray.length (ranks uf))),
   aref_get (parr uf) = get_closed r ->
 closed_arr_cond
   (Parray.length (Parray.set (ranks uf) (find uf b)
                     (Parray.get (ranks uf) (find uf a) + 1)))
   (Parray.set (sval r) (find uf b) (find uf a)).
Proof.
  red; ins; rewrite !Parray.length_set.
  destruct (aref_inh (parr uf)) as [v EQ].
  revert H; rewrite !(findE EQ), EQ, aref_get_val.
  destruct r as (r & M); split; [destruct M; done|]; split; ins.
    by destruct M; rewrite Parray.gs2; desf; simpls;
       rewrite (f_equal (fun f => f a) H); apply get_aux_ltu.
  rewrite (f_equal (fun f => f a) H), (f_equal (fun f => f b) H).
  simpl; generalize (proj2 (proj2 M)), (proj1 (proj2 M)); clear.
  intros WF LTU x.
  induction (WF x); constructor; destruct 1; subst.
  revert H1; rewrite Parray.gu, Parray.gs in *; try rewrite Parray.length_set;
    desf; ins; subst; try done; try (by apply H0).

  constructor; destruct 1; simpls.
  by rewrite Parray.gs, get_get_aux in *; desf; auto using get_aux_ltu.

  constructor; destruct 1; simpls.
  by rewrite Parray.gs, get_get_aux in *; desf; auto using get_aux_ltu.

  by apply H0; split; try done; rewrite Parray.gu; clarify.
Qed.

(** The root function after linking the representative of [a] below that
    of [b] does not depend on which closed array (with the same root
    function as [parr uf]) it is computed from. *)
Lemma union_helper2 :
  forall (uf : t) a b n (x y : closed_arr n) z
    (X: get_closed x = aref_get (parr uf))
    (Y: get_closed y = aref_get (parr uf))
    (WFy : Acc (get_rel (Parray.set (sval y) (find uf a) (find uf b))) z)
    (WFx : Acc (get_rel (Parray.set (sval x) (find uf a) (find uf b))) z),
   get_aux (Parray.set (sval x) (find uf a) (find uf b)) z WFx =
   get_aux (Parray.set (sval y) (find uf a) (find uf b)) z WFy.
Proof.
  intros until 2; destruct (aref_inh (parr uf)) as [v EQ].
  rewrite !(findE EQ).
  rewrite EQ, aref_get_val in *.
  assert (Xa := f_equal (fun f => f a) X).
  assert (Xb := f_equal (fun f => f b) X).
  assert (Xz := f_equal (fun f => f z) X).
  assert (Ya := f_equal (fun f => f a) Y).
  assert (Yb := f_equal (fun f => f b) Y).
  assert (Yz := f_equal (fun f => f z) Y).
  clear X Y; simpls.
  rewrite <- Xb, <- Xa; intros.
  destruct x as [x Cx], y as [y Cy]; simpls.
  rewrite get_aux_upd with (WFx := proj2 (proj2 Cx) z);
    try congruence; [|by clear - Cx; destruct Cx as [? []]].
  revert WFy; rewrite Xa, Xb, Xz, <- Ya, <- Yb; intros.
  rewrite get_aux_upd with (WFx := proj2 (proj2 Cy) z);
    try congruence; [|by clear - Cy; destruct Cy as [? []]].
  by rewrite Ya, Yb, Yz.
Qed.

(** [union uf a b] merges the classes of [a] and [b].  Let [a'] and [b'] be
    their representatives.
    - If [a' = b'], it returns [uf] unchanged.
    - Otherwise it compares their ranks [ra] and [rb].  If [ra <? rb], [b']
      is linked below [a']; otherwise [a'] is linked below [b'].
    So [b'] is kept exactly when [rank_le uf a' b'] holds, as stated by
    [find_union].

    NOTE(review): the rank array is only written at the root that is being
    linked below the other one ([b'] resp. [a']), never at the surviving
    root, and the lower-ranked root is the one that survives.  Assuming
    [Parray.create i 0] initialises every rank to 0, all representatives
    then keep rank 0, [ra <? rb] never holds, and the rank heuristic has no
    effect.  This does not affect correctness, but it differs from textbook
    union by rank and may affect the complexity of [find].  Confirm intent
    before changing it: [find_union], [union_comm] and [fast_unionE] depend
    on the current shape. *)
Program Definition union (uf : t) (a b : int31) : t :=
  let a' := find uf a in
  let b' := find uf b in
  if a' == b' then uf
  else
    let ra := Parray.get (ranks uf) a' in
    let rb := Parray.get (ranks uf) b' in
    if ra <? rb then
      {| ranks := Parray.set (ranks uf) b' (ra + 1) ;
         parr := aref_new (parr uf) (fun r PF => existT _ (Parray.set (proj1_sig r) b' a')
                                             (union_helper1 a b PF)) _ |}
    else
      {| ranks := Parray.set (ranks uf) a' (rb + 1) ;
         parr := aref_new (parr uf) (fun r PF => existT _ (Parray.set (proj1_sig r) a' b')
                                             (union_helper1 b a PF)) _ |}.
Next Obligation. by extensionality z; apply union_helper2. Qed.
Next Obligation. by extensionality z; apply union_helper2. Qed.

(** The size of a structure is the length of its rank array. *)
Definition size (uf: t) := Parray.length (ranks uf).

(** ** Proofs of the [UNION_FIND] axioms *)

(** In a fresh structure, every in-range element is its own representative. *)
Theorem find_create :
  forall n x (LTU: x <? n), find (create n) x = x.
Proof.
  intros; unfold create, find; simpl.
  rewrite aref_getuE, aref_get_val; simpl.
  destruct (proj2 (proj2 (create_obligation_1 n)) x).
  simpl; des_eqrefl; try done.
  exfalso; revert EQ; rewrite Parray.ginit; try case eqP; try done.
Qed.

(** A representative is its own representative. *)
Theorem find_find :
  forall uf x, find uf (find uf x) = find uf x.
Proof.
  intros; destruct (aref_inh (parr uf)) as [v EQ].
  rewrite !(findE EQ); destruct uf; clear.
  simpls; destruct v as [v Cv]; simpl.
  destruct (proj2 (proj2 Cv) (get_aux v x (proj2 (proj2 Cv) x))); simpl.
  des_eqrefl; try done.
  by exfalso; revert EQ; rewrite get_get_aux; case eqP.
Qed.

(** [rank_le] is total: a strict rank inequality one way gives [rank_le]
    the other way. *)
Theorem rank_leC :
  forall uf a b,
    rank_le uf (find uf a) (find uf b) = false ->
    rank_le uf (find uf b) (find uf a).
Proof.
  intros until 0; generalize (find uf a), (find uf b); unfold rank_le, int31_ltu; ins.
  rewrite Z.ltb_antisym, negb_neg; apply Z.leb_le.
  apply negbFE, Z.ltb_lt in H; omega.
Qed.

(** When [rank_le] fails, both argument orders of [union] build the same
    structure. *)
Theorem union_comm :
  forall uf a b,
    rank_le uf (find uf a) (find uf b) = false ->
  union uf a b = union uf b a.
Proof.
  unfold union, rank_le; intros.
  rewrite (negbFE H), (negbTE (rank_leC H)).
  do 2 case eqP; try congruence; intros; clear.
  f_equal; f_equal; apply proof_irrelevance.
Qed.

(** Effect of [union] on [find].  The class whose representative is not
    kept is redirected to the kept representative.  Every other element
    keeps its representative. *)
Theorem find_union : forall uf a b x,
  if rank_le uf (find uf a) (find uf b) then
    find (union uf a b) x = find uf x /\ find uf x <> find uf a \/
    find (union uf a b) x = find uf b /\ find uf x = find uf a
  else
    find (union uf a b) x = find uf x /\ find uf x <> find uf b \/
    find (union uf a b) x = find uf a /\ find uf x = find uf b.
Proof.
  intros; case (eqP (find uf a) (find uf b)); intro X.
    by rewrite X; unfold union; case eqP; try done;
       desf; vauto; case (eqP (find uf x) (find uf b)); vauto.

  revert X; wlogC a b / (rank_le uf (find uf a) (find uf b)).
    by rewrite (rank_leC EQ); rewrite (union_comm EQ); auto.
  intros; clarify.

  destruct (aref_inh (parr uf)) as [v EQ].
  destruct (aref_inh (parr (union uf a b))) as [w EQ'].
  rewrite !(findE EQ), !(findE EQ'); simpl.

  revert EQ'; unfold union, rank_le in *. revert w; case eqP; try done; intros _.
  apply negbTE in H; rewrite H; clear H; simpls; subst.
  intros; apply (f_equal (@aref_get _ _ _)) in EQ'.
  rewrite aref_get_val in EQ'; revert EQ'.
  edestruct (aref_get_new EQ) as [? ->]; simpl.
  match goal with |- context[get_closed (exist _ _ ?y)] => generalize y end.
  revert w X; rewrite Parray.length_set, !(findE EQ); ins; rewrite <- EQ'; clear - X.
  destruct v as [v Cv]; simpl.
  rewrite get_aux_upd with (WFx := proj2 (proj2 Cv) x); desf; vauto.
  by destruct Cv as [? []].
Qed.

(** [fast_union uf a Pa b Pb] links the representative [a] below the
    representative [b] without calling [find] or comparing ranks.  It
    mirrors the second branch of [union] (see [fast_unionE]), including the
    rank bookkeeping noted there. *)
Program Definition fast_union (uf : t) a (Pa: find uf a = a) b (Pb: find uf b = b) : t :=
  let rb := Parray.get (ranks uf) b in
  @mkT
    (Parray.set (ranks uf) a (rb + 1))
    (aref_new (parr uf) (fun r PF => Parray.set (proj1_sig r) a b) _).
Next Obligation. by rewrite <- Pa, <- Pb; apply union_helper1. Qed.
Next Obligation.
  extensionality z; unfold get_closed.
  by do 2 match goal with |- context [get_aux _ _ ?x] => generalize x end;
     rewrite <- Pa, <- Pb; apply union_helper2.
Qed.
Arguments fast_union: clear implicits.

(** [fast_union] agrees with [union] when [a] and [b] are in different
    classes and [rank_le] holds between their representatives. *)
Theorem fast_unionE : forall uf a
  (Pa : find uf a = a) b
  (Pb : find uf b = b),
  find uf a <> find uf b ->
  rank_le uf (find uf a) (find uf b) ->
  fast_union uf a Pa b Pb = union uf a b.
Proof.
  unfold rank_le, fast_union, union; intros; case eqP; [done|intros _].
  rewrite (negbTE H0); f_equal; try congruence.
  do 2 match goal with |- context [aref_new _ _ ?x] => generalize x end.
  simpl; generalize (fast_union_obligation_1 (uf:=uf) Pa Pb),
                    (union_helper1 (uf:=uf) b a).
  destruct (aref_inh (parr uf)) as [? ->]; ins; rewrite !aref_new_val; f_equal.
  do 2 match goal with |- context [exist _ _ ?x] => generalize x end.
  rewrite Pa, Pb; ins; do 3 f_equal; apply proof_irrelevance.
Qed.

(** For a non-zero [x], [x - 1] does not wrap around: its integer value is
    non-negative and strictly below that of [x]. *)
Lemma phi_m1 x : x <> 0 -> (0 <= phi (x - 1) < phi x)%Z.
Proof.
  intros; assert (phi x <> 0)%Z.
    by intro M; apply (f_equal phi_inv) in M; rewrite phi_inv_phi in *.
  pose proof (phi_bounded x).
  rewrite spec_sub, Zmod_small; change (phi 1) with 1%Z; omega.
Qed.

(** [int31_iota x] is the list [x - 1, x - 2, ..., 0] in decreasing order.
    It is empty when [x = 0].  [phi_m1] provides the termination
    argument. *)
Function int31_iota x { measure (fun x => Zabs_nat (phi x)) x } :=
  if x == 0 then nil else x - 1 :: int31_iota (x - 1).
Proof.
  by intro x; case eqP; ins; apply Zabs_nat_lt, phi_m1.
Qed.

(** The representatives are the indices below [size uf] that are their own
    representative. *)
Definition reps (uf : t) : list int31 :=
  filter (fun x => find uf x == x) (int31_iota (size uf)).


(** Filtering a duplicate-free list keeps it duplicate-free. *)
Lemma filter_uniq :
  forall (A: eqType) (f : A -> bool) l, NoDup l -> NoDup (filter f l).
Proof.
  induction l; ins; desf; inv H; eauto.
  econstructor; eauto; rewrite filter_In; tauto.
Qed.

(** [int31_iota y] contains exactly the elements below [y]. *)
Lemma In_int31_iota x y :
  In x (int31_iota y) <-> x <? y.
Proof.
  apply int31_iota_ind; ins; [apply/eqP in e|apply/neqP in e]; instantiate; subst.
    split; ins; des; try apply/int31_ltuP in H; vauto.
    pose proof (phi_bounded x); change (phi 0) with 0%Z in *; omega.
  rewrite H; clear H; split; ins; desf; eauto using eqxx;
  try apply/int31_ltuP in H.
    by apply/int31_ltuP; apply phi_m1.
    by apply phi_m1 in e; apply/int31_ltuP; instantiate; omega.
  case (eqP (x0 - 1) x); vauto; right.
  apply/int31_ltuP.
  pose proof (phi_bounded x0).
  pose proof (phi_bounded x).
  assert (phi x0 - 1 <> phi x)%Z.
    by intro M; apply (f_equal phi_inv) in M; rewrite phi_inv_phi in *.
  rewrite spec_sub; change (phi 1) with 1%Z; rewrite Zmod_small; omega.
Qed.

(** The list of representatives has no duplicates. *)
Theorem reps_uniq : forall uf, NoDup (reps uf).
Proof.
  ins; apply filter_uniq.
  eapply int31_iota_ind; ins; eauto using NoDup.
  apply/neqP in e; econstructor; eauto.
  red; generalize (proj2 (phi_m1 e)).
  rewrite In_int31_iota; ins.
  apply/int31_ltuP in H1; instantiate; omega.
Qed.

(** The size, i.e. the rank array length, is positive. *)
Theorem size_pos : forall t, 0 <? size t.
Proof. ins; apply Parray.length_pos. Qed.

(** Size of a fresh structure, inherited from [Parray.length_create]. *)
Theorem size_create :
  forall n, size (create n) = if 0 <? n then n else 1.
Proof. intros; apply Parray.length_create. Qed.

(** [union] preserves the size. *)
Theorem size_union : forall uf a b, size (union uf a b) = size uf.
Proof.
  by unfold size, union; ins; desf; ins; apply Parray.length_set.
Qed.

(** Representatives are below the size. *)
Theorem ltu_find_size : forall uf a, find uf a <? size uf.
Proof.
  unfold size; ins; destruct (aref_inh (parr uf)) as [v EQ].
  rewrite <- (proj1 (proj2_sig v)), (findE EQ).
  by destruct v as [v [? []]]; apply get_aux_ltu.
Qed.

(** [reps uf] contains exactly the elements that are their own representative. *)
Theorem repsOK : forall uf r, In r (reps uf) <-> find uf r = r.
Proof.
  split; ins; unfold reps in *; rewrite filter_In in *; desf.
    by apply/eqP.
  case eqP; try done; intros _; split; try done.
  rewrite In_int31_iota, <- H; apply ltu_find_size.
Qed.

(** [copy uf] copies the rank array with [Parray.copy] and the reference
    with [aref_copy].  The reference is transported along
    [Parray.length_copy] so that its index matches the copied rank array. *)
Definition copy (uf : t) : t :=
  let (x, m) := uf in
  @mkT
    (Parray.copy x)
    (eq_rect_r (fun i => aref (get_closed (n:=i)))
      (aref_copy m) (Parray.length_copy x)).

(** [Parray.copy] returns an equal array. *)
Lemma array_copy_eq : forall A (a: Parray.t A), Parray.copy a = a.
Proof.
  intros; apply Parray.eqI; intros; rewrite ?Parray.unsafe_getE;
    auto using Parray.length_copy, Parray.gcopy.
Qed.

(** [copy] is logically the identity. *)
Theorem copyE : forall uf, copy uf = uf.
Proof.
  destruct uf; simpl.
  generalize (Parray.length_copy ranks0).
  rewrite array_copy_eq; intros; f_equal.
  by rewrite aref_copyE, (UIP_refl _ _ e).
Qed.

End Puf.