refactor out of range interpreter, accept any size of output
This commit is contained in:
parent
63be983d74
commit
7f76505a86
4 changed files with 40 additions and 55 deletions
|
@ -1982,8 +1982,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// constants do not require directives
|
||||
if let Some(FlatExpression::Number(ref x)) = e.field {
|
||||
let bits: Vec<_> = Interpreter::default()
|
||||
.execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(FlatExpression::Number)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::ir::Directive;
|
||||
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
|
||||
use crate::solvers::Solver;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -65,66 +64,59 @@ impl Interpreter {
|
|||
}
|
||||
}
|
||||
},
|
||||
Statement::Directive(ref d) => match (&d.solver, self.should_try_out_of_range) {
|
||||
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
|
||||
Self::try_solve_out_of_range(&d, &mut witness)
|
||||
Statement::Directive(ref d) => {
|
||||
let mut inputs: Vec<_> = d
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| i.evaluate(&witness).unwrap())
|
||||
.collect();
|
||||
|
||||
let res = match (&d.solver, self.should_try_out_of_range) {
|
||||
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
|
||||
Ok(Self::try_solve_with_out_of_range_bits(
|
||||
*bitwidth,
|
||||
inputs.pop().unwrap(),
|
||||
))
|
||||
}
|
||||
_ => Self::execute_solver(&d.solver, &inputs),
|
||||
}
|
||||
_ => {
|
||||
let inputs: Vec<_> = d
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| i.evaluate(&witness).unwrap())
|
||||
.collect();
|
||||
match self.execute_solver(&d.solver, &inputs) {
|
||||
Ok(res) => {
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(_) => return Err(Error::Solver),
|
||||
};
|
||||
.map_err(|_| Error::Solver)?;
|
||||
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Witness(witness))
|
||||
}
|
||||
|
||||
fn try_solve_out_of_range<T: Field>(d: &Directive<T>, witness: &mut BTreeMap<FlatVariable, T>) {
|
||||
fn try_solve_with_out_of_range_bits<T: Field>(bit_width: usize, input: T) -> Vec<T> {
|
||||
use num::traits::Pow;
|
||||
use num_bigint::BigUint;
|
||||
|
||||
// we target the `2a - 2b` part of the `<` check by only returning out-of-range results
|
||||
// when the input is not a single summand
|
||||
let value = d.inputs[0].evaluate(&witness).unwrap();
|
||||
|
||||
let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
|
||||
let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
|
||||
|
||||
let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) {
|
||||
candidate
|
||||
} else {
|
||||
value.to_biguint()
|
||||
input.to_biguint()
|
||||
};
|
||||
|
||||
let mut num = input;
|
||||
let mut res = vec![];
|
||||
let bits = T::get_required_bits();
|
||||
for i in (0..bits).rev() {
|
||||
if T::from(2).to_biguint().pow(i as usize) <= num {
|
||||
num -= T::from(2).to_biguint().pow(i as usize);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
}
|
||||
}
|
||||
assert_eq!(num, T::zero().to_biguint());
|
||||
let padding = bit_width - T::get_required_bits();
|
||||
|
||||
println!("RES {:?}", res);
|
||||
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
(0..padding)
|
||||
.map(|_| T::zero())
|
||||
.chain((0..T::get_required_bits()).rev().scan(input, |state, i| {
|
||||
if BigUint::from(2usize).pow(i) <= *state {
|
||||
*state = (*state).clone() - BigUint::from(2usize).pow(i);
|
||||
Some(T::one())
|
||||
} else {
|
||||
Some(T::zero())
|
||||
}
|
||||
}))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
|
||||
|
@ -138,11 +130,7 @@ impl Interpreter {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn execute_solver<T: Field>(
|
||||
&self,
|
||||
solver: &Solver,
|
||||
inputs: &[T],
|
||||
) -> Result<Vec<T>, String> {
|
||||
pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> {
|
||||
let (expected_input_count, expected_output_count) = solver.get_signature();
|
||||
assert_eq!(inputs.len(), expected_input_count);
|
||||
|
||||
|
|
|
@ -140,9 +140,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
// unwrap inputs to their constant value
|
||||
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
// run the solver
|
||||
let outputs = Interpreter::default()
|
||||
.execute_solver(&d.solver, &inputs)
|
||||
.unwrap();
|
||||
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
|
||||
assert_eq!(outputs.len(), d.outputs.len());
|
||||
|
||||
// insert the results in the substitution
|
||||
|
|
|
@ -129,7 +129,7 @@ fn unpack256_unchecked() {
|
|||
.to_string();
|
||||
|
||||
// let's try to prove that the least significant bit of 0 is 1
|
||||
// we exploit the fact that the bits of 0 are the bits of p, and p is even
|
||||
// we exploit the fact that the bits of 0 are the bits of p, and p is odd
|
||||
// we want this to succeed as the non strict version does not enforce the bits to be in range
|
||||
|
||||
let stdlib_path = std::fs::canonicalize(
|
||||
|
|
Loading…
Reference in a new issue