1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #1025 from Zokrates/cheaper-dynamic-comparison

Reduce the cost of dynamic LT checks
This commit is contained in:
Thibaut Schaeffer 2021-11-11 10:21:55 +01:00 committed by GitHub
commit d2e8b905c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 184 additions and 1803 deletions

View file

@ -0,0 +1 @@
Reduce cost of dynamic comparison

View file

@ -2,26 +2,26 @@
The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same cell have the same precedence. Operators are binary, unless the syntax is provided.
| Operator | Description | `field` | `u8/u16` `u32/u64` | `bool` | Associativity | Remarks |
|----------------------------|------------------------------------------------------------|------------------------------|-------------------------------|-----------------------------|---------------|---------|
| `**`<br> | Power | &check; | &nbsp; | &nbsp; | Left | [^1] |
| `+x`<br>`-x`<br>`!x`<br> | Positive<br>Negative<br>Negation<br> | &check;<br>&check;<br>&nbsp; | &check;<br>&check;<br>&nbsp; | &nbsp;<br>&nbsp;<br>&check; | Right | |
| `*`<br>`/`<br>`%`<br> | Multiplication<br> Division<br> Remainder<br> | &check;<br>&check;<br>&nbsp; | &check;<br>&check;<br>&check; | &nbsp;<br>&nbsp;<br>&nbsp; | Left | |
| `+`<br>`-`<br> | Addition<br> Subtraction<br> | &check; | &check; | &nbsp; | Left | |
| `<<`<br>`>>`<br> | Left shift<br> Right shift<br> | &nbsp; | &check; | &nbsp; | Left | [^2] |
| `&` | Bitwise AND | &nbsp; | &check; | &nbsp; | Left | |
| <code>&#124;</code> | Bitwise OR | &nbsp; | &check; | &nbsp; | Left | |
| `^` | Bitwise XOR | &nbsp; | &check; | &nbsp; | Left | |
| `>=`<br>`>`<br>`<=`<br>`<` | Greater or equal<br>Greater<br>Lower or equal<br>Lower<br> | &check; | &check; | &nbsp; | Left | [^3] |
| `!=`<br>`==`<br> | Not Equal<br>Equal<br> | &check; | &check; | &check; | Left | |
| `&&` | Boolean AND | &nbsp; | &nbsp; | &check; | Left | |
| <code>&#124;&#124;</code> | Boolean OR | &nbsp; | &nbsp; | &check; | Left | |
| `if c then x else y fi` | Conditional expression | &check; | &check; | &check; | Right | [^4] |
| Operator | Description | `field` | `u8/u16` `u32/u64` | `bool` | Associativity |
|----------------------------|------------------------------------------------------------|------------------------------|-------------------------------|-----------------------------|---------------|
| `**`<br> | Power | &check;[^1] | &nbsp; | &nbsp; | Left |
| `+x`<br>`-x`<br>`!x`<br> | Positive<br>Negative<br>Negation<br> | &check;<br>&check;<br>&nbsp; | &check;<br>&check;<br>&check; | &nbsp;<br>&nbsp;<br>&check; | Right |
| `*`<br>`/`<br>`%`<br> | Multiplication<br> Division<br> Remainder<br> | &check;<br>&check;<br>&nbsp; | &check;<br>&check;<br>&check; | &nbsp;<br>&nbsp;<br>&nbsp; | Left |
| `+`<br>`-`<br> | Addition<br> Subtraction<br> | &check; | &check; | &nbsp; | Left |
| `<<`<br>`>>`<br> | Left shift<br> Right shift<br> | &nbsp; | &check;[^2] | &nbsp; | Left |
| `&` | Bitwise AND | &nbsp; | &check; | &nbsp; | Left |
| <code>&#124;</code> | Bitwise OR | &nbsp; | &check; | &nbsp; | Left |
| `^` | Bitwise XOR | &nbsp; | &check; | &nbsp; | Left |
| `>=`<br>`>`<br>`<=`<br>`<` | Greater or equal<br>Greater<br>Lower or equal<br>Lower<br> | &check;[^3] | &check; | &nbsp; | Left |
| `!=`<br>`==`<br> | Not Equal<br>Equal<br> | &check; | &check; | &check; | Left |
| `&&` | Boolean AND | &nbsp; | &nbsp; | &check; | Left |
| <code>&#124;&#124;</code> | Boolean OR | &nbsp; | &nbsp; | &check; | Left |
| `c ? x : y`<br><br>`if c then x else y fi` | Conditional expression | &check; | &check; | &check; | Right | |
[^1]: The exponent must be a compile-time constant of type `u32`
[^2]: The right operand must be a compile time constant of type `u32`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant
[^4]: Conditional expression can also be written using a ternary operator: `c ? x : y`
[^3]: If neither of the operands can be determined to be a compile-time constant, then we have a restriction: for the check `a < b`, if the field prime `p` is represented on `N` bits, `|a - b|` must fit in `N - 2` bits.
Failing to respect this condition will lead to a runtime error.

