Merge pull request #907 from Zokrates/cache-branch-conditions
Avoid creating empty if-else statements, cache conditions when flattening
This commit is contained in:
commit
8e88733bd6
5 changed files with 42 additions and 11 deletions
1
changelogs/unreleased/907-schaeff
Normal file
1
changelogs/unreleased/907-schaeff
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce the cost of conditionals
|
|
@ -35,6 +35,8 @@ pub struct Flattener<'ast, T: Field> {
|
||||||
layout: HashMap<Identifier<'ast>, FlatVariable>,
|
layout: HashMap<Identifier<'ast>, FlatVariable>,
|
||||||
/// Cached bit decompositions to avoid re-generating them
|
/// Cached bit decompositions to avoid re-generating them
|
||||||
bits_cache: HashMap<FlatExpression<T>, Vec<FlatExpression<T>>>,
|
bits_cache: HashMap<FlatExpression<T>, Vec<FlatExpression<T>>>,
|
||||||
|
/// Cached flattened conditions for branches
|
||||||
|
condition_cache: HashMap<BooleanExpression<'ast, T>, FlatVariable>,
|
||||||
}
|
}
|
||||||
|
|
||||||
trait FlattenOutput<T: Field>: Sized {
|
trait FlattenOutput<T: Field>: Sized {
|
||||||
|
@ -159,6 +161,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
next_var_idx: 0,
|
next_var_idx: 0,
|
||||||
layout: HashMap::new(),
|
layout: HashMap::new(),
|
||||||
bits_cache: HashMap::new(),
|
bits_cache: HashMap::new(),
|
||||||
|
condition_cache: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,6 +416,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
name_x_or_y.into(),
|
name_x_or_y.into(),
|
||||||
T::one().into(),
|
T::one().into(),
|
||||||
));
|
));
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
s => vec![s],
|
s => vec![s],
|
||||||
|
@ -437,10 +441,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
consequence: U,
|
consequence: U,
|
||||||
alternative: U,
|
alternative: U,
|
||||||
) -> FlatUExpression<T> {
|
) -> FlatUExpression<T> {
|
||||||
let condition = self.flatten_boolean_expression(statements_flattened, condition);
|
let condition_flat =
|
||||||
|
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||||
|
|
||||||
let condition_id = self.use_sym();
|
let condition_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
|
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||||
|
|
||||||
|
self.condition_cache.insert(condition, condition_id);
|
||||||
|
|
||||||
let (consequence, alternative) = if self.config.isolate_branches {
|
let (consequence, alternative) = if self.config.isolate_branches {
|
||||||
let mut consequence_statements = vec![];
|
let mut consequence_statements = vec![];
|
||||||
|
@ -636,7 +643,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
statements_flattened: &mut FlatStatements<T>,
|
statements_flattened: &mut FlatStatements<T>,
|
||||||
expression: BooleanExpression<'ast, T>,
|
expression: BooleanExpression<'ast, T>,
|
||||||
) -> FlatExpression<T> {
|
) -> FlatExpression<T> {
|
||||||
// those will be booleans in the future
|
// check the cache
|
||||||
|
if let Some(c) = self.condition_cache.get(&expression) {
|
||||||
|
return (*c).into();
|
||||||
|
}
|
||||||
|
|
||||||
match expression {
|
match expression {
|
||||||
BooleanExpression::Identifier(x) => {
|
BooleanExpression::Identifier(x) => {
|
||||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
|
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
|
||||||
|
@ -2184,7 +2195,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
ZirStatement::IfElse(condition, consequence, alternative) => {
|
ZirStatement::IfElse(condition, consequence, alternative) => {
|
||||||
let condition = self.flatten_boolean_expression(statements_flattened, condition);
|
let condition_flat =
|
||||||
|
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||||
|
|
||||||
|
let condition_id = self.use_sym();
|
||||||
|
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||||
|
|
||||||
|
self.condition_cache.insert(condition, condition_id);
|
||||||
|
|
||||||
if self.config.isolate_branches {
|
if self.config.isolate_branches {
|
||||||
let mut consequence_statements = vec![];
|
let mut consequence_statements = vec![];
|
||||||
|
@ -2198,10 +2215,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
.for_each(|s| self.flatten_statement(&mut alternative_statements, s));
|
.for_each(|s| self.flatten_statement(&mut alternative_statements, s));
|
||||||
|
|
||||||
let consequence_statements =
|
let consequence_statements =
|
||||||
self.make_conditional(consequence_statements, condition.clone());
|
self.make_conditional(consequence_statements, condition_id.into());
|
||||||
let alternative_statements = self.make_conditional(
|
let alternative_statements = self.make_conditional(
|
||||||
alternative_statements,
|
alternative_statements,
|
||||||
FlatExpression::Sub(box FlatExpression::Number(T::one()), box condition),
|
FlatExpression::Sub(
|
||||||
|
box FlatExpression::Number(T::one()),
|
||||||
|
box condition_id.into(),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
statements_flattened.extend(consequence_statements);
|
statements_flattened.extend(consequence_statements);
|
||||||
|
|
|
@ -601,11 +601,13 @@ fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>(
|
||||||
|
|
||||||
assert_eq!(consequence.len(), alternative.len());
|
assert_eq!(consequence.len(), alternative.len());
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
if !consequence_statements.is_empty() || !alternative_statements.is_empty() {
|
||||||
condition.clone(),
|
statements_buffer.push(zir::ZirStatement::IfElse(
|
||||||
consequence_statements,
|
condition.clone(),
|
||||||
alternative_statements,
|
consequence_statements,
|
||||||
));
|
alternative_statements,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
use crate::zir::IfElse;
|
use crate::zir::IfElse;
|
||||||
|
|
||||||
|
|
4
zokrates_core_test/tests/tests/cached_condition.json
Normal file
4
zokrates_core_test/tests/tests/cached_condition.json
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
{
|
||||||
|
"entry_point": "./tests/tests/cached_condition.zok",
|
||||||
|
"max_constraint_count": 2015
|
||||||
|
}
|
4
zokrates_core_test/tests/tests/cached_condition.zok
Normal file
4
zokrates_core_test/tests/tests/cached_condition.zok
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
// `a < b` should be flattened a single time, even if 1000 elements are assigned conditionally
|
||||||
|
|
||||||
|
def main(field a, field b) -> field[1000]:
|
||||||
|
return if a < b then [0f; 1000] else [1; 1000] fi
|
Loading…
Reference in a new issue