1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

add flat resolving

This commit is contained in:
schaeff 2019-07-01 19:17:54 +02:00
parent 2c8370138e
commit 9e14f28ff4
3 changed files with 271 additions and 7 deletions

View file

@ -0,0 +1,191 @@
// Turn nested imports which end by a flat symbol into non-nested imports of that flat symbol
// Before
// // main.code
// import "myflatfun.code" as foo
//
// // myflatfun.code
// def myflatfun() -> ()
// // hidden (flat code)
// After
// // main.code
// def foo() -> ():
// // hidden (flat code)
//
// // myflatfun.code
// def myflatfun() -> ()
// // hidden (flat code)
use crate::typed_absy::folder::*;
use crate::typed_absy::Folder;
use crate::typed_absy::*;
use absy::ModuleId;
use flat_absy::FlatFunction;
use types::FunctionKey;
use zokrates_field::field::Field;
pub struct FlatResolver<'ast, T: Field> {
modules: TypedModules<'ast, T>,
}
impl<'ast, T: Field> FlatResolver<'ast, T> {
fn with_modules(m: TypedModules<'ast, T>) -> Self {
FlatResolver { modules: m }
}
pub fn resolve(p: TypedProgram<T>) -> TypedProgram<T> {
FlatResolver::with_modules(p.modules.clone()).fold_program(p)
}
fn resolve_external_symbol(
&self,
key: &FunctionKey,
module_id: &ModuleId,
) -> Option<FlatFunction<T>> {
match self
.modules
.get(module_id)
.unwrap()
.functions
.get(&key)
.unwrap()
{
TypedFunctionSymbol::There(key, module_id) => {
self.resolve_external_symbol(key, module_id)
}
TypedFunctionSymbol::Flat(f) => Some(f.clone()),
_ => None,
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for FlatResolver<'ast, T> {
fn fold_function_symbol(
&mut self,
s: TypedFunctionSymbol<'ast, T>,
) -> TypedFunctionSymbol<'ast, T> {
match s {
TypedFunctionSymbol::There(key, module_id) => {
match self.resolve_external_symbol(&key, &module_id) {
Some(f) => TypedFunctionSymbol::Flat(f),
None => TypedFunctionSymbol::There(key, module_id),
}
}
s => fold_function_symbol(self, s),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use flat_absy::*;
use types::{FunctionKey, Signature, Type};
use zokrates_field::field::FieldPrime;
#[test]
fn remove_linked_flat_import() {
let flat_fun = FlatFunction {
arguments: vec![],
signature: Signature::new().outputs(vec![Type::FieldElement]),
statements: vec![FlatStatement::Return(FlatExpressionList {
expressions: vec![FlatExpression::Number(FieldPrime::from(42))],
})],
};
let main = TypedModule {
functions: vec![
(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::There(
FunctionKey::with_id("myflatfun")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
String::from("other"),
),
),
(
FunctionKey::with_id("main")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Here(TypedFunction {
signature: Signature::new().outputs(vec![Type::FieldElement]),
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
vec![],
)
.into(),
])],
}),
),
]
.into_iter()
.collect(),
imports: vec![],
};
let other = TypedModule {
functions: vec![(
FunctionKey::with_id("myflatfun")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Flat(flat_fun.clone()),
)]
.into_iter()
.collect(),
imports: vec![],
};
let prog = TypedProgram {
main: String::from("main"),
modules: vec![
(String::from("main"), main),
(String::from("other"), other.clone()),
]
.into_iter()
.collect(),
};
let new_main = TypedModule {
functions: vec![
(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Flat(flat_fun),
),
(
FunctionKey::with_id("main")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Here(TypedFunction {
signature: Signature::new().outputs(vec![Type::FieldElement]),
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
vec![],
)
.into(),
])],
}),
),
]
.into_iter()
.collect(),
imports: vec![],
};
let expected_prog = TypedProgram {
main: String::from("main"),
modules: vec![
(String::from("main"), new_main),
(String::from("other"), other),
]
.into_iter()
.collect(),
};
assert_eq!(FlatResolver::resolve(prog), expected_prog);
}
}

View file

@ -132,7 +132,12 @@ impl<'ast, T: Field> Inliner<'ast, T> {
// if the function called is in some other module, we switch context to that module and call the function locally there
TypedFunctionSymbol::There(function_key, module_id) => {
let current_module = self.change_module(module_id);
let res = self.try_inline_call(&function_key, expressions)?;
let res = self
.try_inline_call(&function_key, expressions)
.expect(&format!(
"inlining external symbols should always succeed, failed for {:?}",
function_key
));
self.change_module(current_module);
Ok(res)
}
@ -366,12 +371,76 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use std::collections::HashMap;
// use types::{FunctionKey, Signature, Type};
// use zokrates_field::field::FieldPrime;
#[cfg(test)]
mod tests {
use super::*;
use flat_absy::*;
use std::collections::HashMap;
use types::{FunctionKey, Signature, Type};
use zokrates_field::field::FieldPrime;
#[test]
#[should_panic]
fn non_resolved_flat_call() {
let main = TypedModule {
functions: vec![
(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::There(
FunctionKey::with_id("myflatfun")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
String::from("other"),
),
),
(
FunctionKey::with_id("main")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Here(TypedFunction {
signature: Signature::new().outputs(vec![Type::FieldElement]),
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
vec![],
)
.into(),
])],
}),
),
]
.into_iter()
.collect(),
imports: vec![],
};
let other = TypedModule {
functions: vec![(
FunctionKey::with_id("myflatfun")
.signature(Signature::new().outputs(vec![Type::FieldElement])),
TypedFunctionSymbol::Flat(FlatFunction {
arguments: vec![],
signature: Signature::new().outputs(vec![Type::FieldElement]),
statements: vec![FlatStatement::Return(FlatExpressionList {
expressions: vec![FlatExpression::Number(FieldPrime::from(42))],
})],
}),
)]
.into_iter()
.collect(),
imports: vec![],
};
let prog = TypedProgram {
main: String::from("main"),
modules: vec![(String::from("main"), main), (String::from("other"), other)]
.into_iter()
.collect(),
};
let _ = Inliner::inline(prog);
}
}
// <<<<<<< HEAD
// #[test]

View file

@ -6,12 +6,14 @@
mod core_lib_injector;
mod flat_propagation;
mod flat_resolver;
mod inline;
mod power_check;
mod propagation;
mod unroll;
pub use self::core_lib_injector::CoreLibInjector;
use self::flat_resolver::FlatResolver;
use self::inline::Inliner;
use self::power_check::PowerChecker;
use self::propagation::Propagator;
@ -27,6 +29,8 @@ pub trait Analyse {
impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
fn analyse(self) -> Self {
let r = PowerChecker::check(self);
// remove chains of imports ending by a flat function
let r = FlatResolver::resolve(r);
// unroll
let r = Unroller::unroll(r);
// inline