1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #761 from Zokrates/constant-range-check

Constant range check
This commit is contained in:
Thibaut Schaeffer 2021-04-29 20:14:24 +02:00 committed by GitHub
commit 4f17446e3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 731 additions and 243 deletions

View file

@ -0,0 +1 @@
Introduce constant range checks for checks of the form `x < c` where `p` is a compile-time constant, also for other comparison operators. This works for any `x` and `p`, unlike dynamic `x < y` comparison

View file

@ -22,4 +22,4 @@ The following table lists the precedence and associativity of all operators. Ope
[^2]: The right operand must be a compile time constant of type `u32`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant

View file

@ -9,4 +9,4 @@ def main(field a) -> bool:
// maxvalue = 2**252 - 1
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < (maxvalue + 1)
return a < (maxvalue + 1)

View file

@ -4,4 +4,4 @@
def main(field a) -> bool:
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < p
return a < p

View file

@ -208,6 +208,12 @@ pub enum FlatExpression<T> {
Mult(Box<FlatExpression<T>>, Box<FlatExpression<T>>),
}
impl<T> From<T> for FlatExpression<T> {
fn from(other: T) -> Self {
Self::Number(other)
}
}
impl<T: Field> FlatExpression<T> {
pub fn apply_substitution(
self,

View file

@ -178,34 +178,46 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
// Let's assume b = [1, 1, 1, 0]
//
// 1. Init `sizeUnknown = true`
// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b.
// 2. Loop over `b`:
// * b[0] = 1
// when `b` is 1 we check wether `a` is 0 in that particular run and update
// `sizeUnknown` accordingly:
// `sizeUnknown = sizeUnknown && a[0]`
// * b[1] = 1
// `sizrUnknown = sizeUnknown && a[1]`
// * b[2] = 1
// `sizeUnknown = sizeUnknown && a[2]`
// * b[3] = 0
// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`,
// otherwise `a` can be {0,1}:
// `true == (!sizeUnknown || !a[3])`
// ```
// **true => a -> 0
// sizeUnkown *
// **false => a -> {0,1}
// ```
fn strict_le_check(
/// Compute a range check between the bid endian decomposition of an expression and the
/// big endian decomposition of a constant
///
/// # Arguments
/// * `a` - the big-endian bit decomposition of the expression to check against the range
/// * `b` - the big-endian bit decomposition of the upper bound we're checking against
///
/// # Returns
/// * a vector of FlatExpression which all evaluate to `1` if `a <= b` and `0` otherwise
///
/// # Notes
///
/// Algorithm from [the sapling spec](https://github.com/zcash/zips/blob/master/protocol/sapling.pdf) A.3.2.2
///
/// Let's assume b = [1, 1, 1, 0]
///
/// 1. Init `sizeUnknown = true`
/// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b.
/// 2. Loop over `b`:
/// * b[0] = 1
/// when `b` is 1 we check wether `a` is 0 in that particular run and update
/// `sizeUnknown` accordingly:
/// `sizeUnknown = sizeUnknown && a[0]`
/// * b[1] = 1
/// `sizeUnknown = sizeUnknown && a[1]`
/// * b[2] = 1
/// `sizeUnknown = sizeUnknown && a[2]`
/// * b[3] = 0
/// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`,
/// otherwise `a` can be {0,1}:
/// `true == (!sizeUnknown || !a[3])`
/// **true => a -> 0
/// sizeUnkown *
/// **false => a -> {0,1}
fn constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
b: &[bool],
a: Vec<FlatVariable>,
) {
) -> Vec<FlatExpression<T>> {
let len = b.len();
assert_eq!(a.len(), T::get_required_bits());
assert_eq!(a.len(), b.len());
@ -224,6 +236,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Number(T::from(1)),
));
let mut res = vec![];
for (i, b) in b.iter().enumerate() {
if *b {
statements_flattened.push(FlatStatement::Definition(
@ -267,12 +281,93 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box and_name.into(),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::from(1)),
or,
));
res.push(or);
}
}
res
}
/// Compute an equality check between two expressions
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `left - the first `FlatExpression`
/// * `right` - the second `FlatExpression`
///
/// # Returns
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
fn eq_check(
&mut self,
statements_flattened: &mut Vec<FlatStatement<T>>,
left: FlatExpression<T>,
right: FlatExpression<T>,
) -> FlatExpression<T> {
let left = self.define(left, statements_flattened);
let right = self.define(right, statements_flattened);
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
// Y == X * M
// 0 == (1-Y) * X
let x = FlatExpression::Sub(box left.into(), box right.into());
let name_y = self.use_sym();
let name_m = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
}
/// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `a` - the big-endian bit decomposition of the expression we enforce to be in range
/// * `b` - the big-endian bit decomposition of the upper bound of the range
fn enforce_constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
b: &[bool],
) {
let conditions = self.constant_le_check(statements_flattened, a, b);
let conditions_count = conditions.len();
let conditions_sum = conditions
.into_iter()
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
});
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::from(0)),
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
));
}
/// Flatten an if/else expression
@ -346,6 +441,109 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
/// Compute a strict check against a constant
/// # Arguments
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `e` - the `FlatExpression` that's being checked against the range.
/// * `c` - the constant strict upper bound of the range
///
/// # Returns
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
fn constant_lt_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
e: FlatExpression<T>,
c: T,
) -> FlatExpression<T> {
if c == T::zero() {
// this is the case c == 0, we return 0, aka false
return T::zero().into();
}
self.constant_field_le_check(statements_flattened, e, c - T::one())
}
/// Compute a range check against a constant
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `e` - the `FlatExpression` that's being checked against the range.
/// * `c` - the constant upper bound of the range
///
/// # Returns
/// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise
fn constant_field_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
e: FlatExpression<T>,
c: T,
) -> FlatExpression<T> {
let bit_width = T::get_required_bits();
// decompose e to bits
let e_id = self.define(e, statements_flattened);
// define variables for the bits
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
e_bits_be.clone(),
Solver::bits(bit_width),
vec![e_id],
)));
// bitness checks
for bit in e_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// bit decomposition check
let mut e_sum = FlatExpression::Number(T::from(0));
for (i, bit) in e_bits_be.iter().take(bit_width).enumerate() {
e_sum = FlatExpression::Add(
box e_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(e_id),
e_sum,
));
// check that this decomposition does not overflow the field
self.enforce_constant_le_check(
statements_flattened,
&e_bits_be,
&T::max_value().bit_vector_be(),
);
let conditions =
self.constant_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
// return `len(conditions) == sum(conditions)`
self.eq_check(
statements_flattened,
T::from(conditions.len()).into(),
conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
)
}
/// Flattens a boolean expression
///
/// # Arguments
@ -370,7 +568,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::FieldLt(box lhs, box rhs) => {
// Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
@ -378,157 +575,181 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs);
let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs);
// lhs
let lhs_id = self.define(lhs_flattened, statements_flattened);
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
// lhs
{
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(safe_width),
vec![lhs_id],
)));
// bitness checks
for bit in lhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Number(constant)) => {
self.constant_lt_check(statements_flattened, x, constant)
}
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
(FlatExpression::Number(constant), x) => self.constant_lt_check(
statements_flattened,
FlatExpression::Sub(box T::max_value().into(), box x),
T::max_value() - constant,
),
(lhs_flattened, rhs_flattened) => {
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2`
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
// lhs
let lhs_id = self.define(lhs_flattened, statements_flattened);
for (i, bit) in lhs_bits_be.iter().enumerate().take(safe_width) {
lhs_sum = FlatExpression::Add(
box lhs_sum,
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
// lhs
{
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(safe_width),
vec![lhs_id],
),
));
// bitness checks
for bit in lhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
for (i, bit) in lhs_bits_be.iter().take(safe_width).enumerate() {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id),
lhs_sum,
));
}
// rhs
let rhs_id = self.define(rhs_flattened, statements_flattened);
// rhs
{
// define variables for the bits
let rhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
rhs_bits_be.clone(),
Solver::bits(safe_width),
vec![rhs_id],
),
));
// bitness checks
for bit in rhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// bit decomposition check
let mut rhs_sum = FlatExpression::Number(T::from(0));
for (i, bit) in rhs_bits_be.iter().take(safe_width).enumerate() {
rhs_sum = FlatExpression::Add(
box rhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(rhs_id),
rhs_sum,
));
}
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id),
lhs_sum,
));
}
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// rhs
let rhs_id = self.define(rhs_flattened, statements_flattened);
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// rhs
{
// define variables for the bits
let rhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// bitness checks
for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
rhs_bits_be.clone(),
Solver::bits(safe_width),
vec![rhs_id],
)));
// bitness checks
for bit in rhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// bit decomposition check
let mut rhs_sum = FlatExpression::Number(T::from(0));
for (i, bit) in rhs_bits_be.iter().enumerate().take(safe_width) {
rhs_sum = FlatExpression::Add(
box rhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
),
// check that the decomposition is in the field with a strict `< p` checks
self.enforce_constant_le_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in sub_bits_be.iter().take(bit_width).enumerate() {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened
.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(rhs_id),
rhs_sum,
));
}
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
));
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
statements_flattened,
&T::max_value_bit_vector_be(),
sub_bits_be.clone(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
BooleanExpression::BoolEq(box lhs, box rhs) => {
// lhs and rhs are booleans, they flatten to 0 or 1
@ -563,42 +784,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
BooleanExpression::FieldEq(box lhs, box rhs) => {
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
// Y == X * M
// 0 == (1-Y) * X
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let name_y = self.use_sym();
let name_m = self.use_sym();
let rhs = self.flatten_field_expression(statements_flattened, rhs);
let x = self.flatten_field_expression(
statements_flattened,
FieldElementExpression::Sub(box lhs, box rhs),
);
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
self.eq_check(statements_flattened, lhs, rhs)
}
BooleanExpression::UintEq(box lhs, box rhs) => {
// We reduce each side into range and apply the same approach as for field elements
@ -610,9 +800,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// Y == X * M
// 0 == (1-Y) * X
let name_y = self.use_sym();
let name_m = self.use_sym();
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool());
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool());
@ -623,29 +810,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flatten_uint_expression(statements_flattened, rhs)
.get_field_unchecked();
let x = FlatExpression::Sub(box lhs, box rhs);
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
self.eq_check(statements_flattened, lhs, rhs)
}
BooleanExpression::FieldLe(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(
@ -713,10 +878,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
self.constant_le_check(
statements_flattened,
&T::max_value_bit_vector_be(),
sub_bits_be.clone(),
&sub_bits_be,
&T::max_value().bit_vector_be(),
);
// sum(sym_b{i} * 2**i)
@ -2343,7 +2508,7 @@ mod tests {
ZirStatement::Assertion(BooleanExpression::FieldEq(
box FieldElementExpression::Add(
box FieldElementExpression::Identifier("x".into()),
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
box FieldElementExpression::Identifier("y".into()),
)),

View file

@ -14,8 +14,8 @@ use zokrates_field::Bn128Field;
#[test]
fn out_of_range() {
let source = r#"
def main(private field a) -> field:
field x = if a < 5555 then 3333 else 4444 fi
def main(private field a, private field b) -> field:
field x = if a < b then 3333 else 4444 fi
assert(x == 3333)
return 1
"#
@ -36,6 +36,9 @@ fn out_of_range() {
let interpreter = Interpreter::try_out_of_range();
assert!(interpreter
.execute(&res.prog(), &[Bn128Field::from(10000)])
.execute(
&res.prog(),
&[Bn128Field::from(10000), Bn128Field::from(5555)]
)
.is_err());
}

View file

@ -0,0 +1,86 @@
{
"entry_point": "./tests/tests/le.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"values": ["1"]
}
}
},
{
"input": {
"values": ["1"]
},
"output": {
"Ok": {
"values": ["1"]
}
}
},
{
"input": {
"values": ["2"]
},
"output": {
"Ok": {
"values": ["1"]
}
}
},
{
"input": {
"values": ["41"]
},
"output": {
"Ok": {
"values": ["1"]
}
}
},
{
"input": {
"values": ["42"]
},
"output": {
"Ok": {
"values": ["1"]
}
}
},
{
"input": {
"values": ["43"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
},
{
"input": {
"values": ["44"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
},
{
"input": {
"values": ["100"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
}
]
}

View file

@ -0,0 +1,2 @@
def main(field e) -> bool:
return e <= 42

View file

@ -0,0 +1,86 @@
{
"entry_point": "./tests/tests/native_le.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0", "0"]
},
"output": {
"Ok": {
"values": ["1", "1"]
}
}
},
{
"input": {
"values": ["1", "1"]
},
"output": {
"Ok": {
"values": ["1", "1"]
}
}
},
{
"input": {
"values": ["2", "2"]
},
"output": {
"Ok": {
"values": ["1", "1"]
}
}
},
{
"input": {
"values": ["41", "41"]
},
"output": {
"Ok": {
"values": ["1", "1"]
}
}
},
{
"input": {
"values": ["42", "42"]
},
"output": {
"Ok": {
"values": ["1", "1"]
}
}
},
{
"input": {
"values": ["43", "43"]
},
"output": {
"Ok": {
"values": ["0", "0"]
}
}
},
{
"input": {
"values": ["44", "44"]
},
"output": {
"Ok": {
"values": ["0", "0"]
}
}
},
{
"input": {
"values": ["100", "100"]
},
"output": {
"Ok": {
"values": ["0", "0"]
}
}
}
]
}

View file

@ -0,0 +1,47 @@
from "utils/pack/bool/unpack.zok" import main as unpack
from "utils/casts/u32_to_bits" import main as u32_to_bits
// this comparison works for any N smaller than the field size, which is the case in practice
def le<N>(bool[N] a_bits, bool[N] c_bits) -> bool:
bool size_unknown = false
u32 verified_conditions = 0 // `and(conditions) == (sum(conditions) == len(conditions))`, here we initialize `sum(conditions)`
size_unknown = true
for u32 i in 0..N do
verified_conditions = verified_conditions + if c_bits[i] || (!size_unknown || !a_bits[i]) then 1 else 0 fi
size_unknown = if c_bits[i] then size_unknown && a_bits[i] else size_unknown fi // this is actually not required in the last round
endfor
return verified_conditions == N // this checks that all conditions were verified
// this instanciates comparison starting from field elements
def le<N>(field a, field c) -> bool:
field MAX = 21888242871839275222246405745257275088548364400416034343698204186575808495616
bool[N] MAX_BITS = unpack::<N>(MAX)
bool[N] a_bits = unpack(a)
assert(le(a_bits, MAX_BITS))
bool[N] c_bits = unpack(c)
assert(le(c_bits, MAX_BITS))
return le(a_bits, c_bits)
// this instanciates comparison starting from u32
def le(u32 a, u32 c) -> bool:
bool[32] a_bits = u32_to_bits(a)
bool[32] c_bits = u32_to_bits(c)
return le(a_bits, c_bits)
def main(field a, u32 b) -> (bool, bool):
u32 N = 254
field c = 42
u32 d = 42
return le::<N>(a, c), le(b, d)

View file

@ -0,0 +1,76 @@
{
"entry_point": "./tests/tests/range.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"values": ["0", "1", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["1"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "1"]
}
}
},
{
"input": {
"values": ["2"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["254"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["255"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "1", "0"]
}
}
},
{
"input": {
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "1", "0"]
}
}
},
{
"input": {
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495616"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "0", "0"]
}
}
}
]
}

