1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

refactor out of range interpreter, accept any size of output

This commit is contained in:
schaeff 2021-08-05 16:03:57 +02:00
parent 63be983d74
commit 7f76505a86
4 changed files with 40 additions and 55 deletions

View file

@ -1982,8 +1982,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// constants do not require directives
if let Some(FlatExpression::Number(ref x)) = e.field {
let bits: Vec<_> = Interpreter::default()
.execute_solver(&Solver::bits(to), &[x.clone()])
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
.unwrap()
.into_iter()
.map(FlatExpression::Number)

View file

@ -1,5 +1,4 @@
use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::Directive;
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
use crate::solvers::Solver;
use serde::{Deserialize, Serialize};
@ -65,66 +64,59 @@ impl Interpreter {
}
}
},
Statement::Directive(ref d) => match (&d.solver, self.should_try_out_of_range) {
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
Self::try_solve_out_of_range(&d, &mut witness)
Statement::Directive(ref d) => {
let mut inputs: Vec<_> = d
.inputs
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
let res = match (&d.solver, self.should_try_out_of_range) {
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
Ok(Self::try_solve_with_out_of_range_bits(
*bitwidth,
inputs.pop().unwrap(),
))
}
_ => Self::execute_solver(&d.solver, &inputs),
}
_ => {
let inputs: Vec<_> = d
.inputs
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
match self.execute_solver(&d.solver, &inputs) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
continue;
}
Err(_) => return Err(Error::Solver),
};
.map_err(|_| Error::Solver)?;
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
},
}
}
}
Ok(Witness(witness))
}
fn try_solve_out_of_range<T: Field>(d: &Directive<T>, witness: &mut BTreeMap<FlatVariable, T>) {
fn try_solve_with_out_of_range_bits<T: Field>(bit_width: usize, input: T) -> Vec<T> {
use num::traits::Pow;
use num_bigint::BigUint;
// we target the `2a - 2b` part of the `<` check by only returning out-of-range results
// when the input is not a single summand
let value = d.inputs[0].evaluate(&witness).unwrap();
let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) {
candidate
} else {
value.to_biguint()
input.to_biguint()
};
let mut num = input;
let mut res = vec![];
let bits = T::get_required_bits();
for i in (0..bits).rev() {
if T::from(2).to_biguint().pow(i as usize) <= num {
num -= T::from(2).to_biguint().pow(i as usize);
res.push(T::one());
} else {
res.push(T::zero());
}
}
assert_eq!(num, T::zero().to_biguint());
let padding = bit_width - T::get_required_bits();
println!("RES {:?}", res);
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
(0..padding)
.map(|_| T::zero())
.chain((0..T::get_required_bits()).rev().scan(input, |state, i| {
if BigUint::from(2usize).pow(i) <= *state {
*state = (*state).clone() - BigUint::from(2usize).pow(i);
Some(T::one())
} else {
Some(T::zero())
}
}))
.collect()
}
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
@ -138,11 +130,7 @@ impl Interpreter {
}
}
pub fn execute_solver<T: Field>(
&self,
solver: &Solver,
inputs: &[T],
) -> Result<Vec<T>, String> {
pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert_eq!(inputs.len(), expected_input_count);

View file

@ -140,9 +140,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
// unwrap inputs to their constant value
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = Interpreter::default()
.execute_solver(&d.solver, &inputs)
.unwrap();
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
assert_eq!(outputs.len(), d.outputs.len());
// insert the results in the substitution

View file

@ -129,7 +129,7 @@ fn unpack256_unchecked() {
.to_string();
// let's try to prove that the least significant bit of 0 is 1
// we exploit the fact that the bits of 0 are the bits of p, and p is even
// we exploit the fact that the bits of 0 are the bits of p, and p is odd
// we want this to succeed as the non strict version does not enforce the bits to be in range
let stdlib_path = std::fs::canonicalize(