diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index a8b2c68c..bc72757b 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -33,8 +33,6 @@ pub struct Flattener<'ast, T: Field> { layout: HashMap, FlatVariable>, /// Cached bit decompositions to avoid re-generating them bits_cache: HashMap, Vec>>, - /// Cached flattened conditions for branches - condition_cache: HashMap, FlatVariable>, } trait FlattenOutput: Sized { @@ -168,7 +166,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { next_var_idx: 0, layout: HashMap::new(), bits_cache: HashMap::new(), - condition_cache: HashMap::new(), } } @@ -459,8 +456,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { let condition_id = self.use_sym(); statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); - self.condition_cache.insert(condition, condition_id); - let (consequence, alternative) = if self.config.isolate_branches { let mut consequence_statements = vec![]; @@ -754,11 +749,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { - // check the cache - if let Some(c) = self.condition_cache.get(&expression) { - return (*c).into(); - } - match expression { BooleanExpression::Identifier(x) => { FlatExpression::Identifier(*self.layout.get(&x).unwrap()) @@ -2179,8 +2169,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { let condition_id = self.use_sym(); statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); - self.condition_cache.insert(condition, condition_id); - if self.config.isolate_branches { let mut consequence_statements = vec![]; let mut alternative_statements = vec![]; diff --git a/zokrates_core/src/static_analysis/condition_redefiner.rs b/zokrates_core/src/static_analysis/condition_redefiner.rs new file mode 100644 index 00000000..600a3a3d --- /dev/null +++ b/zokrates_core/src/static_analysis/condition_redefiner.rs @@ -0,0 +1,57 @@ +use crate::typed_absy::{ + folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, + CoreIdentifier, Expr, Identifier, TypedProgram, TypedStatement, Variable, +}; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ConditionRedefiner<'ast, T> { + index: usize, + buffer: Vec>, +} + +impl<'ast, T: Field> ConditionRedefiner<'ast, T> { + pub fn redefine(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + Self::default().fold_program(p) + } +} + +impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + assert!(self.buffer.is_empty()); + let s = fold_statement(self, s); + let buffer = std::mem::take(&mut self.buffer); + buffer.into_iter().chain(s).collect() + } + + fn fold_conditional_expression + Conditional<'ast, T> + Fold<'ast, T>>( + &mut self, + _: &E::Ty, + e: ConditionalExpression<'ast, T, E>, + ) -> ConditionalOrExpression<'ast, T, E> { + let condition = self.fold_boolean_expression(*e.condition); + let condition = match condition { + condition @ BooleanExpression::Value(_) + | condition @ BooleanExpression::Identifier(_) => condition, + condition => { + let condition_id = Identifier::from(CoreIdentifier::Condition(self.index)); + self.buffer.push(TypedStatement::Definition( + Variable::boolean(condition_id.clone()).into(), + condition.into(), + )); + self.index += 1; + BooleanExpression::Identifier(condition_id) + } + }; + + let consequence = e.consequence.fold(self); + let alternative = e.alternative.fold(self); + + ConditionalOrExpression::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + e.kind, + )) + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index b46582c6..241d6983 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,6 +5,7 @@ //! @date 2018 mod branch_isolator; +mod condition_redefiner; mod constant_argument_checker; mod constant_resolver; mod flat_propagation; @@ -19,6 +20,7 @@ mod variable_write_remover; mod zir_propagation; use self::branch_isolator::Isolator; +use self::condition_redefiner::ConditionRedefiner; use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; use self::out_of_bounds::OutOfBoundsChecker; @@ -158,6 +160,11 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = OutOfBoundsChecker::check(r).map_err(Error::from)?; log::trace!("\n{}", r); + // redefine conditions + log::debug!("Static analyser: Redefine conditions"); + let r = ConditionRedefiner::redefine(r); + log::trace!("\n{}", r); + // convert to zir, removing complex types log::debug!("Static analyser: Convert to zir"); let zir = Flattener::flatten(r); diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index 972a69e6..56246eb5 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -7,6 +7,7 @@ pub enum CoreIdentifier<'ast> { Source(&'ast str), Call(usize), Constant(CanonicalConstantIdentifier<'ast>), + Condition(usize), } impl<'ast> fmt::Display for CoreIdentifier<'ast> { @@ -15,6 +16,7 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { CoreIdentifier::Source(s) => write!(f, "{}", s), CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i), CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id), + CoreIdentifier::Condition(i) => write!(f, "#CONDITION_{}", i), } } }