1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

fix to_bits, improve naming

This commit is contained in:
schaeff 2021-04-20 18:35:03 +02:00
parent c8322bf0db
commit c105e801e5
4 changed files with 19 additions and 60 deletions

47
lt.zok
View file

@ -1,47 +0,0 @@
from "utils/pack/bool/unpack.zok" import main as unpack
from "utils/casts/u32_to_bits" import main as u32_to_bits
// this comparison works for any N smaller than the field size, which is the case in practice
def lt<N>(bool[N] a_bits, bool[N] c_bits) -> bool:
bool[N] is_not_smaller_run = [false; N]
bool[N] size_unknown = [false; N]
u32 verified_conditions = 0 // `and(conditions) == (sum(conditions) == len(conditions))`, here we initialize `sum(conditions)`
size_unknown[0] = true
for u32 i in 0..N - 1 do
is_not_smaller_run[i] = if c_bits[i] then a_bits[i] else is_not_smaller_run[i] fi
size_unknown[i + 1] = if c_bits[i] then size_unknown[i] && is_not_smaller_run[i] else size_unknown[i] fi
verified_conditions = verified_conditions + if c_bits[i] then 1 else if (!size_unknown[i] || !a_bits[i]) then 1 else 0 fi fi
endfor
u32 i = N - 1
is_not_smaller_run[i] = if c_bits[i] then a_bits[i] else is_not_smaller_run[i] fi
verified_conditions = verified_conditions + if c_bits[i] then 1 else if (!size_unknown[i] || !a_bits[i]) then 1 else 0 fi fi
return verified_conditions == N // this checks that all conditions were verified
// this instanciates comparison starting from field elements
def lt<N>(field a, field c) -> bool:
bool[N] a_bits = unpack(a)
bool[N] c_bits = unpack(c)
return lt(a_bits, c_bits)
// this instanciates comparison starting from u32
def lt(u32 a, u32 c) -> bool:
bool[32] a_bits = u32_to_bits(a)
bool[32] c_bits = u32_to_bits(c)
return lt(a_bits, c_bits)
def main(field a) -> bool:
u32 N = 254
field c = 42
//u32 d = 42
return lt::<N>(a, c)

View file

@ -204,7 +204,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// # Returns
//
// * a vector of FlatExpression which all evaluate to 1 if a <= b and 0 otherwise
fn strict_le_check(
fn constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
@ -217,6 +217,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let mut is_not_smaller_run = vec![];
let mut size_unknown = vec![];
println!("B {:?}", b);
for _ in 0..len {
is_not_smaller_run.push(self.use_sym());
size_unknown.push(self.use_sym());
@ -326,14 +328,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res
}
fn enforce_strict_le_check(
fn enforce_constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
b: &[bool],
) {
let statements: Vec<_> = self
.strict_le_check(statements_flattened, a, b)
.constant_le_check(statements_flattened, a, b)
.into_iter()
.map(|c| FlatStatement::Condition(FlatExpression::Number(T::from(1)), c))
.collect();
@ -422,7 +424,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
return T::zero().into();
}
self.constant_le_check(statements_flattened, e, c - T::one())
self.constant_field_le_check(statements_flattened, e, c - T::one())
}
/// Compute a range check against a constant
@ -435,7 +437,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
///
/// # Returns
/// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise
fn constant_le_check(
fn constant_field_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
e: FlatExpression<T>,
@ -485,13 +487,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
));
// check that this decomposition does not overflow the field
self.enforce_strict_le_check(
self.enforce_constant_le_check(
statements_flattened,
&e_bits_be,
&T::max_value().bit_vector_be(),
);
let conditions = self.strict_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
println!("YOOOOOO");
let conditions =
self.constant_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
// return `len(conditions) == sum(conditions)`
self.eq_check(
@ -521,6 +525,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened: &mut FlatStatements<T>,
expression: BooleanExpression<'ast, T>,
) -> FlatExpression<T> {
println!("{}", expression);
// those will be booleans in the future
match expression {
BooleanExpression::Identifier(x) => {
@ -538,7 +544,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Number(constant)) => {
println!("yo");
println!("yo {} < {}", x, constant);
self.constant_lt_check(statements_flattened, x, constant)
}
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
@ -687,7 +693,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
// check that the decomposition is in the field with a strict `< p` checks
self.enforce_strict_le_check(
self.enforce_constant_le_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),
@ -840,7 +846,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
self.constant_le_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),

View file

@ -31,10 +31,10 @@ def le(u32 a, u32 c) -> bool:
return le(a_bits, c_bits)
def main(field a, field c) -> bool:
def main(field a) -> bool:
u32 N = 254
//field c = 42
field c = 42
//u32 d = 42

View file

@ -125,7 +125,7 @@ pub trait Field:
.collect()
}
let field_bytes_le = Self::to_byte_vector(&Self::max_value());
let field_bytes_le = self.to_byte_vector();
// reverse for big-endianess
let field_bytes_be = field_bytes_le.into_iter().rev().collect::<Vec<u8>>();