1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
ZoKrates/zokrates_core/src/optimizer/redefinition.rs
2022-05-11 10:00:11 +02:00

480 lines
16 KiB
Rust

//! Module containing the `RedefinitionOptimizer` to remove code of the form
// ```
// b := a
// c := b
// ```
// and replace by
// ```
// c := a
// ```
// # Redefinition rules
// ## Defined variables
// We say that a variable `v` is defined in constraint `c_n` or an ir program if there exists a constraint `c_i` with `i < n` of the form:
// ```
// q == k * v
// where:
// - q is quadratic and does not contain `v`
// - k is a scalar
// ```
// ## Optimization rules
// We maintain `s`, a set of substitutions as a mapping of `(variable => linear_combination)`. It starts empty.
// We also maintain `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
// - input variables are inserted into `i`
// - the `~one` variable is inserted into `i`
// - For each directive
// - if all directive inputs are of the form as `coeff * ~one`, execute the solver with all `coeff` as inputs, introducing `res`
// as the output, and insert `(v, o)` into `s` for `(v, o)` in `(d.outputs, res)`
// - else, for each variable `v` introduced, insert `v` into `i`
// - For each constraint `c`, we replace all variables by their value in `s` if any, otherwise leave them unchanged. Let's call `c_0` the resulting constraint. We either return `c_0` or nothing based on the form of `c_0`:
// - `~one * lin == k * v if v isn't in i`: insert `(v, lin / k)` into `s` and return nothing
// - `q == k * v if v isn't in i`: insert `v` into `i` and return `c_0`
// - otherwise return `c_0`
use crate::flat_absy::Variable;
use std::collections::{HashMap, HashSet};
use zokrates_ast::ir::folder::Folder;
use zokrates_ast::ir::LinComb;
use zokrates_ast::ir::*;
use zokrates_field::Field;
use zokrates_interpreter::Interpreter;
#[derive(Debug)]
pub struct RedefinitionOptimizer<T> {
/// Map of renamings for reassigned variables while processing the program.
substitution: HashMap<Variable, CanonicalLinComb<T>>,
/// Set of variables that should not be substituted
pub ignore: HashSet<Variable>,
}
impl<T> RedefinitionOptimizer<T> {
pub fn init<I: IntoIterator<Item = Statement<T>>>(p: &ProgIterator<T, I>) -> Self {
RedefinitionOptimizer {
substitution: HashMap::new(),
ignore: vec![Variable::one()]
.into_iter()
.chain(p.arguments.iter().map(|p| p.id))
.chain(p.returns())
.into_iter()
.collect(),
}
}
}
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
match s {
Statement::Constraint(quad, lin, message) => {
let quad = self.fold_quadratic_combination(quad);
let lin = self.fold_linear_combination(lin);
if lin.is_zero() {
return vec![Statement::Constraint(quad, lin, message)];
}
let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0)
|| self.substitution.contains_key(&lin.0[0].0)
{
true => (Some(Statement::Constraint(quad, lin, message)), None, None),
false => match lin.try_summand() {
// if the right side is a single variable
Ok((variable, coefficient)) => match quad.try_linear() {
// if the left side is linear
Ok(l) => (None, Some((variable, l / &coefficient)), None),
// if the left side isn't linear
Err(quad) => (
Some(Statement::Constraint(
quad,
LinComb::summand(coefficient, variable),
message,
)),
None,
Some(variable),
),
},
Err(l) => (Some(Statement::Constraint(quad, l, message)), None, None),
},
};
// insert into the ignored set
if let Some(v) = to_ignore {
self.ignore.insert(v);
}
// insert into the substitution map
if let Some((k, v)) = to_insert {
self.substitution.insert(k, v.into_canonical());
};
// decide whether the constraint should be kept
match constraint {
Some(c) => vec![c],
_ => vec![],
}
}
Statement::Directive(d) => {
let d = self.fold_directive(d);
// check if the inputs are constants, ie reduce to the form `coeff * ~one`
let inputs: Vec<_> = d
.inputs
.into_iter()
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
.map(|i| i.reduce())
.map(|q| {
match q
.try_linear()
.map(|l| l.try_constant().map_err(|l| l.into()))
{
Ok(r) => r,
Err(e) => Err(e),
}
})
.collect::<Vec<Result<T, QuadComb<T>>>>();
match inputs.iter().all(|i| i.is_ok()) {
true => {
// unwrap inputs to their constant value
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
assert_eq!(outputs.len(), d.outputs.len());
// insert the results in the substitution
for (output, value) in d.outputs.into_iter().zip(outputs.into_iter()) {
self.substitution
.insert(output, LinComb::from(value).into_canonical());
}
vec![]
}
false => {
//reconstruct the input expressions
let inputs: Vec<_> = inputs
.into_iter()
.map(|i| {
i.map(|v| LinComb::summand(v, Variable::one()).into())
.unwrap_or_else(|q| q)
})
.collect();
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
for o in d.outputs.iter().cloned() {
self.ignore.insert(o);
}
vec![Statement::Directive(Directive { inputs, ..d })]
}
}
}
}
}
fn fold_linear_combination(&mut self, lc: LinComb<T>) -> LinComb<T> {
match lc
.0
.iter()
.any(|(variable, _)| self.substitution.get(variable).is_some())
{
true =>
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
{
lc.0.into_iter()
.map(|(variable, coefficient)| {
self.substitution
.get(&variable)
.map(|l| LinComb::from(l.clone()) * &coefficient)
.unwrap_or_else(|| LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
}
false => lc,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::flat_absy::Parameter;
use zokrates_field::Bn128Field;
#[test]
fn remove_synonyms() {
// def main(x):
// y = x
// z = y
// return z
let x = Parameter::public(Variable::new(0));
let y = Variable::new(1);
let out = Variable::public(0);
let p: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::definition(y, x.id),
Statement::definition(out, y),
],
return_count: 1,
};
let optimized: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![Statement::definition(out, x.id)],
return_count: 1,
};
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
#[test]
fn keep_one() {
// def main(x):
// one = x
let one = Variable::one();
let x = Parameter::public(Variable::new(0));
let p: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![Statement::definition(one, x.id)],
return_count: 1,
};
let optimized = p.clone();
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
#[test]
fn remove_synonyms_in_condition() {
// def main(x) -> (1):
// y = x
// z = y
// z == y
// ~out_0 = z
// ->
// def main(x) -> (1)
// x == x // will be eliminated as a tautology
// return x
let x = Parameter::public(Variable::new(0));
let y = Variable::new(1);
let z = Variable::new(2);
let out = Variable::public(0);
let p: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::definition(y, x.id),
Statement::definition(z, y),
Statement::constraint(z, y),
Statement::definition(out, z),
],
return_count: 1,
};
let optimized: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::constraint(x.id, x.id),
Statement::definition(out, x.id),
],
return_count: 1,
};
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
#[test]
fn remove_multiple_synonyms() {
// def main(x) -> (2):
// y = x
// t = 1
// z = y
// w = t
// ~out_0 = z
// ~out_1 = w
// ->
// def main(x):
// return x, 1
let x = Parameter::public(Variable::new(0));
let y = Variable::new(1);
let z = Variable::new(2);
let t = Variable::new(3);
let w = Variable::new(4);
let out_1 = Variable::public(0);
let out_0 = Variable::public(1);
let p: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::definition(y, x.id),
Statement::definition(t, Bn128Field::from(1)),
Statement::definition(z, y),
Statement::definition(w, t),
Statement::definition(out_0, z),
Statement::definition(out_1, w),
],
return_count: 2,
};
let optimized: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::definition(out_0, x.id),
Statement::definition(out_1, Bn128Field::from(1)),
],
return_count: 2,
};
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
#[test]
fn substitute_lincomb() {
// def main(x, y) -> (1):
// a = x + y
// b = a + x + y
// c = b + x + y
// 2*c == 6*x + 6*y
// ~out_0 = a + b + c
// ->
// def main(x, y) -> (1):
// 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology
// ~out_0 = 6*x + 6*y
let x = Parameter::public(Variable::new(0));
let y = Parameter::public(Variable::new(1));
let a = Variable::new(2);
let b = Variable::new(3);
let c = Variable::new(4);
let r = Variable::public(0);
let p: Prog<Bn128Field> = Prog {
arguments: vec![x, y],
statements: vec![
Statement::definition(a, LinComb::from(x.id) + LinComb::from(y.id)),
Statement::definition(
b,
LinComb::from(a) + LinComb::from(x.id) + LinComb::from(y.id),
),
Statement::definition(
c,
LinComb::from(b) + LinComb::from(x.id) + LinComb::from(y.id),
),
Statement::constraint(
LinComb::summand(2, c),
LinComb::summand(6, x.id) + LinComb::summand(6, y.id),
),
Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)),
],
return_count: 1,
};
let expected: Prog<Bn128Field> = Prog {
arguments: vec![x, y],
statements: vec![
Statement::constraint(
LinComb::summand(6, x.id) + LinComb::summand(6, y.id),
LinComb::summand(6, x.id) + LinComb::summand(6, y.id),
),
Statement::definition(
r,
LinComb::summand(1, x.id)
+ LinComb::summand(1, y.id)
+ LinComb::summand(2, x.id)
+ LinComb::summand(2, y.id)
+ LinComb::summand(3, x.id)
+ LinComb::summand(3, y.id),
),
],
return_count: 1,
};
let mut optimizer = RedefinitionOptimizer::init(&p);
let optimized = optimizer.fold_program(p);
assert_eq!(optimized, expected);
}
#[test]
fn keep_existing_quadratic_variable() {
// def main(x, y):
// z = x * y
// z = x
// return
// ->
// def main(x, y):
// z = x * y
// z = x
// return
let x = Parameter::public(Variable::new(0));
let y = Parameter::public(Variable::new(1));
let z = Variable::new(2);
let p: Prog<Bn128Field> = Prog {
arguments: vec![x, y],
statements: vec![
Statement::definition(
z,
QuadComb::from_linear_combinations(LinComb::from(x.id), LinComb::from(y.id)),
),
Statement::definition(z, LinComb::from(x.id)),
],
return_count: 0,
};
let optimized = p.clone();
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
#[test]
fn keep_existing_variable() {
// def main(x) -> (1):
// x == 1
// x == 2
// return x
// ->
// unchanged
let x = Parameter::public(Variable::new(0));
let p: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![
Statement::constraint(x.id, Bn128Field::from(1)),
Statement::constraint(x.id, Bn128Field::from(2)),
],
return_count: 1,
};
let optimized = p.clone();
let mut optimizer = RedefinitionOptimizer::init(&p);
assert_eq!(optimizer.fold_program(p), optimized);
}
}