section \<open>Some common bitwise operations on uint8 and uint64\<close>

theory Bit_Operations
  imports Main "HOL-Word.Word_Bitwise" 
    "Native_Word.Uint8" "Native_Word.Uint64" "Native_Word.Code_Target_Bits_Int"
    "More_List" "HOL-Library.List_Lexorder"
begin

subsection \<open>Definitions\<close>

(* bit_test k x holds iff bit k of x is set: x is masked with the single-bit
   value 1 << k and the result is compared against 0.  For uint8/uint64 this
   agrees with test_bit (bit_test_test_bit_uint8, bit_test_test_bit_uint64). *)
definition bit_test :: "nat \<Rightarrow> ('a::{bit_operations, zero, one}) \<Rightarrow> bool" where
  "bit_test k x \<longleftrightarrow> x AND (1 << k) \<noteq> 0"

(* bit_get k x extracts bit k of x as a value of x's own type: x is shifted
   right by k and masked with 1.  For uint8 the result is 1 if bit k is set
   and 0 otherwise (bit_get_uint8). *)
definition bit_get :: "nat \<Rightarrow> 'a :: {bit_operations, one} \<Rightarrow> 'a" where
  "bit_get k x = (x >> k) AND 1"

(* bit_put k x b ORs b, shifted left by k, into x.  It can only set bits and
   never clears them: bit_put k x 0 = x (bit_put_0_uint8) and
   bit_put k x 1 = bit_set k x (bit_put_1).
   NOTE(review): b is presumably meant to be 0 or 1; a wider b would set
   several bits starting at position k. *)
definition bit_put :: "nat \<Rightarrow> 'a :: {bit_operations} \<Rightarrow> 'a \<Rightarrow> 'a" where
  "bit_put k x b = x OR (b << k)"

(* bit_set k x is x with bit k set (OR with the mask 1 << k).  On uint8/uint64
   it equals set_bit x k True (set_bit_bit_set_uint8, set_bit_bit_set_uint64). *)
definition bit_set :: "nat \<Rightarrow> 'a :: {bit_operations, one} \<Rightarrow> 'a" where
  "bit_set k x = x OR (1 << k)"

(* bit_clear k x is x with bit k cleared (AND with the complement of the
   mask 1 << k). *)
definition bit_clear :: "nat \<Rightarrow> 'a :: {bit_operations, one} \<Rightarrow> 'a" where
  "bit_clear k x = x AND NOT (1 << k)"

(* Number of set bits of a uint8: counts the indices k in 0..7 for which
   bit_test k s0 holds.  This is the direct formulation; bit_pop_count_opt
   is a shift-and-mask variant of the same computation. *)
definition bit_pop_count :: "uint8 \<Rightarrow> nat" where
  "bit_pop_count s0 = length (filter (\<lambda> k. bit_test k s0) [0, 1, 2, 3, 4, 5, 6, 7])"

(* Shift-and-mask population count of a uint8, intended to agree with
   bit_pop_count.  NOTE(review): no equivalence lemma appears in this part
   of the file; confirm one exists.
     s1: each 2-bit field holds the number of set bits in that field (mask 0x55);
     s2: each 4-bit field holds the count for that nibble (mask 0x33);
     s3: the low nibble holds the total count, at most 8 (mask 0x0F).
   Each partial count fits in its field, so the uint8 additions never wrap. *)
definition bit_pop_count_opt :: "uint8 \<Rightarrow> nat" where
  "bit_pop_count_opt s0 =
      (let s1 = (s0 AND 0x55) + ((s0 >> 1) AND 0x55);
           s2 = (s1 AND 0x33) + ((s1 >> 2) AND 0x33);
           s3 = (s2 AND 0x0F) + ((s2 >> 4) AND 0x0F)
        in  nat_of_uint8 s3)"

(* bit_is_subset x_sub x holds iff every bit set in x_sub is also set in x,
   i.e. masking x_sub with x leaves it unchanged. *)
definition bit_is_subset :: "'a :: bit_operations \<Rightarrow> 'a \<Rightarrow> bool" where
  "bit_is_subset x_sub x \<longleftrightarrow> x_sub AND x = x_sub"

(* Intersection of two bit sets: bitwise AND. *)
definition bit_inter :: "'a :: bit_operations \<Rightarrow> 'a \<Rightarrow> 'a" where
  "bit_inter x1 x2 = x1 AND x2"

(* Union of two bit sets: bitwise OR. *)
definition bit_union :: "'a :: bit_operations \<Rightarrow> 'a \<Rightarrow> 'a" where
  "bit_union x1 x2 = x1 OR x2"

(* Index of the least significant set bit of x, i.e. the number of trailing
   zero bits (the "last" one in the MSB-first to_bl order); 64 if x = 0.
   Binary search over the masks in mks: for each (m, k), m masks the low k
   bits; if those bits are all zero they are shifted out and k is added to
   the running count n. *)
definition bit_last_one_pos :: "uint64 \<Rightarrow> nat" where
  "bit_last_one_pos x = 
     (if x = 0 then 64
      else 
        let mks = [(0xFFFFFFFF, 32), (0xFFFF, 16), (0xFF, 8), (0xF, 4), (0x3, 2), (0x1, 1)];
            upd = (\<lambda> (x, n) (m, k). if x AND m = 0 then (x >> k, n + k) else (x, n))
         in snd (foldl upd (x, 0) mks))"

(* Clears the least significant set bit of x: x AND -x isolates that bit
   and the XOR removes it from x.
   NOTE(review): the isolating step relies on uminus being two's-complement
   negation (as on word types); confirm this for every instance used. *)
definition bit_clear_last_one :: "'a :: {bit_operations, uminus} \<Rightarrow> 'a" where
  "bit_clear_last_one x = x XOR (x AND (-x))"

(* Left shift of x by k bits, written as a prefix function with the shift
   amount first. *)
definition bit_shiftL :: "nat \<Rightarrow> 'a :: bit_operations \<Rightarrow> 'a" where
  "bit_shiftL k x = x << k"

subsection \<open>Auxiliary lemmas\<close>

(* Explicit enumeration of the naturals k <= 6.  Used to discharge word
   goals case by case (e.g. in pow2_leq_64_uint8). *)
lemma nat7_cases:
  fixes k::nat
  shows "k \<le> 6 \<Longrightarrow> k = 0 \<or> k = 1 \<or> k = 2 \<or> k = 3 \<or> k = 4 \<or> k = 5 \<or> k = 6"
  by presburger

(* Explicit enumeration of the naturals k < 8.  Used before word_bitwise
   (e.g. in bit_test_bl_8word and bit_get_word8). *)
lemma nat8_cases:
  fixes k::nat
  shows "k < 8 \<Longrightarrow> k = 0 \<or> k = 1 \<or> k = 2 \<or> k = 3 \<or> k = 4 \<or> k = 5 \<or> k = 6 \<or> k = 7"
  by presburger

(* Explicit enumeration of the naturals k < 64.  Used before word_bitwise in
   bit_test_bl_64word. *)
lemma nat64_cases:
  fixes k::nat
  shows "k < 64 \<Longrightarrow> k = 0 \<or> k = 1 \<or> k = 2 \<or> k = 3 \<or> k = 4 \<or> k = 5 \<or> k = 6 \<or> k = 7 \<or> k = 8 \<or> k = 9 \<or> k = 10 \<or> k = 11 \<or>
          k = 12 \<or> k = 13 \<or> k = 14 \<or> k = 15 \<or> k = 16 \<or> k = 17 \<or> k = 18 \<or> k = 19 \<or> k = 20 \<or> k = 21 \<or> k = 22 \<or>
          k = 23 \<or> k = 24 \<or> k = 25 \<or> k = 26 \<or> k = 27 \<or> k = 28 \<or> k = 29 \<or> k = 30 \<or> k = 31 \<or> k = 32 \<or> k = 33 \<or>
          k = 34 \<or> k = 35 \<or> k = 36 \<or> k = 37 \<or> k = 38 \<or> k = 39 \<or> k = 40 \<or> k = 41 \<or> k = 42 \<or> k = 43 \<or> k = 44 \<or>
          k = 45 \<or> k = 46 \<or> k = 47 \<or> k = 48 \<or> k = 49 \<or> k = 50 \<or> k = 51 \<or> k = 52 \<or> k = 53 \<or> k = 54 \<or> k = 55 \<or>
          k = 56 \<or> k = 57 \<or> k = 58 \<or> k = 59 \<or> k = 60 \<or> k = 61 \<or> k = 62 \<or> k = 63" (is "k < 64 \<Longrightarrow> ?P k")
  by presburger  

(* takefill on an appended list: if k fits inside xs, only xs is consulted;
   otherwise all of xs is kept and the remaining k - length xs elements are
   taken (with fill value d) from ys.
   NOTE(review): this lemma reuses the name of an existing library fact.
   Inside its own proof, the metis/smt calls that mention takefill_append can
   only refer to that pre-existing fact.  Later unqualified references
   (e.g. in bit_test_bl_8word) presumably resolve to this lemma.  Confirm
   that the shadowing is intended. *)
lemma takefill_append:
  shows "takefill d k (xs @ ys) = (if k \<le> length xs then takefill d k xs else xs @ takefill d (k - length xs) ys)"
  apply (induction xs)
   apply (auto split: if_split_asm)
    apply (metis le_add_same_cancel2 length_takefill not_le plus_1_eq_Suc take_takefill takefill.simps(1) takefill_minus takefill_minus_simps(2) zero_le)
   apply (metis add.left_neutral append_Nil2 le_SucE takefill.Z takefill_Suc_Cons takefill_append)
  apply (smt Suc_diff_le add_Suc_right diff_Suc_Suc length_append length_takefill nat_le_linear takefill_Suc_Cons takefill_append)
  done

