From 759f4dd5534b779c0d73172a94a91ce63ccaac48 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 15 Jun 2021 16:41:38 +0200 Subject: [PATCH] fold module id in folders --- access.zok | 2 + .../examples/imports/import_with_alias.zok | 4 +- .../src/static_analysis/constant_inliner.rs | 56 +++++++------------ zokrates_core/src/typed_absy/folder.rs | 29 ++++++++-- zokrates_core/src/typed_absy/result_folder.rs | 26 ++++++++- 5 files changed, 73 insertions(+), 44 deletions(-) create mode 100644 access.zok diff --git a/access.zok b/access.zok new file mode 100644 index 00000000..7d01f505 --- /dev/null +++ b/access.zok @@ -0,0 +1,2 @@ +def main(field[3][2] a, u32 index) -> field[2]: + return a[index] \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_with_alias.zok b/zokrates_cli/examples/imports/import_with_alias.zok index 5e013691..7302c4e2 100644 --- a/zokrates_cli/examples/imports/import_with_alias.zok +++ b/zokrates_cli/examples/imports/import_with_alias.zok @@ -1,8 +1,8 @@ -from "./bar" import main as bar +//from "./bar" import main as bar from "./baz" import BAZ as baz import "./foo" as f def main() -> field: field foo = f() - assert(foo == bar() + baz) + assert(foo == 21 + baz) return foo \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 634bcef5..e92bd259 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -48,7 +48,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { assert_eq!(id.version, 0); match id.id { CoreIdentifier::Call(..) => { - unreachable!("calls indentifiers are only available after call inlining") + unreachable!("calls identifiers are only available after call inlining") } CoreIdentifier::Source(id) => self .constants @@ -60,46 +60,34 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { - fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - TypedProgram { - modules: p - .modules - .into_iter() - .map(|(m_id, m)| { - if !self.treated(&m_id) { - self.change_location(m_id.clone()); - (m_id, self.fold_module(m)) - } else { - (m_id, m) - } - }) - .collect(), - ..p + fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { + // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet + if !self.treated(&id) { + let current_m_id = self.change_location(id.clone()); + self.constants.entry(self.location.clone()).or_default(); + let m = self.fold_module(self.modules.get(&id).unwrap().clone()); + + self.modules.insert(id.clone(), m); + + self.change_location(current_m_id); } + id } fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - assert!(self - .constants - .insert(self.location.clone(), Default::default()) - .is_none()); + // initialise a constant map for this module + self.constants.entry(self.location.clone()).or_default(); + TypedModule { constants: m .constants .into_iter() .map(|(id, tc)| { + let id = self.fold_canonical_constant_identifier(id); + let constant = match tc { TypedConstantSymbol::There(imported_id) => { - if !self.treated(&imported_id.module) { - let current_m_id = self.change_location(imported_id.module.clone()); - let m = self.fold_module( - self.modules.get(&imported_id.module).unwrap().clone(), - ); - - self.modules.insert(imported_id.module.clone(), m); - - self.change_location(current_m_id); - } + let imported_id = self.fold_canonical_constant_identifier(imported_id); self.constants .get(&imported_id.module) .unwrap() @@ -115,12 +103,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { .fold_expression(constant) .unwrap(); - assert!(self - .constants + self.constants .entry(self.location.clone()) .or_default() - .insert(id.id.into(), constant.clone()) - .is_none()); + .insert(id.id.into(), constant.clone()); ( id, @@ -176,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> FieldElementExpression<'ast, T> { match e { FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => self.fold_expression(c).try_into().unwrap(), + Some(c) => c.try_into().unwrap(), None => fold_field_expression(self, e), }, e => fold_field_expression(self, e), diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 22f3e229..41a39706 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -215,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized { } } + fn fold_canonical_constant_identifier( + &mut self, + i: CanonicalConstantIdentifier<'ast>, + ) -> CanonicalConstantIdentifier<'ast> { + CanonicalConstantIdentifier { + module: self.fold_module_id(i.module), + id: i.id, + } + } + + fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId { + i + } + fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> { match e { TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(), @@ -342,7 +356,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( constants: m .constants .into_iter() - .map(|(key, tc)| (key, f.fold_constant_symbol(tc))) + .map(|(id, tc)| { + ( + f.fold_canonical_constant_identifier(id), + f.fold_constant_symbol(tc), + ) + }) .collect(), functions: m .functions @@ -1014,7 +1033,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>( ) -> TypedConstantSymbol<'ast, T> { match s { TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)), - there => there, + TypedConstantSymbol::There(id) => { + TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id)) + } } } @@ -1052,8 +1073,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( modules: p .modules .into_iter() - .map(|(module_id, module)| (module_id, f.fold_module(module))) + .map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module))) .collect(), - main: p.main, + main: f.fold_module_id(p.main), } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index cbca2f7e..b78e81a2 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -114,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized { }) } + fn fold_canonical_constant_identifier( + &mut self, + i: CanonicalConstantIdentifier<'ast>, + ) -> Result, Self::Error> { + Ok(CanonicalConstantIdentifier { + module: self.fold_module_id(i.module)?, + id: i.id, + }) + } + + fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result { + Ok(i) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { Ok(n) } @@ -923,6 +937,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>( key: DeclarationFunctionKey<'ast>, ) -> Result, F::Error> { Ok(DeclarationFunctionKey { + module: f.fold_module_id(key.module)?, signature: f.fold_signature(key.signature)?, ..key }) @@ -1067,7 +1082,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { match s { TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)), - there => Ok(there), + TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There( + f.fold_canonical_constant_identifier(id)?, + )), } } @@ -1077,7 +1094,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { match s { TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)), - there => Ok(there), // by default, do not fold modules recursively + TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There( + f.fold_declaration_function_key(key)?, + )), + s => Ok(s), } } @@ -1109,6 +1129,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( .into_iter() .map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m))) .collect::>()?, - main: p.main, + main: f.fold_module_id(p.main)?, }) }