1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

execute directives at compile time if the inputs are constant

This commit is contained in:
schaeff 2020-06-17 13:56:58 +02:00
parent 911ea3171b
commit 1b5ff6aa98
3 changed files with 97 additions and 15 deletions

View file

@ -79,6 +79,27 @@ impl<T: Field> Eq for LinComb<T> {}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>);
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct CanonicalQuadComb<T> {
left: CanonicalLinComb<T>,
right: CanonicalLinComb<T>,
}
impl<T> From<CanonicalQuadComb<T>> for QuadComb<T> {
fn from(q: CanonicalQuadComb<T>) -> Self {
QuadComb {
left: q.left.into(),
right: q.right.into(),
}
}
}
impl<T> From<CanonicalLinComb<T>> for LinComb<T> {
fn from(l: CanonicalLinComb<T>) -> Self {
LinComb(l.0.into_iter().collect())
}
}
impl<T> LinComb<T> {
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
let res = vec![(var, mult.into())];
@ -160,6 +181,15 @@ impl<T: Field> LinComb<T> {
}
}
impl<T: Field> QuadComb<T> {
pub fn as_canonical(&self) -> CanonicalQuadComb<T> {
CanonicalQuadComb {
left: self.left.as_canonical(),
right: self.right.as_canonical(),
}
}
}
impl<T: Field> fmt::Display for LinComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.is_zero() {

View file

@ -133,7 +133,7 @@ impl Interpreter {
}
}
fn execute_solver<T: Field>(&self, s: &Solver, inputs: &Vec<T>) -> Result<Vec<T>, String> {
pub fn execute_solver<T: Field>(&self, s: &Solver, inputs: &Vec<T>) -> Result<Vec<T>, String> {
use solvers::Signed;
let (expected_input_count, expected_output_count) = s.get_signature();
assert!(inputs.len() == expected_input_count);

View file

@ -23,7 +23,7 @@
// ## Optimization rules
// We maintain `s`, a set of substitutions as a mapping of `(variable => linear_combination)`. It starts empty.
// We also maintaint `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
// We also maintain `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
// - input variables are inserted into `i`
// - the `~one` variable is inserted into `i`
@ -114,25 +114,77 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
}
Statement::Directive(d) => {
let d = self.fold_directive(d);
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
for o in d.outputs.iter().cloned() {
self.ignore.insert(o);
let inputs = d
.inputs
.iter()
.map(|i| LinComb::from(i.as_canonical()))
.map(|l| match l.0.len() {
0 => Ok(T::from(0)),
_ => l
.try_summand()
.map(|(variable, coefficient)| match variable {
v if v == FlatVariable::one() => Ok(coefficient),
_ => Err(LinComb::summand(coefficient, variable).into()),
})
.unwrap_or(Err(l.into())),
})
.collect::<Vec<Result<T, LinComb<T>>>>();
match inputs.iter().all(|r| r.is_ok()) {
true => {
let inputs = inputs.into_iter().map(|i| i.unwrap()).collect();
let outputs = Interpreter::default()
.execute_solver(&d.solver, &inputs)
.unwrap();
assert_eq!(outputs.len(), d.outputs.len());
for (output, value) in d.outputs.into_iter().zip(outputs.into_iter()) {
self.substitution.insert(output, value.into());
}
vec![]
}
false => {
let inputs = inputs
.into_iter()
.map(|i| {
i.map(|v| LinComb::summand(v, FlatVariable::one()))
.unwrap_or_else(|q| q)
})
.collect();
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
for o in d.outputs.iter().cloned() {
self.ignore.insert(o);
}
vec![Statement::Directive(Directive { inputs, ..d })]
}
}
vec![Statement::Directive(d)]
}
}
}
fn fold_linear_combination(&mut self, lc: LinComb<T>) -> LinComb<T> {
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
lc.0.into_iter()
.map(|(variable, coefficient)| {
self.substitution
.get(&variable)
.map(|l| l.clone() * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
match lc
.0
.iter()
.find(|(variable, _)| self.substitution.get(&variable).is_some())
.is_some()
{
true =>
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
{
lc.0.into_iter()
.map(|(variable, coefficient)| {
self.substitution
.get(&variable)
.map(|l| l.clone() * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
}
false => lc,
}
}
fn fold_argument(&mut self, a: FlatVariable) -> FlatVariable {