1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

extract eq check code, implement c < x checks, update tests to use dynamic comparison

This commit is contained in:
schaeff 2021-03-16 12:24:32 +01:00
parent 0ffcb31392
commit 0a963dd152
8 changed files with 232 additions and 165 deletions

View file

@ -1,3 +0,0 @@
def main(field x):
assert(x < 3)
return

View file

@ -9,4 +9,4 @@ def main(field a) -> bool:
// maxvalue = 2**252 - 1
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < (maxvalue + 1)
return a < (maxvalue + 1)

View file

@ -4,4 +4,4 @@
def main(field a) -> bool:
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < p
return a < p

View file

@ -218,6 +218,12 @@ pub enum FlatExpression<T> {
Mult(Box<FlatExpression<T>>, Box<FlatExpression<T>>),
}
impl<T> From<T> for FlatExpression<T> {
fn from(other: T) -> Self {
Self::Number(other)
}
}
impl<T: Field> FlatExpression<T> {
pub fn apply_substitution(
self,

View file

@ -288,6 +288,47 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res
}
fn eq_check(
&mut self,
statements_flattened: &mut Vec<FlatStatement<T>>,
left: FlatExpression<T>,
right: FlatExpression<T>,
) -> FlatExpression<T> {
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
// Y == X * M
// 0 == (1-Y) * X
let x = FlatExpression::Sub(box left, box right);
let name_y = self.use_sym();
let name_m = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
}
fn enforce_strict_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
@ -375,6 +416,96 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
/// Compute a range check against a constant
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `e` - the `FlatExpression` that's being checked against the range.
/// * `c` - the constant upper bound of the range
///
/// # Returns
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
fn constant_range_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
e: FlatExpression<T>,
c: T,
) -> FlatExpression<T> {
// we make use of constant `<=` checks in this function, therefore we rely on the fact that:
// `a < c <=> (a <= c - 1 if c !=0, false if c == 0)`
if c == T::zero() {
// this is the case c == 0, we return 0, aka false
return T::zero().into();
}
let c = c - T::one();
let bit_width = T::get_required_bits();
// decompose e to bits
let e_id = self.define(e, statements_flattened);
// define variables for the bits
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
e_bits_be.clone(),
Solver::bits(bit_width),
vec![e_id],
)));
// bitness checks
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(e_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(e_bits_be[i]),
box FlatExpression::Identifier(e_bits_be[i]),
),
));
}
// bit decomposition check
let mut e_sum = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
e_sum = FlatExpression::Add(
box e_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(e_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(e_id),
e_sum,
));
// check that this decomposition does not overflow the field
self.enforce_strict_le_check(
statements_flattened,
&e_bits_be,
&T::max_value().bit_vector_be(),
);
let conditions = self.strict_le_check(statements_flattened, &e_bits_be, &c.bit_vector_be());
// return `len(conditions) == sum(conditions)`
self.eq_check(
statements_flattened,
T::from(conditions.len()).into(),
conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
)
}
/// Flattens a boolean expression
///
/// # Arguments
@ -401,7 +532,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::Lt(box lhs, box rhs) => {
// Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
@ -412,105 +542,18 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.flatten_field_expression(symbols, statements_flattened, rhs);
match (lhs_flattened, rhs_flattened) {
(lhs_flattened, FlatExpression::Number(constant)) => {
// decompose lhs to bits
let lhs_id = self.define(lhs_flattened, statements_flattened);
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(bit_width),
vec![lhs_id],
)));
// bitness checks
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Identifier(lhs_bits_be[i]),
),
));
}
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id),
lhs_sum,
));
// check that this decomposition does not overflow the field
self.enforce_strict_le_check(
statements_flattened,
&lhs_bits_be,
&T::max_value().bit_vector_be(),
);
let conditions = self.strict_le_check(
statements_flattened,
&lhs_bits_be,
&constant.bit_vector_be(),
);
// define the condition as `conditions.len() - sum(conditions)`
let condition = FlatExpression::Sub(
box FlatExpression::Number(T::from(conditions.len())),
box conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
);
// return `condition == 0`
// copy pasted the fieldeq flattening code
let name_y = self.use_sym();
let name_m = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![condition.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(
box condition.clone(),
box FlatExpression::Identifier(name_m),
),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box condition),
));
res
(x, FlatExpression::Number(constant)) => {
self.constant_range_check(statements_flattened, x, constant)
}
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
(FlatExpression::Number(constant), x) => self.constant_range_check(
statements_flattened,
FlatExpression::Sub(box T::max_value().into(), box x),
T::max_value() - constant,
),
(lhs_flattened, rhs_flattened) => {
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2`
// lhs
let lhs_id = self.define(lhs_flattened, statements_flattened);
@ -707,43 +750,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
BooleanExpression::FieldEq(box lhs, box rhs) => {
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
// Y == X * M
// 0 == (1-Y) * X
let lhs = self.flatten_field_expression(symbols, statements_flattened, lhs);
let name_y = self.use_sym();
let name_m = self.use_sym();
let rhs = self.flatten_field_expression(symbols, statements_flattened, rhs);
let x = self.flatten_field_expression(
symbols,
statements_flattened,
FieldElementExpression::Sub(box lhs, box rhs),
);
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
self.eq_check(statements_flattened, lhs, rhs)
}
BooleanExpression::UintEq(box lhs, box rhs) => {
// We reduce each side into range and apply the same approach as for field elements
@ -755,9 +766,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// Y == X * M
// 0 == (1-Y) * X
let name_y = self.use_sym();
let name_m = self.use_sym();
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool());
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool());
@ -768,29 +776,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flatten_uint_expression(symbols, statements_flattened, rhs)
.get_field_unchecked();
let x = FlatExpression::Sub(box lhs, box rhs);
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
self.eq_check(statements_flattened, lhs, rhs)
}
BooleanExpression::Le(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(

View file

@ -0,0 +1,76 @@
{
"entry_point": "./tests/tests/range.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"values": ["0", "1", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["1"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "1"]
}
}
},
{
"input": {
"values": ["2"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["254"]
},
"output": {
"Ok": {
"values": ["0", "0", "1", "1", "0"]
}
}
},
{
"input": {
"values": ["255"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "1", "0"]
}
}
},
{
"input": {
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "1", "0"]
}
}
},
{
"input": {
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495616"]
},
"output": {
"Ok": {
"values": ["0", "0", "0", "0", "0"]
}
}
}
]
}

View file

@ -0,0 +1,2 @@
def main(field x) -> bool[5]:
return [x < 0, x < 1, x < 255, x < 0 - 1, x - 2 > 0 - 2]

View file

@ -30,7 +30,7 @@ mod tests {
#[test]
fn max_value_bits() {
let bits = FieldPrime::max_value_bit_vector_be();
let bits = FieldPrime::max_value().bit_vector_be();
assert_eq!(
bits[0..10].to_vec(),
vec![true, true, false, false, false, false, false, true, true, false]