extract eq check code, implement c < x checks, update tests to use dynamic comparison
This commit is contained in:
parent
0ffcb31392
commit
0a963dd152
8 changed files with 232 additions and 165 deletions
3
test.zok
3
test.zok
|
@ -1,3 +0,0 @@
|
|||
def main(field x):
|
||||
assert(x < 3)
|
||||
return
|
|
@ -9,4 +9,4 @@ def main(field a) -> bool:
|
|||
// maxvalue = 2**252 - 1
|
||||
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return 0 < (maxvalue + 1)
|
||||
return a < (maxvalue + 1)
|
|
@ -4,4 +4,4 @@
|
|||
def main(field a) -> bool:
|
||||
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return 0 < p
|
||||
return a < p
|
|
@ -218,6 +218,12 @@ pub enum FlatExpression<T> {
|
|||
Mult(Box<FlatExpression<T>>, Box<FlatExpression<T>>),
|
||||
}
|
||||
|
||||
impl<T> From<T> for FlatExpression<T> {
|
||||
fn from(other: T) -> Self {
|
||||
Self::Number(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatExpression<T> {
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
|
|
|
@ -288,6 +288,47 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
res
|
||||
}
|
||||
|
||||
fn eq_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
left: FlatExpression<T>,
|
||||
right: FlatExpression<T>,
|
||||
) -> FlatExpression<T> {
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// X = a - b
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
|
||||
let x = FlatExpression::Sub(box left, box right);
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn enforce_strict_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
|
@ -375,6 +416,96 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Compute a range check against a constant
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `e` - the `FlatExpression` that's being checked against the range.
|
||||
/// * `c` - the constant upper bound of the range
|
||||
///
|
||||
/// # Returns
|
||||
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
|
||||
fn constant_range_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
// we make use of constant `<=` checks in this function, therefore we rely on the fact that:
|
||||
// `a < c <=> (a <= c - 1 if c !=0, false if c == 0)`
|
||||
|
||||
if c == T::zero() {
|
||||
// this is the case c == 0, we return 0, aka false
|
||||
return T::zero().into();
|
||||
}
|
||||
|
||||
let c = c - T::one();
|
||||
|
||||
let bit_width = T::get_required_bits();
|
||||
// decompose e to bits
|
||||
let e_id = self.define(e, statements_flattened);
|
||||
|
||||
// define variables for the bits
|
||||
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
e_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![e_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(e_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(e_bits_be[i]),
|
||||
box FlatExpression::Identifier(e_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut e_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
e_sum = FlatExpression::Add(
|
||||
box e_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(e_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(e_id),
|
||||
e_sum,
|
||||
));
|
||||
|
||||
// check that this decomposition does not overflow the field
|
||||
self.enforce_strict_le_check(
|
||||
statements_flattened,
|
||||
&e_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
let conditions = self.strict_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
|
||||
|
||||
// return `len(conditions) == sum(conditions)`
|
||||
self.eq_check(
|
||||
statements_flattened,
|
||||
T::from(conditions.len()).into(),
|
||||
conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Flattens a boolean expression
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -401,7 +532,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
BooleanExpression::Lt(box lhs, box rhs) => {
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
let bit_width = T::get_required_bits();
|
||||
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
|
||||
|
||||
// We know from semantic checking that lhs and rhs have the same type
|
||||
// What the expression will flatten to depends on that type
|
||||
|
@ -412,105 +542,18 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
self.flatten_field_expression(symbols, statements_flattened, rhs);
|
||||
|
||||
match (lhs_flattened, rhs_flattened) {
|
||||
(lhs_flattened, FlatExpression::Number(constant)) => {
|
||||
// decompose lhs to bits
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
|
||||
// check that this decomposition does not overflow the field
|
||||
self.enforce_strict_le_check(
|
||||
statements_flattened,
|
||||
&lhs_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
let conditions = self.strict_le_check(
|
||||
statements_flattened,
|
||||
&lhs_bits_be,
|
||||
&constant.bit_vector_be(),
|
||||
);
|
||||
|
||||
// define the condition as `conditions.len() - sum(conditions)`
|
||||
let condition = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::from(conditions.len())),
|
||||
box conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
);
|
||||
|
||||
// return `condition == 0`
|
||||
|
||||
// copy pasted the fieldeq flattening code
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![condition.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(
|
||||
box condition.clone(),
|
||||
box FlatExpression::Identifier(name_m),
|
||||
),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box condition),
|
||||
));
|
||||
|
||||
res
|
||||
(x, FlatExpression::Number(constant)) => {
|
||||
self.constant_range_check(statements_flattened, x, constant)
|
||||
}
|
||||
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
|
||||
(FlatExpression::Number(constant), x) => self.constant_range_check(
|
||||
statements_flattened,
|
||||
FlatExpression::Sub(box T::max_value().into(), box x),
|
||||
T::max_value() - constant,
|
||||
),
|
||||
(lhs_flattened, rhs_flattened) => {
|
||||
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2`
|
||||
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
|
@ -707,43 +750,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}
|
||||
BooleanExpression::FieldEq(box lhs, box rhs) => {
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// X = a - b
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
let lhs = self.flatten_field_expression(symbols, statements_flattened, lhs);
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
let rhs = self.flatten_field_expression(symbols, statements_flattened, rhs);
|
||||
|
||||
let x = self.flatten_field_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
FieldElementExpression::Sub(box lhs, box rhs),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
self.eq_check(statements_flattened, lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintEq(box lhs, box rhs) => {
|
||||
// We reduce each side into range and apply the same approach as for field elements
|
||||
|
@ -755,9 +766,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// Y == X * M
|
||||
// 0 == (1-Y) * X
|
||||
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool());
|
||||
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool());
|
||||
|
||||
|
@ -768,29 +776,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.flatten_uint_expression(symbols, statements_flattened, rhs)
|
||||
.get_field_unchecked();
|
||||
|
||||
let x = FlatExpression::Sub(box lhs, box rhs);
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
));
|
||||
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
));
|
||||
|
||||
res
|
||||
self.eq_check(statements_flattened, lhs, rhs)
|
||||
}
|
||||
BooleanExpression::Le(box lhs, box rhs) => {
|
||||
let lt = self.flatten_boolean_expression(
|
||||
|
|
76
zokrates_core_test/tests/tests/range.json
Normal file
76
zokrates_core_test/tests/tests/range.json
Normal file
|
@ -0,0 +1,76 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/range.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "1", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "1"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["254"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "1", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["255"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "1", "0"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495616"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "0", "0", "0", "0"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
2
zokrates_core_test/tests/tests/range.zok
Normal file
2
zokrates_core_test/tests/tests/range.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main(field x) -> bool[5]:
|
||||
return [x < 0, x < 1, x < 255, x < 0 - 1, x - 2 > 0 - 2]
|
|
@ -30,7 +30,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn max_value_bits() {
|
||||
let bits = FieldPrime::max_value_bit_vector_be();
|
||||
let bits = FieldPrime::max_value().bit_vector_be();
|
||||
assert_eq!(
|
||||
bits[0..10].to_vec(),
|
||||
vec![true, true, false, false, false, false, false, true, true, false]
|
||||
|
|
Loading…
Reference in a new issue