1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

implement function call memoization

This commit is contained in:
schaeff 2020-03-02 18:14:05 +01:00
parent 6da9911493
commit fcdbed507a
7 changed files with 105 additions and 10 deletions

View file

@ -26,22 +26,30 @@ use zokrates_field::field::Field;
pub struct Inliner<'ast, T: Field> { pub struct Inliner<'ast, T: Field> {
modules: TypedModules<'ast, T>, // the modules in which to look for functions when inlining modules: TypedModules<'ast, T>, // the modules in which to look for functions when inlining
module_id: TypedModuleId, // the current module we're visiting module_id: TypedModuleId, // the current module we're visiting
function_key: FunctionKey<'ast>, // the current function we're visiting
statement_buffer: Vec<TypedStatement<'ast, T>>, // a buffer of statements to be added to the inlined statements statement_buffer: Vec<TypedStatement<'ast, T>>, // a buffer of statements to be added to the inlined statements
stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, // the current call stack stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, // the current call stack
call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, // the call count for each function call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, // the call count for each function
call_cache: HashMap<
(TypedModuleId, FunctionKey<'ast>),
HashMap<(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), Vec<TypedExpression<'ast, T>>>,
>,
} }
impl<'ast, T: Field> Inliner<'ast, T> { impl<'ast, T: Field> Inliner<'ast, T> {
fn with_modules_and_module_id<S: Into<TypedModuleId>>( fn with_modules_and_module_id<S: Into<TypedModuleId>>(
modules: TypedModules<'ast, T>, modules: TypedModules<'ast, T>,
module_id: S, module_id: S,
key: FunctionKey<'ast>,
) -> Self { ) -> Self {
Inliner { Inliner {
modules, modules,
module_id: module_id.into(), module_id: module_id.into(),
function_key: key,
statement_buffer: vec![], statement_buffer: vec![],
stack: vec![], stack: vec![],
call_count: HashMap::new(), call_count: HashMap::new(),
call_cache: HashMap::new(),
} }
} }
@ -59,7 +67,8 @@ impl<'ast, T: Field> Inliner<'ast, T> {
.unwrap(); .unwrap();
// initialize an inliner over all modules, starting from the main module // initialize an inliner over all modules, starting from the main module
let mut inliner = Inliner::with_modules_and_module_id(p.modules, main_module_id); let mut inliner =
Inliner::with_modules_and_module_id(p.modules, main_module_id, main_key.clone());
// inline all calls in the main function, recursively // inline all calls in the main function, recursively
let main = inliner.fold_function_symbol(main); let main = inliner.fold_function_symbol(main);
@ -101,10 +110,18 @@ impl<'ast, T: Field> Inliner<'ast, T> {
expressions: Vec<TypedExpression<'ast, T>>, expressions: Vec<TypedExpression<'ast, T>>,
) -> Result<Vec<TypedExpression<'ast, T>>, (FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>)> ) -> Result<Vec<TypedExpression<'ast, T>>, (FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>)>
{ {
match self.call_cache().get(&(key.clone(), expressions.clone())) {
Some(exprs) => return Ok(exprs.clone()),
None => {}
};
// here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways // here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways
match self.module().functions.get(&key).unwrap().clone() { let res = match self.module().functions.get(&key).unwrap().clone() {
// if the function called is in the same module, we can go ahead and inline in this module // if the function called is in the same module, we can go ahead and inline in this module
TypedFunctionSymbol::Here(function) => { TypedFunctionSymbol::Here(function) => {
let (current_module, current_key) =
self.change_context(self.module_id.clone(), key.clone());
// increase the number of calls for this function by one // increase the number of calls for this function by one
let count = self let count = self
.call_count .call_count
@ -118,7 +135,7 @@ impl<'ast, T: Field> Inliner<'ast, T> {
let inputs_bindings: Vec<_> = function let inputs_bindings: Vec<_> = function
.arguments .arguments
.iter() .iter()
.zip(expressions) .zip(expressions.clone())
.map(|(a, e)| { .map(|(a, e)| {
TypedStatement::Definition( TypedStatement::Definition(
self.fold_assignee(TypedAssignee::Identifier(a.id.clone())), self.fold_assignee(TypedAssignee::Identifier(a.id.clone())),
@ -145,6 +162,8 @@ impl<'ast, T: Field> Inliner<'ast, T> {
// pop this call from the stack // pop this call from the stack
self.stack.pop(); self.stack.pop();
self.change_context(current_module, current_key);
match ret.pop().unwrap() { match ret.pop().unwrap() {
TypedStatement::Return(exprs) => Ok(exprs), TypedStatement::Return(exprs) => Ok(exprs),
_ => unreachable!(""), _ => unreachable!(""),
@ -153,26 +172,59 @@ impl<'ast, T: Field> Inliner<'ast, T> {
// if the function called is in some other module, we switch focus to that module and call the function locally there // if the function called is in some other module, we switch focus to that module and call the function locally there
TypedFunctionSymbol::There(function_key, module_id) => { TypedFunctionSymbol::There(function_key, module_id) => {
// switch focus to `module_id` // switch focus to `module_id`
let current_module = self.change_module(module_id); let (current_module, current_key) =
self.change_context(module_id, function_key.clone());
// inline the call there // inline the call there
let res = self.try_inline_call(&function_key, expressions)?; let res = self.try_inline_call(&function_key, expressions.clone())?;
// switch back focus // switch back focus
self.change_module(current_module); self.change_context(current_module, current_key);
Ok(res) Ok(res)
} }
// if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening // if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening
TypedFunctionSymbol::Flat(embed) => Err((embed.key::<T>(), expressions)), TypedFunctionSymbol::Flat(embed) => Err((embed.key::<T>(), expressions.clone())),
} };
res.map(|exprs| {
self.call_cache_mut()
.insert((key.clone(), expressions), exprs.clone());
exprs
})
} }
// Focus the inliner on another module with id `module_id` and return the current `module_id` // Focus the inliner on another module with id `module_id` and return the current `module_id`
fn change_module(&mut self, module_id: TypedModuleId) -> TypedModuleId { fn change_context(
std::mem::replace(&mut self.module_id, module_id) &mut self,
module_id: TypedModuleId,
function_key: FunctionKey<'ast>,
) -> (TypedModuleId, FunctionKey<'ast>) {
let current_module = std::mem::replace(&mut self.module_id, module_id);
let current_key = std::mem::replace(&mut self.function_key, function_key);
(current_module, current_key)
} }
fn module(&self) -> &TypedModule<'ast, T> { fn module(&self) -> &TypedModule<'ast, T> {
self.modules.get(&self.module_id).unwrap() self.modules.get(&self.module_id).unwrap()
} }
fn call_cache(
&mut self,
) -> &HashMap<(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), Vec<TypedExpression<'ast, T>>>
{
self.call_cache
.entry((self.module_id.clone().clone(), self.function_key.clone()))
.or_insert_with(|| HashMap::new())
}
fn call_cache_mut(
&mut self,
) -> &mut HashMap<
(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Vec<TypedExpression<'ast, T>>,
> {
self.call_cache
.get_mut(&(self.module_id.clone().clone(), self.function_key.clone()))
.unwrap()
}
} }
impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {

View file

@ -0,0 +1,7 @@
use zokrates_field::field::Field;
use std::collections::HashSet;
use typed_absy::identifier::CallIdentifier;
pub struct Memoizer<'ast, T: Field> {
identifiers: HashSet<CallIdentifier<'ast, T>>
}

View file

@ -0,0 +1,6 @@
{
"entry_point": "./tests/tests/arrays/fun_spread.zok",
"max_constraint_count": 1050,
"tests": [
]
}

View file

@ -0,0 +1,7 @@
import "utils/pack/nonStrictUnpack256.zok" as unpack256
def main(field[2] inputs) -> (field[512]):
field[512] preimage512 = [...unpack256(inputs[0]), ...unpack256(inputs[1])]
return preimage512

View file

@ -0,0 +1,3 @@
def dep(field a) -> (field): // this costs 2 constraits per call
field res = a ** 4
return res

View file

@ -0,0 +1,6 @@
{
"entry_point": "./tests/tests/memoize/memoize.zok",
"max_constraint_count": 14,
"tests": [
]
}

View file

@ -0,0 +1,14 @@
from "./dep.zok" import dep as dep
def local(field a) -> (field): // this costs 3 constraints per call
field res = a ** 8
return res // currently expressions in the return statement don't get memoized
def main(field a) -> ():
// calling a local function many times with the same arg should cost only once
local(a) + local(a) + local(a) + local(a) + local(a) == 5 * (a ** 8)
// calling an imported function many times with the same arg should cost only once
dep(a) + dep(a) + dep(a) + dep(a) + dep(a) == 4 * (a ** 4)
return