View file

@ -1,6 +1,6 @@
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
// we can compare numbers whose difference fits in (pbits - 2) bits
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
@ -8,4 +8,4 @@ def main(field a) -> bool:
u32 pbits = FIELD_SIZE_IN_BITS
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1)
return a < (maxvalue + 1)
return a < maxvalue + 1

View file

@ -0,0 +1,11 @@
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers whose difference fits in (pbits - 2) bits
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> bool:
u32 pbits = FIELD_SIZE_IN_BITS
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1)
return maxvalue + 1 < a

File diff suppressed because one or more lines are too long

View file

@ -196,14 +196,17 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
// flatten input program
log::debug!("Flatten");
let program_flattened = Flattener::flatten(typed_ast, config);
log::trace!("\n{}", program_flattened);
// constant propagation after call resolution
log::debug!("Propagate flat program");
let program_flattened = program_flattened.propagate();
log::trace!("\n{}", program_flattened);
// convert to ir
log::debug!("Convert to IR");
let ir_prog = ir::Prog::from(program_flattened);
log::trace!("\n{}", ir_prog);
// optimize
log::debug!("Optimise IR");

View file

@ -37,6 +37,7 @@ pub enum RuntimeError {
LtSum,
LtFinalBitness,
LtFinalSum,
LtSymetric,
Or,
Xor,
Inverse,
@ -81,6 +82,7 @@ impl fmt::Display for RuntimeError {
LtSum => "Sum check failed in Lt check",
LtFinalBitness => "Bitness check failed in final Lt check",
LtFinalSum => "Sum check failed in final Lt check",
LtSymetric => "Symetrical check failed in Lt check",
Or => "Or check failed",
Xor => "Xor check failed",
Inverse => "Division by zero",

View file

@ -641,6 +641,103 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
fn lt_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
lhs_flattened: FlatExpression<T>,
rhs_flattened: FlatExpression<T>,
bit_width: usize,
) -> FlatExpression<T> {
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Number(constant)) => {
self.constant_lt_check(statements_flattened, x, constant)
}
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
(FlatExpression::Number(constant), x) => self.constant_lt_check(
statements_flattened,
FlatExpression::Sub(box T::max_value().into(), box x),
T::max_value() - constant,
),
(lhs_flattened, rhs_flattened) => {
let lhs_id = self.define(lhs_flattened, statements_flattened);
let rhs_id = self.define(rhs_flattened, statements_flattened);
// shifted_sub := 2**safe_width + lhs - rhs
let shifted_sub = FlatExpression::Add(
box FlatExpression::Number(T::from(2).pow(bit_width)),
box FlatExpression::Sub(
box FlatExpression::Identifier(lhs_id),
box FlatExpression::Identifier(rhs_id),
),
);
let sub_width = bit_width + 1;
// define variables for the bits
let shifted_sub_bits_be: Vec<FlatVariable> =
(0..sub_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
shifted_sub_bits_be.clone(),
Solver::bits(sub_width),
vec![shifted_sub.clone()],
)));
// bitness checks
for bit in shifted_sub_bits_be.iter() {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtFinalBitness,
));
}
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in shifted_sub_bits_be.iter().take(sub_width).enumerate() {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(sub_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
shifted_sub,
expr,
RuntimeError::LtFinalSum,
));
// to make this check symetric, we ban the value `a - b == -2**N`, as the value `a - b == 2**N` is already banned
let fail = self.eq_check(
statements_flattened,
FlatExpression::Sub(
box FlatExpression::Identifier(rhs_id),
box FlatExpression::Identifier(lhs_id),
),
FlatExpression::Number(T::from(2).pow(bit_width)),
);
statements_flattened.push(FlatStatement::Condition(
fail,
FlatExpression::Number(T::from(0)),
RuntimeError::LtSymetric,
));
FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(shifted_sub_bits_be[0]),
)
}
}
}
/// Flattens a boolean expression
///
/// # Arguments
@ -673,195 +770,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits();
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2)
let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs);
let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs);
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Number(constant)) => {
self.constant_lt_check(statements_flattened, x, constant)
}
// (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c)
(FlatExpression::Number(constant), x) => self.constant_lt_check(
statements_flattened,
FlatExpression::Sub(box T::max_value().into(), box x),
T::max_value() - constant,
),
(lhs_flattened, rhs_flattened) => {
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to words of width `bit_width - 2`
// lhs
let lhs_id = self.define(lhs_flattened, statements_flattened);
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
// lhs
{
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(safe_width),
vec![lhs_id],
),
));
// bitness checks
for bit in lhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtBitness,
));
}
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
for (i, bit) in lhs_bits_be.iter().take(safe_width).enumerate() {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id),
lhs_sum,
RuntimeError::LtSum,
));
}
// rhs
let rhs_id = self.define(rhs_flattened, statements_flattened);
// rhs
{
// define variables for the bits
let rhs_bits_be: Vec<FlatVariable> =
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
rhs_bits_be.clone(),
Solver::bits(safe_width),
vec![rhs_id],
),
));
// bitness checks
for bit in rhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtBitness,
));
}
// bit decomposition check
let mut rhs_sum = FlatExpression::Number(T::from(0));
for (i, bit) in rhs_bits_be.iter().take(safe_width).enumerate() {
rhs_sum = FlatExpression::Add(
box rhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(rhs_id),
rhs_sum,
RuntimeError::LtSum,
));
}
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtFinalBitness,
));
}
// check that the decomposition is in the field with a strict `< p` checks
self.enforce_constant_le_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in sub_bits_be.iter().take(bit_width).enumerate() {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
subtraction_result,
expr,
RuntimeError::LtFinalSum,
));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
}
self.lt_check(
statements_flattened,
lhs_flattened,
rhs_flattened,
safe_width,
)
}
BooleanExpression::BoolEq(box lhs, box rhs) => {
// lhs and rhs are booleans, they flatten to 0 or 1
@ -944,79 +863,23 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::FieldLe(rhs, lhs),
),
BooleanExpression::UintLt(box lhs, box rhs) => {
let lhs_flattened = self.flatten_uint_expression(statements_flattened, lhs);
let rhs_flattened = self.flatten_uint_expression(statements_flattened, rhs);
let bit_width = lhs.bitwidth.to_usize();
assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool());
// Get the bit width to know the size of the binary decompositions for this Field
// This is not this uint bitwidth
let bit_width = T::get_required_bits();
let lhs_flattened = self
.flatten_uint_expression(statements_flattened, lhs)
.get_field_unchecked();
let rhs_flattened = self
.flatten_uint_expression(statements_flattened, rhs)
.get_field_unchecked();
// lhs
let lhs_id = self.define(lhs_flattened.get_field_unchecked(), statements_flattened);
let rhs_id = self.define(rhs_flattened.get_field_unchecked(), statements_flattened);
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtFinalBitness,
));
}
// check that the decomposition is in the field with a strict `< p` checks
self.enforce_constant_le_check(
self.lt_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
subtraction_result,
expr,
RuntimeError::LtFinalSum,
));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
lhs_flattened,
rhs_flattened,
bit_width,
)
}
BooleanExpression::UintLe(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(

View file

@ -24,25 +24,25 @@ impl<T: Field> Prog<T> {
log::debug!("Constraints: {}", self.constraint_count());
log::debug!("Optimizer: Remove redefinitions");
let r = RedefinitionOptimizer::optimize(self);
log::debug!("Done");
log::trace!("\n{}\n", r);
// remove constraints that are always satisfied
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove tautologies");
let r = TautologyOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
// deduplicate directives which take the same input
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate directive");
let r = DirectiveOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
// remove duplicate constraints
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate constraints");
let r = DuplicateOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
log::debug!("Constraints: {}", r.constraint_count());
r

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/compare_min_to_max.zok",
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
}
]
}

View file

@ -1,7 +1,7 @@
from "field" import FIELD_MAX
// as p - 1 is greater than p/2, comparing to it should fail
// /!\ should be called with a = 0
// as `|a - FIELD_MAX| < 2**(N-2)` the comparison should succeed
def main(field a) -> bool:
field p = FIELD_MAX + a