diff --git a/changelogs/unreleased/761-schaeff b/changelogs/unreleased/761-schaeff new file mode 100644 index 00000000..f51dfda3 --- /dev/null +++ b/changelogs/unreleased/761-schaeff @@ -0,0 +1 @@ +Introduce constant range checks for checks of the form `x < c` where `p` is a compile-time constant, also for other comparison operators. This works for any `x` and `p`, unlike dynamic `x < y` comparison \ No newline at end of file diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md index afaffa38..b330ea79 100644 --- a/zokrates_book/src/language/operators.md +++ b/zokrates_book/src/language/operators.md @@ -22,4 +22,4 @@ The following table lists the precedence and associativity of all operators. Ope [^2]: The right operand must be a compile time constant of type `u32` -[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2` \ No newline at end of file +[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant \ No newline at end of file diff --git a/zokrates_cli/examples/runtime_errors/lt_overflow_max_plus_1.zok b/zokrates_cli/examples/runtime_errors/lt_overflow_max_plus_1.zok index 83b8f337..d4affb1c 100644 --- a/zokrates_cli/examples/runtime_errors/lt_overflow_max_plus_1.zok +++ b/zokrates_cli/examples/runtime_errors/lt_overflow_max_plus_1.zok @@ -9,4 +9,4 @@ def main(field a) -> bool: // maxvalue = 2**252 - 1 field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1 // we added a = 0 to prevent the condition to be evaluated at compile time - return 0 < (maxvalue + 1) \ No newline at end of file + return a < (maxvalue + 1) \ No newline at end of file diff --git a/zokrates_cli/examples/runtime_errors/lt_overflow_p_minus_one.zok b/zokrates_cli/examples/runtime_errors/lt_overflow_p_minus_one.zok index fd7aca40..e0b5d775 100644 --- a/zokrates_cli/examples/runtime_errors/lt_overflow_p_minus_one.zok +++ b/zokrates_cli/examples/runtime_errors/lt_overflow_p_minus_one.zok @@ -4,4 +4,4 @@ def main(field a) -> bool: field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a // we added a = 0 to prevent the condition to be evaluated at compile time - return 0 < p \ No newline at end of file + return a < p \ No newline at end of file diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index 08232fbe..c832578a 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -208,6 +208,12 @@ pub enum FlatExpression { Mult(Box>, Box>), } +impl From for FlatExpression { + fn from(other: T) -> Self { + Self::Number(other) + } +} + impl FlatExpression { pub fn apply_substitution( self, diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index b5a355cd..2ff28827 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -178,34 +178,46 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } - // Let's assume b = [1, 1, 1, 0] - // - // 1. Init `sizeUnknown = true` - // As long as `sizeUnknown` is `true` we don't yet know if a is <= than b. - // 2. Loop over `b`: - // * b[0] = 1 - // when `b` is 1 we check wether `a` is 0 in that particular run and update - // `sizeUnknown` accordingly: - // `sizeUnknown = sizeUnknown && a[0]` - // * b[1] = 1 - // `sizrUnknown = sizeUnknown && a[1]` - // * b[2] = 1 - // `sizeUnknown = sizeUnknown && a[2]` - // * b[3] = 0 - // we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`, - // otherwise `a` can be {0,1}: - // `true == (!sizeUnknown || !a[3])` - // ``` - // **true => a -> 0 - // sizeUnkown * - // **false => a -> {0,1} - // ``` - fn strict_le_check( + /// Compute a range check between the bid endian decomposition of an expression and the + /// big endian decomposition of a constant + /// + /// # Arguments + /// * `a` - the big-endian bit decomposition of the expression to check against the range + /// * `b` - the big-endian bit decomposition of the upper bound we're checking against + /// + /// # Returns + /// * a vector of FlatExpression which all evaluate to `1` if `a <= b` and `0` otherwise + /// + /// # Notes + /// + /// Algorithm from [the sapling spec](https://github.com/zcash/zips/blob/master/protocol/sapling.pdf) A.3.2.2 + /// + /// Let's assume b = [1, 1, 1, 0] + /// + /// 1. Init `sizeUnknown = true` + /// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b. + /// 2. Loop over `b`: + /// * b[0] = 1 + /// when `b` is 1 we check wether `a` is 0 in that particular run and update + /// `sizeUnknown` accordingly: + /// `sizeUnknown = sizeUnknown && a[0]` + /// * b[1] = 1 + /// `sizeUnknown = sizeUnknown && a[1]` + /// * b[2] = 1 + /// `sizeUnknown = sizeUnknown && a[2]` + /// * b[3] = 0 + /// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`, + /// otherwise `a` can be {0,1}: + /// `true == (!sizeUnknown || !a[3])` + /// **true => a -> 0 + /// sizeUnkown * + /// **false => a -> {0,1} + fn constant_le_check( &mut self, statements_flattened: &mut FlatStatements, + a: &[FlatVariable], b: &[bool], - a: Vec, - ) { + ) -> Vec> { let len = b.len(); assert_eq!(a.len(), T::get_required_bits()); assert_eq!(a.len(), b.len()); @@ -224,6 +236,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Number(T::from(1)), )); + let mut res = vec![]; + for (i, b) in b.iter().enumerate() { if *b { statements_flattened.push(FlatStatement::Definition( @@ -267,12 +281,93 @@ impl<'ast, T: Field> Flattener<'ast, T> { box and_name.into(), ); - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Number(T::from(1)), - or, - )); + res.push(or); } } + + res + } + + /// Compute an equality check between two expressions + /// + /// # Arguments + /// + /// * `statements_flattened` - Vector where new flattened statements can be added. + /// * `left - the first `FlatExpression` + /// * `right` - the second `FlatExpression` + /// + /// # Returns + /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise + fn eq_check( + &mut self, + statements_flattened: &mut Vec>, + left: FlatExpression, + right: FlatExpression, + ) -> FlatExpression { + let left = self.define(left, statements_flattened); + let right = self.define(right, statements_flattened); + + // Wanted: (Y = (X != 0) ? 1 : 0) + // X = a - b + // # Y = if X == 0 then 0 else 1 fi + // # M = if X == 0 then 1 else 1/X fi + // Y == X * M + // 0 == (1-Y) * X + + let x = FlatExpression::Sub(box left.into(), box right.into()); + + let name_y = self.use_sym(); + let name_m = self.use_sym(); + + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + vec![name_y, name_m], + Solver::ConditionEq, + vec![x.clone()], + ))); + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(name_y), + FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), + )); + + let res = FlatExpression::Sub( + box FlatExpression::Number(T::one()), + box FlatExpression::Identifier(name_y), + ); + + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Number(T::zero()), + FlatExpression::Mult(box res.clone(), box x), + )); + + res + } + + /// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail + /// + /// # Arguments + /// + /// * `statements_flattened` - Vector where new flattened statements can be added. + /// * `a` - the big-endian bit decomposition of the expression we enforce to be in range + /// * `b` - the big-endian bit decomposition of the upper bound of the range + fn enforce_constant_le_check( + &mut self, + statements_flattened: &mut FlatStatements, + a: &[FlatVariable], + b: &[bool], + ) { + let conditions = self.constant_le_check(statements_flattened, a, b); + + let conditions_count = conditions.len(); + + let conditions_sum = conditions + .into_iter() + .fold(FlatExpression::from(T::zero()), |acc, e| { + FlatExpression::Add(box acc, box e) + }); + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Number(T::from(0)), + FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()), + )); } /// Flatten an if/else expression @@ -346,6 +441,109 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } + /// Compute a strict check against a constant + /// # Arguments + /// * `statements_flattened` - Vector where new flattened statements can be added. + /// * `e` - the `FlatExpression` that's being checked against the range. + /// * `c` - the constant strict upper bound of the range + /// + /// # Returns + /// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise + fn constant_lt_check( + &mut self, + statements_flattened: &mut FlatStatements, + e: FlatExpression, + c: T, + ) -> FlatExpression { + if c == T::zero() { + // this is the case c == 0, we return 0, aka false + return T::zero().into(); + } + + self.constant_field_le_check(statements_flattened, e, c - T::one()) + } + + /// Compute a range check against a constant + /// + /// # Arguments + /// + /// * `statements_flattened` - Vector where new flattened statements can be added. + /// * `e` - the `FlatExpression` that's being checked against the range. + /// * `c` - the constant upper bound of the range + /// + /// # Returns + /// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise + fn constant_field_le_check( + &mut self, + statements_flattened: &mut FlatStatements, + e: FlatExpression, + c: T, + ) -> FlatExpression { + let bit_width = T::get_required_bits(); + // decompose e to bits + let e_id = self.define(e, statements_flattened); + + // define variables for the bits + let e_bits_be: Vec = (0..bit_width).map(|_| self.use_sym()).collect(); + + // add a directive to get the bits + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + e_bits_be.clone(), + Solver::bits(bit_width), + vec![e_id], + ))); + + // bitness checks + for bit in e_bits_be.iter().take(bit_width) { + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(*bit), + FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), + ), + )); + } + + // bit decomposition check + let mut e_sum = FlatExpression::Number(T::from(0)); + + for (i, bit) in e_bits_be.iter().take(bit_width).enumerate() { + e_sum = FlatExpression::Add( + box e_sum, + box FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)), + ), + ); + } + + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(e_id), + e_sum, + )); + + // check that this decomposition does not overflow the field + self.enforce_constant_le_check( + statements_flattened, + &e_bits_be, + &T::max_value().bit_vector_be(), + ); + + let conditions = + self.constant_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be()); + + // return `len(conditions) == sum(conditions)` + self.eq_check( + statements_flattened, + T::from(conditions.len()).into(), + conditions + .into_iter() + .fold(FlatExpression::Number(T::zero()), |acc, e| { + FlatExpression::Add(box acc, box e) + }), + ) + } + /// Flattens a boolean expression /// /// # Arguments @@ -370,7 +568,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { BooleanExpression::FieldLt(box lhs, box rhs) => { // Get the bit width to know the size of the binary decompositions for this Field let bit_width = T::get_required_bits(); - let safe_width = bit_width - 2; // making sure we don't overflow, assert here? // We know from semantic checking that lhs and rhs have the same type // What the expression will flatten to depends on that type @@ -378,157 +575,181 @@ impl<'ast, T: Field> Flattener<'ast, T> { let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs); let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs); - // lhs - let lhs_id = self.define(lhs_flattened, statements_flattened); - - // check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1 - - // lhs - { - // define variables for the bits - let lhs_bits_be: Vec = - (0..safe_width).map(|_| self.use_sym()).collect(); - - // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - lhs_bits_be.clone(), - Solver::bits(safe_width), - vec![lhs_id], - ))); - - // bitness checks - for bit in lhs_bits_be.iter().take(safe_width) { - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(*bit), - FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Identifier(*bit), - ), - )); + match (lhs_flattened, rhs_flattened) { + (x, FlatExpression::Number(constant)) => { + self.constant_lt_check(statements_flattened, x, constant) } + // (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c) + (FlatExpression::Number(constant), x) => self.constant_lt_check( + statements_flattened, + FlatExpression::Sub(box T::max_value().into(), box x), + T::max_value() - constant, + ), + (lhs_flattened, rhs_flattened) => { + let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2` - // bit decomposition check - let mut lhs_sum = FlatExpression::Number(T::from(0)); + // lhs + let lhs_id = self.define(lhs_flattened, statements_flattened); - for (i, bit) in lhs_bits_be.iter().enumerate().take(safe_width) { - lhs_sum = FlatExpression::Add( - box lhs_sum, + // check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1 + + // lhs + { + // define variables for the bits + let lhs_bits_be: Vec = + (0..safe_width).map(|_| self.use_sym()).collect(); + + // add a directive to get the bits + statements_flattened.push(FlatStatement::Directive( + FlatDirective::new( + lhs_bits_be.clone(), + Solver::bits(safe_width), + vec![lhs_id], + ), + )); + + // bitness checks + for bit in lhs_bits_be.iter().take(safe_width) { + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(*bit), + FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), + ), + )); + } + + // bit decomposition check + let mut lhs_sum = FlatExpression::Number(T::from(0)); + + for (i, bit) in lhs_bits_be.iter().take(safe_width).enumerate() { + lhs_sum = FlatExpression::Add( + box lhs_sum, + box FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Number( + T::from(2).pow(safe_width - i - 1), + ), + ), + ); + } + + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(lhs_id), + lhs_sum, + )); + } + + // rhs + let rhs_id = self.define(rhs_flattened, statements_flattened); + + // rhs + { + // define variables for the bits + let rhs_bits_be: Vec = + (0..safe_width).map(|_| self.use_sym()).collect(); + + // add a directive to get the bits + statements_flattened.push(FlatStatement::Directive( + FlatDirective::new( + rhs_bits_be.clone(), + Solver::bits(safe_width), + vec![rhs_id], + ), + )); + + // bitness checks + for bit in rhs_bits_be.iter().take(safe_width) { + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(*bit), + FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), + ), + )); + } + + // bit decomposition check + let mut rhs_sum = FlatExpression::Number(T::from(0)); + + for (i, bit) in rhs_bits_be.iter().take(safe_width).enumerate() { + rhs_sum = FlatExpression::Add( + box rhs_sum, + box FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Number( + T::from(2).pow(safe_width - i - 1), + ), + ), + ); + } + + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(rhs_id), + rhs_sum, + )); + } + + // sym := (lhs * 2) - (rhs * 2) + let subtraction_result = FlatExpression::Sub( box FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)), + box FlatExpression::Number(T::from(2)), + box FlatExpression::Identifier(lhs_id), + ), + box FlatExpression::Mult( + box FlatExpression::Number(T::from(2)), + box FlatExpression::Identifier(rhs_id), ), ); - } - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(lhs_id), - lhs_sum, - )); - } + // define variables for the bits + let sub_bits_be: Vec = + (0..bit_width).map(|_| self.use_sym()).collect(); - // rhs - let rhs_id = self.define(rhs_flattened, statements_flattened); + // add a directive to get the bits + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + sub_bits_be.clone(), + Solver::bits(bit_width), + vec![subtraction_result.clone()], + ))); - // rhs - { - // define variables for the bits - let rhs_bits_be: Vec = - (0..safe_width).map(|_| self.use_sym()).collect(); + // bitness checks + for bit in sub_bits_be.iter().take(bit_width) { + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(*bit), + FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), + ), + )); + } - // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - rhs_bits_be.clone(), - Solver::bits(safe_width), - vec![rhs_id], - ))); - - // bitness checks - for bit in rhs_bits_be.iter().take(safe_width) { - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(*bit), - FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Identifier(*bit), - ), - )); - } - - // bit decomposition check - let mut rhs_sum = FlatExpression::Number(T::from(0)); - - for (i, bit) in rhs_bits_be.iter().enumerate().take(safe_width) { - rhs_sum = FlatExpression::Add( - box rhs_sum, - box FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)), - ), + // check that the decomposition is in the field with a strict `< p` checks + self.enforce_constant_le_check( + statements_flattened, + &sub_bits_be, + &T::max_value().bit_vector_be(), ); + + // sum(sym_b{i} * 2**i) + let mut expr = FlatExpression::Number(T::from(0)); + + for (i, bit) in sub_bits_be.iter().take(bit_width).enumerate() { + expr = FlatExpression::Add( + box expr, + box FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)), + ), + ); + } + + statements_flattened + .push(FlatStatement::Condition(subtraction_result, expr)); + + FlatExpression::Identifier(sub_bits_be[bit_width - 1]) } - - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(rhs_id), - rhs_sum, - )); } - - // sym := (lhs * 2) - (rhs * 2) - let subtraction_result = FlatExpression::Sub( - box FlatExpression::Mult( - box FlatExpression::Number(T::from(2)), - box FlatExpression::Identifier(lhs_id), - ), - box FlatExpression::Mult( - box FlatExpression::Number(T::from(2)), - box FlatExpression::Identifier(rhs_id), - ), - ); - - // define variables for the bits - let sub_bits_be: Vec = - (0..bit_width).map(|_| self.use_sym()).collect(); - - // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - sub_bits_be.clone(), - Solver::bits(bit_width), - vec![subtraction_result.clone()], - ))); - - // bitness checks - for bit in sub_bits_be.iter().take(bit_width) { - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(*bit), - FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Identifier(*bit), - ), - )); - } - - // check that the decomposition is in the field with a strict `< p` checks - self.strict_le_check( - statements_flattened, - &T::max_value_bit_vector_be(), - sub_bits_be.clone(), - ); - - // sum(sym_b{i} * 2**i) - let mut expr = FlatExpression::Number(T::from(0)); - - for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) { - expr = FlatExpression::Add( - box expr, - box FlatExpression::Mult( - box FlatExpression::Identifier(*bit), - box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)), - ), - ); - } - - statements_flattened.push(FlatStatement::Condition(subtraction_result, expr)); - - FlatExpression::Identifier(sub_bits_be[bit_width - 1]) } BooleanExpression::BoolEq(box lhs, box rhs) => { // lhs and rhs are booleans, they flatten to 0 or 1 @@ -563,42 +784,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) } BooleanExpression::FieldEq(box lhs, box rhs) => { - // Wanted: (Y = (X != 0) ? 1 : 0) - // X = a - b - // # Y = if X == 0 then 0 else 1 fi - // # M = if X == 0 then 1 else 1/X fi - // Y == X * M - // 0 == (1-Y) * X + let lhs = self.flatten_field_expression(statements_flattened, lhs); - let name_y = self.use_sym(); - let name_m = self.use_sym(); + let rhs = self.flatten_field_expression(statements_flattened, rhs); - let x = self.flatten_field_expression( - statements_flattened, - FieldElementExpression::Sub(box lhs, box rhs), - ); - - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - vec![name_y, name_m], - Solver::ConditionEq, - vec![x.clone()], - ))); - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(name_y), - FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), - )); - - let res = FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_y), - ); - - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Number(T::zero()), - FlatExpression::Mult(box res.clone(), box x), - )); - - res + self.eq_check(statements_flattened, lhs, rhs) } BooleanExpression::UintEq(box lhs, box rhs) => { // We reduce each side into range and apply the same approach as for field elements @@ -610,9 +800,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Y == X * M // 0 == (1-Y) * X - let name_y = self.use_sym(); - let name_m = self.use_sym(); - assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool()); assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool()); @@ -623,29 +810,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten_uint_expression(statements_flattened, rhs) .get_field_unchecked(); - let x = FlatExpression::Sub(box lhs, box rhs); - - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - vec![name_y, name_m], - Solver::ConditionEq, - vec![x.clone()], - ))); - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(name_y), - FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), - )); - - let res = FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_y), - ); - - statements_flattened.push(FlatStatement::Condition( - FlatExpression::Number(T::zero()), - FlatExpression::Mult(box res.clone(), box x), - )); - - res + self.eq_check(statements_flattened, lhs, rhs) } BooleanExpression::FieldLe(box lhs, box rhs) => { let lt = self.flatten_boolean_expression( @@ -713,10 +878,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // check that the decomposition is in the field with a strict `< p` checks - self.strict_le_check( + self.constant_le_check( statements_flattened, - &T::max_value_bit_vector_be(), - sub_bits_be.clone(), + &sub_bits_be, + &T::max_value().bit_vector_be(), ); // sum(sym_b{i} * 2**i) @@ -2343,7 +2508,7 @@ mod tests { ZirStatement::Assertion(BooleanExpression::FieldEq( box FieldElementExpression::Add( box FieldElementExpression::Identifier("x".into()), - box FieldElementExpression::Number(Bn128Field::from(1)).into(), + box FieldElementExpression::Number(Bn128Field::from(1)), ), box FieldElementExpression::Identifier("y".into()), )), diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index ca7f43ab..6d2f19ab 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -14,8 +14,8 @@ use zokrates_field::Bn128Field; #[test] fn out_of_range() { let source = r#" - def main(private field a) -> field: - field x = if a < 5555 then 3333 else 4444 fi + def main(private field a, private field b) -> field: + field x = if a < b then 3333 else 4444 fi assert(x == 3333) return 1 "# @@ -36,6 +36,9 @@ fn out_of_range() { let interpreter = Interpreter::try_out_of_range(); assert!(interpreter - .execute(&res.prog(), &[Bn128Field::from(10000)]) + .execute( + &res.prog(), + &[Bn128Field::from(10000), Bn128Field::from(5555)] + ) .is_err()); } diff --git a/zokrates_core_test/tests/tests/le.json b/zokrates_core_test/tests/tests/le.json new file mode 100644 index 00000000..8a5092a9 --- /dev/null +++ b/zokrates_core_test/tests/tests/le.json @@ -0,0 +1,86 @@ +{ + "entry_point": "./tests/tests/le.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["2"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["41"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["42"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["43"] + }, + "output": { + "Ok": { + "values": ["0"] + } + } + }, + { + "input": { + "values": ["44"] + }, + "output": { + "Ok": { + "values": ["0"] + } + } + }, + { + "input": { + "values": ["100"] + }, + "output": { + "Ok": { + "values": ["0"] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/le.zok b/zokrates_core_test/tests/tests/le.zok new file mode 100644 index 00000000..fad82ef8 --- /dev/null +++ b/zokrates_core_test/tests/tests/le.zok @@ -0,0 +1,2 @@ +def main(field e) -> bool: + return e <= 42 \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/native_le.json b/zokrates_core_test/tests/tests/native_le.json new file mode 100644 index 00000000..65d0cc62 --- /dev/null +++ b/zokrates_core_test/tests/tests/native_le.json @@ -0,0 +1,86 @@ +{ + "entry_point": "./tests/tests/native_le.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Ok": { + "values": ["1", "1"] + } + } + }, + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "values": ["1", "1"] + } + } + }, + { + "input": { + "values": ["2", "2"] + }, + "output": { + "Ok": { + "values": ["1", "1"] + } + } + }, + { + "input": { + "values": ["41", "41"] + }, + "output": { + "Ok": { + "values": ["1", "1"] + } + } + }, + { + "input": { + "values": ["42", "42"] + }, + "output": { + "Ok": { + "values": ["1", "1"] + } + } + }, + { + "input": { + "values": ["43", "43"] + }, + "output": { + "Ok": { + "values": ["0", "0"] + } + } + }, + { + "input": { + "values": ["44", "44"] + }, + "output": { + "Ok": { + "values": ["0", "0"] + } + } + }, + { + "input": { + "values": ["100", "100"] + }, + "output": { + "Ok": { + "values": ["0", "0"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/native_le.zok b/zokrates_core_test/tests/tests/native_le.zok new file mode 100644 index 00000000..d1f50437 --- /dev/null +++ b/zokrates_core_test/tests/tests/native_le.zok @@ -0,0 +1,47 @@ +from "utils/pack/bool/unpack.zok" import main as unpack +from "utils/casts/u32_to_bits" import main as u32_to_bits + +// this comparison works for any N smaller than the field size, which is the case in practice +def le(bool[N] a_bits, bool[N] c_bits) -> bool: + + bool size_unknown = false + + u32 verified_conditions = 0 // `and(conditions) == (sum(conditions) == len(conditions))`, here we initialize `sum(conditions)` + + size_unknown = true + + for u32 i in 0..N do + verified_conditions = verified_conditions + if c_bits[i] || (!size_unknown || !a_bits[i]) then 1 else 0 fi + size_unknown = if c_bits[i] then size_unknown && a_bits[i] else size_unknown fi // this is actually not required in the last round + endfor + + return verified_conditions == N // this checks that all conditions were verified + +// this instanciates comparison starting from field elements +def le(field a, field c) -> bool: + + field MAX = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + bool[N] MAX_BITS = unpack::(MAX) + + bool[N] a_bits = unpack(a) + assert(le(a_bits, MAX_BITS)) + bool[N] c_bits = unpack(c) + assert(le(c_bits, MAX_BITS)) + + return le(a_bits, c_bits) + +// this instanciates comparison starting from u32 +def le(u32 a, u32 c) -> bool: + bool[32] a_bits = u32_to_bits(a) + bool[32] c_bits = u32_to_bits(c) + + return le(a_bits, c_bits) + +def main(field a, u32 b) -> (bool, bool): + + u32 N = 254 + field c = 42 + + u32 d = 42 + + return le::(a, c), le(b, d) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/range.json b/zokrates_core_test/tests/tests/range.json new file mode 100644 index 00000000..85a71730 --- /dev/null +++ b/zokrates_core_test/tests/tests/range.json @@ -0,0 +1,76 @@ +{ + "entry_point": "./tests/tests/range.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "values": ["0", "1", "1", "1", "0"] + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "values": ["0", "0", "1", "1", "1"] + } + } + }, + { + "input": { + "values": ["2"] + }, + "output": { + "Ok": { + "values": ["0", "0", "1", "1", "0"] + } + } + }, + { + "input": { + "values": ["254"] + }, + "output": { + "Ok": { + "values": ["0", "0", "1", "1", "0"] + } + } + }, + { + "input": { + "values": ["255"] + }, + "output": { + "Ok": { + "values": ["0", "0", "0", "1", "0"] + } + } + }, + { + "input": { + "values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615"] + }, + "output": { + "Ok": { + "values": ["0", "0", "0", "1", "0"] + } + } + }, + { + "input": { + "values": ["21888242871839275222246405745257275088548364400416034343698204186575808495616"] + }, + "output": { + "Ok": { + "values": ["0", "0", "0", "0", "0"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/range.zok b/zokrates_core_test/tests/tests/range.zok new file mode 100644 index 00000000..667dc5c4 --- /dev/null +++ b/zokrates_core_test/tests/tests/range.zok @@ -0,0 +1,2 @@ +def main(field x) -> bool[5]: + return [x < 0, x < 1, x < 255, x < 0 - 1, x - 2 > 0 - 2] \ No newline at end of file diff --git a/zokrates_field/src/bn128.rs b/zokrates_field/src/bn128.rs index 014a222c..b7bf2ea2 100644 --- a/zokrates_field/src/bn128.rs +++ b/zokrates_field/src/bn128.rs @@ -30,7 +30,7 @@ mod tests { #[test] fn max_value_bits() { - let bits = FieldPrime::max_value_bit_vector_be(); + let bits = FieldPrime::max_value().bit_vector_be(); assert_eq!( bits[0..10].to_vec(), vec![true, true, false, false, false, false, false, true, true, false] diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index d4537e64..030841a9 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -116,7 +116,8 @@ pub trait Field: /// Gets the number of bits fn bits(&self) -> u32; /// Returns this `Field`'s largest value as a big-endian bit vector - fn max_value_bit_vector_be() -> Vec { + /// Always returns `Self::get_required_bits()` elements + fn bit_vector_be(&self) -> Vec { fn bytes_to_bits(bytes: &[u8]) -> Vec { bytes .iter() @@ -124,13 +125,26 @@ pub trait Field: .collect() } - let field_bytes_le = Self::to_byte_vector(&Self::max_value()); + let field_bytes_le = self.to_byte_vector(); + // reverse for big-endianess let field_bytes_be = field_bytes_le.into_iter().rev().collect::>(); let field_bits_be = bytes_to_bits(&field_bytes_be); - let field_bits_be = &field_bits_be[field_bits_be.len() - Self::get_required_bits()..]; - field_bits_be.to_vec() + let field_bits_be: Vec<_> = (0..Self::get_required_bits() + .saturating_sub(field_bits_be.len())) + .map(|_| &false) + .chain( + &field_bits_be[field_bits_be + .len() + .saturating_sub(Self::get_required_bits())..], + ) + .cloned() + .collect(); + + assert_eq!(field_bits_be.len(), Self::get_required_bits()); + + field_bits_be } /// Returns the value as a BigUint fn to_biguint(&self) -> BigUint;