From 8553d9d745513a7a0470667ee2f052c710435e8c Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 4 Jun 2021 13:02:17 +0200 Subject: [PATCH 1/4] avoid creating empty if-else statements, cache conditions when flattening --- zokrates_core/src/flatten/mod.rs | 16 ++++- .../static_analysis/flatten_complex_types.rs | 60 +++++++++++-------- .../tests/tests/cached_condition.json | 4 ++ .../tests/tests/cached_condition.zok | 4 ++ 4 files changed, 56 insertions(+), 28 deletions(-) create mode 100644 zokrates_core_test/tests/tests/cached_condition.json create mode 100644 zokrates_core_test/tests/tests/cached_condition.zok diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 249ebe7c..efaa0163 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -35,6 +35,8 @@ pub struct Flattener<'ast, T: Field> { layout: HashMap, FlatVariable>, /// Cached bit decompositions to avoid re-generating them bits_cache: HashMap, Vec>>, + /// Cached flattened conditions for branches + condition_cache: HashMap, FlatVariable>, } trait FlattenOutput: Sized { @@ -159,6 +161,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { next_var_idx: 0, layout: HashMap::new(), bits_cache: HashMap::new(), + condition_cache: HashMap::new(), } } @@ -437,10 +440,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { consequence: U, alternative: U, ) -> FlatUExpression { - let condition = self.flatten_boolean_expression(statements_flattened, condition); + let condition_flat = + self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(condition_id, condition)); + statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); + + self.condition_cache.insert(condition, condition_id); let (consequence, alternative) = if self.config.isolate_branches { let mut consequence_statements = vec![]; @@ -636,7 +642,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { - // those will be booleans in the future + // check the cache + if let Some(c) = self.condition_cache.get(&expression) { + return c.clone().into(); + } + match expression { BooleanExpression::Identifier(x) => { FlatExpression::Identifier(*self.layout.get(&x).unwrap()) diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 6ccd7600..6dd9abf8 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -374,11 +374,13 @@ pub fn fold_array_expression_inner<'ast, T: Field>( assert_eq!(consequence.len(), alternative.len()); - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + if !consequence_statements.is_empty() || !alternative_statements.is_empty() { + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + } use crate::zir::IfElse; @@ -513,11 +515,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( assert_eq!(consequence.len(), alternative.len()); - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + if !consequence_statements.is_empty() || !alternative_statements.is_empty() { + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + } use zir::IfElse; @@ -647,11 +651,13 @@ pub fn fold_field_expression<'ast, T: Field>( let consequence = f.fold_field_expression(&mut consequence_statements, consequence); let alternative = f.fold_field_expression(&mut alternative_statements, alternative); - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + if !consequence_statements.is_empty() || !alternative_statements.is_empty() { + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + } zir::FieldElementExpression::IfElse(box condition, box consequence, box alternative) } @@ -846,11 +852,13 @@ pub fn fold_boolean_expression<'ast, T: Field>( let consequence = f.fold_boolean_expression(&mut consequence_statements, consequence); let alternative = f.fold_boolean_expression(&mut alternative_statements, alternative); - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + if !consequence_statements.is_empty() || !alternative_statements.is_empty() { + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + } zir::BooleanExpression::IfElse(box condition, box consequence, box alternative) } @@ -1048,11 +1056,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( let consequence = f.fold_uint_expression(&mut consequence_statements, consequence); let alternative = f.fold_uint_expression(&mut alternative_statements, alternative); - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + if !consequence_statements.is_empty() || !alternative_statements.is_empty() { + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + } zir::UExpressionInner::IfElse(box condition, box consequence, box alternative) } diff --git a/zokrates_core_test/tests/tests/cached_condition.json b/zokrates_core_test/tests/tests/cached_condition.json new file mode 100644 index 00000000..5601f26f --- /dev/null +++ b/zokrates_core_test/tests/tests/cached_condition.json @@ -0,0 +1,4 @@ +{ + "entry_point": "./tests/tests/cached_condition.zok", + "max_constraint_count": 2015 +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/cached_condition.zok b/zokrates_core_test/tests/tests/cached_condition.zok new file mode 100644 index 00000000..b58544c1 --- /dev/null +++ b/zokrates_core_test/tests/tests/cached_condition.zok @@ -0,0 +1,4 @@ +// `a < b` should be flattened a single time, even if 1000 elements are assigned conditionally + +def main(field a, field b) -> field[1000]: + return if a < b then [0f; 1000] else [1; 1000] fi \ No newline at end of file From e13741ba2bde76f38c7274c0ccf8ae43ed305ee8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 4 Jun 2021 13:23:21 +0200 Subject: [PATCH 2/4] clippy, changelog --- changelogs/unreleased/907-schaeff | 1 + zokrates_core/src/flatten/mod.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelogs/unreleased/907-schaeff diff --git a/changelogs/unreleased/907-schaeff b/changelogs/unreleased/907-schaeff new file mode 100644 index 00000000..f948a79a --- /dev/null +++ b/changelogs/unreleased/907-schaeff @@ -0,0 +1 @@ +Reduce the cost of conditionals \ No newline at end of file diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index efaa0163..34f5f4c9 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -644,7 +644,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) -> FlatExpression { // check the cache if let Some(c) = self.condition_cache.get(&expression) { - return c.clone().into(); + return (*c).into(); } match expression { From 52673ab70ee311bac2bebfc51ac1be40d8401dfa Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 4 Jun 2021 17:54:22 +0200 Subject: [PATCH 3/4] add to cache on statement branching --- zokrates_core/src/flatten/mod.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 34f5f4c9..7dbdfc2f 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -416,6 +416,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { name_x_or_y.into(), T::one().into(), )); + output } s => vec![s], @@ -2194,7 +2195,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { })); } ZirStatement::IfElse(condition, consequence, alternative) => { - let condition = self.flatten_boolean_expression(statements_flattened, condition); + let condition_flat = + self.flatten_boolean_expression(statements_flattened, condition.clone()); + + let condition_id = self.use_sym(); + statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); + + self.condition_cache.insert(condition, condition_id); if self.config.isolate_branches { let mut consequence_statements = vec![]; @@ -2208,10 +2215,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { .for_each(|s| self.flatten_statement(&mut alternative_statements, s)); let consequence_statements = - self.make_conditional(consequence_statements, condition.clone()); + self.make_conditional(consequence_statements, condition_id.clone().into()); let alternative_statements = self.make_conditional( alternative_statements, - FlatExpression::Sub(box FlatExpression::Number(T::one()), box condition), + FlatExpression::Sub( + box FlatExpression::Number(T::one()), + box condition_id.into(), + ), ); statements_flattened.extend(consequence_statements); From 46db99dfbe0d27d0941839c9234c6d013f5d0bba Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 8 Jun 2021 10:28:44 +0200 Subject: [PATCH 4/4] clippy --- zokrates_core/src/flatten/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 7dbdfc2f..712fa848 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -2215,7 +2215,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .for_each(|s| self.flatten_statement(&mut alternative_statements, s)); let consequence_statements = - self.make_conditional(consequence_statements, condition_id.clone().into()); + self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, FlatExpression::Sub(