implement function call memoization
This commit is contained in:
parent
6da9911493
commit
fcdbed507a
7 changed files with 105 additions and 10 deletions
|
@ -26,22 +26,30 @@ use zokrates_field::field::Field;
|
||||||
pub struct Inliner<'ast, T: Field> {
|
pub struct Inliner<'ast, T: Field> {
|
||||||
modules: TypedModules<'ast, T>, // the modules in which to look for functions when inlining
|
modules: TypedModules<'ast, T>, // the modules in which to look for functions when inlining
|
||||||
module_id: TypedModuleId, // the current module we're visiting
|
module_id: TypedModuleId, // the current module we're visiting
|
||||||
|
function_key: FunctionKey<'ast>, // the current function we're visiting
|
||||||
statement_buffer: Vec<TypedStatement<'ast, T>>, // a buffer of statements to be added to the inlined statements
|
statement_buffer: Vec<TypedStatement<'ast, T>>, // a buffer of statements to be added to the inlined statements
|
||||||
stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, // the current call stack
|
stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, // the current call stack
|
||||||
call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, // the call count for each function
|
call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, // the call count for each function
|
||||||
|
call_cache: HashMap<
|
||||||
|
(TypedModuleId, FunctionKey<'ast>),
|
||||||
|
HashMap<(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), Vec<TypedExpression<'ast, T>>>,
|
||||||
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Inliner<'ast, T> {
|
impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
fn with_modules_and_module_id<S: Into<TypedModuleId>>(
|
fn with_modules_and_module_id<S: Into<TypedModuleId>>(
|
||||||
modules: TypedModules<'ast, T>,
|
modules: TypedModules<'ast, T>,
|
||||||
module_id: S,
|
module_id: S,
|
||||||
|
key: FunctionKey<'ast>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Inliner {
|
Inliner {
|
||||||
modules,
|
modules,
|
||||||
module_id: module_id.into(),
|
module_id: module_id.into(),
|
||||||
|
function_key: key,
|
||||||
statement_buffer: vec![],
|
statement_buffer: vec![],
|
||||||
stack: vec![],
|
stack: vec![],
|
||||||
call_count: HashMap::new(),
|
call_count: HashMap::new(),
|
||||||
|
call_cache: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,7 +67,8 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// initialize an inliner over all modules, starting from the main module
|
// initialize an inliner over all modules, starting from the main module
|
||||||
let mut inliner = Inliner::with_modules_and_module_id(p.modules, main_module_id);
|
let mut inliner =
|
||||||
|
Inliner::with_modules_and_module_id(p.modules, main_module_id, main_key.clone());
|
||||||
|
|
||||||
// inline all calls in the main function, recursively
|
// inline all calls in the main function, recursively
|
||||||
let main = inliner.fold_function_symbol(main);
|
let main = inliner.fold_function_symbol(main);
|
||||||
|
@ -101,10 +110,18 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
expressions: Vec<TypedExpression<'ast, T>>,
|
expressions: Vec<TypedExpression<'ast, T>>,
|
||||||
) -> Result<Vec<TypedExpression<'ast, T>>, (FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>)>
|
) -> Result<Vec<TypedExpression<'ast, T>>, (FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>)>
|
||||||
{
|
{
|
||||||
|
match self.call_cache().get(&(key.clone(), expressions.clone())) {
|
||||||
|
Some(exprs) => return Ok(exprs.clone()),
|
||||||
|
None => {}
|
||||||
|
};
|
||||||
|
|
||||||
// here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways
|
// here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways
|
||||||
match self.module().functions.get(&key).unwrap().clone() {
|
let res = match self.module().functions.get(&key).unwrap().clone() {
|
||||||
// if the function called is in the same module, we can go ahead and inline in this module
|
// if the function called is in the same module, we can go ahead and inline in this module
|
||||||
TypedFunctionSymbol::Here(function) => {
|
TypedFunctionSymbol::Here(function) => {
|
||||||
|
let (current_module, current_key) =
|
||||||
|
self.change_context(self.module_id.clone(), key.clone());
|
||||||
|
|
||||||
// increase the number of calls for this function by one
|
// increase the number of calls for this function by one
|
||||||
let count = self
|
let count = self
|
||||||
.call_count
|
.call_count
|
||||||
|
@ -118,7 +135,7 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
let inputs_bindings: Vec<_> = function
|
let inputs_bindings: Vec<_> = function
|
||||||
.arguments
|
.arguments
|
||||||
.iter()
|
.iter()
|
||||||
.zip(expressions)
|
.zip(expressions.clone())
|
||||||
.map(|(a, e)| {
|
.map(|(a, e)| {
|
||||||
TypedStatement::Definition(
|
TypedStatement::Definition(
|
||||||
self.fold_assignee(TypedAssignee::Identifier(a.id.clone())),
|
self.fold_assignee(TypedAssignee::Identifier(a.id.clone())),
|
||||||
|
@ -145,6 +162,8 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
// pop this call from the stack
|
// pop this call from the stack
|
||||||
self.stack.pop();
|
self.stack.pop();
|
||||||
|
|
||||||
|
self.change_context(current_module, current_key);
|
||||||
|
|
||||||
match ret.pop().unwrap() {
|
match ret.pop().unwrap() {
|
||||||
TypedStatement::Return(exprs) => Ok(exprs),
|
TypedStatement::Return(exprs) => Ok(exprs),
|
||||||
_ => unreachable!(""),
|
_ => unreachable!(""),
|
||||||
|
@ -153,26 +172,59 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
||||||
// if the function called is in some other module, we switch focus to that module and call the function locally there
|
// if the function called is in some other module, we switch focus to that module and call the function locally there
|
||||||
TypedFunctionSymbol::There(function_key, module_id) => {
|
TypedFunctionSymbol::There(function_key, module_id) => {
|
||||||
// switch focus to `module_id`
|
// switch focus to `module_id`
|
||||||
let current_module = self.change_module(module_id);
|
let (current_module, current_key) =
|
||||||
|
self.change_context(module_id, function_key.clone());
|
||||||
// inline the call there
|
// inline the call there
|
||||||
let res = self.try_inline_call(&function_key, expressions)?;
|
let res = self.try_inline_call(&function_key, expressions.clone())?;
|
||||||
// switch back focus
|
// switch back focus
|
||||||
self.change_module(current_module);
|
self.change_context(current_module, current_key);
|
||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
// if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening
|
// if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening
|
||||||
TypedFunctionSymbol::Flat(embed) => Err((embed.key::<T>(), expressions)),
|
TypedFunctionSymbol::Flat(embed) => Err((embed.key::<T>(), expressions.clone())),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
res.map(|exprs| {
|
||||||
|
self.call_cache_mut()
|
||||||
|
.insert((key.clone(), expressions), exprs.clone());
|
||||||
|
exprs
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Focus the inliner on another module with id `module_id` and return the current `module_id`
|
// Focus the inliner on another module with id `module_id` and return the current `module_id`
|
||||||
fn change_module(&mut self, module_id: TypedModuleId) -> TypedModuleId {
|
fn change_context(
|
||||||
std::mem::replace(&mut self.module_id, module_id)
|
&mut self,
|
||||||
|
module_id: TypedModuleId,
|
||||||
|
function_key: FunctionKey<'ast>,
|
||||||
|
) -> (TypedModuleId, FunctionKey<'ast>) {
|
||||||
|
let current_module = std::mem::replace(&mut self.module_id, module_id);
|
||||||
|
let current_key = std::mem::replace(&mut self.function_key, function_key);
|
||||||
|
(current_module, current_key)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn module(&self) -> &TypedModule<'ast, T> {
|
fn module(&self) -> &TypedModule<'ast, T> {
|
||||||
self.modules.get(&self.module_id).unwrap()
|
self.modules.get(&self.module_id).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn call_cache(
|
||||||
|
&mut self,
|
||||||
|
) -> &HashMap<(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), Vec<TypedExpression<'ast, T>>>
|
||||||
|
{
|
||||||
|
self.call_cache
|
||||||
|
.entry((self.module_id.clone().clone(), self.function_key.clone()))
|
||||||
|
.or_insert_with(|| HashMap::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_cache_mut(
|
||||||
|
&mut self,
|
||||||
|
) -> &mut HashMap<
|
||||||
|
(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||||
|
Vec<TypedExpression<'ast, T>>,
|
||||||
|
> {
|
||||||
|
self.call_cache
|
||||||
|
.get_mut(&(self.module_id.clone().clone(), self.function_key.clone()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
||||||
|
|
7
zokrates_core/src/static_analysis/memoize.rs
Normal file
7
zokrates_core/src/static_analysis/memoize.rs
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
use zokrates_field::field::Field;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use typed_absy::identifier::CallIdentifier;
|
||||||
|
|
||||||
|
pub struct Memoizer<'ast, T: Field> {
|
||||||
|
identifiers: HashSet<CallIdentifier<'ast, T>>
|
||||||
|
}
|
6
zokrates_core_test/tests/tests/arrays/fun_spread.json
Normal file
6
zokrates_core_test/tests/tests/arrays/fun_spread.json
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"entry_point": "./tests/tests/arrays/fun_spread.zok",
|
||||||
|
"max_constraint_count": 1050,
|
||||||
|
"tests": [
|
||||||
|
]
|
||||||
|
}
|
7
zokrates_core_test/tests/tests/arrays/fun_spread.zok
Normal file
7
zokrates_core_test/tests/tests/arrays/fun_spread.zok
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
import "utils/pack/nonStrictUnpack256.zok" as unpack256
|
||||||
|
|
||||||
|
def main(field[2] inputs) -> (field[512]):
|
||||||
|
|
||||||
|
field[512] preimage512 = [...unpack256(inputs[0]), ...unpack256(inputs[1])]
|
||||||
|
|
||||||
|
return preimage512
|
3
zokrates_core_test/tests/tests/memoize/dep.zok
Normal file
3
zokrates_core_test/tests/tests/memoize/dep.zok
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
def dep(field a) -> (field): // this costs 2 constraits per call
|
||||||
|
field res = a ** 4
|
||||||
|
return res
|
6
zokrates_core_test/tests/tests/memoize/memoize.json
Normal file
6
zokrates_core_test/tests/tests/memoize/memoize.json
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"entry_point": "./tests/tests/memoize/memoize.zok",
|
||||||
|
"max_constraint_count": 14,
|
||||||
|
"tests": [
|
||||||
|
]
|
||||||
|
}
|
14
zokrates_core_test/tests/tests/memoize/memoize.zok
Normal file
14
zokrates_core_test/tests/tests/memoize/memoize.zok
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from "./dep.zok" import dep as dep
|
||||||
|
|
||||||
|
def local(field a) -> (field): // this costs 3 constraints per call
|
||||||
|
field res = a ** 8
|
||||||
|
return res // currently expressions in the return statement don't get memoized
|
||||||
|
|
||||||
|
def main(field a) -> ():
|
||||||
|
// calling a local function many times with the same arg should cost only once
|
||||||
|
local(a) + local(a) + local(a) + local(a) + local(a) == 5 * (a ** 8)
|
||||||
|
|
||||||
|
// calling an imported function many times with the same arg should cost only once
|
||||||
|
dep(a) + dep(a) + dep(a) + dep(a) + dep(a) == 4 * (a ** 4)
|
||||||
|
|
||||||
|
return
|
Loading…
Reference in a new issue