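(** * Library mathcomp_eulerian.foata

    Foata's map on sequences of natural numbers, with the proof that for a
    duplicate-free word [w], [inv_seq (foata w) = maj_seq w]
    (Theorem [foata_inv_eq_maj]). *)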


From mathcomp Require Import all_ssreflect fingroup perm.
From mathcomp_eulerian Require Import descent inversions perm_seq_bridge.

Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.


(** [cyc_last_to_front s] rotates [s] one step to the right: the last element
    is moved to the front and the remaining elements keep their order.
    The empty sequence is returned unchanged. *)
Definition cyc_last_to_front (s : seq nat) : seq nat :=
  match s with
  | [::] => [::]
  | x :: rest => last x rest :: belast x rest
  end.

(** [split_blocks_aux P cur s] cuts [s] into consecutive blocks, closing a
    block right after every element that satisfies [P].  [cur] is the block
    being accumulated (elements seen since the last cut) and is prepended to
    the first block produced.  A trailing run without a [P]-element is
    emitted as a final block only when it is non-empty. *)
Fixpoint split_blocks_aux (P : nat -> bool) (cur : seq nat) (s : seq nat) :
  seq (seq nat) :=
  match s with
  | [::] => if cur is _ :: _ then [:: cur] else [::]
  | x :: rest =>
      if P x then (rcons cur x) :: split_blocks_aux P [::] rest
      else split_blocks_aux P (rcons cur x) rest
  end.

(** [split_blocks P s] cuts [s] right after every element satisfying [P],
    starting with an empty pending block.  Concatenating the blocks gives
    back [s] (lemma [split_blocks_flatten]). *)
Definition split_blocks (P : nat -> bool) (s : seq nat) : seq (seq nat) :=
  split_blocks_aux P [::] s.

(** One insertion step of the Foata map [foata]: append the letter [a] to the
    word [u].
    - If [u] is empty the result is [[:: a]].
    - Otherwise let [x] be the last letter of [u].  When [x < a], [u] is cut
      after every letter [y] with [y < a]; otherwise it is cut after every
      letter [y] with [a < y].  Each block is rotated with
      [cyc_last_to_front] (its last letter moves to the front), the blocks
      are concatenated, and [a] is appended. *)
Definition foata_step (a : nat) (u : seq nat) : seq nat :=
  match u with
  | [::] => [:: a]
  | _ :: _ =>
      let x := last 0 u in
      let P := if x < a then (fun y : nat => y < a) else (fun y => a < y) in
      flatten (map cyc_last_to_front (split_blocks P u)) ++ [:: a]
  end.

(** [foata w] inserts the letters of [w] from left to right with
    [foata_step], starting from the empty word.  The result is a permutation
    of [w] ([foata_perm_eq]).  For duplicate-free [w], its inversion number
    is the major index of [w] ([foata_inv_eq_maj]). *)
Definition foata (w : seq nat) : seq nat :=
  foldl (fun u a => foata_step a u) [::] w.


(** [is_desc_seq w k] tests whether position [k] (0-based) is a descent of
    [w]: the entry at [k.+1] is strictly smaller than the entry at [k].
    Positions past the end read the default value [0]. *)
Definition is_desc_seq (w : seq nat) (k : nat) : bool :=
  nth 0 w k.+1 < nth 0 w k.

(** Major index of [w]: the sum of [k.+1] over the descent positions [k]
    (see [is_desc_seq]) with [k < (size w).-1], i.e. the sum of the 1-based
    positions of the descents.  The proof of [maj_seq_rcons] rewrites this
    exact [iota]-indexed sum, so its shape matters. *)
Definition maj_seq (w : seq nat) : nat :=
  \sum_(k <- iota 0 (size w).-1 | is_desc_seq w k) k.+1.

(** Inversion number of [w].  For every position [j], count the positions
    [i < j] whose entry is strictly larger than the entry at [j], then sum
    over [j].  [inv_seq_rcons] is proved by unfolding this nested sum. *)
Definition inv_seq (w : seq nat) : nat :=
  \sum_(j <- iota 0 (size w))
    \sum_(i <- iota 0 j | nth 0 w i > nth 0 w j) 1.

(** [count_gt a w] is the number of entries of [w] strictly greater than [a]. *)
Definition count_gt (a : nat) (w : seq nat) : nat :=
  count (fun y => a < y) w.



(** Sanity check on a concrete word: the image under [foata] is evaluated by
    computation, then both statistics are computed. *)
Lemma sanity_inv_eq_maj :
  inv_seq (foata [:: 3; 1; 4; 5; 9; 2; 6])
  = maj_seq [:: 3; 1; 4; 5; 9; 2; 6].
Proof.
(* [foata] is computable, so its value is checked by conversion. *)
have Hf : foata [:: 3; 1; 4; 5; 9; 2; 6] = [:: 3; 4; 1; 5; 2; 9; 6] by [].
rewrite Hf /inv_seq /maj_seq /=.
by rewrite !big_cons !big_nil /is_desc_seq /=.
Qed.

(** Sanity check: [foata] fixes [[:: 2; 3; 1]]; both statistics equal 2. *)
Lemma sanity_inv_eq_maj2 :
  inv_seq (foata [:: 2; 3; 1]) = maj_seq [:: 2; 3; 1].
Proof.
have -> : foata [:: 2; 3; 1] = [:: 2; 3; 1] by [].
by rewrite /inv_seq /maj_seq /= !big_cons !big_nil /is_desc_seq /=.
Qed.

(** Sanity check: [foata [:: 3; 1; 2] = [:: 1; 3; 2]]; both statistics equal 1. *)
Lemma sanity_inv_eq_maj3 :
  inv_seq (foata [:: 3; 1; 2]) = maj_seq [:: 3; 1; 2].
Proof.
have -> : foata [:: 3; 1; 2] = [:: 1; 3; 2] by [].
by rewrite /inv_seq /maj_seq /= !big_cons !big_nil /is_desc_seq /=.
Qed.


(** Rotation only reorders: [cyc_last_to_front s] is a permutation of [s]. *)
Lemma cyc_last_to_front_perm_eq s :
  perm_eq (cyc_last_to_front s) s.
Proof.
case: s => [|x s] //=.
(* Write [x :: s] as its prefix followed by its last element. *)
have -> : x :: s = belast x s ++ [:: last x s] by rewrite lastI cats1.
by rewrite -cat1s perm_catC.
Qed.

(** Rotation preserves the length of a sequence. *)
Lemma cyc_last_to_front_size s :
  size (cyc_last_to_front s) = size s.
Proof. by apply: perm_size; exact: cyc_last_to_front_perm_eq. Qed.

(** Rotation preserves duplicate-freeness (in both directions). *)
Lemma cyc_last_to_front_uniq s :
  uniq (cyc_last_to_front s) = uniq s.
Proof. by apply: perm_uniq; exact: cyc_last_to_front_perm_eq. Qed.

(** Concatenating the blocks of [split_blocks_aux P cur s] gives back
    [cur ++ s]: splitting neither drops nor reorders elements. *)
Lemma split_blocks_aux_flatten P cur s :
  flatten (split_blocks_aux P cur s) = cur ++ s.
Proof.
elim: s cur => [|x rest IH] cur /=.
- by case: cur => //=; rewrite cats0.
- case Hp : (P x).
  + by rewrite /= IH cat_rcons.
  + by rewrite IH cat_rcons.
Qed.

(** Concatenating the blocks of [split_blocks P s] gives back [s]. *)
Lemma split_blocks_flatten P s :
  flatten (split_blocks P s) = s.
Proof.
rewrite /split_blocks.
by rewrite split_blocks_aux_flatten cat0s.
Qed.

(** Rotating every block only permutes the concatenation of the blocks. *)
Lemma perm_eq_flatten_map_cyc (bs : seq (seq nat)) :
  perm_eq (flatten (map cyc_last_to_front bs)) (flatten bs).
Proof.
elim: bs => [|b bs IH] //=.
exact: perm_cat (cyc_last_to_front_perm_eq _) IH.
Qed.

(** [foata_step a u] is a permutation of [rcons u a]. *)
Lemma foata_step_perm_eq a u :
  perm_eq (foata_step a u) (rcons u a).
Proof.
rewrite /foata_step.
case: u => [|x u] //=.
rewrite -cats1.
have -> : x :: u ++ [:: a] = (x :: u) ++ [:: a] by [].
(* The appended [a] stays in place; only the prefix is permuted. *)
apply: perm_cat; last exact: perm_refl.
rewrite -[in X in perm_eq _ X](split_blocks_flatten
  (if last x u < a then (fun y => y < a) else (fun y => a < y)) (x :: u)).
exact: perm_eq_flatten_map_cyc.
Qed.

(** [foata_step a u] is exactly one element longer than [u]. *)
Lemma foata_step_size a u :
  size (foata_step a u) = (size u).+1.
Proof. by rewrite (perm_size (foata_step_perm_eq a u)) size_rcons. Qed.

(** Inserting a fresh letter into a duplicate-free word keeps it
    duplicate-free. *)
Lemma foata_step_uniq a u :
  a \notin u -> uniq u -> uniq (foata_step a u).
Proof.
move=> Ha Hu.
rewrite (perm_uniq (foata_step_perm_eq _ _)).
by rewrite rcons_uniq Ha Hu.
Qed.


(** [foata w] is a permutation of [w]. *)
Lemma foata_perm_eq w : perm_eq (foata w) w.
Proof.
rewrite /foata.
(* Generalize the starting accumulator of the fold. *)
suff: forall acc, perm_eq (foldl (fun u a => foata_step a u) acc w) (acc ++ w).
  by move/(_ [::]); rewrite /=.
elim: w => [|a w IH] acc /=.
  by rewrite cats0; exact: perm_refl.
apply: perm_trans (IH _) _.
rewrite -cat1s catA.
apply: perm_cat; last exact: perm_refl.
have Hp := foata_step_perm_eq a acc.
by rewrite -cats1 in Hp.
Qed.

(** [foata] preserves length. *)
Lemma foata_size w : size (foata w) = size w.
Proof. by apply: perm_size; exact: foata_perm_eq. Qed.

(** [foata] maps duplicate-free words to duplicate-free words. *)
Lemma foata_uniq w : uniq w -> uniq (foata w).
Proof.
move=> Hu.
by rewrite (perm_uniq (foata_perm_eq _)).
Qed.

(** A strict upper bound on the entries of [w] also bounds the entries of
    [foata w] (same multiset of entries). *)
Lemma foata_all_lt w n :
  all (fun x => x < n) w -> all (fun x => x < n) (foata w).
Proof.
move=> Hw.
apply/allP => x; rewrite (perm_mem (foata_perm_eq _)).
exact: (allP Hw).
Qed.


(** Appending [a] adds one inversion for every earlier entry larger than [a]. *)
Lemma inv_seq_rcons w a :
  inv_seq (rcons w a) = inv_seq w + count_gt a w.
Proof.
rewrite /inv_seq /count_gt size_rcons.
(* Split off the last column [j = size w] of the double sum. *)
rewrite -addn1 iotaD /= add0n big_cat /=.
rewrite big_cons big_nil addn0.
congr (_ + _).
(* Columns [j < size w] never look at the appended element. *)
- apply: eq_big_seq => j Hj.
  rewrite mem_iota /= add0n in Hj.
  rewrite [LHS]big_seq_cond [RHS]big_seq_cond.
  apply: eq_bigl => i.
  rewrite mem_iota /= add0n.
  case Hir: (i < j) => //=.
  have His : i < size w by apply: ltn_trans Hj.
  have Hjs : j < size w by [].
  by rewrite !nth_rcons Hjs His.
(* The last column counts the entries of [w] above [a]. *)
- rewrite nth_rcons ltnn eq_refl.
  rewrite sum1_count.
  have -> : count (fun j => a < nth 0 (rcons w a) j) (iota 0 (size w))
          = count (fun j => a < nth 0 w j) (iota 0 (size w)).
    apply: eq_in_count => j.
    rewrite mem_iota /= add0n => Hj.
    by rewrite nth_rcons Hj.
  rewrite /count_gt.
  rewrite -[w in RHS](take_size).
  rewrite -(map_nth_iota0 0 (leqnn (size w))).
  by rewrite count_map.
Qed.

(** [count_gt a] is invariant under permutation. *)
Lemma count_gt_perm_eq a w1 w2 :
  perm_eq w1 w2 -> count_gt a w1 = count_gt a w2.
Proof.
move=> Hp.
rewrite /count_gt.
have := perm_filter (fun y => a < y) Hp.
move/perm_size.
by rewrite !size_filter.
Qed.


(** Appending [a] to a non-empty [w] keeps every old descent and adds a
    descent at position [(size w).-1], worth [size w], exactly when the last
    entry of [w] exceeds [a]. *)
Lemma maj_seq_rcons w a :
  w != [::] ->
  maj_seq (rcons w a)
  = maj_seq w + (size w) * (nat_of_bool (last 0 w > a)).
Proof.
move=> Hw.
rewrite /maj_seq size_rcons /=.
have Hszw : (size w).-1.+1 = size w.
  by case: w Hw => //=.
rewrite -[in LHS]Hszw.
(* Split off the new last position [(size w).-1]. *)
rewrite -addn1 iotaD /= add0n big_cat /=.
rewrite big_cons big_nil addn0.
congr (_ + _).
(* Old positions: the appended [a] is out of reach, descents are unchanged. *)
- rewrite [LHS]big_seq_cond [RHS]big_seq_cond.
  apply: eq_bigl => k.
  rewrite mem_iota /= add0n.
  case Hk: (k < (size w).-1) => //=.
  rewrite /is_desc_seq !nth_rcons.
  have Hk1 : k.+1 < size w by rewrite -Hszw ltnS.
  have Hk2 : k < size w by apply: ltnW.
  by rewrite Hk1 Hk2.
(* New position: compares the last entry of [w] with [a]. *)
- rewrite Hszw.
  rewrite /is_desc_seq.
  rewrite !nth_rcons.
  have Hsz1 : (size w).-1 < size w by rewrite -Hszw.
  rewrite Hsz1.
  have Hsz2 : (size w).-1.+1 < size w = false by rewrite Hszw ltnn.
  rewrite Hsz2 Hszw eq_refl.
  rewrite -nth_last.
  by case: (_ < _) => //=; rewrite ?muln1 ?muln0.
Qed.

(** [foata_step a u] always ends with the inserted letter [a]. *)
Lemma foata_step_last d a u :
  last d (foata_step a u) = a.
Proof.
rewrite /foata_step.
case: u => [|x u] //=.
by rewrite last_cat /=.
Qed.

(** Unfolding equation: the last letter of the input is inserted by one
    final [foata_step]. *)
Lemma foata_rcons w a :
  foata (rcons w a) = foata_step a (foata w).
Proof.
rewrite /foata.
elim/last_ind: w => [|w b IH] //=.
by rewrite -!cats1 -catA /= !foldl_cat /=.
Qed.

(** [foata] preserves the last letter of a non-empty word. *)
Lemma foata_last d a w :
  last d (foata (a :: w)) = last a w.
Proof.
rewrite /foata /=.
(* Fold invariant: if the accumulator ends like [u], the result ends like
   [u ++ w]. *)
have Hgen : forall u acc,
  u != [::] -> last d acc = last d u ->
  last d (foldl (fun v b => foata_step b v) acc w) = last d (u ++ w).
  elim: w => [|b w IH] u acc Hu Hacc /=.
    by rewrite cats0 -Hacc.
  have -> : u ++ b :: w = rcons u b ++ w by rewrite cat_rcons.
  apply: IH; first by rewrite -cats1; case: (u) => //=.
  by rewrite (foata_step_last d) last_rcons.
have := Hgen [:: a] [:: a] erefl erefl.
move=> ->.
by [].
Qed.

(** [foata_last] restated for an arbitrary non-empty word. *)
Lemma foata_last_eq d w :
  w != [::] -> last d (foata w) = last d w.
Proof.
case: w => [|a w] //= _.
exact: foata_last.
Qed.


(** [cross_inv s1 s2] counts the inversions between a prefix [s1] and a
    suffix [s2]: for each [b] in [s2], the entries of [s1] strictly greater
    than [b]. *)
Definition cross_inv (s1 s2 : seq nat) : nat :=
  \sum_(b <- s2) count_gt b s1.

(** There are no cross inversions against an empty suffix. *)
Lemma cross_inv_nil_r s1 : cross_inv s1 [::] = 0.
Proof. by rewrite /cross_inv big_nil. Qed.

(** There are no cross inversions against an empty prefix. *)
Lemma cross_inv_nil_l s2 : cross_inv [::] s2 = 0.
Proof.
rewrite /cross_inv /count_gt /=.
by rewrite big1 // => b _.
Qed.

(** Appending [a] to the suffix adds the entries of [s1] above [a]. *)
Lemma cross_inv_rcons s1 s2 a :
  cross_inv s1 (rcons s2 a) = cross_inv s1 s2 + count_gt a s1.
Proof. by rewrite /cross_inv -cats1 big_cat /= big_cons big_nil addn0. Qed.

(** Prepending [b] to the suffix adds the entries of [s1] above [b]. *)
Lemma cross_inv_cons s1 b s2 :
  cross_inv s1 (b :: s2) = count_gt b s1 + cross_inv s1 s2.
Proof. by rewrite /cross_inv big_cons. Qed.

(** The inversions of a concatenation are the inversions inside each part
    plus the cross inversions between them. *)
Lemma inv_seq_cat s1 s2 :
  inv_seq (s1 ++ s2) = inv_seq s1 + inv_seq s2 + cross_inv s1 s2.
Proof.
elim/last_ind: s2 => [|s2 a IH] /=.
  by rewrite cats0 cross_inv_nil_r /inv_seq /= big_nil !addn0.
rewrite -rcons_cat !inv_seq_rcons.
rewrite /count_gt count_cat -/(count_gt a s1) -/(count_gt a s2).
rewrite cross_inv_rcons IH.
set X := count_gt a s1; set Y := count_gt a s2.
set A := inv_seq s1; set B := inv_seq s2; set C := cross_inv s1 s2.
(* Pure commutative rearrangement of the five summands. *)
rewrite -[A + B + C]addnA -[A + (B + Y) + _]addnA -addnA.
congr (_ + _).
rewrite -!addnA; congr (_ + _).
by rewrite addnA addnC addnA addnC.
Qed.

(** Unfolding [count_gt] on a cons: the head contributes [1] exactly when it
    exceeds [a]. *)
Lemma count_gt_cons a b s :
  count_gt a (b :: s) = (nat_of_bool (a < b)) + count_gt a s.
Proof. by rewrite /count_gt /=. Qed.

(** [count_gt a] is additive over concatenation. *)
Lemma count_gt_cat a s1 s2 :
  count_gt a (s1 ++ s2) = count_gt a s1 + count_gt a s2.
Proof. by rewrite /count_gt count_cat. Qed.

(** [cross_inv] is additive in its prefix argument. *)
Lemma cross_inv_cat_l s1 s1' s2 :
  cross_inv (s1 ++ s1') s2 = cross_inv s1 s2 + cross_inv s1' s2.
Proof.
rewrite /cross_inv -big_split /=.
by apply: eq_bigr => b _; rewrite count_gt_cat.
Qed.

(** [cross_inv] is additive in its suffix argument. *)
Lemma cross_inv_cat_r s1 s2 s2' :
  cross_inv s1 (s2 ++ s2') = cross_inv s1 s2 + cross_inv s1 s2'.
Proof. by rewrite /cross_inv big_cat. Qed.

(** Permuting the suffix does not change [cross_inv]. *)
Lemma cross_inv_perm_eq_r s1 s2 s2' :
  perm_eq s2 s2' -> cross_inv s1 s2 = cross_inv s1 s2'.
Proof.
move=> Hp.
rewrite /cross_inv.
by rewrite (perm_big _ Hp).
Qed.

(** Permuting the prefix does not change [cross_inv]. *)
Lemma cross_inv_perm_eq_l s1 s1' s2 :
  perm_eq s1 s1' -> cross_inv s1 s2 = cross_inv s1' s2.
Proof.
move=> Hp.
rewrite /cross_inv.
by apply: eq_bigr => b _; apply: count_gt_perm_eq.
Qed.

(** [count_lt a w] is the number of entries of [w] strictly smaller than [a]. *)
Definition count_lt (a : nat) (w : seq nat) : nat :=
  count (fun y => y < a) w.

(** [count_lt a] is invariant under permutation. *)
Lemma count_lt_perm_eq a w1 w2 :
  perm_eq w1 w2 -> count_lt a w1 = count_lt a w2.
Proof.
move=> Hp.
rewrite /count_lt.
have := perm_filter (fun y => y < a) Hp.
move/perm_size.
by rewrite !size_filter.
Qed.

(** [count_lt a] is additive over concatenation. *)
Lemma count_lt_cat a s1 s2 :
  count_lt a (s1 ++ s2) = count_lt a s1 + count_lt a s2.
Proof. by rewrite /count_lt count_cat. Qed.

(** If [f] permutes every block of [bs], mapping [f] over the blocks only
    permutes their concatenation.  This generalizes
    [perm_eq_flatten_map_cyc] to any block transformation. *)
Lemma perm_eq_flatten_map_pred (f : seq nat -> seq nat) bs :
  (forall b, b \in bs -> perm_eq (f b) b) ->
  perm_eq (flatten (map f bs)) (flatten bs).
Proof.
elim: bs => [|b bs IH] Hf //=.
apply: perm_cat.
  by apply: Hf; rewrite mem_head.
by apply: IH => c Hc; apply: Hf; rewrite in_cons Hc orbT.
Qed.

(** Bookkeeping identity.  Replace each block [b] by a permutation [f b].
    The inversions of the concatenation then change by exactly the change in
    the per-block inversion numbers, because cross inversions between blocks
    depend only on the blocks' contents.  The identity is stated additively
    to avoid truncated subtraction on [nat]. *)
Lemma inv_seq_flatten_swap_eq (f : seq nat -> seq nat) bs :
  (forall b, b \in bs -> perm_eq (f b) b) ->
  inv_seq (flatten (map f bs)) + \sum_(b <- bs) inv_seq b
  = inv_seq (flatten bs) + \sum_(b <- bs) inv_seq (f b).
Proof.
elim: bs => [|b bs IH] Hf.
  by rewrite /= /inv_seq /= big_nil !big_nil !addn0.
have Hb : perm_eq (f b) b by apply: Hf; rewrite mem_head.
have Hbs : forall c, c \in bs -> perm_eq (f c) c.
  by move=> c Hc; apply: Hf; rewrite in_cons Hc orbT.
have IHs := IH Hbs.
rewrite /= !inv_seq_cat.
have HpermFlat := perm_eq_flatten_map_pred Hbs.
(* Cross inversions between the head block and the rest are unchanged. *)
have Hcross : cross_inv (f b) (flatten (map f bs)) = cross_inv b (flatten bs).
  rewrite (cross_inv_perm_eq_l _ Hb).
  by rewrite (cross_inv_perm_eq_r _ HpermFlat).
rewrite Hcross !big_cons.
set IFb := inv_seq (f b); set IB := inv_seq b.
set IFbs := inv_seq (flatten (map f bs)); set IBs := inv_seq (flatten bs).
set CR := cross_inv b (flatten bs).
set Sb := \sum_(c <- bs) inv_seq c.
set SFb := \sum_(c <- bs) inv_seq (f c).
have HSrec : IFbs + Sb = IBs + SFb by [].
apply/eqP.
rewrite -!addnA.
rewrite [IFbs + (CR + _)]addnCA [IBs + (CR + _)]addnCA.
rewrite [IFb + _]addnCA [IB + (CR + _)]addnCA.
rewrite eqn_add2l; apply/eqP.
rewrite [IFbs + (IB + Sb)]addnCA [IBs + (IFb + SFb)]addnCA.
by rewrite HSrec addnCA.
Qed.

(** Rotating [rcons b' l] brings [l] in front of [b']. *)
Lemma cyc_last_to_front_rcons b' l :
  cyc_last_to_front (rcons b' l) = l :: b'.
Proof.
case: b' => [|x b'] /=.
  by [].
by rewrite last_rcons belast_rcons.
Qed.

(** Moving [l] from the end of [rcons b' l] to the front.  With [l] at the
    front, the inversions involving [l] are the [count_lt l b'] entries below
    it.  With [l] at the end, they are the [count_gt l b'] entries above it.
    The identity is stated additively. *)
Lemma inv_seq_cons_eq_rcons_shift l b' :
  inv_seq (l :: b') + count_gt l b' = inv_seq (rcons b' l) + count_lt l b'.
Proof.
rewrite -cat1s -cats1 !inv_seq_cat.
have HCG : cross_inv b' [:: l] = count_gt l b'.
  by rewrite /cross_inv big_seq1.
have HCL : cross_inv [:: l] b' = count_lt l b'.
  rewrite /cross_inv /count_lt.
  have ->: \sum_(b <- b') count_gt b [:: l] = \sum_(b <- b') (nat_of_bool (b < l)).
    by apply: eq_bigr => x _; rewrite /count_gt /= addn0.
  rewrite -sum1_count.
  by rewrite (eq_bigr (fun b => if b < l then 1 else 0));
     [rewrite -big_mkcond | move=> i _; case: (i < l)].
rewrite HCG HCL.
set X := inv_seq b'; set CG := count_gt l b'; set CL := count_lt l b'.
set IL := inv_seq [:: l].
by rewrite -addnA -[in RHS]addnA [IL + X]addnC; congr (_ + _); rewrite addnC.
Qed.

(** Consider a block whose last letter [l] is below [a] and whose other
    letters are all above [a].  Rotating it removes exactly one inversion
    per letter of [b']. *)
Lemma cyc_diff_block_lt a b' l :
  l < a -> all (fun x => a < x) b' ->
  inv_seq (cyc_last_to_front (rcons b' l)) + size b' = inv_seq (rcons b' l).
Proof.
move=> Hla Hb'.
(* Every letter of [b'] is above [l] ... *)
have HCG : count_gt l b' = size b'.
  rewrite /count_gt -size_filter.
  have ->: [seq y <- b' | l < y] = b'.
    apply/all_filterP/allP => x Hx.
    by apply: ltn_trans Hla _; exact: (allP Hb').
  by [].
(* ... hence none is below it. *)
have HCL : count_lt l b' = 0.
  rewrite /count_lt; apply/eqP; rewrite -leqn0 leqNgt -has_count.
  apply/hasPn => x Hx /=.
  rewrite -leqNgt.
  by have Hxa := allP Hb' _ Hx; apply: ltnW; apply: ltn_trans Hla _.
rewrite cyc_last_to_front_rcons.
have := inv_seq_cons_eq_rcons_shift l b'.
by rewrite HCG HCL addn0 => ->.
Qed.

(** Dual of [cyc_diff_block_lt].  If the last letter [l] is above [a] and
    the other letters are all below [a], rotating the block adds exactly one
    inversion per letter of [b']. *)
Lemma cyc_diff_block_gt a b' l :
  a < l -> all (fun x => x < a) b' ->
  inv_seq (cyc_last_to_front (rcons b' l)) = inv_seq (rcons b' l) + size b'.
Proof.
move=> Hal Hb'.
(* Every letter of [b'] is below [l] ... *)
have HCL : count_lt l b' = size b'.
  rewrite /count_lt -size_filter.
  have ->: [seq y <- b' | y < l] = b'.
    apply/all_filterP/allP => x Hx.
    by apply: ltn_trans (allP Hb' _ Hx) Hal.
  by [].
(* ... hence none is above it. *)
have HCG : count_gt l b' = 0.
  rewrite /count_gt; apply/eqP; rewrite -leqn0 leqNgt -has_count.
  apply/hasPn => x Hx /=.
  rewrite -leqNgt.
  have Hxa := allP Hb' _ Hx.
  by apply: ltnW; apply: ltn_trans Hxa _.
rewrite cyc_last_to_front_rcons.
have := inv_seq_cons_eq_rcons_shift l b'.
by rewrite HCG HCL addn0 => <-.
Qed.

(** [split_blocks_aux] never produces an empty block. *)
Lemma split_blocks_aux_all_nonempty P cur s :
  all (fun b => b != [::]) (split_blocks_aux P cur s).
Proof.
elim: s cur => [|x s IH] cur /=.
  by case: cur.
case Hp : (P x) => /=.
  by rewrite IH /= -size_eq0 size_rcons.
exact: IH.
Qed.

(** [split_blocks] never produces an empty block. *)
Lemma split_blocks_all_nonempty P s :
  all (fun b => b != [::]) (split_blocks P s).
Proof. exact: split_blocks_aux_all_nonempty. Qed.



(** [wf_block P b] holds when [b] is non-empty, its last element satisfies
    [P], and none of its other elements does.  These are the blocks produced
    by [split_blocks P s] when [s] ends with a [P]-element
    ([split_blocks_wf]). *)
Definition wf_block (P : pred nat) (b : seq nat) : bool :=
  match b with
  | [::] => false
  | x :: rest =>
      P (last x rest) && all (fun y => ~~ P y) (belast x rest)
  end.

(** [wf_block] on a block written as [rcons b' l]: [l] satisfies [P] and no
    element of [b'] does. *)
Lemma wf_block_rcons P b' l :
  wf_block P (rcons b' l) = P l && all (fun y => ~~ P y) b'.
Proof.
case: b' => [|x b'] /=.
  by rewrite andbT.
by rewrite last_rcons belast_rcons.
Qed.

(** Suppose [s] is non-empty and its last element satisfies [P], and the
    pending block [cur] contains no [P]-element.  Then every block of
    [split_blocks_aux P cur s] is well formed. *)
Lemma split_blocks_aux_wf (P : pred nat) s :
  s != [::] -> P (last 0 s) ->
  forall cur,
  all (fun y => ~~ P y) cur ->
  all (wf_block P) (split_blocks_aux P cur s).
Proof.
elim: s => [|x s IH] Hs HPlast cur Hcur //=.
move: Hs => _.
case Hp: (P x).
(* [x] closes the current block, which is then well formed. *)
- apply/andP; split.
    by rewrite wf_block_rcons Hp Hcur.
  case Hs': (s == [::]).
    by move/eqP: Hs' => ->.
  apply: IH => //.
  + by rewrite Hs'.
  + move: HPlast => /=.
    move: Hs'; case: s => [|? ?] //=.
(* [x] extends the current block; [s] cannot be empty since [P] holds at
   the end of the input. *)
- case Hs': (s == [::]).
    move/eqP: Hs' => Hsnil; rewrite Hsnil /= in HPlast.
    by rewrite Hp in HPlast.
  apply: IH.
  + by rewrite Hs'.
  + move: HPlast; rewrite /=.
    move: Hs'; case: s => [|? ?] //=.
  + by rewrite all_rcons Hp.
Qed.

(** If [s] ends with a [P]-element, every block of [split_blocks P s] is
    well formed. *)
Lemma split_blocks_wf (P : pred nat) s :
  s != [::] -> P (last 0 s) ->
  all (wf_block P) (split_blocks P s).
Proof. by move=> ? ?; apply: split_blocks_aux_wf. Qed.

(** If [s] ends with a [P]-element, [split_blocks_aux P cur s] has exactly
    one block per [P]-element of [s], whatever the pending block [cur]. *)
Lemma split_blocks_aux_size_when_last_P (P : pred nat) s :
  s != [::] -> P (last 0 s) ->
  forall cur, size (split_blocks_aux P cur s) = count P s.
Proof.
elim: s => [|x s IH] Hs HPlast cur //=.
move: Hs => _.
case Hp: (P x) => /=.
- case Hs': (s == [::]).
    by move/eqP: Hs' => H; rewrite H /=.
  rewrite IH //.
    by rewrite Hs'.
  have HsP : P (last 0 s).
    move: HPlast => /= H.
    by case: (s) Hs' H => [|? ?] //.
  exact: HsP.
- case Hs': (s == [::]).
    move/eqP: Hs' => Hsnil; rewrite Hsnil /= in HPlast.
    by rewrite Hp in HPlast.
  rewrite IH //.
    by rewrite Hs'.
  have HsP : P (last 0 s).
    move: HPlast => /= H.
    by case: (s) Hs' H => [|? ?] //.
  exact: HsP.
Qed.

(** If [s] ends with a [P]-element, [split_blocks P s] has one block per
    [P]-element of [s]. *)
Lemma split_blocks_size_eq (P : pred nat) s :
  s != [::] -> P (last 0 s) ->
  size (split_blocks P s) = count P s.
Proof. by move=> ? ?; apply: split_blocks_aux_size_when_last_P. Qed.

(** Destructures a well-formed block as [rcons b' l].  [l] is its unique
    [P]-element, no element of [b'] satisfies [P], and [b] has one more
    element than [b']. *)
Lemma wf_block_decomp P b :
  wf_block P b ->
  exists b' l, [/\ b = rcons b' l, P l, all (fun y => ~~ P y) b'
              & size b = (size b').+1].
Proof.
case: b => [|x rest] //= /andP[HPlast Hbelast].
exists (belast x rest), (last x rest); split.
- by rewrite -lastI.
- by [].
- by [].
- by rewrite size_belast.
Qed.

(** If [a] does not occur in [u], every entry of [u] is either below or
    above [a].
    NOTE(review): the [uniq u] hypothesis is not used by the proof.  It is
    kept so that existing callers stay unchanged. *)
Lemma size_count_lt_gt a u :
  uniq u -> a \notin u -> size u = count_lt a u + count_gt a u.
Proof.
move=> Hu Hau.
rewrite /count_lt /count_gt -count_predUI.
(* No entry is both below and above [a]. *)
have ->: count (predI (fun y => y < a) (fun y => a < y)) u = 0.
  apply/eqP; rewrite -leqn0 leqNgt -has_count.
  apply/hasPn => x _ /=.
  by case: (ltngtP x a) => //.
rewrite addn0 -[LHS](count_predT).
(* Every entry is below or above [a], since it differs from [a]. *)
apply: eq_in_count => x Hx /=.
case: (ltngtP x a) => //= Heq.
by exfalso; move: Hau; rewrite -Heq Hx.
Qed.

(** For well-formed (hence non-empty) blocks, the sum of [(size b).-1] is
    the total length minus the number of blocks. *)
Lemma sum_size_belast_wf P bs :
  all (wf_block P) bs ->
  \sum_(b <- bs) (size b).-1 = size (flatten bs) - size bs.
Proof.
elim: bs => [|b bs IH] /= Hbs.
  by rewrite big_nil.
move: Hbs => /andP[Hb Hbs].
rewrite big_cons IH // size_cat.
case: (wf_block_decomp Hb) => b' [l] [-> _ _ Hsz].
rewrite size_rcons /=.
(* Side condition for [addnBA]: each block has at least one element. *)
have Hflat : size bs <= size (flatten bs).
  elim: bs Hbs {IH} => [|c cs IHcs] //=.
  move=> /andP[Hc Hcs].
  rewrite size_cat /=.
  have Hsc : 0 < size c.
    by case: (c) Hc => [|? ?] //=.
  have IHcss := IHcs Hcs.
  by apply: leq_ltn_trans IHcss _; rewrite -[size _]add0n ltn_add2r.
by rewrite addSn subSS addnBA.
Qed.


(** [cyc_diff_block_lt] summed over a list of blocks.  The blocks are
    well formed for [fun y => y < a] and all their non-last letters are
    above [a], so rotating them removes [(size b).-1] inversions per
    block. *)
Lemma sum_inv_cyc_lt_blocks a bs :
  all (wf_block (fun y => y < a)) bs ->
  all (fun b => match b with
                | [::] => false
                | x :: rest => all (fun y => a < y) (belast x rest)
                end) bs ->
  \sum_(b <- bs) inv_seq (cyc_last_to_front b) + \sum_(b <- bs) (size b).-1
  = \sum_(b <- bs) inv_seq b.
Proof.
elim: bs => [|b bs IH] /= Hbs Hstr.
  by rewrite !big_nil.
move: Hbs => /andP[Hb Hbs].
move: Hstr => /andP[Hstrb Hstrbs].
rewrite !big_cons.
have IHs := IH Hbs Hstrbs.
case: (wf_block_decomp Hb) => b' [l] [Hbeq HPl _ Hsz].
have Hstrb' : all (fun y => a < y) b'.
  move: Hstrb; rewrite Hbeq.
  case: b' Hbeq Hsz => [|y b' /=] _ Hsz //=.
  by rewrite belast_rcons.
have := cyc_diff_block_lt HPl Hstrb'.
rewrite -Hbeq => Hcyc.
rewrite Hsz /=.
set IB := inv_seq b; set ICB := inv_seq (cyc_last_to_front b).
set SI := \sum_(b0 <- bs) inv_seq (cyc_last_to_front b0).
set SS := \sum_(b0 <- bs) (size b0).-1.
set SB := \sum_(b0 <- bs) inv_seq b0.
have HE : ICB + size b' = IB by rewrite /ICB /IB; exact: Hcyc.
have IHs' : SI + SS = SB by rewrite -/SI -/SS -/SB in IHs.
rewrite -addnA [SI + (size b' + SS)]addnCA.
by rewrite addnA HE IHs'.
Qed.

(** [cyc_diff_block_gt] summed over a list of blocks.  The blocks are
    well formed for [fun y => a < y] and all their non-last letters are
    below [a], so rotating them adds [(size b).-1] inversions per block. *)
Lemma sum_inv_cyc_gt_blocks a bs :
  all (wf_block (fun y => a < y)) bs ->
  all (fun b => match b with
                | [::] => false
                | x :: rest => all (fun y => y < a) (belast x rest)
                end) bs ->
  \sum_(b <- bs) inv_seq (cyc_last_to_front b)
  = \sum_(b <- bs) inv_seq b + \sum_(b <- bs) (size b).-1.
Proof.
elim: bs => [|b bs IH] /= Hbs Hstr.
  by rewrite !big_nil.
move: Hbs => /andP[Hb Hbs].
move: Hstr => /andP[Hstrb Hstrbs].
rewrite !big_cons.
have IHs := IH Hbs Hstrbs.
case: (wf_block_decomp Hb) => b' [l] [Hbeq HPl _ Hsz].
have Hstrb' : all (fun y => y < a) b'.
  move: Hstrb; rewrite Hbeq.
  case: b' Hbeq Hsz => [|y b' /=] _ Hsz //=.
  by rewrite belast_rcons.
have := cyc_diff_block_gt HPl Hstrb'.
rewrite -Hbeq => Hcyc.
rewrite Hsz /=.
set IB := inv_seq b; set ICB := inv_seq (cyc_last_to_front b).
set SI := \sum_(b0 <- bs) inv_seq (cyc_last_to_front b0).
set SS := \sum_(b0 <- bs) (size b0).-1.
set SB := \sum_(b0 <- bs) inv_seq b0.
have HE : ICB = IB + size b' by rewrite /ICB /IB; exact: Hcyc.
have IHs' : SI = SB + SS by rewrite -/SI -/SB -/SS in IHs.
rewrite HE IHs'.
by rewrite -!addnA; congr (IB + _); rewrite addnCA.
Qed.

(** Assume [last u < a] and [a] does not occur in [u].  Then, in every block
    of [split_blocks (fun y => y < a) u], all letters but the last are
    strictly above [a]. *)
Lemma split_blocks_lt_strict a u :
  uniq u -> a \notin u ->
  u != [::] -> last 0 u < a ->
  all (fun b => match b with
                | [::] => false
                | x :: rest => all (fun y => a < y) (belast x rest)
                end) (split_blocks (fun y => y < a) u).
Proof.
move=> Hu Hau Hu' Hla.
have HPlast : (fun y : nat => y < a) (last 0 u) by [].
have Hwf := @split_blocks_wf (fun y : nat => y < a) u Hu' HPlast.
apply/allP => b Hb.
have Hwfb := allP Hwf _ Hb.
case: (wf_block_decomp Hwfb) => b' [l] [Hbeq _ Hb' Hsz].
case: b Hb Hbeq Hsz Hwfb => [|x rest] //= Hb Hbeq Hsz Hwfb.
have Hbelast_eq : belast x rest = b'.
  by have := f_equal (belast 0) Hbeq; rewrite belast_rcons /= => -[->].
rewrite Hbelast_eq.
apply/allP => y Hy.
(* [y] is a letter of [u], hence different from [a] ... *)
have HyU : y \in u.
  rewrite -(split_blocks_flatten (fun y => y < a) u).
  apply/flattenP. exists (rcons b' l).
    by rewrite -Hbeq; exact: Hb.
  by rewrite mem_rcons in_cons Hy orbT.
have Hya_neq : y != a.
  by apply: contra Hau => /eqP <-.
(* ... and, being a non-last letter of a well-formed block, not below [a]. *)
have Hyge : ~~ (y < a).
  by have := allP Hb' _ Hy.
rewrite ltn_neqAle.
apply/andP; split.
- by rewrite eq_sym.
- by rewrite leqNgt.
Qed.

(** Assume [a < last u] and [a] does not occur in [u].  Then, in every block
    of [split_blocks (fun y => a < y) u], all letters but the last are
    strictly below [a]. *)
Lemma split_blocks_gt_strict a u :
  uniq u -> a \notin u ->
  u != [::] -> a < last 0 u ->
  all (fun b => match b with
                | [::] => false
                | x :: rest => all (fun y => y < a) (belast x rest)
                end) (split_blocks (fun y => a < y) u).
Proof.
move=> Hu Hau Hu' Hla.
have HPlast : (fun y => a < y) (last 0 u) by [].
have Hwf := split_blocks_wf Hu' HPlast.
apply/allP => b Hb.
have Hwfb := allP Hwf _ Hb.
case: (wf_block_decomp Hwfb) => b' [l] [Hbeq _ Hb' Hsz].
case: b Hb Hbeq Hsz Hwfb => [|x rest] //= Hb Hbeq Hsz Hwfb.
have Hbelast_eq : belast x rest = b'.
  by have := f_equal (belast 0) Hbeq; rewrite belast_rcons /= => -[->].
rewrite Hbelast_eq.
apply/allP => y Hy.
(* [y] is a letter of [u], hence different from [a] ... *)
have HyU : y \in u.
  rewrite -(split_blocks_flatten (fun y => a < y) u).
  apply/flattenP. exists (rcons b' l).
    by rewrite -Hbeq; exact: Hb.
  by rewrite mem_rcons in_cons Hy orbT.
have Hya_neq : y != a.
  by apply: contra Hau => /eqP <-.
(* ... and, being a non-last letter of a well-formed block, not above [a]. *)
have Hyle : ~~ (a < y).
  by have := allP Hb' _ Hy.
rewrite ltn_neqAle.
apply/andP; split.
- exact: Hya_neq.
- by rewrite leqNgt.
Qed.

(** If [last u < a] ([u] non-empty, duplicate-free, not containing [a]),
    then [foata_step a u] has as many inversions as [u].  Appending [a] adds
    [count_gt a u] inversions.  The rotations remove
    [\sum (size b).-1 = size u - #blocks = count_gt a u] of them. *)
Lemma foata_step_inv_lt a u :
  u != [::] -> uniq u -> a \notin u -> last 0 u < a ->
  inv_seq (foata_step a u) = inv_seq u.
Proof.
case: u => [|x u] //= _ Hu' Hau Hla.
have HPlast : (fun y : nat => y < a) (last 0 (x :: u)) by [].
have Huniq : uniq (x :: u) by [].
have Hu_ne : (x :: u) != [::] by [].
have Hwf : all (wf_block (fun y : nat => y < a)) (split_blocks (fun y : nat => y < a) (x :: u)).
  exact: split_blocks_wf.
have Hstr := split_blocks_lt_strict Huniq Hau Hu_ne Hla.
have Hflat_bs : flatten (split_blocks (fun y : nat => y < a) (x :: u)) = x :: u by exact: split_blocks_flatten.
rewrite /foata_step /=.
have Hxulast : last x u < a by [].
rewrite Hxulast.
set bs := split_blocks (fun y : nat => y < a) (x :: u) in Hwf Hstr Hflat_bs *.
(* Inversions contributed by the appended [a]. *)
rewrite inv_seq_cat /cross_inv big_seq1.
have ->: inv_seq [:: a] = 0 by rewrite /inv_seq /= !big_cons !big_nil.
rewrite addn0.
set ru := flatten (map cyc_last_to_front bs).
have Hperm_ru : perm_eq ru (x :: u) by rewrite /ru -[in X in perm_eq _ X]Hflat_bs;
  exact: perm_eq_flatten_map_cyc.
have Hcg : count_gt a ru = count_gt a (x :: u) by exact: count_gt_perm_eq.
rewrite Hcg.
(* Inversions removed by the rotations. *)
have Hpermall : forall b, b \in bs -> perm_eq (cyc_last_to_front b) b.
  by move=> ? _; exact: cyc_last_to_front_perm_eq.
have Hswap := inv_seq_flatten_swap_eq Hpermall.
have Hsum_cyc := sum_inv_cyc_lt_blocks Hwf Hstr.
rewrite Hflat_bs in Hswap.
set IRu := inv_seq ru. set IU := inv_seq (x :: u).
set SB := \sum_(b <- bs) inv_seq b.
set SCB := \sum_(b <- bs) inv_seq (cyc_last_to_front b).
set SS := \sum_(b <- bs) (size b).-1.
have Hswap' : IRu + SB = IU + SCB by rewrite -/IRu -/SB -/IU -/SCB in Hswap.
have Hsum' : SCB + SS = SB by rewrite -/SCB -/SB -/SS in Hsum_cyc.
have Hkey : IRu + SS = IU.
  apply/eqP. rewrite -(eqn_add2r SCB).
  by rewrite -addnA [SS + SCB]addnC Hsum' Hswap'.
(* The number of removed inversions is exactly [count_gt a (x :: u)]. *)
have Hsize_u : size (x :: u) = count_lt a (x :: u) + count_gt a (x :: u)
  by exact: size_count_lt_gt.
have Hnumb : size bs = count_lt a (x :: u).
  rewrite /bs split_blocks_size_eq //= /count_lt.
have Hsumss : SS = size (x :: u) - size bs.
  rewrite /SS (sum_size_belast_wf Hwf).
  by rewrite Hflat_bs.
rewrite Hnumb Hsize_u in Hsumss.
rewrite addKn in Hsumss.
rewrite Hsumss in Hkey.
exact: Hkey.
Qed.

(** If [a < last u] ([u] non-empty, duplicate-free, not containing [a]),
    then [foata_step a u] has [size u] more inversions than [u].  Appending
    [a] adds [count_gt a u] inversions.  The rotations add
    [\sum (size b).-1 = size u - #blocks = count_lt a u] more. *)
Lemma foata_step_inv_gt a u :
  u != [::] -> uniq u -> a \notin u -> a < last 0 u ->
  inv_seq (foata_step a u) = inv_seq u + size u.
Proof.
case: u => [|x u] //= _ Hu' Hau Hla.
have HPlast : (fun y : nat => a < y) (last 0 (x :: u)) by [].
have Huniq : uniq (x :: u) by [].
have Hu_ne : (x :: u) != [::] by [].
have Hwf : all (wf_block (fun y : nat => a < y)) (split_blocks (fun y : nat => a < y) (x :: u)).
  exact: split_blocks_wf.
have Hstr := split_blocks_gt_strict Huniq Hau Hu_ne Hla.
have Hflat_bs : flatten (split_blocks (fun y : nat => a < y) (x :: u)) = x :: u
  by exact: split_blocks_flatten.
rewrite /foata_step /=.
(* Here [last x u < a] is false, so [foata_step] cuts after letters above [a]. *)
have Hxulast : ~~ (last x u < a) by rewrite -leqNgt; apply: ltnW.
move: Hxulast => /negbTE ->.
set bs := split_blocks (fun y : nat => a < y) (x :: u) in Hwf Hstr Hflat_bs *.
(* Inversions contributed by the appended [a]. *)
rewrite inv_seq_cat /cross_inv big_seq1.
have ->: inv_seq [:: a] = 0 by rewrite /inv_seq /= !big_cons !big_nil.
rewrite addn0.
set ru := flatten (map cyc_last_to_front bs).
have Hperm_ru : perm_eq ru (x :: u) by rewrite /ru -[in X in perm_eq _ X]Hflat_bs;
  exact: perm_eq_flatten_map_cyc.
have Hcg : count_gt a ru = count_gt a (x :: u) by exact: count_gt_perm_eq.
rewrite Hcg.
(* Inversions added by the rotations. *)
have Hpermall : forall b, b \in bs -> perm_eq (cyc_last_to_front b) b.
  by move=> ? _; exact: cyc_last_to_front_perm_eq.
have Hswap := inv_seq_flatten_swap_eq Hpermall.
have Hsum_cyc := sum_inv_cyc_gt_blocks Hwf Hstr.
rewrite Hflat_bs in Hswap.
set IRu := inv_seq ru. set IU := inv_seq (x :: u).
set SB := \sum_(b <- bs) inv_seq b.
set SCB := \sum_(b <- bs) inv_seq (cyc_last_to_front b).
set SS := \sum_(b <- bs) (size b).-1.
have Hswap' : IRu + SB = IU + SCB by rewrite -/IRu -/SB -/IU -/SCB in Hswap.
have Hsum' : SCB = SB + SS by rewrite -/SCB -/SB -/SS in Hsum_cyc.
have Hkey : IRu = IU + SS.
  apply/eqP. rewrite -(eqn_add2r SB).
  by rewrite Hswap' Hsum' [SB + SS]addnC addnA.
(* The number of added inversions is exactly [count_lt a (x :: u)]. *)
have Hsize_u : size (x :: u) = count_lt a (x :: u) + count_gt a (x :: u)
  by exact: size_count_lt_gt.
have Hnumb : size bs = count_gt a (x :: u).
  rewrite /bs split_blocks_size_eq //= /count_gt.
have Hsumss : SS = size (x :: u) - size bs.
  rewrite /SS (sum_size_belast_wf Hwf).
  by rewrite Hflat_bs.
rewrite Hnumb Hsize_u in Hsumss.
rewrite -addnBA // subnn addn0 in Hsumss.
rewrite Hsumss in Hkey.
rewrite Hkey.
rewrite -addnA -[count_lt a (x :: u) + _]Hsize_u.
by [].
Qed.

(** Inversion count of one [foata_step].  It adds [size u] inversions
    exactly when [a < last u], and none otherwise.  This combines
    [foata_step_inv_gt] and [foata_step_inv_lt]. *)
Lemma foata_step_inv a u :
  u != [::] -> uniq u -> a \notin u ->
  inv_seq (foata_step a u)
    = inv_seq u + (size u) * (nat_of_bool (a < last 0 u)).
Proof.
move=> Hu Hu' Hau.
case Hla: (a < last 0 u).
- by rewrite muln1; exact: foata_step_inv_gt.
- rewrite muln0 addn0.
  apply: foata_step_inv_lt => //.
  (* [last u < a] follows since [last u] is a letter of [u], hence [!= a]. *)
  rewrite ltn_neqAle leqNgt Hla andbT.
  apply/eqP => Hlast.
  move/negP: Hau; apply.
  rewrite -Hlast.
  by case: (u) Hu => [//=|? ?] _; exact: mem_last.
Qed.


(** Main theorem: for a duplicate-free word [w], the inversion number of
    [foata w] equals the major index of [w].  The proof is by induction on
    the last letter.  [foata_step_inv] and [maj_seq_rcons] increase both
    sides by the same amount. *)
Theorem foata_inv_eq_maj w :
  uniq w -> inv_seq (foata w) = maj_seq w.
Proof.
elim/last_ind: w => [|w a IH] Hu.
  by rewrite /foata /= /inv_seq /maj_seq /= !big_nil.
have Hw : uniq w by move: Hu; rewrite rcons_uniq; case/andP.
have Ha : a \notin w by move: Hu; rewrite rcons_uniq; case/andP.
have IHw := IH Hw.
rewrite foata_rcons.
case Hwnil: (w == [::]).
  move/eqP: Hwnil => ->.
  by rewrite /= /inv_seq /maj_seq /= !big_cons !big_nil.
have Hw_ne : w != [::] by rewrite Hwnil.
have Hfw_ne : foata w != [::].
  rewrite -size_eq0 foata_size.
  by case: (w) Hw_ne => [|? ?].
have Hfw_uniq : uniq (foata w) by exact: foata_uniq.
have Hfw_a : a \notin foata w.
  by rewrite (perm_mem (foata_perm_eq _)).
rewrite foata_step_inv //.
rewrite (foata_last_eq 0 Hw_ne).
rewrite foata_size.
rewrite (maj_seq_rcons _ Hw_ne).
by rewrite IHw.
Qed.


(** [cyc_first_to_back s] rotates [s] one step to the left: the first
    element moves to the end.  It undoes [cyc_last_to_front]
    ([cyc_first_to_backK]). *)
Definition cyc_first_to_back (s : seq nat) : seq nat :=
  match s with
  | [::] => [::]
  | x :: rest => rcons rest x
  end.

(** Left rotation preserves length. *)
Lemma cyc_first_to_back_size s :
  size (cyc_first_to_back s) = size s.
Proof. by case: s => //= ? ?; rewrite size_rcons. Qed.

(** Left rotation only reorders: [cyc_first_to_back s] is a permutation
    of [s]. *)
Lemma cyc_first_to_back_perm_eq s :
  perm_eq (cyc_first_to_back s) s.
Proof.
case: s => [|x s] //=.
by rewrite -cats1 -cat1s perm_catC.
Qed.

(** Left rotation preserves duplicate-freeness (in both directions). *)
Lemma cyc_first_to_back_uniq s :
  uniq (cyc_first_to_back s) = uniq s.
Proof. by apply: perm_uniq; exact: cyc_first_to_back_perm_eq. Qed.

(** [cyc_first_to_back] undoes [cyc_last_to_front]. *)
Lemma cyc_first_to_backK s :
  cyc_first_to_back (cyc_last_to_front s) = s.
Proof.
case: s => [|x s] //=.
by rewrite -lastI.
Qed.

(** [cyc_last_to_front] undoes [cyc_first_to_back]. *)
Lemma cyc_last_to_frontK s :
  cyc_last_to_front (cyc_first_to_back s) = s.
Proof.
case: s => [|x s] //=.
case: s => [|y s] /=.
  by [].
rewrite last_rcons.
rewrite belast_rcons /=.
by [].
Qed.

(** [split_blocks_inv_aux P cur s] cuts [s] into consecutive blocks, opening
    a new block at every element satisfying [P].  [cur] is the block being
    accumulated.  It is emitted, when non-empty, as soon as the next
    [P]-element is reached and at the end of the input.  This splitting
    recovers the blocks of [split_blocks] after they were rotated by
    [cyc_last_to_front] ([split_blocks_inv_cyc_wf]). *)
Fixpoint split_blocks_inv_aux (P : nat -> bool) (cur : seq nat) (s : seq nat) :
  seq (seq nat) :=
  match s with
  | [::] => if cur is _ :: _ then [:: cur] else [::]
  | x :: rest =>
      if P x then
        (* A [P]-element closes the pending block (if any) and opens a new one. *)
        if cur is _ :: _ then cur :: split_blocks_inv_aux P [:: x] rest
        else split_blocks_inv_aux P [:: x] rest
      else split_blocks_inv_aux P (rcons cur x) rest
  end.

(** [split_blocks_inv P s] cuts [s] before every element satisfying [P];
    concatenating the blocks gives back [s] ([split_blocks_inv_flatten]). *)
Definition split_blocks_inv (P : nat -> bool) (s : seq nat) : seq (seq nat) :=
  split_blocks_inv_aux P [::] s.

(** Concatenating the blocks of [split_blocks_inv_aux P cur s] gives back
    [cur ++ s]. *)
Lemma split_blocks_inv_aux_flatten P cur s :
  flatten (split_blocks_inv_aux P cur s) = cur ++ s.
Proof.
elim: s cur => [|x rest IH] cur /=.
- by case: cur => //=; rewrite cats0.
- case Hp: (P x).
  + case: cur => [|y cur] /=.
    * by rewrite IH /=.
    * by rewrite IH.
  + by rewrite IH cat_rcons.
Qed.

(** Concatenating the blocks of [split_blocks_inv P s] gives back [s]. *)
Lemma split_blocks_inv_flatten P s :
  flatten (split_blocks_inv P s) = s.
Proof.
rewrite /split_blocks_inv.
by rewrite split_blocks_inv_aux_flatten cat0s.
Qed.


(** After rotation, a well-formed block starts with its unique [P]-element,
    followed only by elements that do not satisfy [P]. *)
Lemma cyc_last_to_front_wf (P : pred nat) b :
  wf_block P b ->
  if cyc_last_to_front b is x :: rest then
    P x && all (fun y => ~~ P y) rest
  else false.
Proof.
case/wf_block_decomp => b' [l] [-> Hl Hb' _].
rewrite cyc_last_to_front_rcons /=.
by rewrite Hl Hb'.
Qed.



(** Unfolding equation when the next element satisfies [P]: the pending
    block (if non-empty) is emitted and a new block starts at [x]. *)
Lemma split_blocks_inv_aux_cons_P (P : pred nat) cur x rest :
  P x ->
  split_blocks_inv_aux P cur (x :: rest)
  = (if cur is _ :: _ then [:: cur] else [::])
    ++ split_blocks_inv_aux P [:: x] rest.
Proof. by move=> Hx /=; rewrite Hx; case: cur. Qed.

(** Unfolding equation when the next element does not satisfy [P]: it is
    appended to the pending block. *)
Lemma split_blocks_inv_aux_cons_notP (P : pred nat) cur x rest :
  ~~ P x ->
  split_blocks_inv_aux P cur (x :: rest)
  = split_blocks_inv_aux P (rcons cur x) rest.
Proof. by move=> /negbTE Hx /=; rewrite Hx. Qed.





(** A run [nots] of elements that do not satisfy [P] is absorbed into the
    pending block. *)
Lemma split_blocks_inv_aux_app_notP (P : pred nat) cur nots rest :
  all (fun y => ~~ P y) nots ->
  split_blocks_inv_aux P cur (nots ++ rest)
  = split_blocks_inv_aux P (cur ++ nots) rest.
Proof.
elim: nots cur => [|y nots IH] cur /= Hnots.
  by rewrite cats0.
move: Hnots => /andP[Hy Hnots].
rewrite (negbTE Hy) IH // -cats1 -catA /=.
by [].
Qed.


(** Starting from an empty pending block, a leading [P]-element [l] opens
    the first block.  The [all (fun y => ~~ P y) nots] hypothesis is not
    needed by the proof. *)
Lemma split_blocks_inv_aux_one_block (P : pred nat) l nots tail :
  P l -> all (fun y => ~~ P y) nots ->
  split_blocks_inv_aux P [::] ((l :: nots) ++ tail)
  = split_blocks_inv_aux P [:: l] (nots ++ tail).
Proof. by move=> Hl _ /=; rewrite Hl. Qed.


(** A block [l :: nots] (with [P l] and no [P]-element in [nots]) is emitted
    whole.  This happens either at the end of the input or when the next
    element [y] satisfies [P].
    NOTE(review): the hypothesis [forall x, x \in rest -> True] is vacuous
    and is ignored by the proof.  It is kept only so that the statement
    (and hence its callers) is unchanged. *)
Lemma split_blocks_inv_aux_block_then_P (P : pred nat) l nots rest :
  P l -> all (fun y => ~~ P y) nots ->
  (forall x, x \in rest -> True) ->
  match rest with
  | [::] => split_blocks_inv_aux P [:: l] nots
            = [:: l :: nots]
  | y :: _ => P y ->
              split_blocks_inv_aux P [:: l] (nots ++ rest)
              = (l :: nots) :: split_blocks_inv_aux P [::] rest
  end.
Proof.
move=> Hl Hnots _.
case: rest => [|y rest'].
- rewrite -[in LHS](cats0 nots).
  rewrite (split_blocks_inv_aux_app_notP _ _ Hnots) /=.
  by case: nots Hnots => [|? ?] //=; rewrite cats0.
- move=> Hy.
  rewrite (split_blocks_inv_aux_app_notP _ _ Hnots) /=.
  rewrite Hy.
  by case: nots Hnots => [|? ?] //=; rewrite cats0.
Qed.

(** [split_blocks_inv] recovers the rotated blocks: after rotation, every
    well-formed block starts with its unique [P]-element, so cutting before
    [P]-elements splits the concatenation back into those blocks.
    NOTE(review): the equation also holds for [bs = [::]] (both sides are
    [[::]]).  The [bs != [::]] hypothesis is kept for interface
    compatibility. *)
Lemma split_blocks_inv_cyc_wf (P : pred nat) bs :
  bs != [::] ->
  all (wf_block P) bs ->
  split_blocks_inv P (flatten (map cyc_last_to_front bs))
  = map cyc_last_to_front bs.
Proof.
rewrite /split_blocks_inv.
elim: bs => [|b bs IH] // _ /andP[Hb Hbs].
case/wf_block_decomp: (Hb) => b' [l] [Hbeq Hl Hb' _].
rewrite Hbeq /= cyc_last_to_front_rcons /=.
case Hbsnil: (bs == [::]).
(* Single block: it is emitted at the end of the input. *)
- move/eqP: Hbsnil => -> /=.
  rewrite cats0.
  rewrite Hl.
  rewrite -[in LHS](cats0 b') (split_blocks_inv_aux_app_notP _ _ Hb') /=.
  by case: b' Hb' Hbeq => [|? ?] /=; rewrite ?cats0.
(* Several blocks: the head block is closed by the [P]-element that starts
   the next rotated block. *)
- have Hbs_ne : bs != [::] by rewrite Hbsnil.
  have IHs := IH Hbs_ne Hbs.
  rewrite Hl.
  rewrite (split_blocks_inv_aux_app_notP _ _ Hb').
  case: bs Hbs Hbs_ne IHs Hbsnil IH => [|c bs0] //= /andP[Hc Hbs0] _ IHs _ _.
  case/wf_block_decomp: (Hc) => c' [m] [Hceq Hm Hc' _].
  rewrite Hceq /= cyc_last_to_front_rcons /=.
  rewrite Hceq /= cyc_last_to_front_rcons /= in IHs.
  rewrite Hm.
  move: IHs.
  rewrite Hm.
  rewrite (split_blocks_inv_aux_app_notP _ _ Hc') => IHs.
  by rewrite IHs.
Qed.

(** One step of the inverse Foata map.

    [s] is split into its last letter [a] and the prefix [r] holding all its
    other letters.  For [s = [::]] this gives [(0, [::])].

    When [r] is non-empty, the block predicate is chosen by comparing the
    first letter [h] of [r] with [a].  This mirrors [foata_step], which
    compares the last letter of its input with [a].  [r] is then cut with
    [split_blocks_inv], each block is rotated back with [cyc_first_to_back],
    and the blocks are concatenated.

    [foata_step_undoK] proves that this undoes [foata_step a u] whenever [u]
    is non-empty, duplicate-free and does not contain [a]. *)
Definition foata_step_undo (s : seq nat) : (nat × seq nat) :=
  let a := last 0 s in
  (* every letter of [s] except the last one *)
  let r := belast (head 0 s) (behead s) in
  match r with
  | [::](a, [::])
  | h :: _
      let P := if h < a then (fun y : naty < a) else (fun ya < y) in
      (a, flatten (map cyc_first_to_back (split_blocks_inv P r)))
  end.

(** [split_blocks_inv_aux] depends on its predicate only extensionally:
    pointwise-equal predicates give the same result for any accumulator. *)
Lemma split_blocks_inv_aux_eq (P Q : pred nat) cur s :
  P =1 Q split_blocks_inv_aux P cur s = split_blocks_inv_aux Q cur s.
Proof.
moveHPQ.
(* Induction on [s], generalizing the accumulator [cur]. *)
elim: s cur ⇒ [|x rest IH] cur //=.
rewrite (HPQ x).
by case: (Q x); case: cur; rewrite ?IH.
Qed.

(** [split_blocks_inv] depends on its predicate only extensionally.  This
    follows directly from [split_blocks_inv_aux_eq]. *)
Lemma split_blocks_inv_eq (P Q : pred nat) s :
  P =1 Q split_blocks_inv P s = split_blocks_inv Q s.
Proof.
moveHPQ.
by rewrite /split_blocks_inv (split_blocks_inv_aux_eq _ _ HPQ).
Qed.

(** [foata_step_undo] is a left inverse of [foata_step a].  This holds on
    non-empty, duplicate-free words [u] that do not contain [a]. *)
Lemma foata_step_undoK a u :
  u != [::] uniq u a \notin u
  foata_step_undo (foata_step a u) = (a, u).
Proof.
moveHne Hu Hau.
(* The last letter of [u] differs from [a], since [a \notin u]. *)
have Hxu_last : last 0 u != a.
  apply/eqPHeq.
  move/negP: Hau; apply.
  rewrite -Heq.
  case: (u) Hne ⇒ [|x0 u0] //= _.
  exact: mem_last.
rewrite /foata_step.
case: u Hne Hu Hau Hxu_last ⇒ [|x u] //= _ Hu Hau Hxu_last.
(* Name the predicate and the block decomposition used by [foata_step]. *)
set P := if last x u < a then (fun y : naty < a) else (fun ya < y).
set bs := split_blocks P (x :: u).
(* The last letter of [x :: u] satisfies [P].  The inequality is strict
   because that letter differs from [a]. *)
have HPlast : P (last 0 (x :: u)).
  rewrite /P /=.
  case Hcase: (last x u < a); first by rewrite Hcase.
  rewrite ltn_neqAle leqNgt Hcase andbT.
  by rewrite eq_sym Hxu_last.
have Hwf : all (wf_block P) bs by exact: split_blocks_wf.
(* Some letter satisfies [P], so the block list is non-empty. *)
have Hbs_ne : bs != [::].
  rewrite -size_eq0 split_blocks_size_eq //.
  rewrite -lt0n -has_count; apply/hasP; (last 0 (x :: u)) ⇒ //.
  exact: mem_last.
(* [ru] is the output of the forward step without the trailing [a]; it is a
   permutation of [x :: u]. *)
set ru := flatten (map cyc_last_to_front bs).
have Hperm_ru : perm_eq ru (x :: u).
  rewrite /ru -[X in perm_eq _ X](split_blocks_flatten P (x :: u)).
  exact: perm_eq_flatten_map_cyc.
have Hru_size : size ru = size (x :: u).
  by rewrite (perm_size Hperm_ru).
have Hru_ne : ru != [::].
  by rewrite -size_eq0 Hru_size /=.
(* Unfold the undo step: the last letter it reads is [a], and the prefix it
   reads is [ru]. *)
rewrite /foata_step_undo.
have Hlast : last 0 (ru ++ [:: a]) = a by rewrite last_cat /=.
have Hbelast : belast (head 0 (ru ++ [:: a])) (behead (ru ++ [:: a])) = ru.
  move: Hru_ne Hperm_ru Hru_size Hlast.
  case: ru ⇒ [|y ru'] //= _ Hperm Hsize Hlast.
  by rewrite cats1 belast_rcons.
rewrite Hlast Hbelast.
(* Expose the first block [b1].  After rotation, its last letter [l1]
   becomes the head of [ru]. *)
case Hbs_eq: bs Hbs_ne Hwf Hperm_ru Hru_size Hru_ne ⇒ [|b1 brest] // _.
move⇒ /andP[Hb1 Hbrest] Hperm_ru Hru_size Hru_ne.
case/wf_block_decomp: Hb1b1' [l1] [Hb1eq Hl1 Hb1' _].
have Hru_decomp : ru = (l1 :: b1') ++ flatten (map cyc_last_to_front brest).
  by rewrite /ru Hbs_eq /= Hb1eq cyc_last_to_front_rcons.
have Hru_head : head 0 ru = l1 by rewrite Hru_decomp.
case Hru: ru Hru_ne Hbelast Hru_decomp Hru_head ⇒ [|y ru'] //.
move_ Hbelast Hru_decomp Hru_head.
have Hyl : y = l1 by [].
move: Hl1; rewrite -HylHyl1.
(* The head [y] of [ru] satisfies [P].  Hence [y < a] agrees with
   [last x u < a], so the undo step picks the same predicate as the
   forward step. *)
have Hy_in : y \in (x :: u).
  rewrite -(perm_mem Hperm_ru) Hru.
  by rewrite mem_head.
have Hyne_a : y != a.
  apply/eqPHeq.
  by move: Hau; rewrite -Heq Hy_in.
have HPunfold : z, P z = (if last x u < a then z < a else a < z).
  by movez; rewrite /P; case: ifP.
have Hy_la : (y < a) = (last x u < a).
  case Hla: (last x u < a).
  -
    by move: Hyl1; rewrite (HPunfold y) Hla.
  -
    have Hay : a < y by move: Hyl1; rewrite (HPunfold y) Hla.
    apply/negbTE; rewrite -leqNgt.
    exact: ltnW.
rewrite Hy_la.
(* Re-splitting [ru] with [P] recovers the rotated blocks. *)
have Hbs_decomp : bs = b1 :: brest by rewrite Hbs_eq.
have Hwf_full : all (wf_block P) bs.
  exact: split_blocks_wf.
have Hsbi : split_blocks_inv P ru = map cyc_last_to_front bs.
  have Hru_eq : ru = flatten (map cyc_last_to_front bs) by [].
  rewrite Hru_eq.
  apply: split_blocks_inv_cyc_wf ⇒ //.
  by rewrite Hbs_decomp.
congr (_, _).
(* Second component.  Transport [Hsbi] to the syntactic predicate built by
   [foata_step_undo].  Then undo each rotation with [cyc_first_to_backK] and
   reassemble the original word with [split_blocks_flatten]. *)
have HPiff : (Q : pred nat),
  Q =1 (if last x u < a then (fun z : natz < a) else (fun za < z))
  split_blocks_inv Q (y :: ru') = map cyc_last_to_front bs.
  moveQ HQ.
  have HQP : Q =1 P by movez; rewrite HQ /P.
  by rewrite (split_blocks_inv_eq _ HQP) -Hru Hsbi.
rewrite HPiff; last by movez //=.
rewrite -map_comp.
have Hcomp : b, b \in bs
  (cyc_first_to_back \o cyc_last_to_front) b = b.
  by moveb _ /=; exact: cyc_first_to_backK.
rewrite (_ : [seq (cyc_first_to_back \o cyc_last_to_front) i | i <- bs] = bs);
  first by exact: split_blocks_flatten.
have Heq : [seq (cyc_first_to_back \o cyc_last_to_front) i | i <- bs]
         = [seq id i | i <- bs] by apply/eq_in_map; apply: Hcomp.
by rewrite Heq map_id.
Qed.

(** [foata_inv_aux n s] applies [foata_step_undo] up to [n] times.

    Each round splits the current word into a letter and a remainder, and
    recursion continues on the remainder.  The recovered letters are put back
    with [rcons], so the letter peeled off first ends up last.  [n] acts as
    fuel: when it reaches [0] the result is [[::]], whatever word is left. *)
Fixpoint foata_inv_aux (n : nat) (s : seq nat) : seq nat :=
  if n is m.+1 then
    let: (peeled, remainder) := foata_step_undo s in
    rcons (foata_inv_aux m remainder) peeled
  else [::].

(** Inverse Foata map.  It runs [foata_inv_aux] with fuel [size s], i.e. one
    undo step per letter.  [foata_invK] shows that it inverts [foata] on
    duplicate-free words. *)
Definition foata_inv (s : seq nat) : seq nat :=
  foata_inv_aux (size s) s.

(** With fuel equal to the length of a duplicate-free word [w],
    [foata_inv_aux] recovers [w] from [foata w]. *)
Lemma foata_invK_aux n w :
  size w = n uniq w
  foata_inv_aux n (foata w) = w.
Proof.
(* Induction on the fuel [n].  In the step case [w] is split at its last
   letter [a]. *)
elim: n w ⇒ [|n IH] w Hsz Hu /=.
  by move: Hu Hsz; case: w ⇒ //=.
case/lastP: w Hsz Hu ⇒ [|w' a] // Hsz Hu.
rewrite size_rcons in Hsz; case: HszHsz.
(* [foata w'] is a duplicate-free permutation of [w'], so it avoids [a]. *)
have Hw' : uniq w' by move: Hu; rewrite rcons_uniq ⇒ /andP[].
have Hau : a \notin w' by move: Hu; rewrite rcons_uniq ⇒ /andP[].
have Hfw' := foata_perm_eq w'.
have Hfw_uniq : uniq (foata w') by exact: foata_uniq.
have Hfw_au : a \notin foata w'.
  by rewrite (perm_mem Hfw').
(* [foata (rcons w' a)] becomes one more [foata_step a] applied to
   [foata w']. *)
rewrite foata_rcons.
case Hw'_nil: (w' == [::]).
(* [w'] empty: both sides compute directly. *)
- move/eqP: Hw'_nil Hsz ⇒ → Hsz0.
  have Hsz_n : n = 0 by case: n IH Hsz0.
  by rewrite Hsz_n /= /foata_step /foata_step_undo /=.
(* [w'] non-empty: [foata_step_undoK] undoes the last step, then the IH
   finishes. *)
- have Hw'_ne : w' != [::] by rewrite Hw'_nil.
  have Hfw'_ne : foata w' != [::].
    by rewrite -size_eq0 foata_size; case: (w') Hw'_ne.
  have Hcanc := foata_step_undoK Hfw'_ne Hfw_uniq Hfw_au.
  rewrite Hcanc.
  have IHw := IH w' Hsz Hw'.
  by rewrite IHw.
Qed.

(** [foata_inv] is a left inverse of [foata] on duplicate-free words.  The
    proof uses [foata_size] to match the fuel with [size w]. *)
Lemma foata_invK w :
  uniq w foata_inv (foata w) = w.
Proof.
moveHu.
rewrite /foata_inv foata_size.
exact: foata_invK_aux.
Qed.

(** [foata] is injective on duplicate-free words, since it has the left
    inverse [foata_inv] there. *)
Lemma foata_inj_uniq w1 w2 :
  uniq w1 uniq w2 foata w1 = foata w2 w1 = w2.
Proof.
moveHu1 Hu2 Heq.
have H1 := foata_invK Hu1.
have H2 := foata_invK Hu2.
by rewrite -H1 -H2 Heq.
Qed.


(** * Equidistribution of [inv] and [maj] on permutations.
    The word-level results above are transported to [{perm 'I_n.+1}] through
    [perm_to_seq] / [seq_to_perm]. *)
Section Equidistribution.

(** The permutation statistic [maj] agrees with [maj_seq] on the word
    [perm_to_seq s]. *)
Lemma maj_eq_maj_seq n (s : {perm 'I_n.+1}) :
  maj s = maj_seq (perm_to_seq s).
Proof.
rewrite /maj /maj_seq perm_to_seq_size /=.
(* Rewrite the LHS as a sum over descent positions ['I_n]. *)
rewrite [LHS](_ : _ = \sum_(i : 'I_n | is_descent s i) (val i).+1); last first.
  by apply: eq_bigli; rewrite mem_descent_set.
(* Reindex the RHS over [enum 'I_n] and compare the summands pointwise. *)
rewrite -[iota 0 n]val_enum_ord big_map.
rewrite [LHS]big_mkcond -big_enum.
rewrite [RHS](_ : _ = \sum_(i <- enum 'I_n) (if is_descent s i then (val i).+1 else 0)).
  by [].
rewrite -big_mkcond.
by apply: eq_bigk //=; rewrite is_descent_perm_seq.
Qed.

(** Lift [foata] from words to permutations of ['I_n.+1].  The first three
    lemmas supply the size, [uniq] and bound side conditions that
    [seq_to_perm] requires. *)
Section FoataPerm.

Variable n : nat.

Lemma foata_perm_to_seq_size (s : {perm 'I_n.+1}) :
  size (foata (perm_to_seq s)) = n.+1.
Proof. by rewrite foata_size perm_to_seq_size. Qed.

Lemma foata_perm_to_seq_uniq (s : {perm 'I_n.+1}) :
  uniq (foata (perm_to_seq s)).
Proof. by apply: foata_uniq; exact: perm_to_seq_uniq. Qed.

(* Bounds carry over because [foata] only permutes its input letters. *)
Lemma foata_perm_to_seq_bnd (s : {perm 'I_n.+1}) :
  all (fun xx < n.+1) (foata (perm_to_seq s)).
Proof.
apply/allPx.
rewrite (perm_mem (foata_perm_eq _)) ⇒ Hx.
exact: (allP (perm_to_seq_bnd s)).
Qed.

(** [foata] transported to permutations. *)
Definition foata_perm (s : {perm 'I_n.+1}) : {perm 'I_n.+1} :=
  seq_to_perm (foata_perm_to_seq_size s) (foata_perm_to_seq_uniq s)
              (foata_perm_to_seq_bnd s).

(** Characteristic equation of [foata_perm] at the word level. *)
Lemma perm_to_seq_foata_perm (s : {perm 'I_n.+1}) :
  perm_to_seq (foata_perm s) = foata (perm_to_seq s).
Proof. by rewrite /foata_perm perm_to_seq_seq_to_perm. Qed.

End FoataPerm.

(** The permutation statistic [inv] agrees with [inv_seq] on the word
    [perm_to_seq s].  Both double sums are reindexed over ['I_n.+1]. *)
Lemma inv_eq_inv_seq n (s : {perm 'I_n.+1}) :
  inv s = inv_seq (perm_to_seq s).
Proof.
rewrite (inv_double_sum s).
rewrite /inv_seq perm_to_seq_size.
rewrite -[iota 0 n.+1]val_enum_ord big_map.
rewrite -big_enum.
apply: eq_bigrj _.
(* Inner sum: widen [iota 0 j] to the full range [0, n.+1), guarding each
   term by [i < j]. *)
have Hjn : (j : nat) n.+1 by apply: ltnW; exact: ltn_ord.
rewrite -(subn0 (val j)) -[iota 0 _]/(index_iota _ _).
rewrite (big_nat_widen 0 (val j) n.+1 _ _ Hjn).
rewrite -[index_iota 0 n.+1]/(iota 0 (n.+1 - 0)) subn0.
rewrite -[iota 0 n.+1]val_enum_ord big_map.
rewrite subn0.
rewrite [in LHS]big_mkcond /=.
rewrite [in RHS]big_seq_cond.
under [in RHS]eq_bigli do rewrite mem_enum andTb.
rewrite [in RHS]big_mkcond.
have ->: enum 'I_n.+1 = index_enum (Finite.clone _ 'I_n.+1) by rewrite enumT.
(* Compare the summands pointwise via [nth_perm_to_seq]. *)
apply: eq_bigri _.
have Hi : (i : nat) < n.+1 := ltn_ord i.
have Hj' : (j : nat) < n.+1 := ltn_ord j.
rewrite (nth_perm_to_seq s Hi) (nth_perm_to_seq s Hj').
have ->: Ordinal Hi = i by apply: val_inj.
have ->: Ordinal Hj' = j by apply: val_inj.
by case Hij: (i < j); rewrite andbC //=.
Qed.

(** Foata's theorem for permutations: [inv (foata_perm s) = maj s].  It is
    reduced to the word-level [foata_inv_eq_maj]. *)
Lemma foata_perm_inv_maj n (s : {perm 'I_n.+1}) :
  inv (foata_perm s) = maj s.
Proof.
rewrite inv_eq_inv_seq perm_to_seq_foata_perm.
rewrite foata_inv_eq_maj; last exact: perm_to_seq_uniq.
by rewrite -maj_eq_maj_seq.
Qed.

(** [foata_perm] is injective.  This follows from injectivity of
    [perm_to_seq] together with [foata_inj_uniq]. *)
Lemma foata_perm_inj n : injective (@foata_perm n).
Proof.
moves1 s2 Heq.
apply: (@perm_to_seq_inj n.+1).
have H1 := perm_to_seq_foata_perm s1.
have H2 := perm_to_seq_foata_perm s2.
have HF : foata (perm_to_seq s1) = foata (perm_to_seq s2).
  by rewrite -H1 -H2 Heq.
apply: foata_inj_uniq HF; exact: perm_to_seq_uniq.
Qed.

(** Equidistribution of [inv] and [maj] on permutations of ['I_n.+1]: for
    every [k], as many permutations have [k] inversions as have major index
    [k].  The level set of [inv] at [k] is the image under [foata_perm] of the
    level set of [maj] at [k].  Since [foata_perm] is injective, the two sets
    have equal cardinality. *)
Theorem inv_maj_equidistr n k :
  #|[set s : {perm 'I_n.+1} | inv s == k]|
  = #|[set s : {perm 'I_n.+1} | maj s == k]|.
Proof.
have Hinj := @foata_perm_inj n.
have Hbij : [set s : {perm 'I_n.+1} | inv s == k]
          = @foata_perm n @: [set s : {perm 'I_n.+1} | maj s == k].
  apply/setPt; rewrite !inE.
  apply/idP/imsetP.
  (* An injective endomap of a finite type is bijective, so every [t] with
     [inv t = k] is hit by [finv (@foata_perm n) t]. *)
  - moveHinv.
     (finv (@foata_perm n) t).
    + rewrite inE.
      have Hcanc := f_finv Hinj t.
      rewrite -[X in maj _ == X](_ : inv t = k); last by apply/eqP.
      by rewrite -{2}Hcanc foata_perm_inv_maj.
    + by rewrite (f_finv Hinj t).
  (* Conversely, [foata_perm] maps the [maj]-level set into the [inv]-level
     set. *)
  - cases Hs →.
    rewrite inE in Hs.
    by rewrite foata_perm_inv_maj.
rewrite Hbij.
by rewrite (card_imset _ Hinj).
Qed.

End Equidistribution.