diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 4acda308..31c1494d 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -26,22 +26,30 @@ use zokrates_field::field::Field; pub struct Inliner<'ast, T: Field> { modules: TypedModules<'ast, T>, // the modules in which to look for functions when inlining module_id: TypedModuleId, // the current module we're visiting + function_key: FunctionKey<'ast>, // the current function we're visiting statement_buffer: Vec>, // a buffer of statements to be added to the inlined statements stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, // the current call stack call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, // the call count for each function + call_cache: HashMap< + (TypedModuleId, FunctionKey<'ast>), + HashMap<(FunctionKey<'ast>, Vec>), Vec>>, + >, } impl<'ast, T: Field> Inliner<'ast, T> { fn with_modules_and_module_id>( modules: TypedModules<'ast, T>, module_id: S, + key: FunctionKey<'ast>, ) -> Self { Inliner { modules, module_id: module_id.into(), + function_key: key, statement_buffer: vec![], stack: vec![], call_count: HashMap::new(), + call_cache: HashMap::new(), } } @@ -59,7 +67,8 @@ impl<'ast, T: Field> Inliner<'ast, T> { .unwrap(); // initialize an inliner over all modules, starting from the main module - let mut inliner = Inliner::with_modules_and_module_id(p.modules, main_module_id); + let mut inliner = + Inliner::with_modules_and_module_id(p.modules, main_module_id, main_key.clone()); // inline all calls in the main function, recursively let main = inliner.fold_function_symbol(main); @@ -101,10 +110,18 @@ impl<'ast, T: Field> Inliner<'ast, T> { expressions: Vec>, ) -> Result>, (FunctionKey<'ast>, Vec>)> { + match self.call_cache().get(&(key.clone(), expressions.clone())) { + Some(exprs) => return Ok(exprs.clone()), + None => {} + }; + // here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways - match self.module().functions.get(&key).unwrap().clone() { + let res = match self.module().functions.get(&key).unwrap().clone() { // if the function called is in the same module, we can go ahead and inline in this module TypedFunctionSymbol::Here(function) => { + let (current_module, current_key) = + self.change_context(self.module_id.clone(), key.clone()); + // increase the number of calls for this function by one let count = self .call_count @@ -118,7 +135,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { let inputs_bindings: Vec<_> = function .arguments .iter() - .zip(expressions) + .zip(expressions.clone()) .map(|(a, e)| { TypedStatement::Definition( self.fold_assignee(TypedAssignee::Identifier(a.id.clone())), @@ -145,6 +162,8 @@ impl<'ast, T: Field> Inliner<'ast, T> { // pop this call from the stack self.stack.pop(); + self.change_context(current_module, current_key); + match ret.pop().unwrap() { TypedStatement::Return(exprs) => Ok(exprs), _ => unreachable!(""), @@ -153,26 +172,59 @@ impl<'ast, T: Field> Inliner<'ast, T> { // if the function called is in some other module, we switch focus to that module and call the function locally there TypedFunctionSymbol::There(function_key, module_id) => { // switch focus to `module_id` - let current_module = self.change_module(module_id); + let (current_module, current_key) = + self.change_context(module_id, function_key.clone()); // inline the call there - let res = self.try_inline_call(&function_key, expressions)?; + let res = self.try_inline_call(&function_key, expressions.clone())?; // switch back focus - self.change_module(current_module); + self.change_context(current_module, current_key); Ok(res) } // if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening - TypedFunctionSymbol::Flat(embed) => Err((embed.key::(), expressions)), - } + TypedFunctionSymbol::Flat(embed) => Err((embed.key::(), expressions.clone())), + }; + + res.map(|exprs| { + self.call_cache_mut() + .insert((key.clone(), expressions), exprs.clone()); + exprs + }) } // Focus the inliner on another module with id `module_id` and return the current `module_id` - fn change_module(&mut self, module_id: TypedModuleId) -> TypedModuleId { - std::mem::replace(&mut self.module_id, module_id) + fn change_context( + &mut self, + module_id: TypedModuleId, + function_key: FunctionKey<'ast>, + ) -> (TypedModuleId, FunctionKey<'ast>) { + let current_module = std::mem::replace(&mut self.module_id, module_id); + let current_key = std::mem::replace(&mut self.function_key, function_key); + (current_module, current_key) } fn module(&self) -> &TypedModule<'ast, T> { self.modules.get(&self.module_id).unwrap() } + + fn call_cache( + &mut self, + ) -> &HashMap<(FunctionKey<'ast>, Vec>), Vec>> + { + self.call_cache + .entry((self.module_id.clone().clone(), self.function_key.clone())) + .or_insert_with(|| HashMap::new()) + } + + fn call_cache_mut( + &mut self, + ) -> &mut HashMap< + (FunctionKey<'ast>, Vec>), + Vec>, + > { + self.call_cache + .get_mut(&(self.module_id.clone().clone(), self.function_key.clone())) + .unwrap() + } } impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { diff --git a/zokrates_core/src/static_analysis/memoize.rs b/zokrates_core/src/static_analysis/memoize.rs new file mode 100644 index 00000000..a27d0734 --- /dev/null +++ b/zokrates_core/src/static_analysis/memoize.rs @@ -0,0 +1,7 @@ +use zokrates_field::field::Field; +use std::collections::HashSet; +use typed_absy::identifier::CallIdentifier; + +pub struct Memoizer<'ast, T: Field> { + identifiers: HashSet> +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/arrays/fun_spread.json b/zokrates_core_test/tests/tests/arrays/fun_spread.json new file mode 100644 index 00000000..5583d306 --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/fun_spread.json @@ -0,0 +1,6 @@ +{ + "entry_point": "./tests/tests/arrays/fun_spread.zok", + "max_constraint_count": 1050, + "tests": [ + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/arrays/fun_spread.zok b/zokrates_core_test/tests/tests/arrays/fun_spread.zok new file mode 100644 index 00000000..a62a95e0 --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/fun_spread.zok @@ -0,0 +1,7 @@ +import "utils/pack/nonStrictUnpack256.zok" as unpack256 + +def main(field[2] inputs) -> (field[512]): + + field[512] preimage512 = [...unpack256(inputs[0]), ...unpack256(inputs[1])] + + return preimage512 \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/memoize/dep.zok b/zokrates_core_test/tests/tests/memoize/dep.zok new file mode 100644 index 00000000..df49dc48 --- /dev/null +++ b/zokrates_core_test/tests/tests/memoize/dep.zok @@ -0,0 +1,3 @@ +def dep(field a) -> (field): // this costs 2 constraits per call + field res = a ** 4 + return res \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/memoize/memoize.json b/zokrates_core_test/tests/tests/memoize/memoize.json new file mode 100644 index 00000000..af46d1d1 --- /dev/null +++ b/zokrates_core_test/tests/tests/memoize/memoize.json @@ -0,0 +1,6 @@ +{ + "entry_point": "./tests/tests/memoize/memoize.zok", + "max_constraint_count": 14, + "tests": [ + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/memoize/memoize.zok b/zokrates_core_test/tests/tests/memoize/memoize.zok new file mode 100644 index 00000000..7d7ce93c --- /dev/null +++ b/zokrates_core_test/tests/tests/memoize/memoize.zok @@ -0,0 +1,14 @@ +from "./dep.zok" import dep as dep + +def local(field a) -> (field): // this costs 3 constraints per call + field res = a ** 8 + return res // currently expressions in the return statement don't get memoized + +def main(field a) -> (): + // calling a local function many times with the same arg should cost only once + local(a) + local(a) + local(a) + local(a) + local(a) == 5 * (a ** 8) + + // calling an imported function many times with the same arg should cost only once + dep(a) + dep(a) + dep(a) + dep(a) + dep(a) == 4 * (a ** 4) + + return