1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #907 from Zokrates/cache-branch-conditions

Avoid creating empty if-else statements, cache conditions when flattening
This commit is contained in:
Thibaut Schaeffer 2021-06-08 15:05:07 +02:00 committed by GitHub
commit 8e88733bd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 11 deletions

View file

@ -0,0 +1 @@
Reduce the cost of conditionals

View file

@ -35,6 +35,8 @@ pub struct Flattener<'ast, T: Field> {
layout: HashMap<Identifier<'ast>, FlatVariable>,
/// Cached bit decompositions to avoid re-generating them
bits_cache: HashMap<FlatExpression<T>, Vec<FlatExpression<T>>>,
/// Cached flattened conditions for branches
condition_cache: HashMap<BooleanExpression<'ast, T>, FlatVariable>,
}
trait FlattenOutput<T: Field>: Sized {
@ -159,6 +161,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
next_var_idx: 0,
layout: HashMap::new(),
bits_cache: HashMap::new(),
condition_cache: HashMap::new(),
}
}
@ -413,6 +416,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
name_x_or_y.into(),
T::one().into(),
));
output
}
s => vec![s],
@ -437,10 +441,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
consequence: U,
alternative: U,
) -> FlatUExpression<T> {
let condition = self.flatten_boolean_expression(statements_flattened, condition);
let condition_flat =
self.flatten_boolean_expression(statements_flattened, condition.clone());
let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
self.condition_cache.insert(condition, condition_id);
let (consequence, alternative) = if self.config.isolate_branches {
let mut consequence_statements = vec![];
@ -636,7 +643,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened: &mut FlatStatements<T>,
expression: BooleanExpression<'ast, T>,
) -> FlatExpression<T> {
// those will be booleans in the future
// check the cache
if let Some(c) = self.condition_cache.get(&expression) {
return (*c).into();
}
match expression {
BooleanExpression::Identifier(x) => {
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
@ -2184,7 +2195,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}));
}
ZirStatement::IfElse(condition, consequence, alternative) => {
let condition = self.flatten_boolean_expression(statements_flattened, condition);
let condition_flat =
self.flatten_boolean_expression(statements_flattened, condition.clone());
let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
self.condition_cache.insert(condition, condition_id);
if self.config.isolate_branches {
let mut consequence_statements = vec![];
@ -2198,10 +2215,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.for_each(|s| self.flatten_statement(&mut alternative_statements, s));
let consequence_statements =
self.make_conditional(consequence_statements, condition.clone());
self.make_conditional(consequence_statements, condition_id.into());
let alternative_statements = self.make_conditional(
alternative_statements,
FlatExpression::Sub(box FlatExpression::Number(T::one()), box condition),
FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
);
statements_flattened.extend(consequence_statements);

View file

@ -601,11 +601,13 @@ fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>(
assert_eq!(consequence.len(), alternative.len());
statements_buffer.push(zir::ZirStatement::IfElse(
condition.clone(),
consequence_statements,
alternative_statements,
));
if !consequence_statements.is_empty() || !alternative_statements.is_empty() {
statements_buffer.push(zir::ZirStatement::IfElse(
condition.clone(),
consequence_statements,
alternative_statements,
));
}
use crate::zir::IfElse;

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/cached_condition.zok",
"max_constraint_count": 2015
}

View file

@ -0,0 +1,4 @@
// `a < b` should be flattened a single time, even if 1000 elements are assigned conditionally
def main(field a, field b) -> field[1000]:
return if a < b then [0f; 1000] else [1; 1000] fi