subsection \<open>Additions to the Word library\<close>

(* A sum of distinct powers 2^k over k < m, selected by an arbitrary
   predicate P, is strictly below 2^m.  Proof by induction on m. *)
lemma sum_2_pow_ub: 
  shows "(\<Sum>k\<leftarrow>[0..<m]. if P k then 2 ^ k else 0) < (2::int) ^ m"
proof (induction m)
  case 0
  then show ?case 
    by simp
next
  case (Suc m)
  then show ?case
    by auto (smt zle2p)
qed

(* For 2^m <= n < 2^(m+1), removing the top bit 2^m does not change the
   m-element bool list of n: both values truncate to the same m low bits. *)
lemma bin_to_bl_trunc:
  assumes "n \<ge> 2 ^ m" "n < 2 ^ (Suc m)"
  shows "bin_to_bl m n = bin_to_bl m (n - 2 ^ m)"
proof-
  (* Truncation to m bits (n mod 2^m) is exactly n - 2^m in this range. *)
  have "bintrunc m n = n - 2 ^ m"
    using assms
    unfolding bintrunc_mod2p
    by (smt Bit_def int_mod_eq' minus_mod_self2 power_BIT)
  thus ?thesis
    by (metis bin_bl_bin bl_bin_bl size_bin_to_bl)
qed  

(* Bit-level counterpart of bin_to_bl_trunc: for 2^m <= n < 2^(m+1) and k < m,
   bit k of n equals bit k of n - 2^m. *)
lemma bin_nth_trunc:
  assumes "n \<ge> 2 ^ m" "n < 2 ^ (Suc m)" "k < m"
  shows "bin_nth n k = bin_nth (n - 2 ^ m) k"
proof-
  (* Truncation to m bits (n mod 2^m) is exactly n - 2^m in this range. *)
  have "bintrunc m n = n - 2 ^ m"
    using assms
    unfolding bintrunc_mod2p
    by (smt Bit_def int_mod_eq' minus_mod_self2 power_BIT)
  thus ?thesis
    by (metis assms(3) nth_bintr)
qed

(* A non-negative integer below 2^m has bit m clear.  Induction on m,
   applying the hypothesis to bin_rest n (n with its lowest bit dropped). *)
lemma bin_nth_False:
  assumes "n \<ge> 0" "n < 2 ^ m"
  shows "bin_nth n m = False"
  using assms
proof (induction m arbitrary: n)
  case 0
  hence "n = 0"
    by simp
  then show ?case
    by simp
next
  case (Suc m)
  show ?case
    using Suc(1)[of "bin_rest n"] Suc.prems
    using bin_rest_def by auto
qed
  
(* For 2^m <= n < 2^(m+1), bit m of n is set.  Induction on m, applying the
   hypothesis to bin_rest n. *)
lemma bin_nth_True:
  assumes "n \<ge> 2^m" "n < 2^(Suc m)"
  shows "bin_nth n m"
  using assms
proof (induction m arbitrary: n)
  case 0
  hence "n = 1"
    by simp
  then show ?case
    by simp
next
  case (Suc m)
  show ?case
    using Suc(1)[of "bin_rest n"] Suc.prems
    by (simp add: bin_rest_def)
qed

(* Binary expansion: a non-negative integer n < 2^m is the sum of 2^k over the
   positions k < m of its set bits.  Induction on m.  In the step, if
   n >= 2^m then bit m is set and the lower bits are those of n - 2^m. *)
lemma sum_2_pow: 
  assumes "0 \<le> n" "n < 2^m"
  shows "sum_list (map (\<lambda> k. if bin_nth n k then 2 ^ k else 0) [0..<m]) = n"
  using assms
proof (induction m arbitrary: n)
  case 0
  then show ?case
    by simp
next
  case (Suc m)
  show ?case
  proof (cases "n < 2 ^ m")
    case True
    then show ?thesis
      using Suc(1)[of n] Suc.prems bin_nth_False
      by simp
  next
    case False

    (* Bits below m are unaffected by subtracting the top bit 2^m. *)
    have *: "map (\<lambda> k. if bin_nth (n - 2 ^ m) k then 2 ^ k else 0) [0..<m] = 
          map (\<lambda> k. if bin_nth n k then 2 ^ k else 0) [0..<m]"
      using bin_nth_trunc[of m n] Suc.prems False
      by auto
    show ?thesis
      using Suc(1)[of "n - 2^m"] Suc.prems False bin_nth_True[of m n]
      by (auto simp add: *)
  qed
qed

(* The (m+1)-element list of n is bit m of n followed by the m-element list:
   bin_to_bl puts the most significant bit first. *)
lemma bin_to_bl_Suc:
  shows "bin_to_bl (Suc m) n = (bin_nth n m) # bin_to_bl m n" (is "?lhs = ?rhs")
proof (rule nth_equalityI)
  show "length ?lhs = length ?rhs"
    by (metis length_Cons size_bin_to_bl)
next
  fix i
  assume "i < length ?lhs"
  thus "bin_to_bl (Suc m) n ! i = (bin_nth n m # bin_to_bl m n) ! i"
    using nth_bin_to_bl[of i "Suc m" n]
    using nth_bin_to_bl[of i "m" n]
    by (metis Suc_pred add_diff_cancel_left add_diff_cancel_left' less_Suc_eq_0_disj nth_Cons' nth_bin_to_bl plus_1_eq_Suc size_bin_to_bl)
qed

(* For n1, n2 in [0, 2^m), comparing their m-element MSB-first bool lists in
   the lexicographic order of List_Lexorder (with False < True) is the same
   as comparing the integers themselves. *)
lemma bin_to_bl_lex:
  assumes "n1 \<ge> 0" "n1 < 2^m" "n2 \<ge> 0" "n2 < 2^m"
  shows "bin_to_bl m n1 < bin_to_bl m n2 \<longleftrightarrow> n1 < n2"
proof
  (* (==>): unfold list_less into its two cases and show n1 < n2 in each. *)
  assume "bin_to_bl m n1 < bin_to_bl m n2"
  moreover
  (* Case 1: the first list is a proper prefix of the second.  This is
     impossible, since both lists have length m. *)
  {
    assume "\<exists>a v. bin_to_bl m n2 = bin_to_bl m n1 @ a # v"
    then obtain a v where "bin_to_bl m n2 = bin_to_bl m n1 @ a # v"
      by auto
    hence "length (bin_to_bl m n2) = length (bin_to_bl m n1) + length v + 1"
      by auto
    hence False
      using size_bin_to_bl
      by auto
    hence "n1 < n2"
      by auto
  }
  moreover
  (* Case 2: the lists agree on a prefix u, then n1 has False where n2 has
     True; v and w are the remaining (less significant) bits. *)
  {
    assume "\<exists>u a b v w. a < b \<and> bin_to_bl m n1 = u @ a # v \<and> bin_to_bl m n2 = u @ b # w"
    then obtain u v w where
      *: "bin_to_bl m n1 = u @ False # v" "bin_to_bl m n2 = u @ True # w"
      by auto
    have "length v = length w"
      using * size_bin_to_bl
      by (metis add.right_neutral add_Suc_right add_diff_cancel_left' length_append list.size(4) nat.simps(1))

    have "length v < m"
      using * size_bin_to_bl
      by (metis One_nat_def add.commute dual_order.strict_trans2 length_append less_add_one list.size(4) not_add_less1 not_le)

    (* Write n1 and n2 as sums of powers of two over their set bits, then
       restate those sums by indexing the reversed (LSB-first) lists. *)
    have "n1 = (\<Sum>k\<leftarrow>[0..<m]. if bin_nth n1 k then 2 ^ k else 0)"
         "n2 = (\<Sum>k\<leftarrow>[0..<m]. if bin_nth n2 k then 2 ^ k else 0)"
      using sum_2_pow[of n1 m] assms(1-2) *(1)
      using sum_2_pow[of n2 m] assms(3-4) *(2)
      by auto
    moreover
    have "map (\<lambda> k. if rev (bin_to_bl m n1) ! k then 2 ^ k else 0) [0..<m] = 
          map (\<lambda> k. if bin_nth n1 k then 2 ^ k else 0) [0..<m]"
          "map (\<lambda> k. if rev (bin_to_bl m n2) ! k then 2 ^ k else 0) [0..<m] = 
          map (\<lambda> k. if bin_nth n2 k then 2 ^ k else 0) [0..<m]"
      using bin_nth_bl[of _ m n1] bin_nth_bl[of _ m n2]
      by auto
    ultimately
    have **: "n1 = (\<Sum>k\<leftarrow>[0..<m]. if rev (bin_to_bl m n1) ! k then 2 ^ k else 0)"
         "n2 = (\<Sum>k\<leftarrow>[0..<m]. if rev (bin_to_bl m n2) ! k then 2 ^ k else 0)"
      by smt+

    have ***: 
      "n1 = (\<Sum>k\<leftarrow>[0..<m]. if (rev v @ False # rev u) ! k then 2 ^ k else 0)"
      "n2 = (\<Sum>k\<leftarrow>[0..<m]. if (rev w @ True # rev u) ! k then 2 ^ k else 0)"
      using * **
      by auto

    (* Split the index range of n1's sum at position length v, the first
       differing bit counted from the least significant end.  That bit
       contributes 0 to n1. *)
    have "[0..<m] = [0..<length v] @ [length v] @ [length v + 1 ..< m]"
      using `length v < m`
      by (metis One_nat_def Suc_leI add_gr_0 append_Cons append_Nil le_add_same_cancel2 less_imp_add_positive less_numeral_extra(1) upt_add_eq_append upt_eq_list_intros(2))
    hence "n1 = (\<Sum>k\<leftarrow>[0..<length v]. if (rev v @ False # rev u) ! k then 2 ^ k else 0) + 
                (\<Sum>k\<leftarrow>[length v]. if (rev v @ False # rev u) ! k then 2 ^ k else 0) +
                (\<Sum>k\<leftarrow>[length v+1..<m]. if (rev v @ False # rev u) ! k then 2 ^ k else 0)"
      using ***
      by auto
    moreover
    have "(\<Sum>k\<leftarrow>[0..<length v]. if (rev v @ False # rev u) ! k then 2 ^ k else 0) = 
          (\<Sum>k\<leftarrow>[0..<length v]. if rev v ! k then 2 ^ k else 0)"
    proof-
      have "map (\<lambda> k. if (rev v @ False # rev u) ! k then 2 ^ k else 0) [0..<length v] = 
            map (\<lambda> k. if rev v ! k then 2 ^ k else 0) [0..<length v]"
        by (auto simp add: nth_append)
      thus ?thesis
        by smt
    qed
    moreover
    have "(\<Sum>k\<leftarrow>[length v]. if (rev v @ False # rev u) ! k then 2 ^ k else 0) = 0"
      by (auto simp add: nth_append)
    moreover
    have "(\<Sum>k\<leftarrow>[length v+1..<m]. if (rev v @ False # rev u) ! k then 2 ^ k else 0) = 
          (\<Sum>k\<leftarrow>[length v+1..<m]. if rev u ! (k - length v - 1) then 2 ^ k else 0)"
    proof-
      have "map (\<lambda> k. if (rev v @ False # rev u) ! k then 2 ^ k else 0) [length v+1..<m] = 
            map (\<lambda> k. if rev u ! (k - length v - 1) then 2 ^ k else 0) [length v+1 ..<m]"
        by (auto simp add: nth_append)
      thus ?thesis
        by smt
    qed
    ultimately
    have n1: "n1 = (\<Sum>k\<leftarrow>[0..<length v]. if rev v ! k then 2 ^ k else 0) +
                   (\<Sum>k\<leftarrow>[length v+1..<m]. if rev u ! (k - length v - 1) then 2 ^ k else 0)"
      by smt

    (* Same decomposition for n2.  Its bit at position length v is set and
       contributes 2 ^ length v. *)
    have "[0..<m] = [0..<length w] @ [length w] @ [length w + 1 ..< m]"
      using `length v < m` `length v = length w`
      by (metis One_nat_def Suc_leI add_gr_0 append_Cons append_Nil le_add_same_cancel2 less_imp_add_positive less_numeral_extra(1) upt_add_eq_append upt_eq_list_intros(2))
    hence "n2 = (\<Sum>k\<leftarrow>[0..<length w]. if (rev w @ True # rev u) ! k then 2 ^ k else 0) + 
                (\<Sum>k\<leftarrow>[length w]. if (rev w @ True # rev u) ! k then 2 ^ k else 0) +
                (\<Sum>k\<leftarrow>[length w+1..<m]. if (rev w @ True # rev u) ! k then 2 ^ k else 0)"
      using ***
      by auto
    moreover
    have "(\<Sum>k\<leftarrow>[0..<length w]. if (rev w @ True # rev u) ! k then 2 ^ k else 0) = 
          (\<Sum>k\<leftarrow>[0..<length v]. if rev w ! k then 2 ^ k else 0)"
    proof-
      have "map (\<lambda> k. if (rev w @ True # rev u) ! k then 2 ^ k else 0) [0..<length w] = 
            map (\<lambda> k. if rev w ! k then 2 ^ k else 0) [0..<length v]"
        using `length v = length w`
        by (auto simp add: nth_append)
      thus ?thesis
        by smt
    qed
    moreover
    have "(\<Sum>k\<leftarrow>[length w]. if (rev w @ True # rev u) ! k then 2 ^ k else 0) = 2 ^ length v"
      using `length v = length w`
      by (auto simp add: nth_append)
    moreover
    have "(\<Sum>k\<leftarrow>[length w+1..<m]. if (rev w @ True # rev u) ! k then 2 ^ k else 0) = 
          (\<Sum>k\<leftarrow>[length v+1..<m]. if rev u ! (k - length v - 1) then 2 ^ k else 0)"
    proof-
      have "map (\<lambda> k. if (rev w @ True # rev u) ! k then 2 ^ k else 0) [length w+1..<m] = 
            map (\<lambda> k. if rev u ! (k - length v - 1) then 2 ^ k else 0) [length v+1 ..<m]"
        using `length v = length w`
        by (auto simp add: nth_append)
      thus ?thesis
        by smt
    qed
    ultimately
    have n2: "n2 = (\<Sum>k\<leftarrow>[0..<length v]. if rev w ! k then 2 ^ k else 0) + 2 ^ length v + 
               (\<Sum>k\<leftarrow>[length v+1..<m]. if rev u ! (k - length v - 1) then 2 ^ k else 0)"
      by smt

    (* The low part of n1 is below 2 ^ length v and the low part of n2 is
       non-negative.  The high parts coincide, hence n1 < n2. *)
    have "(\<Sum>k\<leftarrow>[0..<length v]. if rev v ! k then 2 ^ k else 0) < (2::int) ^ length v"
      using sum_2_pow_ub
      by simp
    moreover
    have "(\<Sum>k\<leftarrow>[0..<length v]. if rev w ! k then 2 ^ k else 0) \<ge> (0::int)"
      by (rule sum_list_nonneg) auto
    ultimately
    have "n1 < n2"
      using n1 n2
      by smt
  }
  ultimately
  show "n1 < n2"
    using assms
    unfolding list_less_def lexord_def
    by auto
next
  (* (<==): induction on m, splitting on whether the top bit (bit m) of n2,
     and then of n1, is set. *)
  assume "n1 < n2"
  then show "bin_to_bl m n1 < bin_to_bl m n2"
    using assms
  proof (induction m arbitrary: n1 n2)
    case 0
    then show ?case
      by simp
  next
    case (Suc m)
    show ?case
    proof (cases "n2 < 2 ^ m")
      (* Both top bits are clear: equal heads, so compare the tails by the
         induction hypothesis. *)
      case True
      hence "bin_to_bl m n1 < bin_to_bl m n2"
        using Suc(1)[of n1 n2]
        using Suc.prems
        by simp
      moreover
      have "bin_nth n1 m = False" "bin_nth n2 m = False"
        using True Suc.prems
        using bin_nth_False
        by simp_all
      ultimately
      show ?thesis
        using bin_to_bl_Suc[of m n1] bin_to_bl_Suc[of m n2]
        by simp
    next
      case False

      have "bin_nth n2 m = True"
        using bin_nth_True[of m n2] False Suc.prems
        by simp

      show ?thesis
      proof (cases "n1 < 2 ^ m")
        (* Top bit of n1 is clear and top bit of n2 is set: the lists differ
           already at their heads. *)
        case True
        hence "bin_nth n1 m = False"
          using Suc.prems
          using bin_nth_False
          by simp
        then show ?thesis
          using `bin_nth n2 m = True`
          using bin_to_bl_Suc[of m n1] bin_to_bl_Suc[of m n2]
          by simp
      next
        (* Both top bits are set: subtract 2^m from both and use the
           induction hypothesis on the tails. *)
        case False
        have "bin_nth n1 m = True"
          using bin_nth_True[of m n1] Suc.prems False
          by simp

        have "bin_to_bl m (n1 - 2 ^ m) < bin_to_bl m (n2 - 2 ^ m)"
             "bin_to_bl (Suc m) n1 = True # bin_to_bl m (n1 - 2 ^ m)"
             "bin_to_bl (Suc m) n2 = True # bin_to_bl m (n2 - 2 ^ m)"
          using `bin_nth n1 m = True` `bin_nth n2 m = True`
          using bin_to_bl_Suc[of m n1] bin_to_bl_Suc[of m n2]
          using Suc(1)[of "n1 - 2^m" "n2 - 2^m"] Suc.prems
          using `\<not> n1 < 2 ^ m` `\<not> n2 < 2 ^ m`
          using bin_to_bl_trunc[of m n1] bin_to_bl_trunc[of m n2]
          by auto

        thus ?thesis
          by simp
      qed
    qed
  qed
qed

(* A non-negative integer x whose bits at all positions >= n are clear is
   below 2^n. *)
lemma trailing_zeros_ub':
  fixes n :: nat
  assumes "x \<ge> 0" "\<forall> k \<ge> n. \<not> bin_nth x k"
  shows "x < 2 ^ n"
proof-
  (* Every natural number lies below some power of two; use this to bound x
     by some 2^m. *)
  have "\<forall> x::nat. \<exists> m. 2 ^ m > x"
  proof
    fix x :: nat
    show "\<exists> m. x < 2 ^ m"
    proof (induction x)
      case 0
      then show ?case
        by simp
    next
      case (Suc x)
      then obtain m where "x < 2 ^ m"
        by auto
      hence "Suc x < 2 ^ Suc m"
        by auto
      then show ?case
        by blast
    qed
  qed
  then obtain m where "x < 2 ^ m"
    by (metis nat_int numeral_power_eq_of_nat_cancel_iff zless_nat_conj)
  (* If n <= m, drop the all-zero summands for positions n..m-1 from the
     binary expansion of x (sum_2_pow) and bound the rest (sum_2_pow_ub).
     Otherwise x < 2^m <= 2^n directly. *)
  show ?thesis
  proof (cases "n \<le> m")
    case True
    have "x = (\<Sum>k\<leftarrow>[0..<m]. if bin_nth x k then 2 ^ k else 0)"
      using sum_2_pow[of x m] `x < 2 ^ m` assms
      by simp
    moreover
    have "[0..<m] = [0..<n] @ [n..<m]"
      using `n \<le> m`
      by (metis le_iff_add upt_add_eq_append zero_le)
    ultimately
    have "x = (\<Sum>k\<leftarrow>[0..<n]. if bin_nth x k then 2 ^ k else 0) + 
              (\<Sum>k\<leftarrow>[n..<m]. if bin_nth x k then 2 ^ k else 0)"
      by simp
    moreover
    have "(\<Sum>k\<leftarrow>[n..<m]. if bin_nth x k then 2 ^ k else 0) = (\<Sum>k\<leftarrow>[n..<m]. (0::int))"
      using assms
      by (smt atLeastLessThan_iff map_eq_conv set_upt)
    ultimately
    have "x = (\<Sum>k\<leftarrow>[0..<n]. if bin_nth x k then 2 ^ k else 0)"
      by simp
    thus ?thesis
      using sum_2_pow_ub[of "\<lambda> k. bin_nth x k" n]
      by simp
  next
    case False
    then show ?thesis 
      using `x < 2 ^ m`
      using less_le_trans by fastforce
  qed
qed

subsection \<open>Additions to the @{text uint8} and @{text uint64} types\<close>

(* A uint8 has 8 bits. *)
lemma size_uint8 [simp]: 
  shows "size (s::uint8) = 8"
  by transfer (simp add: word_size)

(* A uint64 has 64 bits. *)
lemma size_uint64 [simp]: 
  shows "size (s::uint64) = 64"
  by transfer (simp add: word_size)

(* 0 and 1 are distinct uint8 values. *)
lemma zero_neq_one_uint8 [simp]:
  "(0::uint8) \<noteq> (1::uint8)"
  by transfer simp

(* Values of uint8_of_nat and nat_of_uint8 on the small constants used in
   this theory, proved by code evaluation. *)
lemma [simp]:
  "uint8_of_nat 64 = 64"
  "nat_of_uint8 64 = 64"
  "uint8_of_nat 8 = 8"
  "nat_of_uint8 8 = 8"
  "uint8_of_nat 1 = 1"
  "nat_of_uint8 1 = 1"
  "uint8_of_nat 0 = 0"
  "nat_of_uint8 0 = 0"
  by eval+

(* Evaluated values of the uint8 powers 2^6 and 2^7. *)
lemma pow2_uint8 [simp]:
  "(2::uint8)^6 = 64"
  "(2::uint8)^7 = 128"
  by eval+

(* Order fact between two uint8 constants, proved by evaluation. *)
lemma lt_64_128 [simp]: 
  "(64::uint8) < 128"
  by eval

(* No bit of the uint8 zero is set. *)
lemma zero_no_bit_8 [simp]:
  shows "\<not> (0::uint8) !! s"
  by transfer auto

(* No bit of the uint64 zero is set. *)
lemma zero_no_bit_64 [simp]:
  shows "\<not> (0::uint64) !! s"
  by transfer auto

(* Bitwise extensionality for uint8: two values are equal iff they agree on
   bits 0..7. *)
lemma uint8_eq_I:
  fixes x y :: uint8
  shows "x = y \<longleftrightarrow> (\<forall> i < 8. x !! i = y !! i)"
  by transfer (metis (full_types) test_bit_uint8.abs_eq test_bit_uint8_code word_eqI)

(* Bitwise extensionality for uint64: two values are equal iff they agree on
   bits 0..63. *)
lemma uint64_eq_I:
  fixes x y :: uint64
  shows "x = y \<longleftrightarrow> (\<forall> i < 64. x !! i = y !! i)"
  by transfer (metis (full_types) test_bit_uint64.abs_eq test_bit_uint64_code word_eqI)

(* Every value produced by uint8_of_nat is >= 0 in the uint8 order. *)
lemma uint8_of_nat_nonneg [simp]:
  shows "uint8_of_nat x \<ge> 0"
  unfolding uint8_of_nat_def
  by transfer auto

(* Converting a uint8 to nat and back yields the original value. *)
lemma uint8_of_nat_nat_of_uint8 [simp]: 
  "uint8_of_nat (nat_of_uint8 x) = x"
  unfolding uint8_of_nat_def
  by simp (transfer, word_bitwise, metis word_of_nat word_unat.Rep_inverse)

(* Converting a natural number below 2^8 = 256 to uint8 and back yields the
   original number (no wrap-around). *)
lemma nat_of_uint8_uint8_of_nat [simp]:
  assumes "x < 2 ^ 8"
  shows "nat_of_uint8 (uint8_of_nat x) = x"
  using assms
  unfolding uint8_of_nat_def
  unfolding comp_def
  by (transfer, simp add: unat_def uint_word_of_int)

(* nat_of_uint8 preserves and reflects the strict order on uint8. *)
lemma nat_of_uint8_mono:
  fixes a b :: uint8
  shows "a < b \<longleftrightarrow> nat_of_uint8 a < nat_of_uint8 b"
  by transfer (metis word_less_nat_alt)

(* nat_of_uint8 preserves and reflects the non-strict order on uint8. *)
lemma nat_of_uint8_mono_leq:
  fixes a b :: uint8
  shows "a \<le> b \<longleftrightarrow> nat_of_uint8 a \<le> nat_of_uint8 b"
  by transfer (simp add: word_le_nat_alt)

(* On naturals below 256, uint8_of_nat preserves and reflects the strict
   order. *)
lemma uint8_of_nat_mono:
  fixes a b :: nat
  assumes "a < 2 ^ 8" "b < 2 ^ 8"
  shows "a < b \<longleftrightarrow> uint8_of_nat a < uint8_of_nat b"
  using assms
  unfolding uint8_of_nat_def
  by transfer (auto simp add: wi_less)

(* On naturals below 256, uint8_of_nat preserves and reflects the non-strict
   order. *)
lemma uint8_of_nat_mono_leq:
  fixes a b :: nat
  assumes "a < 2 ^ 8" "b < 2 ^ 8"
  shows "a \<le> b \<longleftrightarrow> uint8_of_nat a \<le> uint8_of_nat b"
  using assms
  unfolding uint8_of_nat_def
  by transfer (auto simp add: wi_le)

(* uint8_of_nat is injective on naturals below 256: apply nat_of_uint8 to
   both sides and use nat_of_uint8_uint8_of_nat.
   NOTE(review): this is declared [simp] although its conclusion is a bare
   equation between variables, which the simplifier can only use as a
   conditional rewrite of "x = y"; confirm that the attribute is intended. *)
lemma uint8_of_nat_inj [simp]: 
  assumes "x < 2 ^ 8" "y < 2 ^ 8"
  assumes "uint8_of_nat x = uint8_of_nat y"
  shows "x = y"
proof-
  have "x < 256" "y < 256"
    using assms
    by simp_all
  have "nat_of_uint8 (uint8_of_nat x) = nat_of_uint8 (uint8_of_nat y)"
    using assms
    by simp
  thus ?thesis
    using `x < 256` `y < 256` nat_of_uint8_uint8_of_nat[of x] nat_of_uint8_uint8_of_nat[of y]
    by simp
qed

(* Shifting a uint8 left by 0 is the identity. *)
lemma shiftL_zero_id [simp]:
  fixes x :: uint8
  shows "x << 0 = x"
  by transfer auto

(* Shifting the uint8 value 1 one more position doubles it.  The statement is
   restricted to k < 7, i.e. Suc k < 8. *)
lemma shiftL_Suc [simp]:
  assumes "k < 7" 
  shows "(1::uint8) << (Suc k) = 2 * (1 << k)"
  using assms
  by transfer simp

(* For k < 8, the uint8 power 2^k equals the shift 1 << k.  Induction on k,
   closed by auto with the simp rules shiftL_zero_id and shiftL_Suc
   available. *)
lemma pow2_eq_shiftL_uint8:
  assumes "k < 8"
  shows "(2::uint8) ^ k = (1::uint8) << k"
  using assms
 by (induction k) auto

(* For k <= 6, 2^k <= 64 in uint8: rewrite the power to 1 << k, transfer to
   8-bit words and check the seven cases given by nat7_cases. *)
lemma pow2_leq_64_uint8:
  assumes "k \<le> 6"
  shows "(2::uint8)^k \<le> 64"
  using assms
proof (subst pow2_eq_shiftL_uint8)
  show "k < 8"
    using assms
    by simp
next
  show "(1::uint8) << k \<le> 64"
    using assms
  proof transfer
    fix k :: nat
    assume "k \<le> 6"
    thus "(1::8 word) << k \<le> 64"
      using nat7_cases[of k]
      by auto
  qed
qed

(* Doubling a uint8 below 128 does not wrap around, so doubling commutes
   with nat_of_uint8. *)
lemma nat_of_uint8_mult2 [simp]:
  assumes "x < 128"
  shows "nat_of_uint8 (2 * x) = 2 * nat_of_uint8 x"
  using assms
proof transfer
  fix x :: "8 word"
  assume "x < 128"
  thus "unat (2 * x) = 2 * unat x"
    using unat_mult_lem[of 2 x]
    using unat_mono[of x 128]
    by simp
qed

(* For k < 8, the uint8 power 2^k converts to the natural number 2^k.
   Induction on k.  The step needs (2::uint8)^k < 128 (from
   pow2_leq_64_uint8 and lt_64_128) so that doubling does not wrap
   (cf. nat_of_uint8_mult2). *)
lemma nat_of_uint8_2pow [simp]:
  assumes "k < 8"
  shows "nat_of_uint8 (2 ^ k) = 2 ^ k"
  using assms
proof (induction k)
  case 0
  thus ?case
    by simp
next
  case (Suc k)
  hence "k \<le> 6"
    by simp
  hence "(2::uint8)^k < 128"
    using pow2_leq_64_uint8[of k] lt_64_128
    using dual_order.strict_trans2 by blast
  thus ?case
    using Suc
    by simp
qed
     
(* For k < 8, the uint8 power 2^k is the image of the natural number 2^k
   under uint8_of_nat. *)
lemma uint8_of_nat_2pow:
  assumes "k < 8"
  shows "2^k = uint8_of_nat (2^k)"
  using nat_of_uint8_2pow[of k] assms
  by (metis uint8_of_nat_nat_of_uint8)

(* 1 << n is strictly increasing in n for uint8, as long as n + 1 < 8. *)
lemma shiftL_mono_uint8 [simp]:
  assumes "n < 7"
  shows "(1::uint8) << n < (1::uint8) << (Suc n)"
  using assms
  by (metis Suc_mono eval_nat_numeral(2) less_SucI nat_of_uint8_2pow nat_of_uint8_mono one_less_numeral_iff pow2_eq_shiftL_uint8 power_less_power_Suc semiring_norm(26) semiring_norm(27) semiring_norm(76) semiring_normalization_rules(27))

(* If bit k of a uint8 x is set (k < 8), then x >= 2^k.
   NOTE(review): the transfer step's "bang_is_le" can only refer to the
   pre-existing word-level fact of the same name, since this lemma is not yet
   available inside its own proof.  After this point the unqualified name is
   shadowed; confirm that this is intended. *)
lemma bang_is_le:
  fixes x :: uint8 and k :: nat
  assumes "k < 8" "x !! k"
  shows "x \<ge> 2^k"
proof-
  have "x \<ge> 1 << k"
    using assms
    by transfer (simp add: bang_is_le)
  moreover 
  have "(1::uint8) << k = 2 ^ k"
    using `k < 8`
    by (simp add: pow2_eq_shiftL_uint8)
  ultimately
  show ?thesis
    by simp
qed

(* Natural-number version of bang_is_le: for x < 256 and k < 8, if bit k of
   uint8_of_nat x is set then x >= 2^k. *)
lemma bang_is_le_nat:
  fixes x :: nat and k ::nat
  assumes "x < 2 ^ 8" "k < 8" "uint8_of_nat x !! k"
  shows "x \<ge> 2 ^ k"
proof-
  have "2 ^ k \<le> uint8_of_nat x"
    using bang_is_le[OF assms(2-3)]
    by simp
  hence "nat_of_uint8 (2 ^ k) \<le> nat_of_uint8 (uint8_of_nat x)"
    using nat_of_uint8_mono[of "uint8_of_nat x" "2 ^ k"]
    by auto
  thus ?thesis
    using `k < 8` `x < 2 ^ 8`
    by simp
qed
  
(* uint8 version of trailing_zeros_ub': if n < 8 and no bit at position >= n
   is set, then x < 1 << n.  The proof transfers to 8-bit words, applies the
   integer lemma to uint x, and checks that 1 << n does not overflow. *)
lemma trailing_zeros_ub:
  fixes x :: uint8
  assumes "n < 8" "\<forall> k \<ge> n. \<not> x !! k"
  shows "x < 1 << n"
  using assms
proof transfer
  fix n :: nat and x :: "8 word"
  assume *: "n < 8" "\<forall>k\<ge>n. \<not> x !! k"

  have "\<forall>k\<ge>n. \<not> bin_nth (uint x) k"
    using *
    by (simp add: word_test_bit_def)
  hence "uint x < 2 ^ n"
    using trailing_zeros_ub'[of "uint x" n]
    by simp
  hence "uint x < 1 << n"
    by (simp add: shiftl_int_def)
  moreover
  (* For n < 8 the word shift 1 << n is not truncated, so its unsigned value
     is the integer 1 << n. *)
  have "size (1::8 word) = 8"
    unfolding word_size
    by simp
  hence "uint ((1::8 word) << n) = 1 << n"
    using `n < 8`
    unfolding uint_shiftl
    using bintrunc_shiftl[of 8 1 n]
    using bintrunc_minus_simps(4)[of "8 - n"]
    by simp
  ultimately
  show "x < 1 << n"
    using word_less_alt[of x "1 << n"]
    by simp
qed

subsection \<open>Properties of bitwise operations\<close>

text \<open>@{text to_bl} - conversion to bool lists\<close>

(* Bool list of a uint8, lifted from the 8-bit word to_bl.  The list is
   most-significant-bit first: bit k is rev (to_bl_uint8 x) ! k
   (test_bit_bl_uint8). *)
lift_definition to_bl_uint8 :: "uint8 \<Rightarrow> bool list" is to_bl
  done

(* Bool list of a uint64, lifted from the 64-bit word to_bl.  The list is
   most-significant-bit first: bit k is rev (to_bl_uint64 x) ! k
   (test_bit_bl_uint64). *)
lift_definition to_bl_uint64 :: "uint64 \<Rightarrow> bool list" is to_bl
  done

(* The bool list of a uint8 has 8 elements. *)
lemma length_to_bl_uint8 [simp]:
  shows "length (to_bl_uint8 x) = 8"
  by transfer simp

(* The bool list of a uint64 has 64 elements. *)
lemma length_to_bl_uint64 [simp]:
  shows "length (to_bl_uint64 x) = 64"
  by transfer simp

(* Bit k of a uint8 is set iff k < 8 and element k of the reversed (LSB-first)
   bool list is True. *)
lemma test_bit_bl_uint8:
  fixes x :: uint8
  shows "test_bit x k \<longleftrightarrow> k < size x \<and> rev (to_bl_uint8 x) ! k"
  by transfer (auto simp add: test_bit_bl)

(* Bit k of a uint64 is set iff k < 64 and element k of the reversed
   (LSB-first) bool list is True. *)
lemma test_bit_bl_uint64:
  fixes x :: uint64
  shows "test_bit x k \<longleftrightarrow> k < size x \<and> rev (to_bl_uint64 x) ! k"
  by transfer (auto simp add: test_bit_bl)

text \<open>@{text of_bl} - construction from bool list\<close>

(* uint8 built from a bool list, lifted from the word of_bl.  For an
   8-element list, bit i of the result is rev xs ! i (test_bit_of_bl_uint8). *)
lift_definition of_bl_uint8 :: "bool list \<Rightarrow> uint8" is of_bl
  done

(* uint64 built from a bool list, lifted from the word of_bl.  For a
   64-element list, bit i of the result is rev xs ! i (test_bit_of_bl_uint64). *)
lift_definition of_bl_uint64 :: "bool list \<Rightarrow> uint64" is of_bl
  done

(* Converting a uint8 to its bool list and back is the identity. *)
lemma of_bl_to_bl_uint8[simp]: 
  shows "of_bl_uint8 (to_bl_uint8 x) = x"
  by transfer simp

(* Converting a uint64 to its bool list and back is the identity. *)
lemma of_bl_to_bl_uint64[simp]: 
  shows "of_bl_uint64 (to_bl_uint64 x) = x"
  by transfer simp

(* For an 8-element list xs and i < 8, bit i of of_bl_uint8 xs is element i
   of rev xs (the list is read most-significant-bit first). *)
lemma test_bit_of_bl_uint8:
  assumes "i < 8" "length xs = 8"
  shows "of_bl_uint8 xs !! i \<longleftrightarrow> rev xs ! i"
  using assms
proof transfer
  fix i :: nat and xs :: "bool list"
  assume "i < 8" "length xs = 8"
  thus "(of_bl :: bool list \<Rightarrow> 8 word) xs !! i = rev xs ! i"
    using test_bit_of_bl[of xs i]
    by auto
qed

(* For a 64-element list xs and i < 64, bit i of of_bl_uint64 xs is element i
   of rev xs (the list is read most-significant-bit first). *)
lemma test_bit_of_bl_uint64:
  assumes "i < 64" "length xs = 64"
  shows "of_bl_uint64 xs !! i \<longleftrightarrow> rev xs ! i"
  using assms
proof transfer
  fix i :: nat and xs :: "bool list"
  assume "i < 64" "length xs = 64"
  thus "(of_bl :: bool list \<Rightarrow> 64 word) xs !! i = rev xs ! i"
    using test_bit_of_bl[of xs i]
    by auto
qed

(* to_bl_uint8 inverts of_bl_uint8 on lists of length 8.  Proved elementwise:
   element i of the MSB-first list is bit 8 - i - 1, which is compared using
   test_bit_bl_uint8 and test_bit_of_bl_uint8. *)
lemma to_bl_of_bl_uint8:
  assumes "length x = 8"
  shows "to_bl_uint8 (of_bl_uint8 x) = x" (is "?lhs = x")
proof (rule nth_equalityI)
  show "length ?lhs = length x"
    using assms
    by simp
next
  fix i
  assume "i < length ?lhs"
  hence "i < 8"
    using assms
    by simp
  show "?lhs ! i = x ! i"
    using `i < 8` `length x = 8`
    using test_bit_bl_uint8[of "of_bl_uint8 x" "8 - i - 1"]
    using test_bit_of_bl_uint8[of "8 - i - 1" x]
    using rev_nth[of "8 - i - 1" ?lhs]
    using rev_nth[of "8 - i - 1" x]
    by auto
qed


text \<open>@{text bit_test}\<close>

(* On 8-bit words, the mask test x AND (1 << k) /= 0 holds iff k < 8 and
   element k of the reversed (LSB-first) bool list is True.  Proved by
   bit-blasting (word_bitwise) after enumerating k with nat8_cases. *)
lemma bit_test_bl_8word:
  fixes x :: "8 word"
  shows "(x AND (1 << k) \<noteq> 0) = (k < size x \<and> rev (to_bl x) ! k)"
  using nat8_cases[of k]
  apply word_bitwise
  apply (case_tac "k < 8")
    apply (simp_all add: takefill_append)
  apply ((erule disjE, simp)+, simp)
  done

(* On 64-bit words, the mask test x AND (1 << k) /= 0 holds iff k < 64 and
   element k of the reversed (LSB-first) bool list is True.  Proved by
   bit-blasting (word_bitwise) after enumerating k with nat64_cases. *)
lemma bit_test_bl_64word:
  fixes x :: "64 word"
  shows "(x AND (1 << k) \<noteq> 0) = (k < size x \<and> rev (to_bl x) ! k)"
  using nat64_cases[of k]
  apply word_bitwise
  apply (case_tac "k < 64")
    apply (simp_all add: takefill_append)
  apply ((erule disjE, simp)+, simp)
  done

(* bit_test on uint8 in terms of the bool-list representation, transferred
   from bit_test_bl_8word. *)
lemma bit_test_bl_uint8:
  fixes x :: uint8
  shows "bit_test k x \<longleftrightarrow> k < size x \<and> rev (to_bl_uint8 x) ! k"
  unfolding bit_test_def
  by transfer (metis bit_test_bl_8word)

(* bit_test on uint64 in terms of the bool-list representation, transferred
   from bit_test_bl_64word. *)
lemma bit_test_bl_uint64:
  fixes x :: uint64
  shows "bit_test k x \<longleftrightarrow> k < size x \<and> rev (to_bl_uint64 x) ! k"
  unfolding bit_test_def
  by transfer (metis bit_test_bl_64word)

(* bit_test and the library test_bit coincide on uint8. *)
lemma bit_test_test_bit_uint8:
  fixes x :: uint8
  shows "bit_test k x \<longleftrightarrow> test_bit x k"
  by (simp add: bit_test_bl_uint8 test_bit_bl_uint8)

(* bit_test and the library test_bit coincide on uint64. *)
lemma bit_test_test_bit_uint64:
  fixes x :: uint64
  shows "bit_test k x \<longleftrightarrow> test_bit x k"
  by (simp add: bit_test_bl_uint64 test_bit_bl_uint64)

text \<open>@{text to_bl'} - bool list via @{text bit_test}\<close>

(* LSB-first bool list of s: element k is bit_test k s, for k < size s.
   For uint8/uint64 it is the reverse of to_bl (to_bl'_uint8, to_bl'_uint64). *)
definition to_bl' :: "'a :: {bit_operations, zero, one, size} \<Rightarrow> bool list" where
  "to_bl' s = map (\<lambda> k. bit_test k s) [0..<size s]"

(* For uint8, the LSB-first list to_bl' is the reverse of the MSB-first list
   to_bl_uint8.  Proved elementwise on 8-bit words via bit_test_bl_8word. *)
lemma to_bl'_uint8:
  fixes x :: uint8
  shows "to_bl' x = rev (to_bl_uint8 x)"
  unfolding to_bl'_def bit_test_def
proof transfer
  fix x :: "8 word"
  show "map (\<lambda>k. x AND (1 << k) \<noteq> 0) [0..<size x] = rev (to_bl x)" (is "?lhs = ?rhs")
  proof (rule nth_equalityI)
    show "length ?lhs = length ?rhs"
      by (simp add: word_size)
  next
    fix i
    assume "i < length ?lhs"
    hence "i < 8"
      by (simp add: word_size)
    thus "?lhs ! i = ?rhs ! i"
      using bit_test_bl_8word
      using Word.test_bit_bl[symmetric, of x i]
      using size_uint8.abs_eq
      by simp
  qed
qed

(* For uint64, the LSB-first list to_bl' is the reverse of the MSB-first list
   to_bl_uint64.  Proved elementwise on 64-bit words via bit_test_bl_64word. *)
lemma to_bl'_uint64:
  fixes x :: uint64
  shows "to_bl' x = rev (to_bl_uint64 x)"
  unfolding to_bl'_def bit_test_def
proof transfer
  fix x :: "64 word"
  show "map (\<lambda>k. x AND (1 << k) \<noteq> 0) [0..<size x] = rev (to_bl x)" (is "?lhs = ?rhs")
  proof (rule nth_equalityI)
    show "length ?lhs = length ?rhs"
      by (simp add: word_size)
  next
    fix i
    assume "i < length ?lhs"
    hence "i < 64"
      by (simp add: word_size)
    thus "?lhs ! i = ?rhs ! i"
      using bit_test_bl_64word
      using Word.test_bit_bl[symmetric, of x i]
      using size_uint64.abs_eq
      by simp
  qed
qed

text \<open>@{text one_bits_pos} - positions of 1 bits\<close>

(* Indices of the set bits of x, computed with positions (from More_List) over
   the LSB-first list to_bl' x.  The later uses of sorted_positions,
   distinct_positions and set_positions show that the result is sorted and
   duplicate-free, and that it contains exactly the p < length (to_bl' x)
   with to_bl' x ! p. *)
definition one_bits_pos :: "'a :: {bit_operations, zero, one, size} \<Rightarrow> nat list" where 
  "one_bits_pos x = positions (to_bl' x)"

(* For a uint8, one_bits_pos lists the indices in 0..7 whose bit is set, in
   increasing order.  Both sides are sorted and distinct and have the same
   element set, so they are equal by sorted_distinct_set_unique. *)
lemma one_bits_uint8':
  fixes x :: uint8
  shows "one_bits_pos x = filter (\<lambda> k. bit_test k x) [0..<8]" (is "?lhs = ?rhs")
proof-
  let ?A = "{p. p < length (to_bl' x) \<and> to_bl' x ! p}"
  have "sorted ?lhs" "distinct ?lhs" "set ?lhs = ?A"
    unfolding one_bits_pos_def
    by (auto simp add: sorted_positions distinct_positions set_positions)
  moreover
  have "sorted ?rhs" "distinct ?rhs"
    by auto
  moreover
  have "set ?rhs = ?A"
    using bit_test_bl_uint8[of _ x]
    using to_bl'_uint8[of x]
    by simp
  ultimately
  show ?thesis
    using sorted_distinct_set_unique
    by metis
qed

(* For a uint64, one_bits_pos lists the indices in 0..63 whose bit is set, in
   increasing order.  Both sides are sorted and distinct and have the same
   element set, so they are equal by sorted_distinct_set_unique. *)
lemma one_bits_uint64':
  fixes x :: uint64
  shows "one_bits_pos x = filter (\<lambda> k. bit_test k x) [0..<64]" (is "?lhs = ?rhs")
proof-
  let ?A = "{p. p < length (to_bl' x) \<and> to_bl' x ! p}"
  have "sorted ?lhs" "distinct ?lhs" "set ?lhs = ?A"
    unfolding one_bits_pos_def
    by (auto simp add: sorted_positions distinct_positions set_positions)
  moreover
  have "sorted ?rhs" "distinct ?rhs"
    by auto
  moreover
  have "set ?rhs = ?A"
    using bit_test_bl_uint64[of _ x]
    using to_bl'_uint64[of x]
    by simp
  ultimately
  show ?thesis
    using sorted_distinct_set_unique
    by metis
qed

(* Membership in one_bits_pos for uint8: k is listed iff k < 8 and bit k is
   set. *)
lemma set_one_bits_pos_uint8:
  fixes x :: uint8
  shows "k \<in> set (one_bits_pos x) \<longleftrightarrow> k < 8 \<and> test_bit x k"
  unfolding one_bits_pos_def
  by (simp add: set_positions test_bit_bl_uint8 to_bl'_uint8)

(* Membership in one_bits_pos for uint64: k is listed iff k < 64 and bit k is
   set. *)
lemma set_one_bits_pos_uint64:
  fixes x :: uint64
  shows "k \<in> set (one_bits_pos x) \<longleftrightarrow> k < 64 \<and> test_bit x k"
  unfolding one_bits_pos_def
  by (simp add: set_positions test_bit_bl_uint64 to_bl'_uint64)

text \<open>@{text bit_set}\<close>

(* Bits of set_bit x k b on uint8: position k becomes b (if k is in range);
   every other position is unchanged. *)
lemma set_bit_test_gen_uint8:
  fixes x :: uint8
  shows "test_bit (set_bit x k b) i = (if i = k then k < size x \<and> b else test_bit x i)" 
  by transfer (metis test_bit_set_gen)

(* On 8-bit words, ORing with 1 << k sets bit k (when k < 8) and leaves all
   other bits unchanged. *)
lemma bit_set_test_gen_word8:
  fixes x :: "8 word"
  shows "(x OR (1 << k)) !! i = (if i = k then k < size x else x !! i)"
  by (metis (full_types) nth_w2p shiftl_1 test_bit_size word_ao_nth wsst_TYs(3))

(* Bits of bit_set k x on uint8, transferred from bit_set_test_gen_word8. *)
lemma bit_set_test_gen_uint8:
  fixes x :: uint8
  shows "test_bit (bit_set k x) i = (if i = k then k < size x else test_bit x i)"
  unfolding bit_set_def
  by transfer (metis bit_set_test_gen_word8)

(* On uint8, bit_set k x coincides with the library set_bit x k True; the two
   are compared bitwise via word_eqI. *)
lemma set_bit_bit_set_uint8:
  fixes x :: uint8
  shows "bit_set k x = set_bit x k True"
  unfolding bit_set_def
  by (transfer, rule word_eqI) (metis bit_set_test_gen_word8 Word.test_bit_set_gen)

(* Bits of set_bit x k b on uint64: position k becomes b (if k is in range);
   every other position is unchanged. *)
lemma set_bit_test_gen_uint64:
  fixes x :: uint64
  shows "test_bit (set_bit x k b) i = (if i = k then k < size x \<and> b else test_bit x i)" 
  by transfer (metis test_bit_set_gen)

(* 64-bit word counterpart of bit_set_test_gen_word8.  OR-ing with 1 << k
   sets bit k iff k < size x and leaves the other bits unchanged.
   The proof is identical to the 8-bit version. *)
lemma bit_set_test_gen_word64:
  fixes x :: "64 word"
  shows "(x OR (1 << k)) !! i = (if i = k then k < size x else x !! i)"
  by (metis (full_types) nth_w2p shiftl_1 test_bit_size word_ao_nth wsst_TYs(3))

(* Bitwise specification of bit_set on uint64.  Proof: unfold bit_set and
   transfer to bit_set_test_gen_word64. *)
lemma bit_set_test_gen_uint64:
  fixes x :: uint64
  shows "test_bit (bit_set k x) i = (if i = k then k < size x else test_bit x i)"
  unfolding bit_set_def
  by transfer (metis bit_set_test_gen_word64)

(* bit_set coincides with set_bit x k True on uint64.  Proof: bit-by-bit
   word equality (word_eqI) after transfer. *)
lemma set_bit_bit_set_uint64:
  fixes x :: uint64
  shows "bit_set k x = set_bit x k True"
  unfolding bit_set_def
  by (transfer, rule word_eqI) (metis bit_set_test_gen_word64 Word.test_bit_set_gen)


(* Setting a valid bit index k (k < 8) adds exactly k to the set of one-bit
   positions of a uint8.  if_split_asm is needed to split the conditional
   produced by bit_set_test_gen_uint8 when it occurs in assumptions. *)
lemma set_one_bit_pos_bit_set_uint8:
  fixes x :: uint8
  assumes "k < 8"
  shows "set (one_bits_pos (bit_set k x)) = insert k (set (one_bits_pos x))"
  using assms
  by (auto simp add: set_one_bits_pos_uint8 bit_set_test_gen_uint8 split: if_split_asm)

(* uint64 analogue of set_one_bit_pos_bit_set_uint8: setting a bit index
   k < 64 inserts k into the set of one-bit positions. *)
lemma set_one_bit_pos_bit_set_uint64:
  fixes x :: uint64
  assumes "k < 64"
  shows "set (one_bits_pos (bit_set k x)) = insert k (set (one_bits_pos x))"
  using assms
  by (auto simp add: set_one_bits_pos_uint64 bit_set_test_gen_uint64 split: if_split_asm)

text \<open>@{text bit_put}\<close>

(* Putting the value 1 at position k is the same as setting bit k.  Both
   definitions unfold to x OR (1 << k), so the lemma is stated generically,
   not only for uint8/uint64. *)
lemma bit_put_1:
  "bit_put k x 1 = bit_set k x"
  unfolding bit_set_def bit_put_def
  by simp

(* Putting the value 0 leaves a uint8 unchanged: bit_put k x 0 unfolds to
   x OR (0 << k), which simp reduces to x after transfer to words.
   NOTE(review): only a uint8 version is stated in this part of the theory. *)
lemma bit_put_0_uint8:
  fixes x :: uint8
  shows "bit_put k x 0 = x"
  unfolding bit_put_def
  by transfer simp

text \<open>@{text bit_get}\<close>

(* Shifting an 8-bit word right by k and masking with 1 yields 1 if bit k is
   set (and k is a valid index) and 0 otherwise.  Proof: nat8_cases[of k]
   (defined earlier in the file; presumably a case split enumerating small k,
   TODO confirm) is added before word_bitwise blasts the goal into individual
   bits; the remaining goals are split on k < 8. *)
lemma bit_get_word8:
  fixes x :: "8 word"
  shows "(x >> k) AND 1 = (if k < size x \<and> x !! k then 1 else 0)"
  using nat8_cases[of k]
  apply word_bitwise
  apply (case_tac "k < 8")
   apply (auto simp add: takefill_append)
  done

(* bit_get on uint8 returns the selected bit as the value 0 or 1, and 0 for
   out-of-range positions.  Transferred from bit_get_word8. *)
lemma bit_get_uint8:
  fixes x :: uint8
  shows "bit_get k x = (if k < size x \<and> test_bit x k then 1 else 0)"
  unfolding bit_get_def
  by transfer (metis bit_get_word8)


text \<open>@{text bit_clear}\<close>

(* On 8-bit words, masking with NOT (1 << k) clears bit k, i.e. it agrees
   with the library's set_bit x k False.  Proof: bit-by-bit equality
   (word_eqI) using the nth-bit simp rules for AND/NOT, 2^n and set_bit. *)
lemma bit_clear_set_bit_word8:
  fixes x :: "8 word"
  shows "x AND NOT (1 << k) = set_bit x k False"
  by (rule word_eqI, auto simp add: word_size word_ao_nth test_bit_set_gen nth_w2p word_ops_nth_size)

(* 64-bit word counterpart of bit_clear_set_bit_word8; the proof is identical. *)
lemma bit_clear_set_bit_word64:
  fixes x :: "64 word"
  shows "x AND NOT (1 << k) = set_bit x k False"
  by (rule word_eqI, auto simp add: word_size word_ao_nth test_bit_set_gen nth_w2p word_ops_nth_size)

(* bit_clear on uint8 agrees with set_bit x k False.  Proof: unfold
   bit_clear and transfer to bit_clear_set_bit_word8. *)
lemma bit_clear_set_bit_uint8:
  fixes x :: uint8
  shows "bit_clear k x = set_bit x k False"
  unfolding bit_clear_def
  by transfer (metis bit_clear_set_bit_word8)

(* bit_clear on uint64 agrees with set_bit x k False.  Proof: transfer to
   bit_clear_set_bit_word64. *)
lemma bit_clear_set_bit_uint64:
  fixes x :: uint64
  shows "bit_clear k x = set_bit x k False"
  unfolding bit_clear_def
  by transfer (metis bit_clear_set_bit_word64)

text \<open>@{text bit_pop_count}\<close>

(* The list-based population count equals the number of reported one-bit
   positions.  bit_pop_count filters the literal list [0, 1, ..., 7].  The
   auxiliary equation rewrites it as [0..<8], so that one_bits_uint8' applies.
   one_bits_uint8' is stated earlier in the file; presumably it is the uint8
   analogue of the filter characterisation over [0..<64] proven for uint64. *)
lemma bit_pop_count_uint8:
  fixes x :: uint8
  shows "bit_pop_count x = length (one_bits_pos x)"
proof-
  have "[0..<8] = [0, 1, 2, 3, 4, 5, 6, 7]"
    by (rule sorted_distinct_set_unique) auto
  thus ?thesis
    unfolding bit_pop_count_def one_bits_uint8'
    by simp
qed

text \<open>@{text bit_inter}\<close>

(* Bit k of bit_inter x y is set exactly when it is set in both x and y.
   NOTE(review): the name reads "bit_int" rather than "bit_inter"; it is kept
   unchanged because other theories may refer to it. *)
lemma bit_int_uint64:
  fixes x :: uint64
  shows "test_bit (bit_inter x y) k \<longleftrightarrow> test_bit x k \<and> test_bit y k"
  unfolding bit_inter_def
  by transfer (simp add: word_ao_nth)

text \<open>@{text bit_union}\<close>

(* Bit k of bit_union x y is set exactly when it is set in x or in y. *)
lemma bit_union_uint8:
  fixes x :: uint8
  shows "test_bit (bit_union x y) k \<longleftrightarrow> test_bit x k \<or> test_bit y k"
  unfolding bit_union_def
  by transfer (simp add: word_ao_nth)

text \<open>@{text bit_is_subset}\<close>

(* x is a bit-subset of y (x AND y = x) iff every bit set in x, at a valid
   index, is also set in y.
   NOTE(review): the metis option hide_lams belongs to the older Isabelle
   releases this theory targets; newer releases renamed it, so check this
   proof when porting. *)
lemma bit_is_subset_uint8:
  fixes x y :: uint8
  shows "bit_is_subset x y \<longleftrightarrow> (\<forall> k < size x. test_bit x k \<longrightarrow> test_bit y k)"
  unfolding bit_is_subset_def
  by transfer (metis (no_types, hide_lams) word_ao_nth word_eqI wsst_TYs(3))

text \<open>@{text of_pos}\<close>

(* Builds a bit vector from a list of bit positions.  Starting from 0, every
   position in xs is set with bit_set (via foldr).  For uint8 and uint64 it is
   a left inverse of one_bits_pos (see of_pos_one_bits_pos_uint8/uint64). *)
definition of_pos :: "nat list \<Rightarrow> 'a :: {bit_operations, zero, one}" where
   "of_pos xs = foldr bit_set xs 0"

(* Simp rule: the empty position list yields the all-zero value. *)
lemma of_pos_Nil [simp]:
  shows "of_pos [] = 0"
  by (simp add: of_pos_def)

(* Simp rule: a leading position x is set on top of the value built from the
   rest of the list. *)
lemma of_pos_Cons [simp]:
  shows "of_pos (x # xs) = bit_set x (of_pos xs)"
  by (simp add: of_pos_def)

(* For an in-range index k < 8, bit k of the uint8 built by of_pos xs is set
   exactly when k occurs in xs.  Proof: induction on xs, with the bound
   k < 8 carried into each case.  The Cons step uses the bitwise
   specification of bit_set (bit_set_test_gen_uint8). *)
lemma bit_test_of_pos_uint8:
  assumes "k < 8"
  shows "bit_test k ((of_pos::nat list \<Rightarrow> uint8) xs) \<longleftrightarrow> k \<in> set xs"
  using assms
proof (induction xs)
  case Nil
  thus ?case
    by (simp add: bit_test_test_bit_uint8)
next
  case (Cons x xs)
  then show ?case 
    by (simp add: bit_test_test_bit_uint8 bit_set_test_gen_uint8)
qed

(* uint64 analogue of bit_test_of_pos_uint8: for k < 64, bit k of of_pos xs
   is set exactly when k occurs in xs.  Same induction on xs. *)
lemma bit_test_of_pos_uint64:
  assumes "k < 64"
  shows "bit_test k ((of_pos::nat list \<Rightarrow> uint64) xs) \<longleftrightarrow> k \<in> set xs"
  using assms
proof (induction xs)
  case Nil
  thus ?case
    by (simp add: bit_test_test_bit_uint64)
next
  case (Cons x xs)
  then show ?case 
    by (simp add: bit_test_test_bit_uint64 bit_set_test_gen_uint64)
qed

(* For a bool list of length 8, setting the bits at the True positions of xs
   gives the same uint8 as of_bl_uint8 (rev xs).  The reversal presumably
   accounts for of_bl_uint8 reading its list most-significant bit first
   (TODO confirm).  Proof: extensionality over the 8 bit positions
   (uint8_eq_I), comparing bits via bit_test_of_pos_uint8 and
   test_bit_of_bl_uint8. *)
lemma of_pos_positions_of_bl_uint8:
  assumes "length xs = 8"
  shows "of_pos (positions xs) = of_bl_uint8 (rev xs)"
proof (subst uint8_eq_I, rule allI, rule impI)
  fix i :: nat
  assume "i < 8"
  thus "(of_pos :: nat list \<Rightarrow> uint8) (positions xs) !! i = of_bl_uint8 (rev xs) !! i"
    using bit_test_of_pos_uint8[of i "positions xs"] `length xs = 8`
    unfolding bit_test_test_bit_uint8
    by (simp add: set_positions test_bit_of_bl_uint8)
qed

(* uint64 analogue of of_pos_positions_of_bl_uint8 for bool lists of
   length 64.  Proof: extensionality over the 64 bit positions. *)
lemma of_pos_positions_of_bl_uint64:
  assumes "length xs = 64"
  shows "of_pos (positions xs) = of_bl_uint64 (rev xs)"
proof (subst uint64_eq_I, rule allI, rule impI)
  fix i :: nat
  assume "i < 64"
  thus "(of_pos :: nat list \<Rightarrow> uint64) (positions xs) !! i = of_bl_uint64 (rev xs) !! i"
    using bit_test_of_pos_uint64[of i "positions xs"] `length xs = 64`
    unfolding bit_test_test_bit_uint64
    by (simp add: set_positions test_bit_of_bl_uint64)
qed

(* Right-inverse direction for uint8: a sorted, duplicate-free list of
   indices below 8 is recovered by one_bits_pos from of_pos xs.
   Proof:
   - first show that set xs equals the set of True positions of
     to_bl' (of_pos xs);
   - one_bits_pos is then the sorted list of that set, and two sorted,
     distinct lists with the same set are equal (sorted_distinct_set_unique). *)
lemma one_bits_pos_of_pos_uint8:
  assumes "set xs \<subseteq> {0..<8}" "sorted xs" "distinct xs"
  shows "one_bits_pos ((of_pos :: nat list \<Rightarrow> uint8) xs) = xs"
proof-
  let ?A = "{p. p < length (to_bl' ((of_pos :: nat list \<Rightarrow> uint8) xs)) \<and> to_bl' ((of_pos :: nat list \<Rightarrow> uint8) xs) ! p}"
  have *: "set xs = ?A"
    using `set xs \<subseteq> {0..<8}`
    unfolding to_bl'_def
    using bit_test_of_pos_uint8 
    by auto
  show ?thesis
    using assms
    using sorted_distinct_set_unique[of xs "sorted_list_of_set ?A"]
    unfolding one_bits_pos_def positions_sorted_list_of_set
    by (simp add: *)
qed

(* Simp rule: rebuilding a uint8 from its one-bit positions gives it back,
   i.e. of_pos is a left inverse of one_bits_pos.  Proof steps:
   - rewrite with of_pos_positions_of_bl_uint8 at xs = to_bl' x; this leaves
     the side goal length (to_bl' x) = 8, closed via to_bl'_def;
   - finish with to_bl'_uint8. *)
lemma of_pos_one_bits_pos_uint8 [simp]:
  fixes x :: uint8
  shows "of_pos (one_bits_pos x) = x"
  unfolding one_bits_pos_def
  apply (subst of_pos_positions_of_bl_uint8[of "to_bl' x"])
   apply (simp add: to_bl'_def)
  apply (subst to_bl'_uint8)
  apply simp
  done

(* Simp rule, uint64 analogue of of_pos_one_bits_pos_uint8: of_pos undoes
   one_bits_pos.  The proof follows the same steps with the 64-bit lemmas. *)
lemma of_pos_one_bits_pos_uint64 [simp]:
  fixes x :: uint64
  shows "of_pos (one_bits_pos x) = x"
  unfolding one_bits_pos_def
  apply (subst of_pos_positions_of_bl_uint64[of "to_bl' x"])
   apply (simp add: to_bl'_def)
  apply (subst to_bl'_uint64)
  apply simp
  done

text \<open>Upper bounds\<close>

(* Upper bound preservation for uint8: if x < 2^n, then setting a bit at a
   position a < n keeps the value below 2^n.  The hypothesis n < 8 keeps
   2^n below 256; the proof uses that bound when moving between nat and uint8.
   Proof outline:
   1. Every bit of x at a position k >= n is clear.  By contradiction, a set
      bit k >= n would give x >= 2^k >= 2^n.
   2. By bit_set_test_gen_uint8, the same holds for bit_set a x, since a < n.
   3. trailing_zeros_ub (from earlier in the file) turns this back into the
      bound < 2^n.
   NOTE(review): unlike its siblings, the name carries no _uint8 suffix. *)
lemma bit_set_ub:
  fixes x :: uint8
  assumes "x < 2 ^ n" "a < n" "n < 8"
  shows "bit_set a x < 2 ^ n"
proof-
  (* Step 1: no bit at or above position n is set in x. *)
  have "\<forall> k. k \<ge> n \<longrightarrow> \<not> x !! k"
  proof (rule ccontr)
    assume "\<not> ?thesis"
    then obtain k where "k \<ge> n" "x !! k"
      by auto
    have "k < 8"
      using `x !! k`
      using test_bit_uint8_code by auto
    hence "x \<ge> 2^k"
      using bang_is_le[of k x] `x !! k`
      by auto
    have "x \<ge> 2^n"
    proof-
      (* Both powers fit below 256, so monotonicity can be carried over
         from nat to uint8. *)
      have "(2::nat) ^ k < 256" "(2::nat) ^ n < 256"
        using `k \<ge> n` `k < 8`
        using power_strict_increasing_iff[of "2::nat" k 8]
        using power_strict_increasing_iff[of "2::nat" n 8]
        by auto
      thus ?thesis
        using `k \<ge> n` `k < 8` `x \<ge> 2 ^ k`
        using uint8_of_nat_2pow[of k]
        using uint8_of_nat_2pow[of n]
        using uint8_of_nat_mono_leq[of "2 ^ n" "2 ^ k"]
        by auto
    qed
    thus False
      using `x < 2 ^ n`
      by simp
  qed
  (* Step 2: setting bit a < n does not introduce a bit at or above n. *)
  hence "\<forall> k. k \<ge> n \<longrightarrow> \<not> bit_set a x !! k"
    using assms
    using bit_set_test_gen_uint8 
    by auto
  (* Step 3: convert the bit-level statement back into the numeric bound. *)
  thus ?thesis
    using trailing_zeros_ub[of n "bit_set a x"] `n < 8`
    by (simp add: pow2_eq_shiftL_uint8)
qed

(* If every position in xs is below n (with n < 8), the uint8 built by
   of_pos xs is below 2^n.  Proof: induction on xs after unfolding of_pos
   to foldr bit_set xs 0.
   - Nil: the value is 0.
   - Cons: bit_set_ub. *)
lemma of_pos_ub:
  fixes xs :: "nat list"
  assumes "set xs \<subseteq> {0..<n}" "n < 8"
  shows "(of_pos :: nat list \<Rightarrow> uint8) xs < 2 ^ n"
  using assms
  unfolding of_pos_def
proof (induction xs)
  case Nil
  then show ?case 
    by (simp add: nat_of_uint8_mono)
next
  case (Cons a xs)
  then show ?case
    using bit_set_ub
    by auto
qed

(* For n < 8, a uint8 s is below 2^n exactly when all of its one-bit
   positions lie in {0..<n}.  proof safe yields two goals:
   - first goal (positions bounded ==> s < 2^n): of_pos_ub applied to
     one_bits_pos s, then the simp rule of_pos_one_bits_pos_uint8;
   - second goal (s < 2^n ==> every set position x < n): a set bit x gives
     2^x <= s < 2^n (bang_is_le), hence x < n. *)
lemma set_one_bits_pos_ub:
  assumes "n < 8"
  shows "s < (2::uint8) ^ n \<longleftrightarrow> set (one_bits_pos s) \<subseteq> {0..<n}"
proof safe
  fix s::uint8
  assume "set (one_bits_pos s) \<subseteq> {0..<n}"
  hence "of_pos (one_bits_pos s) < (2::uint8) ^ n"
    using of_pos_ub[of "one_bits_pos s" n] `n < 8`
    by simp
  thus "s < 2 ^ n"
    by simp
next
  fix s::uint8 and x
  assume "s < 2 ^ n" "x \<in> set (one_bits_pos s)"
  thus "x \<in> {0..<n}"
    using `n < 8`
    by (smt bang_is_le atLeastLessThan_iff dual_order.strict_trans2 nat_of_uint8_2pow nat_of_uint8_mono nat_power_less_imp_less set_one_bits_pos_uint8 zero_le zero_less_numeral)
qed

end
