execute directives at compile time if the inputs are constant
This commit is contained in:
parent
911ea3171b
commit
1b5ff6aa98
3 changed files with 97 additions and 15 deletions
|
@ -79,6 +79,27 @@ impl<T: Field> Eq for LinComb<T> {}
|
|||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>);
|
||||
|
||||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct CanonicalQuadComb<T> {
|
||||
left: CanonicalLinComb<T>,
|
||||
right: CanonicalLinComb<T>,
|
||||
}
|
||||
|
||||
impl<T> From<CanonicalQuadComb<T>> for QuadComb<T> {
|
||||
fn from(q: CanonicalQuadComb<T>) -> Self {
|
||||
QuadComb {
|
||||
left: q.left.into(),
|
||||
right: q.right.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<CanonicalLinComb<T>> for LinComb<T> {
|
||||
fn from(l: CanonicalLinComb<T>) -> Self {
|
||||
LinComb(l.0.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> LinComb<T> {
|
||||
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
|
||||
let res = vec![(var, mult.into())];
|
||||
|
@ -160,6 +181,15 @@ impl<T: Field> LinComb<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> QuadComb<T> {
|
||||
pub fn as_canonical(&self) -> CanonicalQuadComb<T> {
|
||||
CanonicalQuadComb {
|
||||
left: self.left.as_canonical(),
|
||||
right: self.right.as_canonical(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for LinComb<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.is_zero() {
|
||||
|
|
|
@ -133,7 +133,7 @@ impl Interpreter {
|
|||
}
|
||||
}
|
||||
|
||||
fn execute_solver<T: Field>(&self, s: &Solver, inputs: &Vec<T>) -> Result<Vec<T>, String> {
|
||||
pub fn execute_solver<T: Field>(&self, s: &Solver, inputs: &Vec<T>) -> Result<Vec<T>, String> {
|
||||
use solvers::Signed;
|
||||
let (expected_input_count, expected_output_count) = s.get_signature();
|
||||
assert!(inputs.len() == expected_input_count);
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
// ## Optimization rules
|
||||
|
||||
// We maintain `s`, a set of substitutions as a mapping of `(variable => linear_combination)`. It starts empty.
|
||||
// We also maintaint `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
|
||||
// We also maintain `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
|
||||
|
||||
// - input variables are inserted into `i`
|
||||
// - the `~one` variable is inserted into `i`
|
||||
|
@ -114,25 +114,77 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
}
|
||||
Statement::Directive(d) => {
|
||||
let d = self.fold_directive(d);
|
||||
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
|
||||
for o in d.outputs.iter().cloned() {
|
||||
self.ignore.insert(o);
|
||||
|
||||
let inputs = d
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| LinComb::from(i.as_canonical()))
|
||||
.map(|l| match l.0.len() {
|
||||
0 => Ok(T::from(0)),
|
||||
_ => l
|
||||
.try_summand()
|
||||
.map(|(variable, coefficient)| match variable {
|
||||
v if v == FlatVariable::one() => Ok(coefficient),
|
||||
_ => Err(LinComb::summand(coefficient, variable).into()),
|
||||
})
|
||||
.unwrap_or(Err(l.into())),
|
||||
})
|
||||
.collect::<Vec<Result<T, LinComb<T>>>>();
|
||||
|
||||
match inputs.iter().all(|r| r.is_ok()) {
|
||||
true => {
|
||||
let inputs = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
let outputs = Interpreter::default()
|
||||
.execute_solver(&d.solver, &inputs)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(outputs.len(), d.outputs.len());
|
||||
|
||||
for (output, value) in d.outputs.into_iter().zip(outputs.into_iter()) {
|
||||
self.substitution.insert(output, value.into());
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
false => {
|
||||
let inputs = inputs
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
i.map(|v| LinComb::summand(v, FlatVariable::one()))
|
||||
.unwrap_or_else(|q| q)
|
||||
})
|
||||
.collect();
|
||||
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
|
||||
for o in d.outputs.iter().cloned() {
|
||||
self.ignore.insert(o);
|
||||
}
|
||||
vec![Statement::Directive(Directive { inputs, ..d })]
|
||||
}
|
||||
}
|
||||
vec![Statement::Directive(d)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_linear_combination(&mut self, lc: LinComb<T>) -> LinComb<T> {
|
||||
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
|
||||
lc.0.into_iter()
|
||||
.map(|(variable, coefficient)| {
|
||||
self.substitution
|
||||
.get(&variable)
|
||||
.map(|l| l.clone() * &coefficient)
|
||||
.unwrap_or(LinComb::summand(coefficient, variable))
|
||||
})
|
||||
.fold(LinComb::zero(), |acc, x| acc + x)
|
||||
match lc
|
||||
.0
|
||||
.iter()
|
||||
.find(|(variable, _)| self.substitution.get(&variable).is_some())
|
||||
.is_some()
|
||||
{
|
||||
true =>
|
||||
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
|
||||
{
|
||||
lc.0.into_iter()
|
||||
.map(|(variable, coefficient)| {
|
||||
self.substitution
|
||||
.get(&variable)
|
||||
.map(|l| l.clone() * &coefficient)
|
||||
.unwrap_or(LinComb::summand(coefficient, variable))
|
||||
})
|
||||
.fold(LinComb::zero(), |acc, x| acc + x)
|
||||
}
|
||||
false => lc,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_argument(&mut self, a: FlatVariable) -> FlatVariable {
|
||||
|
|
Loading…
Reference in a new issue