1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

revert to original impl, add symetric check, add logs

This commit is contained in:
schaeff 2021-10-15 13:00:48 +03:00
parent 1e04d56a7b
commit f98585b784
8 changed files with 54 additions and 7 deletions

View file

@ -1,6 +1,6 @@
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
// we can compare numbers whose difference fits in (pbits - 2) bits
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
@ -8,4 +8,4 @@ def main(field a) -> bool:
u32 pbits = FIELD_SIZE_IN_BITS
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1)
return a < (maxvalue + 1)
return a < maxvalue + 1

View file

@ -0,0 +1,11 @@
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers whose difference fits in (pbits - 2) bits
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> bool:
u32 pbits = FIELD_SIZE_IN_BITS
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1)
return maxvalue + 1 < a

View file

@ -196,14 +196,17 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
// flatten input program
log::debug!("Flatten");
let program_flattened = Flattener::flatten(typed_ast, config);
log::trace!("\n{}", program_flattened);
// constant propagation after call resolution
log::debug!("Propagate flat program");
let program_flattened = program_flattened.propagate();
log::trace!("\n{}", program_flattened);
// convert to ir
log::debug!("Convert to IR");
let ir_prog = ir::Prog::from(program_flattened);
log::trace!("\n{}", ir_prog);
// optimize
log::debug!("Optimise IR");

View file

@ -37,6 +37,7 @@ pub enum RuntimeError {
LtSum,
LtFinalBitness,
LtFinalSum,
LtSymetric,
Or,
Xor,
Inverse,
@ -81,6 +82,7 @@ impl fmt::Display for RuntimeError {
LtSum => "Sum check failed in Lt check",
LtFinalBitness => "Bitness check failed in final Lt check",
LtFinalSum => "Sum check failed in final Lt check",
LtSymetric => "Symetrical check failed in Lt check",
Or => "Or check failed",
Xor => "Xor check failed",
Inverse => "Division by zero",

View file

@ -706,6 +706,21 @@ impl<'ast, T: Field> Flattener<'ast, T> {
RuntimeError::LtFinalSum,
));
// to make this check symetric, we ban the value `a - b == -2**N`, as the value `a - b == 2**N` is already banned
let fail = self.eq_check(
statements_flattened,
FlatExpression::Sub(
box FlatExpression::Identifier(rhs_id),
box FlatExpression::Identifier(lhs_id),
),
FlatExpression::Number(T::from(2).pow(bit_width)),
);
statements_flattened.push(FlatStatement::Condition(
fail,
FlatExpression::Number(T::from(0)),
RuntimeError::LtSymetric,
));
FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(shifted_sub_bits_be[0]),

View file

@ -24,25 +24,25 @@ impl<T: Field> Prog<T> {
log::debug!("Constraints: {}", self.constraint_count());
log::debug!("Optimizer: Remove redefinitions");
let r = RedefinitionOptimizer::optimize(self);
log::debug!("Done");
log::trace!("\n{}\n", r);
// remove constraints that are always satisfied
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove tautologies");
let r = TautologyOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
// deduplicate directives which take the same input
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate directive");
let r = DirectiveOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
// remove duplicate constraints
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate constraints");
let r = DuplicateOptimizer::optimize(r);
log::debug!("Done");
log::trace!("\n{}\n", r);
log::debug!("Constraints: {}", r.constraint_count());
r

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/compare_min_to_max.zok",
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
}
]
}

View file

@ -1,7 +1,7 @@
from "field" import FIELD_MAX
// as p - 1 is greater than p/2, comparing to it should fail
// /!\ should be called with a = 0
// as `|a - FIELD_MAX| < 2**(N-2)` the comparison should succeed
def main(field a) -> bool:
field p = FIELD_MAX + a