From 7f76505a86771ad3f131c7651d11ab8c205376a4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 16:03:57 +0200 Subject: [PATCH] refactor out of range interpreter, accept any size of output --- zokrates_core/src/flatten/mod.rs | 3 +- zokrates_core/src/ir/interpreter.rs | 86 +++++++++------------ zokrates_core/src/optimizer/redefinition.rs | 4 +- zokrates_core/tests/out_of_range.rs | 2 +- 4 files changed, 40 insertions(+), 55 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index e096f9b2..f44c5364 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1982,8 +1982,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::default() - .execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) .unwrap() .into_iter() .map(FlatExpression::Number) diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index ef344e91..510c89b2 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -1,5 +1,4 @@ use crate::flat_absy::flat_variable::FlatVariable; -use crate::ir::Directive; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; use crate::solvers::Solver; use serde::{Deserialize, Serialize}; @@ -65,66 +64,59 @@ impl Interpreter { } } }, - Statement::Directive(ref d) => match (&d.solver, self.should_try_out_of_range) { - (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { - Self::try_solve_out_of_range(&d, &mut witness) + Statement::Directive(ref d) => { + let mut inputs: Vec<_> = d + .inputs + .iter() + .map(|i| i.evaluate(&witness).unwrap()) + .collect(); + + let res = match (&d.solver, self.should_try_out_of_range) { + (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { + Ok(Self::try_solve_with_out_of_range_bits( + *bitwidth, + inputs.pop().unwrap(), + )) + } + _ => Self::execute_solver(&d.solver, &inputs), } - _ => { - let inputs: Vec<_> = d - .inputs - .iter() - .map(|i| i.evaluate(&witness).unwrap()) - .collect(); - match self.execute_solver(&d.solver, &inputs) { - Ok(res) => { - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } - continue; - } - Err(_) => return Err(Error::Solver), - }; + .map_err(|_| Error::Solver)?; + + for (i, o) in d.outputs.iter().enumerate() { + witness.insert(*o, res[i].clone()); } - }, + } } } Ok(Witness(witness)) } - fn try_solve_out_of_range(d: &Directive, witness: &mut BTreeMap) { + fn try_solve_with_out_of_range_bits(bit_width: usize, input: T) -> Vec { use num::traits::Pow; + use num_bigint::BigUint; - // we target the `2a - 2b` part of the `<` check by only returning out-of-range results - // when the input is not a single summand - let value = d.inputs[0].evaluate(&witness).unwrap(); - - let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); + let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) { candidate } else { - value.to_biguint() + input.to_biguint() }; - let mut num = input; - let mut res = vec![]; - let bits = T::get_required_bits(); - for i in (0..bits).rev() { - if T::from(2).to_biguint().pow(i as usize) <= num { - num -= T::from(2).to_biguint().pow(i as usize); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - assert_eq!(num, T::zero().to_biguint()); + let padding = bit_width - T::get_required_bits(); - println!("RES {:?}", res); - - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } + (0..padding) + .map(|_| T::zero()) + .chain((0..T::get_required_bits()).rev().scan(input, |state, i| { + if BigUint::from(2usize).pow(i) <= *state { + *state = (*state).clone() - BigUint::from(2usize).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } fn check_inputs(&self, program: &Prog, inputs: &[U]) -> Result<(), Error> { @@ -138,11 +130,7 @@ impl Interpreter { } } - pub fn execute_solver( - &self, - solver: &Solver, - inputs: &[T], - ) -> Result, String> { + pub fn execute_solver(solver: &Solver, inputs: &[T]) -> Result, String> { let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 6b197b1e..4222281c 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -140,9 +140,7 @@ impl Folder for RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::default() - .execute_solver(&d.solver, &inputs) - .unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index 95fd5497..c27e2361 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -129,7 +129,7 @@ fn unpack256_unchecked() { .to_string(); // let's try to prove that the least significant bit of 0 is 1 - // we exploit the fact that the bits of 0 are the bits of p, and p is even + // we exploit the fact that the bits of 0 are the bits of p, and p is odd // we want this to succeed as the non strict version does not enforce the bits to be in range let stdlib_path = std::fs::canonicalize(