Merge pull request #761 from Zokrates/constant-range-check
Constant range check
This commit is contained in:
commit
4f17446e3d
15 changed files with 731 additions and 243 deletions
1
changelogs/unreleased/761-schaeff
Normal file
1
changelogs/unreleased/761-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Introduce constant range checks for checks of the form `x < c` where `p` is a compile-time constant, also for other comparison operators. This works for any `x` and `p`, unlike dynamic `x < y` comparison
|
|
@ -22,4 +22,4 @@ The following table lists the precedence and associativity of all operators. Ope
|
|||
|
||||
[^2]: The right operand must be a compile time constant of type `u32`
|
||||
|
||||
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`
|
||||
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant
|
|
@ -9,4 +9,4 @@ def main(field a) -> bool:
|
|||
// maxvalue = 2**252 - 1
|
||||
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return 0 < (maxvalue + 1)
|
||||
return a < (maxvalue + 1)
|
|
@ -4,4 +4,4 @@
|
|||
def main(field a) -> bool:
|
||||
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return 0 < p
|
||||
return a < p
|
|
@ -208,6 +208,12 @@ pub enum FlatExpression<T> {
|
|||
Mult(Box<FlatExpression<T>>, Box<FlatExpression<T>>),
|
||||
}
|
||||
|
||||
impl<T> From<T> for FlatExpression<T> {
|
||||
fn from(other: T) -> Self {
|
||||
Self::Number(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatExpression<T> {
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
|
|
|
@ -178,34 +178,46 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
// Let's assume b = [1, 1, 1, 0]
|
||||
//
|
||||
// 1. Init `sizeUnknown = true`
|
||||
// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b.
|
||||
// 2. Loop over `b`:
|
||||
// * b[0] = 1
|
||||
// when `b` is 1 we check wether `a` is 0 in that particular run and update
|
||||
// `sizeUnknown` accordingly:
|
||||
// `sizeUnknown = sizeUnknown && a[0]`
|
||||
// * b[1] = 1
|
||||
// `sizrUnknown = sizeUnknown && a[1]`
|
||||
// * b[2] = 1
|
||||
// `sizeUnknown = sizeUnknown && a[2]`
|
||||
// * b[3] = 0
|
||||
// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`,
|
||||
// otherwise `a` can be {0,1}:
|
||||
// `true == (!sizeUnknown || !a[3])`
|
||||
// ```
|
||||
// **true => a -> 0
|
||||
// sizeUnkown *
|
||||
// **false => a -> {0,1}
|
||||
// ```
|
||||
fn strict_le_check(
|
||||
/// Compute a range check between the bid endian decomposition of an expression and the
|
||||
/// big endian decomposition of a constant
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `a` - the big-endian bit decomposition of the expression to check against the range
|
||||
/// * `b` - the big-endian bit decomposition of the upper bound we're checking against
|
||||
///
|
||||
/// # Returns
|
||||
/// * a vector of FlatExpression which all evaluate to `1` if `a <= b` and `0` otherwise
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Algorithm from [the sapling spec](https://github.com/zcash/zips/blob/master/protocol/sapling.pdf) A.3.2.2
|
||||
///
|
||||
/// Let's assume b = [1, 1, 1, 0]
|
||||
///
|
||||
/// 1. Init `sizeUnknown = true`
|
||||
/// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b.
|
||||
/// 2. Loop over `b`:
|
||||
/// * b[0] = 1
|
||||
/// when `b` is 1 we check wether `a` is 0 in that particular run and update
|
||||
/// `sizeUnknown` accordingly:
|
||||
/// `sizeUnknown = sizeUnknown && a[0]`
|
||||
/// * b[1] = 1
|
||||
/// `sizeUnknown = sizeUnknown && a[1]`
|
||||
/// * b[2] = 1
|
||||
/// `sizeUnknown = sizeUnknown && a[2]`
|
||||
/// * b[3] = 0
|
||||
/// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`,
|
||||
/// otherwise `a` can be {0,1}:
|
||||
/// `true == (!sizeUnknown || !a[3])`
|
||||
/// **true => a -> 0
|
||||
/// sizeUnkown *
|
||||
/// **false => a -> {0,1}
|
||||
fn constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
b: &[bool],
|
||||
a: Vec<FlatVariable>,
|
||||
) {
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let len = b.len();
|
||||
assert_eq!(a.len(), T::get_required_bits());
|
||||
assert_eq!(a.len(), b.len());
|
||||
|
@ -224,6 +236,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Number(T::from(1)),
|
||||
));
|
||||
|
||||
let mut res = vec![];
|
||||
|
||||
for (i, b) in b.iter().enumerate() {
|
||||
if *b {
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
|
@ -267,12 +281,93 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box and_name.into(),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::from(1)),
|
||||
or,
|
||||
));
|
||||
res.push(or);
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Compute an equality check between two expressions
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `left - the first `FlatExpression`
|
||||
/// * `right` - the second `FlatExpression`
|
||||
///
|
||||
/// # Returns
|
||||
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
|
||||
fn eq_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
left: FlatExpression<T>,
|
||||
right: FlatExpression<T>,
|
||||
) -> FlatExpression<T> {
|
||||
let left = self.define(left, statements_flattened);
|
||||
let right = self.define(right, statements_flattened);
|
||||
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// X = a - b
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
|
||||
let x = FlatExpression::Sub(box left.into(), box right.into());
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `a` - the big-endian bit decomposition of the expression we enforce to be in range
|
||||
/// * `b` - the big-endian bit decomposition of the upper bound of the range
|
||||
fn enforce_constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
b: &[bool],
|
||||
) {
|
||||
let conditions = self.constant_le_check(statements_flattened, a, b);
|
||||
|
||||
let conditions_count = conditions.len();
|
||||
|
||||
let conditions_sum = conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::from(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
});
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::from(0)),
|
||||
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
|
||||
));
|
||||
}
|
||||
|
||||
/// Flatten an if/else expression
|
||||
|
@ -346,6 +441,109 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Compute a strict check against a constant
|
||||
/// # Arguments
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `e` - the `FlatExpression` that's being checked against the range.
|
||||
/// * `c` - the constant strict upper bound of the range
|
||||
///
|
||||
/// # Returns
|
||||
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
|
||||
fn constant_lt_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
if c == T::zero() {
|
||||
// this is the case c == 0, we return 0, aka false
|
||||
return T::zero().into();
|
||||
}
|
||||
|
||||
self.constant_field_le_check(statements_flattened, e, c - T::one())
|
||||
}
|
||||
|
||||
/// Compute a range check against a constant
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `e` - the `FlatExpression` that's being checked against the range.
|
||||
/// * `c` - the constant upper bound of the range
|
||||
///
|
||||
/// # Returns
|
||||
/// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise
|
||||
fn constant_field_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
let bit_width = T::get_required_bits();
|
||||
// decompose e to bits
|
||||
let e_id = self.define(e, statements_flattened);
|
||||
|
||||
// define variables for the bits
|
||||
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
e_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![e_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in e_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut e_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in e_bits_be.iter().take(bit_width).enumerate() {
|
||||
e_sum = FlatExpression::Add(
|
||||
box e_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(e_id),
|
||||
e_sum,
|
||||
));
|
||||
|
||||
// check that this decomposition does not overflow the field
|
||||
self.enforce_constant_le_check(
|
||||
statements_flattened,
|
||||
&e_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
let conditions =
|
||||
self.constant_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
|
||||
|
||||
// return `len(conditions) == sum(conditions)`
|
||||
self.eq_check(
|
||||
statements_flattened,
|
||||
T::from(conditions.len()).into(),
|
||||
conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Flattens a boolean expression
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -370,7 +568,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
BooleanExpression::FieldLt(box lhs, box rhs) => {
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
let bit_width = T::get_required_bits();
|
||||
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
|
||||
|
||||
// We know from semantic checking that lhs and rhs have the same type
|
||||
// What the expression will flatten to depends on that type
|
||||
|
@ -378,157 +575,181 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
|
||||
|
||||
// lhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in lhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
match (lhs_flattened, rhs_flattened) {
|
||||
(x, FlatExpression::Number(constant)) => {
|
||||
self.constant_lt_check(statements_flattened, x, constant)
|
||||
}
|
||||
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
|
||||
(FlatExpression::Number(constant), x) => self.constant_lt_check(
|
||||
statements_flattened,
|
||||
FlatExpression::Sub(box T::max_value().into(), box x),
|
||||
T::max_value() - constant,
|
||||
),
|
||||
(lhs_flattened, rhs_flattened) => {
|
||||
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2`
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
for (i, bit) in lhs_bits_be.iter().enumerate().take(safe_width) {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
|
||||
|
||||
// lhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![lhs_id],
|
||||
),
|
||||
));
|
||||
|
||||
// bitness checks
|
||||
for bit in lhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in lhs_bits_be.iter().take(safe_width).enumerate() {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(
|
||||
T::from(2).pow(safe_width - i - 1),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// rhs
|
||||
let rhs_id = self.define(rhs_flattened, statements_flattened);
|
||||
|
||||
// rhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![rhs_id],
|
||||
),
|
||||
));
|
||||
|
||||
// bitness checks
|
||||
for bit in rhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in rhs_bits_be.iter().take(safe_width).enumerate() {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(
|
||||
T::from(2).pow(safe_width - i - 1),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_id),
|
||||
rhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// sym := (lhs * 2) - (rhs * 2)
|
||||
let subtraction_result = FlatExpression::Sub(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
}
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// rhs
|
||||
let rhs_id = self.define(rhs_flattened, statements_flattened);
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// rhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
// bitness checks
|
||||
for bit in sub_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![rhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in rhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in rhs_bits_be.iter().enumerate().take(safe_width) {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
),
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.enforce_constant_le_check(
|
||||
statements_flattened,
|
||||
&sub_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in sub_bits_be.iter().take(bit_width).enumerate() {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened
|
||||
.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_id),
|
||||
rhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// sym := (lhs * 2) - (rhs * 2)
|
||||
let subtraction_result = FlatExpression::Sub(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
),
|
||||
);
|
||||
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in sub_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.strict_le_check(
|
||||
statements_flattened,
|
||||
&T::max_value_bit_vector_be(),
|
||||
sub_bits_be.clone(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
BooleanExpression::BoolEq(box lhs, box rhs) => {
|
||||
// lhs and rhs are booleans, they flatten to 0 or 1
|
||||
|
@ -563,42 +784,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}
|
||||
BooleanExpression::FieldEq(box lhs, box rhs) => {
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// X = a - b
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
let x = self.flatten_field_expression(
|
||||
statements_flattened,
|
||||
FieldElementExpression::Sub(box lhs, box rhs),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
self.eq_check(statements_flattened, lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintEq(box lhs, box rhs) => {
|
||||
// We reduce each side into range and apply the same approach as for field elements
|
||||
|
@ -610,9 +800,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool());
|
||||
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool());
|
||||
|
||||
|
@ -623,29 +810,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.flatten_uint_expression(statements_flattened, rhs)
|
||||
.get_field_unchecked();
|
||||
|
||||
let x = FlatExpression::Sub(box lhs, box rhs);
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
self.eq_check(statements_flattened, lhs, rhs)
|
||||
}
|
||||
BooleanExpression::FieldLe(box lhs, box rhs) => {
|
||||
let lt = self.flatten_boolean_expression(
|
||||
|
@ -713,10 +878,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.strict_le_check(
|
||||
self.constant_le_check(
|
||||
statements_flattened,
|
||||
&T::max_value_bit_vector_be(),
|
||||
sub_bits_be.clone(),
|
||||
&sub_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
|
@ -2343,7 +2508,7 @@ mod tests {
|
|||
ZirStatement::Assertion(BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier("x".into()),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
box FieldElementExpression::Identifier("y".into()),
|
||||
)),
|
||||
|
|
|
@ -14,8 +14,8 @@ use zokrates_field::Bn128Field;
|
|||
#[test]
|
||||
fn out_of_range() {
|
||||
let source = r#"
|
||||
def main(private field a) -> field:
|
||||
field x = if a < 5555 then 3333 else 4444 fi
|
||||
def main(private field a, private field b) -> field:
|
||||
field x = if a < b then 3333 else 4444 fi
|
||||
assert(x == 3333)
|
||||
return 1
|
||||
"#
|
||||
|
@ -36,6 +36,9 @@ fn out_of_range() {
|
|||
let interpreter = Interpreter::try_out_of_range();
|
||||
|
||||
assert!(interpreter
|
||||
.execute(&res.prog(), &[Bn128Field::from(10000)])
|
||||
.execute(
|
||||
&res.prog(),
|
||||
&[Bn128Field::from(10000), Bn128Field::from(5555)]
|
||||
)
|
||||
.is_err());
|
||||
}
|
||||
|
|
86
zokrates_core_test/tests/tests/le.json
Normal file
86
zokrates_core_test/tests/tests/le.json
Normal file
|
@ -0,0 +1,86 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/le.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["41"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["42"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["43"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["44"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["100"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
2
zokrates_core_test/tests/tests/le.zok
Normal file
2
zokrates_core_test/tests/tests/le.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main(field e) -> bool:
|
||||
return e <= 42
|
86
zokrates_core_test/tests/tests/native_le.json
Normal file
86
zokrates_core_test/tests/tests/native_le.json
Normal file
|
@ -0,0 +1,86 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/native_le.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2", "2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["41", "41"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["42", "42"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["43", "43"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["44", "44"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["100", "100"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
47
zokrates_core_test/tests/tests/native_le.zok
Normal file
47
zokrates_core_test/tests/tests/native_le.zok
Normal file
|
@ -0,0 +1,47 @@
|
|||
from "utils/pack/bool/unpack.zok" import main as unpack
|
||||
from "utils/casts/u32_to_bits" import main as u32_to_bits
|
||||
|
||||
// this comparison works for any N smaller than the field size, which is the case in practice
|
||||
def le<N>(bool[N] a_bits, bool[N] c_bits) -> bool:
|
||||
|
||||
bool size_unknown = false
|
||||
|
||||
u32 verified_conditions = 0 // `and(conditions) == (sum(conditions) == len(conditions))`, here we initialize `sum(conditions)`
|
||||
|
||||
size_unknown = true
|
||||
|
||||
for u32 i in 0..N do
|
||||
verified_conditions = verified_conditions + if c_bits[i] || (!size_unknown || !a_bits[i]) then 1 else 0 fi
|
||||
size_unknown = if c_bits[i] then size_unknown && a_bits[i] else size_unknown fi // this is actually not required in the last round
|
||||
endfor
|
||||
|
||||
return verified_conditions == N // this checks that all conditions were verified
|
||||
|
||||
// this instanciates comparison starting from field elements
|
||||
def le<N>(field a, field c) -> bool:
|
||||
|
||||
field MAX = 21888242871839275222246405745257275088548364400416034343698204186575808495616
|
||||
bool[N] MAX_BITS = unpack::<N>(MAX)
|
||||
|
||||
bool[N] a_bits = unpack(a)
|
||||
assert(le(a_bits, MAX_BITS))
|
||||
bool[N] c_bits = unpack(c)
|
||||
assert(le(c_bits, MAX_BITS))
|
||||
|
||||
return le(a_bits, c_bits)
|
||||
|
||||
// this instanciates comparison starting from u32
|
||||
def le(u32 a, u32 c) -> bool:
|
||||
bool[32] a_bits = u32_to_bits(a)
|
||||
bool[32] c_bits = u32_to_bits(c)
|
||||
|
||||
return le(a_bits, c_bits)
|
||||
|
||||
def main(field a, u32 b) -> (bool, bool):
|
||||
|
||||
u32 N = 254
|
||||
field c = 42
|
||||
|
||||
u32 d = 42
|
||||
|
||||
return le::<N>(a, c), le(b, d)
|
76
zokrates_core_test/tests/tests/range.json
Normal file
76
zokrates_core_test/tests/tests/range.json
Normal file
|
@ -0,0 +1,76 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/range.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "1", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["254"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["255"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495616"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "0", "0"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
2
zokrates_core_test/tests/tests/range.zok
Normal file
2
zokrates_core_test/tests/tests/range.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main(field x) -> bool[5]:
|
||||
return [x < 0, x < 1, x < 255, x < 0 - 1, x - 2 > 0 - 2]
|
|
@ -30,7 +30,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn max_value_bits() {
|
||||
let bits = FieldPrime::max_value_bit_vector_be();
|
||||
let bits = FieldPrime::max_value().bit_vector_be();
|
||||
assert_eq!(
|
||||
bits[0..10].to_vec(),
|
||||
vec![true, true, false, false, false, false, false, true, true, false]
|
||||
|
|
|
@ -116,7 +116,8 @@ pub trait Field:
|
|||
/// Gets the number of bits
|
||||
fn bits(&self) -> u32;
|
||||
/// Returns this `Field`'s largest value as a big-endian bit vector
|
||||
fn max_value_bit_vector_be() -> Vec<bool> {
|
||||
/// Always returns `Self::get_required_bits()` elements
|
||||
fn bit_vector_be(&self) -> Vec<bool> {
|
||||
fn bytes_to_bits(bytes: &[u8]) -> Vec<bool> {
|
||||
bytes
|
||||
.iter()
|
||||
|
@ -124,13 +125,26 @@ pub trait Field:
|
|||
.collect()
|
||||
}
|
||||
|
||||
let field_bytes_le = Self::to_byte_vector(&Self::max_value());
|
||||
let field_bytes_le = self.to_byte_vector();
|
||||
|
||||
// reverse for big-endianess
|
||||
let field_bytes_be = field_bytes_le.into_iter().rev().collect::<Vec<u8>>();
|
||||
let field_bits_be = bytes_to_bits(&field_bytes_be);
|
||||
|
||||
let field_bits_be = &field_bits_be[field_bits_be.len() - Self::get_required_bits()..];
|
||||
field_bits_be.to_vec()
|
||||
let field_bits_be: Vec<_> = (0..Self::get_required_bits()
|
||||
.saturating_sub(field_bits_be.len()))
|
||||
.map(|_| &false)
|
||||
.chain(
|
||||
&field_bits_be[field_bits_be
|
||||
.len()
|
||||
.saturating_sub(Self::get_required_bits())..],
|
||||
)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
assert_eq!(field_bits_be.len(), Self::get_required_bits());
|
||||
|
||||
field_bits_be
|
||||
}
|
||||
/// Returns the value as a BigUint
|
||||
fn to_biguint(&self) -> BigUint;
|
||||
|
|
Loading…
Reference in a new issue