View file

@ -0,0 +1,2 @@
def main(field x) -> bool[5]:
return [x < 0, x < 1, x < 255, x < 0 - 1, x - 2 > 0 - 2]

View file

@ -30,7 +30,7 @@ mod tests {
#[test]
fn max_value_bits() {
let bits = FieldPrime::max_value_bit_vector_be();
let bits = FieldPrime::max_value().bit_vector_be();
assert_eq!(
bits[0..10].to_vec(),
vec![true, true, false, false, false, false, false, true, true, false]

View file

@ -116,7 +116,8 @@ pub trait Field:
/// Gets the number of bits
fn bits(&self) -> u32;
/// Returns this `Field`'s largest value as a big-endian bit vector
fn max_value_bit_vector_be() -> Vec<bool> {
/// Always returns `Self::get_required_bits()` elements
fn bit_vector_be(&self) -> Vec<bool> {
fn bytes_to_bits(bytes: &[u8]) -> Vec<bool> {
bytes
.iter()
@ -124,13 +125,26 @@ pub trait Field:
.collect()
}
let field_bytes_le = Self::to_byte_vector(&Self::max_value());
let field_bytes_le = self.to_byte_vector();
// reverse for big-endianess
let field_bytes_be = field_bytes_le.into_iter().rev().collect::<Vec<u8>>();
let field_bits_be = bytes_to_bits(&field_bytes_be);
let field_bits_be = &field_bits_be[field_bits_be.len() - Self::get_required_bits()..];
field_bits_be.to_vec()
let field_bits_be: Vec<_> = (0..Self::get_required_bits()
.saturating_sub(field_bits_be.len()))
.map(|_| &false)
.chain(
&field_bits_be[field_bits_be
.len()
.saturating_sub(Self::get_required_bits())..],
)
.cloned()
.collect();
assert_eq!(field_bits_be.len(), Self::get_required_bits());
field_bits_be
}
/// Returns the value as a BigUint
fn to_biguint(&self) -> BigUint;