Merge pull request #907 from Zokrates/cache-branch-conditions
Avoid creating empty if-else statements, cache conditions when flattening
This commit is contained in:
commit
8e88733bd6
5 changed files with 42 additions and 11 deletions
1
changelogs/unreleased/907-schaeff
Normal file
1
changelogs/unreleased/907-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce the cost of conditionals
|
|
@ -35,6 +35,8 @@ pub struct Flattener<'ast, T: Field> {
|
|||
layout: HashMap<Identifier<'ast>, FlatVariable>,
|
||||
/// Cached bit decompositions to avoid re-generating them
|
||||
bits_cache: HashMap<FlatExpression<T>, Vec<FlatExpression<T>>>,
|
||||
/// Cached flattened conditions for branches
|
||||
condition_cache: HashMap<BooleanExpression<'ast, T>, FlatVariable>,
|
||||
}
|
||||
|
||||
trait FlattenOutput<T: Field>: Sized {
|
||||
|
@ -159,6 +161,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
next_var_idx: 0,
|
||||
layout: HashMap::new(),
|
||||
bits_cache: HashMap::new(),
|
||||
condition_cache: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -413,6 +416,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
name_x_or_y.into(),
|
||||
T::one().into(),
|
||||
));
|
||||
|
||||
output
|
||||
}
|
||||
s => vec![s],
|
||||
|
@ -437,10 +441,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
consequence: U,
|
||||
alternative: U,
|
||||
) -> FlatUExpression<T> {
|
||||
let condition = self.flatten_boolean_expression(statements_flattened, condition);
|
||||
let condition_flat =
|
||||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||
|
||||
self.condition_cache.insert(condition, condition_id);
|
||||
|
||||
let (consequence, alternative) = if self.config.isolate_branches {
|
||||
let mut consequence_statements = vec![];
|
||||
|
@ -636,7 +643,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened: &mut FlatStatements<T>,
|
||||
expression: BooleanExpression<'ast, T>,
|
||||
) -> FlatExpression<T> {
|
||||
// those will be booleans in the future
|
||||
// check the cache
|
||||
if let Some(c) = self.condition_cache.get(&expression) {
|
||||
return (*c).into();
|
||||
}
|
||||
|
||||
match expression {
|
||||
BooleanExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
|
||||
|
@ -2184,7 +2195,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}));
|
||||
}
|
||||
ZirStatement::IfElse(condition, consequence, alternative) => {
|
||||
let condition = self.flatten_boolean_expression(statements_flattened, condition);
|
||||
let condition_flat =
|
||||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||
|
||||
self.condition_cache.insert(condition, condition_id);
|
||||
|
||||
if self.config.isolate_branches {
|
||||
let mut consequence_statements = vec![];
|
||||
|
@ -2198,10 +2215,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.for_each(|s| self.flatten_statement(&mut alternative_statements, s));
|
||||
|
||||
let consequence_statements =
|
||||
self.make_conditional(consequence_statements, condition.clone());
|
||||
self.make_conditional(consequence_statements, condition_id.into());
|
||||
let alternative_statements = self.make_conditional(
|
||||
alternative_statements,
|
||||
FlatExpression::Sub(box FlatExpression::Number(T::one()), box condition),
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
);
|
||||
|
||||
statements_flattened.extend(consequence_statements);
|
||||
|
|
|
@ -601,11 +601,13 @@ fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>(
|
|||
|
||||
assert_eq!(consequence.len(), alternative.len());
|
||||
|
||||
if !consequence_statements.is_empty() || !alternative_statements.is_empty() {
|
||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
||||
condition.clone(),
|
||||
consequence_statements,
|
||||
alternative_statements,
|
||||
));
|
||||
}
|
||||
|
||||
use crate::zir::IfElse;
|
||||
|
||||
|
|
4
zokrates_core_test/tests/tests/cached_condition.json
Normal file
4
zokrates_core_test/tests/tests/cached_condition.json
Normal file
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/cached_condition.zok",
|
||||
"max_constraint_count": 2015
|
||||
}
|
4
zokrates_core_test/tests/tests/cached_condition.zok
Normal file
4
zokrates_core_test/tests/tests/cached_condition.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
// `a < b` should be flattened a single time, even if 1000 elements are assigned conditionally
|
||||
|
||||
def main(field a, field b) -> field[1000]:
|
||||
return if a < b then [0f; 1000] else [1; 1000] fi
|
Loading…
Reference in a new issue