From 898b3eb7e851c79cc58dfaf76217a74c97a82b05 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Nov 2020 21:50:44 +0000 Subject: [PATCH] wip --- test.zok | 18 +- zokrates_core/src/embed.rs | 2 +- zokrates_core/src/flatten/mod.rs | 2 +- zokrates_core/src/semantics.rs | 107 ++- .../static_analysis/flatten_complex_types.rs | 5 +- zokrates_core/src/static_analysis/mod.rs | 8 +- .../src/static_analysis/propagation.rs | 2 +- .../src/static_analysis/reducer/inline.rs | 100 ++- .../src/static_analysis/reducer/mod.rs | 698 ++++++++++++------ .../static_analysis/reducer/shallow_ssa.rs | 23 +- zokrates_core/src/typed_absy/abi.rs | 2 +- zokrates_core/src/typed_absy/integer.rs | 126 ++-- zokrates_core/src/typed_absy/mod.rs | 161 ++-- zokrates_core/src/typed_absy/result_folder.rs | 667 +++++++++++++++++ zokrates_core/src/typed_absy/types.rs | 34 +- zokrates_core/src/typed_absy/uint.rs | 2 +- zokrates_core/src/zir/from_typed.rs | 1 + zokrates_core/src/zir/mod.rs | 10 +- zokrates_core/src/zir/types.rs | 13 +- .../tests/tests/fact_up_to_4.zok | 6 +- .../tests/tests/generics/call.json | 4 + .../tests/tests/generics/call.zok | 8 + .../tests/tests/generics/multidef.json | 23 + .../tests/tests/generics/multidef.zok | 6 + .../stdlib/utils/casts/u32_to_field.zok | 10 + zokrates_test/src/lib.rs | 20 +- 26 files changed, 1602 insertions(+), 456 deletions(-) create mode 100644 zokrates_core/src/typed_absy/result_folder.rs create mode 100644 zokrates_core_test/tests/tests/generics/call.json create mode 100644 zokrates_core_test/tests/tests/generics/call.zok create mode 100644 zokrates_core_test/tests/tests/generics/multidef.json create mode 100644 zokrates_core_test/tests/tests/generics/multidef.zok create mode 100644 zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok diff --git a/test.zok b/test.zok index 019c6cce..4990a469 100644 --- a/test.zok +++ b/test.zok @@ -1,10 +1,10 @@ -def sum_array

(field[P] a) -> field: - field res = 0 - for u32 i in 0..P do - res = res + a[i] +def main(): + field a = 3 + for u32 i in 0..2 do + for u32 j in 0..2 do + a = a + 1 + endfor + a = a + 1 endfor - return res - -def main(u32 i) -> (field): - u32[2][4] a = [[1, 2, 3, 4], [1, 2, 3, 4]] - return 1 + field b = a + return \ No newline at end of file diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 278f9439..b6e555da 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -52,7 +52,7 @@ impl FlatEmbed { } pub fn key(&self) -> ConcreteFunctionKey<'static> { - ConcreteFunctionKey::with_id(self.id()).signature(self.signature()) + ConcreteFunctionKey::with_location("#EMBED#", self.id()).signature(self.signature()) } pub fn id(&self) -> &'static str { diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 89eaf039..801a93ba 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -889,7 +889,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .inputs(param_expressions.iter().map(|e| e.get_type()).collect()) .outputs(return_types); - let key = FunctionKey::with_id(id).signature(passed_signature); + let key = FunctionKey::with_location("#EMBED#", id).signature(passed_signature); let funct = self.get_embed(&key, &symbols); diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 6a2df64c..7dc8c198 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -422,12 +422,18 @@ impl<'ast, T: Field> Checker<'ast, T> { }; self.functions.insert( - DeclarationFunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature.clone()), + DeclarationFunctionKey::with_location( + module_id.clone(), + declaration.id.clone(), + ) + .signature(funct.signature.clone()), ); functions.insert( - DeclarationFunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature.clone()), + DeclarationFunctionKey::with_location( + module_id.clone(), + declaration.id.clone(), + ) + .signature(funct.signature.clone()), TypedFunctionSymbol::Here(funct), ); } @@ -450,6 +456,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .iter() .filter(|(k, _)| k.id == import.symbol_id) .map(|(_, v)| DeclarationFunctionKey { + module: import.module_id.clone(), id: import.symbol_id.clone(), signature: v.signature(&state.typed_modules).clone(), }) @@ -525,12 +532,12 @@ impl<'ast, T: Field> Checker<'ast, T> { true => {} }; - self.functions.insert(candidate.clone().id(declaration.id)); + let local_key = candidate.clone().id(declaration.id).module(module_id.clone()); + + self.functions.insert(local_key.clone()); functions.insert( - candidate.clone().id(declaration.id), - TypedFunctionSymbol::There( - candidate, - import.module_id.clone(), + local_key, + TypedFunctionSymbol::There(candidate, ), ); } @@ -562,12 +569,18 @@ impl<'ast, T: Field> Checker<'ast, T> { }; self.functions.insert( - DeclarationFunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone().try_into().unwrap()), + DeclarationFunctionKey::with_location( + module_id.clone(), + declaration.id.clone(), + ) + .signature(funct.signature().clone().try_into().unwrap()), ); functions.insert( - DeclarationFunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone().try_into().unwrap()), + DeclarationFunctionKey::with_location( + module_id.clone(), + declaration.id.clone(), + ) + .signature(funct.signature().clone().try_into().unwrap()), TypedFunctionSymbol::Flat(funct), ); } @@ -1363,11 +1376,11 @@ impl<'ast, T: Field> Checker<'ast, T> { let f = functions.pop().unwrap(); let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::, _>>().map_err(|e| vec![ErrorInner { - pos: Some(pos), - message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0) - }])?; + pos: Some(pos), + message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type()) + }])?; - let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, Signature::from(f.signature.clone()).outputs); + let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, variables.iter().map(|v| v.get_type()).collect()); Ok(TypedStatement::MultipleDefinition(variables, call)) }, @@ -1914,9 +1927,9 @@ impl<'ast, T: Field> Checker<'ast, T> { let f = functions.pop().unwrap(); - let signature: Signature = f.signature.into(); + let signature = f.signature; - let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { + let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::, _>>().map_err(|e| ErrorInner { pos: Some(pos), message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0) })?; @@ -1924,25 +1937,28 @@ impl<'ast, T: Field> Checker<'ast, T> { // the return count has to be 1 match signature.outputs.len() { 1 => match &signature.outputs[0] { - Type::Int => unreachable!(), - Type::FieldElement => Ok(FieldElementExpression::FunctionCall( - FunctionKey { + DeclarationType::Int => unreachable!(), + DeclarationType::FieldElement => Ok(FieldElementExpression::FunctionCall( + DeclarationFunctionKey { + module: module_id.clone(), id: f.id.clone(), signature: signature.clone(), }, arguments_checked, ) .into()), - Type::Boolean => Ok(BooleanExpression::FunctionCall( - FunctionKey { + DeclarationType::Boolean => Ok(BooleanExpression::FunctionCall( + DeclarationFunctionKey { + module: module_id.clone(), id: f.id.clone(), signature: signature.clone(), }, arguments_checked, ) .into()), - Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall( - FunctionKey { + DeclarationType::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall( + DeclarationFunctionKey { + module: module_id.clone(), id: f.id.clone(), signature: signature.clone(), }, @@ -1950,23 +1966,25 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .annotate(*bitwidth) .into()), - Type::Struct(members) => Ok(StructExpressionInner::FunctionCall( - FunctionKey { + DeclarationType::Struct(members) => Ok(StructExpressionInner::FunctionCall( + DeclarationFunctionKey { + module: module_id.clone(), id: f.id.clone(), signature: signature.clone(), }, arguments_checked, ) - .annotate(members.clone()) + .annotate(members.clone().into()) .into()), - Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall( - FunctionKey { + DeclarationType::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall( + DeclarationFunctionKey { + module: module_id.clone(), id: f.id.clone(), signature: signature.clone(), }, arguments_checked, ) - .annotate(*array_type.ty.clone(), array_type.size.clone()) + .annotate(Type::from(*array_type.ty.clone()), array_type.size.clone()) .into()), }, n => Err(ErrorInner { @@ -2592,6 +2610,8 @@ impl<'ast, T: Field> Checker<'ast, T> { .next() .unwrap_or(Type::Int); + println!("INFERRED TYPE {:?}", inferred_type); + match inferred_type { Type::Int => { // no need to check the expressions have the same type, this is guaranteed above @@ -2643,6 +2663,8 @@ impl<'ast, T: Field> Checker<'ast, T> { let size = unwrapped_expressions.len() as u32; + println!("hey"); + Ok(ArrayExpressionInner::Value(unwrapped_expressions) .annotate(Type::Boolean, size as usize) .into()) @@ -3379,12 +3401,11 @@ mod tests { state.typed_modules.get(&PathBuf::from("bar")), Some(&TypedModule { functions: vec![( - DeclarationFunctionKey::with_id("main") + DeclarationFunctionKey::with_location("bar", "main") .signature(DeclarationSignature::new()), TypedFunctionSymbol::There( - DeclarationFunctionKey::with_id("main") + DeclarationFunctionKey::with_location("foo", "main") .signature(DeclarationSignature::new()), - "foo".into() ) )] .into_iter() @@ -3672,16 +3693,19 @@ mod tests { .unwrap() .functions .contains_key( - &DeclarationFunctionKey::with_id("foo").signature(DeclarationSignature::new()) + &DeclarationFunctionKey::with_location(MODULE_ID, "foo") + .signature(DeclarationSignature::new()) )); assert!(state .typed_modules .get(&PathBuf::from(MODULE_ID)) .unwrap() .functions - .contains_key(&DeclarationFunctionKey::with_id("foo").signature( - DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) - ))) + .contains_key( + &DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature( + DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) + ) + )) } #[test] @@ -4271,6 +4295,7 @@ mod tests { ]; let foo = DeclarationFunctionKey { + module: "main".into(), id: "foo", signature: DeclarationSignature { inputs: vec![], @@ -4323,6 +4348,7 @@ mod tests { .mock()]; let foo = DeclarationFunctionKey { + module: "main".into(), id: "foo", signature: DeclarationSignature { inputs: vec![], @@ -4856,7 +4882,7 @@ mod tests { typed_absy::Variable::field_element("b"), ], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature( DeclarationSignature::new().outputs(vec![ DeclarationType::FieldElement, DeclarationType::FieldElement, @@ -4874,6 +4900,7 @@ mod tests { ]; let foo = DeclarationFunctionKey { + module: "main".into(), id: "foo", signature: DeclarationSignature { inputs: vec![], diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index e40708a5..00d6c3eb 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -279,7 +279,7 @@ pub fn fold_statement<'ast, T: Field>( )] } typed_absy::TypedStatement::PushCallLog(..) => vec![], - typed_absy::TypedStatement::PopCallLog(..) => vec![], + typed_absy::TypedStatement::PopCallLog => vec![], } } @@ -906,9 +906,8 @@ pub fn fold_function_symbol<'ast, T: Field>( typed_absy::TypedFunctionSymbol::Here(fun) => { zir::ZirFunctionSymbol::Here(f.fold_function(fun)) } - typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There( + typed_absy::TypedFunctionSymbol::There(key) => zir::ZirFunctionSymbol::There( f.fold_function_key(typed_absy::types::ConcreteFunctionKey::try_from(key).unwrap()), - module, ), // by default, do not fold modules recursively typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat), } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 7c5c6ee8..32880229 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -41,12 +41,16 @@ pub trait Analyse { impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self) -> (ZirProgram<'ast, T>, Abi) { // return binding - let r = ReturnBinder::bind(self); + //let r = ReturnBinder::bind(self); // propagated unrolling //let r = PropagatedUnroller::unroll(r).unwrap_or_else(|e| panic!(e)); - let r = reduce_program(r).unwrap(); + println!("{:#?}", self); + + let r = reduce_program(self).unwrap(); + + println!("reduced"); let r = Trimmer::trim(r); diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 2266b126..f0597814 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -331,7 +331,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } s @ TypedStatement::PushCallLog(..) => Some(s), - s @ TypedStatement::PopCallLog(..) => Some(s), + s @ TypedStatement::PopCallLog => Some(s), }; // In verbose mode, we always return a statement diff --git a/zokrates_core/src/static_analysis/reducer/inline.rs b/zokrates_core/src/static_analysis/reducer/inline.rs index 6fd7c705..1c9099ab 100644 --- a/zokrates_core/src/static_analysis/reducer/inline.rs +++ b/zokrates_core/src/static_analysis/reducer/inline.rs @@ -25,23 +25,25 @@ // - The body of the function is in SSA form // - The return value(s) are assigned to internal variables +use embed::FlatEmbed; +use static_analysis::reducer::CallCache; use static_analysis::reducer::Output; use static_analysis::reducer::ShallowTransformer; use static_analysis::reducer::Versions; use typed_absy::CoreIdentifier; use typed_absy::Identifier; +use typed_absy::TypedAssignee; use typed_absy::{ - ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, DeclarationSignature, Signature, - Type, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedModuleId, TypedModules, - TypedStatement, Variable, + ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, + DeclarationSignature, Signature, Type, TypedExpression, TypedFunction, TypedFunctionSymbol, + TypedModules, TypedStatement, Variable, }; use zokrates_field::Field; pub enum InlineError<'ast, T> { Generic(DeclarationSignature<'ast>, ConcreteSignature), - Flat, + Flat(FlatEmbed, Vec>, Vec>), NonConstant( - TypedModuleId, DeclarationFunctionKey<'ast>, Vec>, Vec>, @@ -49,33 +51,29 @@ pub enum InlineError<'ast, T> { } fn get_canonical_function<'ast, T: Field>( - module_id: TypedModuleId, function_key: DeclarationFunctionKey<'ast>, modules: &TypedModules<'ast, T>, -) -> ( - TypedModuleId, - DeclarationFunctionKey<'ast>, - TypedFunction<'ast, T>, -) { +) -> Result<(DeclarationFunctionKey<'ast>, TypedFunction<'ast, T>), FlatEmbed> { match modules - .get(&module_id.clone()) + .get(&function_key.module) .unwrap() .functions .iter() .find(|(key, _)| function_key == **key) .unwrap() { - (key, TypedFunctionSymbol::Here(f)) => (module_id, key.clone(), f.clone()), - _ => unimplemented!(), + (key, TypedFunctionSymbol::Here(f)) => Ok((key.clone(), f.clone())), + (_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &modules), + (_, TypedFunctionSymbol::Flat(f)) => Err(f.clone()), } } pub fn inline_call<'a, 'ast, T: Field>( - module_id: TypedModuleId, k: DeclarationFunctionKey<'ast>, arguments: Vec>, output_types: Vec>, modules: &TypedModules<'ast, T>, + cache: &mut CallCache<'ast, T>, versions: &'a mut Versions<'ast>, ) -> Result< Output<(Vec>, Vec>), Vec>>, @@ -95,16 +93,12 @@ pub fn inline_call<'a, 'ast, T: Field>( let inferred_signature = match ConcreteSignature::try_from(inferred_signature) { Ok(s) => s, Err(()) => { - return Err(InlineError::NonConstant( - module_id, - k, - arguments, - output_types, - )); + return Err(InlineError::NonConstant(k, arguments, output_types)); } }; - let (module_id, decl_key, f) = get_canonical_function(module_id, k.clone(), modules); + let (decl_key, f) = get_canonical_function(k.clone(), modules) + .map_err(|e| InlineError::Flat(e, arguments.clone(), output_types))?; assert_eq!(f.arguments.len(), arguments.len()); // get an assignment of generics for this call site @@ -115,23 +109,34 @@ pub fn inline_call<'a, 'ast, T: Field>( InlineError::Generic(decl_key.signature.clone(), inferred_signature.clone()) })?; + let concrete_key = ConcreteFunctionKey { + module: decl_key.module.clone(), + id: decl_key.id.clone(), + signature: inferred_signature.clone(), + }; + + match cache.get(&(concrete_key.clone(), arguments.clone())) { + Some(v) => { + return Ok(Output::Complete((vec![], v.clone()))); + } + None => {} + }; + let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) { Output::Complete(v) => (v, None), Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)), }; - let call_log = TypedStatement::PushCallLog( - module_id, - decl_key.clone(), - assignment, - ssa_f - .arguments - .into_iter() - .zip(inferred_signature.inputs.clone()) - .map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t)) - .zip(arguments) - .collect(), - ); + let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment); + + let input_bindings: Vec> = ssa_f + .arguments + .into_iter() + .zip(inferred_signature.inputs.clone()) + .map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t)) + .zip(arguments.clone()) + .map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a)) + .collect(); let (statements, returns): (Vec<_>, Vec<_>) = ssa_f.statements.into_iter().partition(|s| match s { @@ -151,7 +156,15 @@ pub fn inline_call<'a, 'ast, T: Field>( .iter() .enumerate() .map(|(i, t)| { - ConcreteVariable::with_id_and_type(Identifier::from(CoreIdentifier::Call(i)), t.clone()) + ConcreteVariable::with_id_and_type( + Identifier::from(CoreIdentifier::Call(i)).version( + *versions + .entry(CoreIdentifier::Call(i).clone()) + .and_modify(|e| *e += 1) // if it was already declared, we increment + .or_insert(0), + ), + t.clone(), + ) }) .collect(); @@ -162,13 +175,26 @@ pub fn inline_call<'a, 'ast, T: Field>( assert_eq!(res.len(), returns.len()); - let call_pop = TypedStatement::PopCallLog(res.into_iter().zip(returns).collect()); + let output_bindings: Vec> = res + .into_iter() + .zip(returns) + .map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a)) + .collect(); + + let pop_log = TypedStatement::PopCallLog; let statements: Vec<_> = std::iter::once(call_log) + .chain(input_bindings) .chain(statements) - .chain(std::iter::once(call_pop)) + .chain(output_bindings) + .chain(std::iter::once(pop_log)) .collect(); + cache.insert( + (concrete_key.clone(), arguments.clone()), + expressions.clone(), + ); + Ok(incomplete_data .map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d)) .unwrap_or_else(|| Output::Complete((statements, expressions)))) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index f0cda4d0..0af8c9d0 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -17,12 +17,20 @@ mod unroll; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; +use typed_absy::result_folder::*; use typed_absy::types::GenericsAssignment; +use typed_absy::Folder; use typed_absy::{ - CoreIdentifier, TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule, - TypedModules, TypedProgram, TypedStatement, + ArrayExpression, ArrayExpressionInner, BooleanExpression, ConcreteFunctionKey, CoreIdentifier, + DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression, + StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction, + TypedFunctionSymbol, TypedModule, TypedModules, TypedProgram, TypedStatement, UExpression, + UExpressionInner, }; + +use std::convert::{TryFrom, TryInto}; + use zokrates_field::Field; use self::shallow_ssa::ShallowTransformer; @@ -44,6 +52,326 @@ pub enum Error { Incompatible, } +type CallCache<'ast, T> = HashMap< + (ConcreteFunctionKey<'ast>, Vec>), + Vec>, +>; + +type Substitutions<'ast> = HashMap, HashMap>; + +struct Sub<'a, 'ast> { + substitutions: &'a mut Substitutions<'ast>, +} + +impl<'a, 'ast> Sub<'a, 'ast> { + fn new(substitutions: &'a mut Substitutions<'ast>) -> Self { + Self { substitutions } + } +} + +impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> { + fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> { + let sub = self.substitutions.entry(id.id.clone()).or_default(); + let version = sub.get(&id.version).cloned().unwrap_or(id.version); + id.version(version) + } +} + +fn register<'ast>( + substitutions: &mut Substitutions<'ast>, + substitute: &Versions<'ast>, + with: &Versions<'ast>, +) { + //assert!(substitute.len() <= with.len()); + for (id, key, value) in substitute + .iter() + .filter_map(|(id, version)| with.get(&id).clone().map(|to| (id, version, to))) + .filter(|(_, key, value)| key != value) + { + let sub = substitutions.entry(id.clone()).or_default(); + sub.insert(*key, *value); + } +} + +struct Reducer<'ast, 'a, T> { + statement_buffer: Vec>, + for_loop_versions: Vec>, + for_loop_index: usize, + modules: &'a TypedModules<'ast, T>, + versions: &'a mut Versions<'ast>, + substitutions: Substitutions<'ast>, + cache: CallCache<'ast, T>, + complete: bool, +} + +impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { + fn new( + modules: &'a TypedModules<'ast, T>, + versions: &'a mut Versions<'ast>, + for_loop_versions: Vec>, + ) -> Self { + Reducer { + statement_buffer: vec![], + for_loop_index: 0, + for_loop_versions, + cache: CallCache::default(), + substitutions: Substitutions::default(), + modules, + versions, + complete: true, + } + } + + fn fold_function_call( + &mut self, + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + output_types: Vec>, + ) -> Result>::Error> + where + E: FunctionCall<'ast, T> + TryFrom, Error = ()> + std::fmt::Debug, + { + let arguments = arguments + .into_iter() + .map(|e| self.fold_expression(e)) + .collect::>()?; + let res = inline_call( + key.clone(), + arguments, + output_types, + &self.modules, + &mut self.cache, + &mut self.versions, + ); + + match res { + Ok(Output::Complete((statements, expressions))) => { + self.complete &= true; + self.statement_buffer.extend(statements); + Ok(expressions[0].clone().try_into().unwrap()) + } + Ok(Output::Incomplete((statements, expressions), _delta_for_loop_versions)) => { + self.complete = false; + self.statement_buffer.extend(statements); + Ok(expressions[0].clone().try_into().unwrap()) + } + Err(InlineError::Generic(a, b)) => Err(Error::Incompatible), + Err(InlineError::NonConstant(key, arguments, _)) => { + self.complete = false; + + Ok(E::function_call(key, arguments)) + } + Err(InlineError::Flat(embed, arguments, output_types)) => { + Ok(E::function_call(embed.key::().into(), arguments)) + } + } + } +} + +impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { + type Error = Error; + + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + let res = match s { + TypedStatement::MultipleDefinition( + v, + TypedExpressionList::FunctionCall(key, arguments, output_types), + ) => { + let arguments = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::>()?; + + match inline_call( + key, + arguments, + output_types, + &self.modules, + &mut self.cache, + &mut self.versions, + ) { + Ok(Output::Complete((statements, expressions))) => { + assert_eq!(v.len(), expressions.len()); + + self.complete &= true; + + Ok(statements + .into_iter() + .chain( + v.into_iter() + .zip(expressions) + .map(|(v, e)| TypedStatement::Definition(v.into(), e)), + ) + .collect()) + } + Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => { + assert_eq!(v.len(), expressions.len()); + + self.complete = false; + + Ok(statements + .into_iter() + .chain( + v.into_iter() + .zip(expressions) + .map(|(v, e)| TypedStatement::Definition(v.into(), e)), + ) + .collect()) + } + Err(InlineError::Generic(..)) => Err(Error::Incompatible), + Err(InlineError::NonConstant(key, arguments, output_types)) => { + self.complete = false; + + Ok(vec![TypedStatement::MultipleDefinition( + v, + TypedExpressionList::FunctionCall(key, arguments, output_types), + )]) + } + Err(InlineError::Flat(embed, arguments, output_types)) => { + Ok(vec![TypedStatement::MultipleDefinition( + v, + TypedExpressionList::FunctionCall( + embed.key::().into(), + arguments, + output_types, + ), + )]) + } + } + } + TypedStatement::For(v, from, to, statements) => { + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => { + let mut out_statements = vec![]; + + let versions_before = self.for_loop_versions[self.for_loop_index].clone(); + + self.versions.values_mut().for_each(|v| *v = *v + 1); + + register(&mut self.substitutions, &self.versions, &versions_before); + + let versions_after = versions_before + .clone() + .into_iter() + .map(|(k, v)| (k, v + 2)) + .collect(); + + let mut transformer = ShallowTransformer::with_versions(&mut self.versions); + + let old_index = self.for_loop_index; + + for index in *from..*to { + let statements: Vec> = + std::iter::once(TypedStatement::Definition( + v.clone().into(), + UExpression::from(index as u32).into(), + )) + .chain(statements.clone().into_iter()) + .map(|s| transformer.fold_statement(s)) + .flatten() + .collect(); + + let new_versions_count = transformer.for_loop_backups.len(); + + self.for_loop_index += new_versions_count; + + out_statements.extend(statements); + } + + let backups = transformer.for_loop_backups; + let blocked = transformer.blocked; + + register(&mut self.substitutions, &versions_after, &self.versions); + + self.for_loop_versions.splice(old_index..old_index, backups); + + self.complete &= !blocked; + + let out_statements = out_statements + .into_iter() + .map(|s| Sub::new(&mut self.substitutions).fold_statement(s)) + .flatten() + .collect(); + + Ok(out_statements) + } + _ => { + self.complete = false; + self.for_loop_index += 1; + Ok(vec![TypedStatement::For(v, from, to, statements)]) + } + } + } + s => fold_statement(self, s), + }; + + res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + BooleanExpression::FunctionCall(key, arguments) => { + self.fold_function_call(key, arguments, vec![Type::Boolean]) + } + e => fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression( + &mut self, + e: UExpression<'ast, T>, + ) -> Result, Self::Error> { + match e.as_inner() { + UExpressionInner::FunctionCall(key, arguments) => { + self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()]) + } + _ => fold_uint_expression(self, e), + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + FieldElementExpression::FunctionCall(key, arguments) => { + self.fold_function_call(key, arguments, vec![Type::FieldElement]) + } + e => fold_field_expression(self, e), + } + } + + fn fold_array_expression( + &mut self, + e: ArrayExpression<'ast, T>, + ) -> Result, Self::Error> { + match e.as_inner() { + ArrayExpressionInner::FunctionCall(key, arguments) => { + self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()]) + } + _ => fold_array_expression(self, e), + } + } + + fn fold_struct_expression( + &mut self, + e: StructExpression<'ast, T>, + ) -> Result, Self::Error> { + match e.as_inner() { + StructExpressionInner::FunctionCall(key, arguments) => { + self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()]) + } + _ => fold_struct_expression(self, e), + } + } +} + pub fn reduce_program(p: TypedProgram) -> Result, Error> { let main_module = p.modules.get(&p.main).unwrap().clone(); @@ -93,27 +421,27 @@ fn reduce_function<'ast, T: Field>( let mut f = new_f; let statements = loop { - match reduce_statements( - f.statements, - &mut for_loop_versions, - &mut versions, - modules, - ) { - Ok(Output::Complete(statements)) => { - break statements; - } - Ok(Output::Incomplete(new_statements, new_for_loop_versions)) => { - let new_f = TypedFunction { - statements: new_statements, - ..f - }; + println!("{}", f); - use typed_absy::folder::Folder; + let mut reducer = Reducer::new(&modules, &mut versions, for_loop_versions); + + let statements = f + .statements + .into_iter() + .map(|s| reducer.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + + match reducer.complete { + true => break statements, + false => { + let new_f = TypedFunction { statements, ..f }; f = Propagator::verbose().fold_function(new_f); - for_loop_versions = new_for_loop_versions; + for_loop_versions = reducer.for_loop_versions; } - Err(e) => return Err(e), } }; @@ -122,111 +450,47 @@ fn reduce_function<'ast, T: Field>( } } -fn reduce_statements<'ast, T: Field>( - statements: Vec>, - for_loop_versions: &mut Vec>, - versions: &mut Versions<'ast>, - modules: &TypedModules<'ast, T>, -) -> Result>, Vec>>, Error> { - let mut versions = versions; +// fn reduce_statements<'ast, T: Field>( +// statements: Vec>, +// for_loop_versions: &mut Vec>, +// versions: &mut Versions<'ast>, +// modules: &TypedModules<'ast, T>, +// ) -> Result>, Vec>>, Error> { +// let mut versions = versions; - let statements = statements - .into_iter() - .map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules)); +// let statements = statements +// .into_iter() +// .map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules)); - statements - .into_iter() - .fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) { - (Ok(state), Ok(e)) => { - use self::Output::*; - match (state, e) { - (Complete(mut s), Complete(c)) => { - s.extend(c); - Ok(Complete(s)) - } - (Complete(mut s), Incomplete(stats, for_loop_versions)) => { - s.extend(stats); - Ok(Incomplete(s, for_loop_versions)) - } - (Incomplete(mut stats, new_for_loop_versions), Complete(c)) => { - stats.extend(c); - Ok(Incomplete(stats, new_for_loop_versions)) - } - (Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => { - stats.extend(new_stats); - versions.extend(new_versions); - Ok(Incomplete(stats, versions)) - } - } - } - (Err(state), _) => Err(state), - (Ok(_), Err(e)) => Err(e), - }) -} - -fn reduce_statement<'ast, T: Field>( - statement: TypedStatement<'ast, T>, - _for_loop_versions: &mut Vec>, - versions: &mut Versions<'ast>, - modules: &TypedModules<'ast, T>, -) -> Result>, Vec>>, Error> { - use self::Output::*; - - match statement { - TypedStatement::MultipleDefinition( - v, - TypedExpressionList::FunctionCall(key, arguments, output_types), - ) => match inline_call( - "main".into(), - key, - arguments, - output_types, - modules, - versions, - ) { - Ok(Output::Complete((statements, expressions))) => { - assert_eq!(v.len(), expressions.len()); - Ok(Output::Complete( - statements - .into_iter() - .chain( - v.into_iter() - .zip(expressions) - .map(|(v, e)| TypedStatement::Definition(v.into(), e)), - ) - .collect(), - )) - } - Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => { - assert_eq!(v.len(), expressions.len()); - Ok(Output::Incomplete( - statements - .into_iter() - .chain( - v.into_iter() - .zip(expressions) - .map(|(v, e)| TypedStatement::Definition(v.into(), e)), - ) - .collect(), - delta_for_loop_versions, - )) - } - Err(InlineError::Generic(..)) => Err(Error::Incompatible), - Err(InlineError::NonConstant(_module, key, arguments, output_types)) => { - Ok(Output::Incomplete( - vec![TypedStatement::MultipleDefinition( - v, - TypedExpressionList::FunctionCall(key, arguments, output_types), - )], - vec![], - )) - } - Err(InlineError::Flat) => unimplemented!(), - }, - TypedStatement::For(..) => unimplemented!(), - s => Ok(Complete(vec![s])), - } -} +// statements +// .into_iter() +// .fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) { +// (Ok(state), Ok(e)) => { +// use self::Output::*; +// match (state, e) { +// (Complete(mut s), Complete(c)) => { +// s.extend(c); +// Ok(Complete(s)) +// } +// (Complete(mut s), Incomplete(stats, for_loop_versions)) => { +// s.extend(stats); +// Ok(Incomplete(s, for_loop_versions)) +// } +// (Incomplete(mut stats, new_for_loop_versions), Complete(c)) => { +// stats.extend(c); +// Ok(Incomplete(stats, new_for_loop_versions)) +// } +// (Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => { +// stats.extend(new_stats); +// versions.extend(new_versions); +// Ok(Incomplete(stats, versions)) +// } +// } +// } +// (Err(state), _) => Err(state), +// (Ok(_), Err(e)) => Err(e), +// }) +// } #[cfg(test)] mod tests { @@ -257,9 +521,11 @@ mod tests { // u32 n_0 = 42 // n_1 = n_0 // a_1 = a_0 - // # PUSH CALL to foo with a_3 := a_1 - // # POP CALL with foo_ifof_42 := a_3 - // a_2 = foo_ifof_42 + // # PUSH CALL to foo + // a_3 := a_1 // input binding + // #RETURN_AT_INDEX_0_0 := a_3 + // # POP CALL + // a_2 = #RETURN_AT_INDEX_0_0 // n_2 = n_1 // return a_2 @@ -295,7 +561,7 @@ mod tests { TypedStatement::MultipleDefinition( vec![Variable::field_element("a").into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -324,7 +590,7 @@ mod tests { TypedModule { functions: vec![ ( - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -332,7 +598,7 @@ mod tests { TypedFunctionSymbol::Here(foo), ), ( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -369,28 +635,23 @@ mod tests { FieldElementExpression::Identifier("a".into()).into(), ), TypedStatement::PushCallLog( - "main".into(), - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), GenericsAssignment::default(), - vec![( - ConcreteVariable::with_id_and_type( - Identifier::from("a").version(3), - ConcreteType::FieldElement, - ), - FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(), - )], ), - TypedStatement::PopCallLog(vec![( - ConcreteVariable::with_id_and_type( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ConcreteType::FieldElement, - ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(3)).into(), + FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0)) + .into(), FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(), - )]), + ), + TypedStatement::PopCallLog, TypedStatement::Definition( Variable::field_element(Identifier::from("a").version(2)).into(), FieldElementExpression::Identifier( @@ -420,7 +681,7 @@ mod tests { "main".into(), TypedModule { functions: vec![( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -435,6 +696,9 @@ mod tests { .collect(), }; + println!("{}", reduced.clone().unwrap()); + println!("{}", expected); + assert_eq!(reduced.unwrap(), expected); } @@ -455,9 +719,12 @@ mod tests { // u32 n_0 = 42 // n_1 = n_0 // field[1] b_0 = [42] - // # PUSH CALL to foo::<1> with a_0 := b_0 - // # POP CALL with foo_if[1]_of[1]_42 := a_0 - // b_1 = foo_if[1]_of[1]_42 + // # PUSH CALL to foo::<1> + // a_0 = b_0 + // K = 1 + // #RETURN_AT_INDEX_0_0 := a_0 + // # POP CALL + // b_1 = #RETURN_AT_INDEX_0_0 // n_2 = n_1 // return a_2 @@ -512,7 +779,8 @@ mod tests { TypedStatement::MultipleDefinition( vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), vec![ArrayExpressionInner::Identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into()], @@ -539,11 +807,12 @@ mod tests { TypedModule { functions: vec![ ( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), ), ( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -584,35 +853,37 @@ mod tests { .into(), ), TypedStatement::PushCallLog( - "main".into(), - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), GenericsAssignment(vec![("K", 1)].into_iter().collect()), - vec![( - ConcreteVariable::array( - Identifier::from("a").version(1), - ConcreteType::FieldElement, - 1, - ) + ), + TypedStatement::Definition( + Variable::array( + Identifier::from("a").version(1), + Type::FieldElement, + 1u32.into(), + ) + .into(), + ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) .into(), - ArrayExpressionInner::Identifier("b".into()) - .annotate(Type::FieldElement, 1u32) - .into(), - )], ), TypedStatement::Definition( Variable::uint("K", UBitwidth::B32).into(), - UExpression::from(1u32).into(), + TypedExpression::Uint(1u32.into()), ), - TypedStatement::PopCallLog(vec![( - ConcreteVariable::array( + TypedStatement::Definition( + Variable::array( Identifier::from(CoreIdentifier::Call(0)).version(0), - ConcreteType::FieldElement, - 1, - ), + Type::FieldElement, + 1u32.into(), + ) + .into(), ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) .annotate(Type::FieldElement, 1u32) .into(), - )]), + ), + TypedStatement::PopCallLog, TypedStatement::Definition( Variable::array( Identifier::from("b").version(1), @@ -645,7 +916,7 @@ mod tests { "main".into(), TypedModule { functions: vec![( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -660,6 +931,9 @@ mod tests { .collect(), }; + println!("{}", reduced.clone().unwrap()); + println!("{}", expected); + assert_eq!(reduced.unwrap(), expected); } @@ -680,9 +954,12 @@ mod tests { // u32 n_0 = 2 // n_1 = 2 // field[1] b_0 = [42] - // # PUSH CALL to foo::<1> with a_3 := b_0 - // # POP CALL with foo_if[1]of[1]_42 := a_3 - // b_1 = foo_if[1]of[1]_42 + // # PUSH CALL to foo::<1> + // a_3 = b_0 + // K = 1 + // #RETURN_AT_INDEX_0_0 = a_3 + // # POP CALL + // b_1 = #RETURN_AT_INDEX_0_0 // n_2 = 2 // return a_2 @@ -755,7 +1032,8 @@ mod tests { ) .into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), vec![ArrayExpressionInner::Identifier("b".into()) .annotate( Type::FieldElement, @@ -798,11 +1076,12 @@ mod tests { TypedModule { functions: vec![ ( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), ), ( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -841,37 +1120,39 @@ mod tests { .into(), ), TypedStatement::PushCallLog( - "main".into(), - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), GenericsAssignment(vec![("K", 1)].into_iter().collect()), - vec![( - ConcreteVariable::array( - Identifier::from("a").version(1), - ConcreteType::FieldElement, - 1, - ) - .into(), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(1.into()).into() - ]) - .annotate(Type::FieldElement, 1u32) - .into(), - )], + ), + TypedStatement::Definition( + Variable::array( + Identifier::from("a").version(1), + Type::FieldElement, + 1u32.into(), + ) + .into(), + ArrayExpressionInner::Value(vec![ + FieldElementExpression::Number(1.into()).into() + ]) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Definition( Variable::uint("K", UBitwidth::B32).into(), UExpression::from(1u32).into(), ), - TypedStatement::PopCallLog(vec![( - ConcreteVariable::array( + TypedStatement::Definition( + Variable::array( Identifier::from(CoreIdentifier::Call(0)).version(0), - ConcreteType::FieldElement, - 1, - ), + Type::FieldElement, + 1u32.into(), + ) + .into(), ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) .annotate(Type::FieldElement, 1u32) .into(), - )]), + ), + TypedStatement::PopCallLog, TypedStatement::Definition( Variable::array( Identifier::from("b").version(1), @@ -902,7 +1183,7 @@ mod tests { "main".into(), TypedModule { functions: vec![( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -917,6 +1198,9 @@ mod tests { .collect(), }; + println!("{}", reduced.clone().unwrap()); + println!("{}", expected); + assert_eq!(reduced.unwrap(), expected); } @@ -1010,7 +1294,7 @@ mod tests { // TypedStatement::MultipleDefinition( // vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()], // TypedExpressionList::FunctionCall( - // DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + // DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()), // vec![ArrayExpressionInner::Identifier("b".into()) // .annotate(Type::FieldElement, 1u32) // .into()], @@ -1029,11 +1313,11 @@ mod tests { // TypedModule { // functions: vec![ // ( - // DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + // DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()), // TypedFunctionSymbol::Here(foo), // ), // ( - // DeclarationFunctionKey::with_id("main"), + // DeclarationFunctionKey::with_location(module_id.clone(), "main"), // TypedFunctionSymbol::Here(main), // ), // ] @@ -1071,7 +1355,7 @@ mod tests { // ), // TypedStatement::PushCallLog( // "main".into(), - // DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + // DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()), // GenericsAssignment(vec![("K", 1)].into_iter().collect()), // vec![( // ConcreteVariable::array("a", ConcreteType::FieldElement, 1).into(), @@ -1113,7 +1397,7 @@ mod tests { // "main".into(), // TypedModule { // functions: vec![( - // DeclarationFunctionKey::with_id("main").signature( + // DeclarationFunctionKey::with_location(module_id.clone(), "main").signature( // DeclarationSignature::new() // .inputs(vec![DeclarationType::FieldElement]) // .outputs(vec![DeclarationType::FieldElement]), @@ -1178,7 +1462,8 @@ mod tests { TypedStatement::MultipleDefinition( vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), vec![ArrayExpressionInner::Value(vec![]) .annotate(Type::FieldElement, 0u32) .into()], @@ -1197,11 +1482,12 @@ mod tests { TypedModule { functions: vec![ ( - DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()), + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), ), ( - DeclarationFunctionKey::with_id("main").signature( + DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new().inputs(vec![]).outputs(vec![]), ), TypedFunctionSymbol::Here(main), diff --git a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs index 9ad35bec..fc923770 100644 --- a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs +++ b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs @@ -37,15 +37,15 @@ use super::{Output, Versions}; pub struct ShallowTransformer<'ast, 'a> { // version index for any variable name - versions: &'a mut Versions<'ast>, + pub versions: &'a mut Versions<'ast>, // A backup of the versions before each for-loop - for_loop_backups: Vec>, + pub for_loop_backups: Vec>, // whether all statements could be unrolled so far. Loops with variable bounds cannot. - blocked: bool, + pub blocked: bool, } impl<'ast, 'a> ShallowTransformer<'ast, 'a> { - fn with_versions(versions: &'a mut Versions<'ast>) -> Self { + pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { ShallowTransformer { versions, for_loop_backups: Vec::default(), @@ -568,7 +568,8 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { e: TypedExpressionList<'ast, T>, ) -> TypedExpressionList<'ast, T> { match e { - TypedExpressionList::FunctionCall(ref k, _, _) => { + TypedExpressionList::FunctionCall(ref k, ref a, _) => { + println!("{:#?}", a.iter().map(|a| a.get_type()).collect::>()); if !k.id.starts_with("_") { self.blocked = true; } @@ -929,7 +930,7 @@ mod tests { let s: TypedStatement = TypedStatement::MultipleDefinition( vec![Variable::field_element("a")], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -943,7 +944,7 @@ mod tests { vec![TypedStatement::MultipleDefinition( vec![Variable::field_element(Identifier::from("a").version(1))], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo").signature( + DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]) @@ -1447,7 +1448,7 @@ mod tests { TypedStatement::MultipleDefinition( vec![Variable::field_element("a").into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo"), + DeclarationFunctionKey::with_location("main", "foo"), vec![FieldElementExpression::Identifier("a".into()).into()], vec![Type::FieldElement], ), @@ -1463,7 +1464,7 @@ mod tests { FieldElementExpression::mult( FieldElementExpression::Identifier("a".into()), FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo"), + DeclarationFunctionKey::with_location("main", "foo"), vec![FieldElementExpression::Identifier("a".into()).into()], ), ) @@ -1511,7 +1512,7 @@ mod tests { TypedStatement::MultipleDefinition( vec![Variable::field_element(Identifier::from("a").version(2)).into()], TypedExpressionList::FunctionCall( - DeclarationFunctionKey::with_id("foo"), + DeclarationFunctionKey::with_location("main", "foo"), vec![FieldElementExpression::Identifier( Identifier::from("a").version(1), ) @@ -1530,7 +1531,7 @@ mod tests { FieldElementExpression::mult( FieldElementExpression::Identifier(Identifier::from("a").version(2)), FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo"), + DeclarationFunctionKey::with_location("main", "foo"), vec![FieldElementExpression::Identifier( Identifier::from("a").version(2), ) diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index caeb31fd..a3369825 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -43,7 +43,7 @@ mod tests { fn generate_abi_from_typed_ast() { let mut functions = HashMap::new(); functions.insert( - ConcreteFunctionKey::with_id("main").into(), + ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { generics: vec![], arguments: vec![ diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 30b89164..6a4b7dcf 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -381,80 +381,70 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { ) -> Result> { let array_ty = array.get_array_type(); - if array_ty == target_array_ty { + if array_ty.weak_eq(&target_array_ty) { return Ok(array); } - // sizes must be equal - let converted = match target_array_ty.size == array_ty.size { - true => - // elements must fit in the target type - { - match array.into_inner() { - ArrayExpressionInner::Value(inline_array) => { - match *target_array_ty.ty { - Type::Int => Ok(inline_array), - Type::FieldElement => { - // try to convert all elements to field - inline_array - .into_iter() - .map(|e| { - FieldElementExpression::try_from_typed(e) - .map(TypedExpression::from) - }) - .collect::>, _>>() - .map_err(TypedExpression::from) - } - Type::Uint(bitwidth) => { - // try to convert all elements to uint - inline_array - .into_iter() - .map(|e| { - UExpression::try_from_typed(e, bitwidth) - .map(TypedExpression::from) - }) - .collect::>, _>>() - .map_err(TypedExpression::from) - } - Type::Array(ref inner_array_ty) => { - // try to convert all elements to array - inline_array - .into_iter() - .map(|e| { - ArrayExpression::try_from_typed(e, inner_array_ty.clone()) - .map(TypedExpression::from) - }) - .collect::>, _>>() - .map_err(TypedExpression::from) - } - Type::Struct(ref struct_ty) => { - // try to convert all elements to struct - inline_array - .into_iter() - .map(|e| { - StructExpression::try_from_typed(e, struct_ty.clone()) - .map(TypedExpression::from) - }) - .collect::>, _>>() - .map_err(TypedExpression::from) - } - Type::Boolean => { - // try to convert all elements to boolean - inline_array - .into_iter() - .map(|e| { - BooleanExpression::try_from_typed(e) - .map(TypedExpression::from) - }) - .collect::>, _>>() - .map_err(TypedExpression::from) - } - } + // elements must fit in the target type + let converted = match array.into_inner() { + ArrayExpressionInner::Value(inline_array) => { + match *target_array_ty.ty { + Type::Int => Ok(inline_array), + Type::FieldElement => { + // try to convert all elements to field + inline_array + .into_iter() + .map(|e| { + FieldElementExpression::try_from_typed(e).map(TypedExpression::from) + }) + .collect::>, _>>() + .map_err(TypedExpression::from) + } + Type::Uint(bitwidth) => { + // try to convert all elements to uint + inline_array + .into_iter() + .map(|e| { + UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from) + }) + .collect::>, _>>() + .map_err(TypedExpression::from) + } + Type::Array(ref inner_array_ty) => { + // try to convert all elements to array + inline_array + .into_iter() + .map(|e| { + ArrayExpression::try_from_typed(e, inner_array_ty.clone()) + .map(TypedExpression::from) + }) + .collect::>, _>>() + .map_err(TypedExpression::from) + } + Type::Struct(ref struct_ty) => { + // try to convert all elements to struct + inline_array + .into_iter() + .map(|e| { + StructExpression::try_from_typed(e, struct_ty.clone()) + .map(TypedExpression::from) + }) + .collect::>, _>>() + .map_err(TypedExpression::from) + } + Type::Boolean => { + // try to convert all elements to boolean + inline_array + .into_iter() + .map(|e| { + BooleanExpression::try_from_typed(e).map(TypedExpression::from) + }) + .collect::>, _>>() + .map_err(TypedExpression::from) } - _ => unreachable!(""), } } - false => Err(array.into()), + _ => unreachable!(), }?; Ok(ArrayExpressionInner::Value(converted).annotate(*target_array_ty.ty, array_ty.size)) diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 103e2f5c..ef8807f3 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -8,6 +8,7 @@ pub mod abi; pub mod folder; pub mod identifier; +pub mod result_folder; mod integer; mod parameter; @@ -19,8 +20,8 @@ pub use self::identifier::CoreIdentifier; pub use self::parameter::{DeclarationParameter, GParameter}; pub use self::types::{ ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey, - DeclarationSignature, DeclarationType, GType, GenericIdentifier, Signature, StructType, Type, - UBitwidth, + DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier, + Signature, StructType, Type, UBitwidth, }; use typed_absy::types::GenericsAssignment; @@ -141,7 +142,7 @@ pub struct TypedModule<'ast, T> { #[derive(Clone, PartialEq)] pub enum TypedFunctionSymbol<'ast, T> { Here(TypedFunction<'ast, T>), - There(DeclarationFunctionKey<'ast>, TypedModuleId), + There(DeclarationFunctionKey<'ast>), Flat(FlatEmbed), } @@ -150,7 +151,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunctionSymbol<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TypedFunctionSymbol::Here(s) => write!(f, "Here({:?})", s), - TypedFunctionSymbol::There(key, module) => write!(f, "There({:?}, {:?})", key, module), + TypedFunctionSymbol::There(key) => write!(f, "There({:?})", key), TypedFunctionSymbol::Flat(s) => write!(f, "Flat({:?})", s), } } @@ -163,8 +164,8 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { ) -> DeclarationSignature<'ast> { match self { TypedFunctionSymbol::Here(f) => f.signature.clone(), - TypedFunctionSymbol::There(key, module_id) => modules - .get(module_id) + TypedFunctionSymbol::There(key) => modules + .get(&key.module) .unwrap() .functions .get(key) @@ -183,10 +184,10 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { .iter() .map(|(key, symbol)| match symbol { TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), - TypedFunctionSymbol::There(ref fun_key, ref module_id) => format!( + TypedFunctionSymbol::There(ref fun_key) => format!( "import {} from \"{}\" as {} // with signature {}", fun_key.id, - module_id.display(), + fun_key.module.display(), key.id, key.signature ), @@ -230,12 +231,16 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "<{}>", - self.generics - .iter() - .map(|g| g.to_string()) - .collect::>() - .join(", ") + "{}", + if self.generics.len() > 0 { + self.generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + } else { + "".to_string() + } )?; write!( f, @@ -271,7 +276,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { for s in &self.statements { match s { - TypedStatement::PopCallLog(..) => tab -= 1, + TypedStatement::PopCallLog => tab -= 1, _ => {} }; @@ -380,13 +385,8 @@ pub enum TypedStatement<'ast, T> { ), MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), // Aux - PushCallLog( - TypedModuleId, - DeclarationFunctionKey<'ast>, - GenericsAssignment<'ast>, - Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>, - ), - PopCallLog(Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>), + PushCallLog(DeclarationFunctionKey<'ast>, GenericsAssignment<'ast>), + PopCallLog, } impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> { @@ -417,16 +417,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> { TypedStatement::MultipleDefinition(ref lhs, ref rhs) => { write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs) } - TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => { - write!( - f, - "PushCallLog({:?}, {:?}, {:?}, {:?})", - module_id, key, generics, assignments - ) - } - TypedStatement::PopCallLog(ref assignments) => { - write!(f, "PopCallLog({:?})", assignments) + TypedStatement::PushCallLog(ref key, ref generics) => { + write!(f, "PushCallLog({:?}, {:?})", key, generics) } + TypedStatement::PopCallLog => write!(f, "PopCallLog"), } } } @@ -480,29 +474,14 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { } write!(f, " = {}", rhs) } - TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => { - write!( - f, - "// PUSH CALL TO {}_{}::<{}> with {}", - module_id.display(), - key.id, - generics, - assignments - .iter() - .map(|(v, e)| format!("{} := {}", v, e)) - .collect::>() - .join(", ") - ) - } - TypedStatement::PopCallLog(ref assignments) => write!( + TypedStatement::PushCallLog(ref key, ref generics) => write!( f, - "// POP CALL with {}", - assignments - .iter() - .map(|(v, e)| format!("{} := {}", v, e)) - .collect::>() - .join(", ") + "// PUSH CALL TO {}/{}::<{}>", + key.module.display(), + key.id, + generics, ), + TypedStatement::PopCallLog => write!(f, "// POP CALL",), } } } @@ -735,7 +714,7 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), - FunctionCall(FunctionKey<'ast, T>, Vec>), + FunctionCall(DeclarationFunctionKey<'ast>, Vec>), Member(Box>, MemberId), Select(Box>, Box>), } @@ -822,7 +801,7 @@ pub enum BooleanExpression<'ast, T> { Box>, ), Member(Box>, MemberId), - FunctionCall(FunctionKey<'ast, T>, Vec>), + FunctionCall(DeclarationFunctionKey<'ast>, Vec>), Select(Box>, Box>), } @@ -848,7 +827,7 @@ pub struct ArrayExpression<'ast, T> { pub enum ArrayExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), - FunctionCall(FunctionKey<'ast, T>, Vec>), + FunctionCall(DeclarationFunctionKey<'ast>, Vec>), IfElse( Box>, Box>, @@ -939,7 +918,7 @@ impl<'ast, T> StructExpression<'ast, T> { pub enum StructExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), - FunctionCall(FunctionKey<'ast, T>, Vec>), + FunctionCall(DeclarationFunctionKey<'ast>, Vec>), IfElse( Box>, Box>, @@ -1328,7 +1307,9 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> { Type::Array(ty) => ArrayExpressionInner::Identifier(v.id) .annotate(*ty.ty, ty.size) .into(), - _ => unimplemented!(), + Type::Struct(ty) => StructExpressionInner::Identifier(v.id).annotate(ty).into(), + Type::Uint(w) => UExpressionInner::Identifier(v.id).annotate(w).into(), + Type::Int => unreachable!(), } } } @@ -1536,3 +1517,69 @@ impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> { StructExpressionInner::Member(box s, member_id).annotate(members) } } + +pub trait FunctionCall<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self; +} + +impl<'ast, T> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self { + FieldElementExpression::FunctionCall(key, arguments) + } +} + +impl<'ast, T> FunctionCall<'ast, T> for BooleanExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self { + BooleanExpression::FunctionCall(key, arguments) + } +} + +impl<'ast, T: Clone> FunctionCall<'ast, T> for UExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self { + let bitwidth = match &key.signature.outputs[0] { + DeclarationType::Uint(bitwidth) => bitwidth.clone(), + _ => unreachable!(), + }; + UExpressionInner::FunctionCall(key, arguments).annotate(bitwidth) + } +} + +impl<'ast, T: Clone> FunctionCall<'ast, T> for ArrayExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self { + let array_ty = match &key.signature.outputs[0] { + DeclarationType::Array(array_ty) => array_ty.clone(), + _ => unreachable!(), + }; + ArrayExpressionInner::FunctionCall(key, arguments) + .annotate(Type::::from(*array_ty.ty), array_ty.size) + } +} + +impl<'ast, T: Clone> FunctionCall<'ast, T> for StructExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + arguments: Vec>, + ) -> Self { + let struct_ty = match &key.signature.outputs[0] { + DeclarationType::Struct(struct_ty) => struct_ty.clone(), + _ => unreachable!(), + }; + + StructExpressionInner::FunctionCall(key, arguments).annotate(StructType::from(struct_ty)) + } +} diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs new file mode 100644 index 00000000..7f6dfaf0 --- /dev/null +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -0,0 +1,667 @@ +// Generic walk through a typed AST. Not mutating in place + +use crate::typed_absy::*; +use typed_absy::types::{ArrayType, StructMember, StructType}; +use zokrates_field::Field; + +pub trait ResultFolder<'ast, T: Field>: Sized { + type Error; + + fn fold_function_symbol( + &mut self, + s: TypedFunctionSymbol<'ast, T>, + ) -> Result, Self::Error> { + fold_function_symbol(self, s) + } + + fn fold_function( + &mut self, + f: TypedFunction<'ast, T>, + ) -> Result, Self::Error> { + fold_function(self, f) + } + + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast>, + ) -> Result, Self::Error> { + Ok(DeclarationParameter { + id: self.fold_declaration_variable(p.id)?, + ..p + }) + } + + fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { + Ok(n) + } + + fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result, Self::Error> { + Ok(Variable { + id: self.fold_name(v.id)?, + _type: self.fold_type(v._type)?, + }) + } + + fn fold_declaration_variable( + &mut self, + v: DeclarationVariable<'ast>, + ) -> Result, Self::Error> { + Ok(DeclarationVariable { + id: self.fold_name(v.id)?, + _type: self.fold_declaration_type(v._type)?, + }) + } + + fn fold_type(&mut self, t: Type<'ast, T>) -> Result, Self::Error> { + use self::GType::*; + + match t { + Array(array_type) => Ok(Array(ArrayType { + ty: box self.fold_type(*array_type.ty)?, + size: self.fold_uint_expression(array_type.size)?, + })), + Struct(struct_type) => Ok(Struct(StructType { + members: struct_type + .members + .into_iter() + .map(|m| { + self.fold_type(*m.ty.clone()) + .map(|ty| StructMember { ty: box ty, ..m }) + }) + .collect::>()?, + ..struct_type + })), + t => Ok(t), + } + } + + fn fold_declaration_type( + &mut self, + t: DeclarationType<'ast>, + ) -> Result, Self::Error> { + Ok(t) + } + + fn fold_assignee( + &mut self, + a: TypedAssignee<'ast, T>, + ) -> Result, Self::Error> { + match a { + TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(self.fold_variable(v)?)), + TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( + box self.fold_assignee(a)?, + box self.fold_uint_expression(index)?, + )), + TypedAssignee::Member(box s, m) => { + Ok(TypedAssignee::Member(box self.fold_assignee(s)?, m)) + } + } + } + + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement(self, s) + } + + fn fold_expression( + &mut self, + e: TypedExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + TypedExpression::FieldElement(e) => Ok(self.fold_field_expression(e)?.into()), + TypedExpression::Boolean(e) => Ok(self.fold_boolean_expression(e)?.into()), + TypedExpression::Uint(e) => Ok(self.fold_uint_expression(e)?.into()), + TypedExpression::Array(e) => Ok(self.fold_array_expression(e)?.into()), + TypedExpression::Struct(e) => Ok(self.fold_struct_expression(e)?.into()), + TypedExpression::Int(e) => Ok(self.fold_int_expression(e)?.into()), + } + } + + fn fold_array_expression( + &mut self, + e: ArrayExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression(self, e) + } + + fn fold_struct_expression( + &mut self, + e: StructExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression(self, e) + } + + fn fold_expression_list( + &mut self, + es: TypedExpressionList<'ast, T>, + ) -> Result, Self::Error> { + fold_expression_list(self, es) + } + + fn fold_int_expression( + &mut self, + e: IntExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_int_expression(self, e) + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression(self, e) + } + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression(self, e) + } + fn fold_uint_expression( + &mut self, + e: UExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression(self, e) + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_inner(self, bitwidth, e) + } + + fn fold_array_expression_inner( + &mut self, + ty: &Type<'ast, T>, + size: UExpression<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression_inner(self, ty, size, e) + } + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression_inner(self, ty, e) + } +} + +pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Result>, F::Error> { + let res = match s { + TypedStatement::Return(expressions) => TypedStatement::Return( + expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + TypedStatement::Definition(a, e) => { + TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?) + } + TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)?), + TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)?), + TypedStatement::For(v, from, to, statements) => TypedStatement::For( + f.fold_variable(v)?, + f.fold_uint_expression(from)?, + f.fold_uint_expression(to)?, + statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ), + TypedStatement::MultipleDefinition(variables, elist) => TypedStatement::MultipleDefinition( + variables + .into_iter() + .map(|v| f.fold_variable(v)) + .collect::>()?, + f.fold_expression_list(elist)?, + ), + s => s, + }; + Ok(vec![res]) +} + +pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &Type<'ast, T>, + _: UExpression<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?), + ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value( + exprs + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + ArrayExpressionInner::FunctionCall(id, exps) => { + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + ArrayExpressionInner::FunctionCall(id, exps) + } + ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { + ArrayExpressionInner::IfElse( + box f.fold_boolean_expression(condition)?, + box f.fold_array_expression(consequence)?, + box f.fold_array_expression(alternative)?, + ) + } + ArrayExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + ArrayExpressionInner::Member(box s, id) + } + ArrayExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + ArrayExpressionInner::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + StructExpressionInner::Identifier(id) => { + StructExpressionInner::Identifier(f.fold_name(id)?) + } + StructExpressionInner::Value(exprs) => StructExpressionInner::Value( + exprs + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + StructExpressionInner::FunctionCall(id, exps) => { + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + StructExpressionInner::FunctionCall(id, exps) + } + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + StructExpressionInner::IfElse( + box f.fold_boolean_expression(condition)?, + box f.fold_struct_expression(consequence)?, + box f.fold_struct_expression(alternative)?, + ) + } + StructExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + StructExpressionInner::Member(box s, id) + } + StructExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + StructExpressionInner::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let e = match e { + FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Identifier(id) => { + FieldElementExpression::Identifier(f.fold_name(id)?) + } + FieldElementExpression::Add(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Add(box e1, box e2) + } + FieldElementExpression::Sub(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Sub(box e1, box e2) + } + FieldElementExpression::Mult(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Mult(box e1, box e2) + } + FieldElementExpression::Div(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Div(box e1, box e2) + } + FieldElementExpression::Pow(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + FieldElementExpression::Pow(box e1, box e2) + } + FieldElementExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_field_expression(cons)?; + let alt = f.fold_field_expression(alt)?; + FieldElementExpression::IfElse(box cond, box cons, box alt) + } + FieldElementExpression::FunctionCall(key, exps) => { + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + FieldElementExpression::FunctionCall(key, exps) + } + FieldElementExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + FieldElementExpression::Member(box s, id) + } + FieldElementExpression::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + FieldElementExpression::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + _: &mut F, + _: IntExpression<'ast, T>, +) -> Result, F::Error> { + unreachable!() +} + +pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let e = match e { + BooleanExpression::Value(v) => BooleanExpression::Value(v), + BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?), + BooleanExpression::FieldEq(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldEq(box e1, box e2) + } + BooleanExpression::BoolEq(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::BoolEq(box e1, box e2) + } + BooleanExpression::ArrayEq(box e1, box e2) => { + let e1 = f.fold_array_expression(e1)?; + let e2 = f.fold_array_expression(e2)?; + BooleanExpression::ArrayEq(box e1, box e2) + } + BooleanExpression::StructEq(box e1, box e2) => { + let e1 = f.fold_struct_expression(e1)?; + let e2 = f.fold_struct_expression(e2)?; + BooleanExpression::StructEq(box e1, box e2) + } + BooleanExpression::UintEq(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintEq(box e1, box e2) + } + BooleanExpression::FieldLt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLt(box e1, box e2) + } + BooleanExpression::FieldLe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLe(box e1, box e2) + } + BooleanExpression::FieldGt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGt(box e1, box e2) + } + BooleanExpression::FieldGe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGe(box e1, box e2) + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLt(box e1, box e2) + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLe(box e1, box e2) + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGt(box e1, box e2) + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGe(box e1, box e2) + } + BooleanExpression::Or(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::Or(box e1, box e2) + } + BooleanExpression::And(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::And(box e1, box e2) + } + BooleanExpression::Not(box e) => { + let e = f.fold_boolean_expression(e)?; + BooleanExpression::Not(box e) + } + BooleanExpression::FunctionCall(key, exps) => { + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + BooleanExpression::FunctionCall(key, exps) + } + BooleanExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_boolean_expression(cons)?; + let alt = f.fold_boolean_expression(alt)?; + BooleanExpression::IfElse(box cond, box cons, box alt) + } + BooleanExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + BooleanExpression::Member(box s, id) + } + BooleanExpression::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + BooleanExpression::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: UExpression<'ast, T>, +) -> Result, F::Error> { + Ok(UExpression { + inner: f.fold_uint_expression_inner(e.bitwidth, e.inner)?, + ..e + }) +} + +pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?), + UExpressionInner::Add(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Add(box left, box right) + } + UExpressionInner::Sub(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Sub(box left, box right) + } + UExpressionInner::Mult(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Mult(box left, box right) + } + UExpressionInner::Xor(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Xor(box left, box right) + } + UExpressionInner::And(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::And(box left, box right) + } + UExpressionInner::Or(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Or(box left, box right) + } + UExpressionInner::LeftShift(box e, box by) => { + let e = f.fold_uint_expression(e)?; + let by = f.fold_field_expression(by)?; + + UExpressionInner::LeftShift(box e, box by) + } + UExpressionInner::RightShift(box e, box by) => { + let e = f.fold_uint_expression(e)?; + let by = f.fold_field_expression(by)?; + + UExpressionInner::RightShift(box e, box by) + } + UExpressionInner::Not(box e) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::Not(box e) + } + UExpressionInner::FunctionCall(key, exps) => { + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + UExpressionInner::FunctionCall(key, exps) + } + UExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + UExpressionInner::Select(box array, box index) + } + UExpressionInner::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_uint_expression(cons)?; + let alt = f.fold_uint_expression(alt)?; + UExpressionInner::IfElse(box cond, box cons, box alt) + } + UExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + UExpressionInner::Member(box s, id) + } + }; + Ok(e) +} + +pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + fun: TypedFunction<'ast, T>, +) -> Result, F::Error> { + Ok(TypedFunction { + arguments: fun + .arguments + .into_iter() + .map(|a| f.fold_parameter(a)) + .collect::>()?, + statements: fun + .statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ..fun + }) +} + +pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: ArrayExpression<'ast, T>, +) -> Result, F::Error> { + let size = f.fold_uint_expression(e.size)?; + + Ok(ArrayExpression { + inner: f.fold_array_expression_inner(&e.ty, size.clone(), e.inner)?, + size, + ..e + }) +} + +pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + es: TypedExpressionList<'ast, T>, +) -> Result, F::Error> { + match es { + TypedExpressionList::FunctionCall(id, arguments, types) => { + Ok(TypedExpressionList::FunctionCall( + id, + arguments + .into_iter() + .map(|a| f.fold_expression(a)) + .collect::>()?, + types + .into_iter() + .map(|t| f.fold_type(t)) + .collect::>()?, + )) + } + } +} + +pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: StructExpression<'ast, T>, +) -> Result, F::Error> { + Ok(StructExpression { + inner: f.fold_struct_expression_inner(&e.ty, e.inner)?, + ..e + }) +} + +pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedFunctionSymbol<'ast, T>, +) -> Result, F::Error> { + match s { + TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)), + there => Ok(there), // by default, do not fold modules recursively + } +} diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 4d2f012a..e8769b97 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -3,7 +3,7 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::path::{Path, PathBuf}; use typed_absy::{TryFrom, TryInto}; -use typed_absy::{UExpression, UExpressionInner}; +use typed_absy::{TypedModuleId, UExpression, UExpressionInner}; pub type GenericIdentifier<'ast> = &'ast str; @@ -15,7 +15,7 @@ pub enum Constant<'ast> { // At this stage we want all constants to be equal impl<'ast> PartialEq for Constant<'ast> { - fn eq(&self, _: &Self) -> bool { + fn eq(&self, other: &Self) -> bool { true } } @@ -193,6 +193,14 @@ impl<'ast, T: PartialEq> PartialEq> for ArrayType<'as } } +impl<'ast, T: PartialEq> ArrayType<'ast, T> { + // array type equality with non-strict size checks + // sizes always match unless they are different constants + pub fn weak_eq(&self, other: &Self) -> bool { + self.ty == other.ty + } +} + fn try_from_g_array_type, U>(t: GArrayType) -> Result, ()> { Ok(GArrayType { size: t.size.try_into().map_err(|_| ())?, @@ -676,6 +684,7 @@ pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct GFunctionKey<'ast, S> { + pub module: TypedModuleId, pub id: FunctionIdentifier<'ast>, pub signature: GSignature, } @@ -703,12 +712,13 @@ impl<'ast> fmt::Display for GenericsAssignment<'ast> { impl<'ast> PartialEq> for ConcreteFunctionKey<'ast> { fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool { - self.id == other.id && self.signature == other.signature + self.module == other.module && self.id == other.id && self.signature == other.signature } } fn try_from_g_function_key, U>(k: GFunctionKey) -> Result, ()> { Ok(GFunctionKey { + module: k.module, signature: signature::try_from_g_signature(k.signature)?, id: k.id, }) @@ -749,8 +759,12 @@ impl<'ast, T> From> for FunctionKey<'ast, T> { } impl<'ast, S> GFunctionKey<'ast, S> { - pub fn with_id>>(id: U) -> Self { + pub fn with_location, U: Into>>( + module: T, + id: U, + ) -> Self { GFunctionKey { + module: module.into(), id: id.into(), signature: GSignature::new(), } @@ -765,11 +779,21 @@ impl<'ast, S> GFunctionKey<'ast, S> { self.id = id.into(); self } + + pub fn module>(mut self, module: T) -> Self { + self.module = module.into(); + self + } } impl<'ast> ConcreteFunctionKey<'ast> { pub fn to_slug(&self) -> String { - format!("{}_{}", self.id, self.signature.to_slug()) + format!( + "{}/{}_{}", + self.module.display(), + self.id, + self.signature.to_slug() + ) } } diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 98e4cc60..d93409e7 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -119,7 +119,7 @@ pub enum UExpressionInner<'ast, T> { Box>, Box>, ), - FunctionCall(FunctionKey<'ast, T>, Vec>), + FunctionCall(DeclarationFunctionKey<'ast>, Vec>), IfElse( Box>, Box>, diff --git a/zokrates_core/src/zir/from_typed.rs b/zokrates_core/src/zir/from_typed.rs index 1491d39c..a290c7e8 100644 --- a/zokrates_core/src/zir/from_typed.rs +++ b/zokrates_core/src/zir/from_typed.rs @@ -4,6 +4,7 @@ use zir; impl<'ast> From> for zir::types::FunctionKey<'ast> { fn from(k: typed_absy::types::ConcreteFunctionKey<'ast>) -> zir::types::FunctionKey<'ast> { zir::types::FunctionKey { + module: k.module, id: k.id, signature: k.signature.into(), } diff --git a/zokrates_core/src/zir/mod.rs b/zokrates_core/src/zir/mod.rs index 5553c3c0..640fe3cd 100644 --- a/zokrates_core/src/zir/mod.rs +++ b/zokrates_core/src/zir/mod.rs @@ -74,7 +74,7 @@ pub struct ZirModule<'ast, T> { #[derive(Debug, Clone, PartialEq)] pub enum ZirFunctionSymbol<'ast, T> { Here(ZirFunction<'ast, T>), - There(FunctionKey<'ast>, ZirModuleId), + There(FunctionKey<'ast>), Flat(FlatEmbed), } @@ -82,8 +82,8 @@ impl<'ast, T> ZirFunctionSymbol<'ast, T> { pub fn signature<'a>(&'a self, modules: &'a ZirModules) -> Signature { match self { ZirFunctionSymbol::Here(f) => f.signature.clone(), - ZirFunctionSymbol::There(key, module_id) => modules - .get(module_id) + ZirFunctionSymbol::There(key) => modules + .get(&key.module) .unwrap() .functions .get(key) @@ -102,10 +102,10 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirModule<'ast, T> { .iter() .map(|(key, symbol)| match symbol { ZirFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), - ZirFunctionSymbol::There(ref fun_key, ref module_id) => format!( + ZirFunctionSymbol::There(ref fun_key) => format!( "import {} from \"{}\" as {} // with signature {}", fun_key.id, - module_id.display(), + fun_key.module.display(), key.id, key.signature ), diff --git a/zokrates_core/src/zir/types.rs b/zokrates_core/src/zir/types.rs index 3b45481e..5a6ee3c1 100644 --- a/zokrates_core/src/zir/types.rs +++ b/zokrates_core/src/zir/types.rs @@ -1,4 +1,5 @@ use std::fmt; +use zir::ZirModuleId; pub type Identifier<'ast> = &'ast str; @@ -95,13 +96,18 @@ pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct FunctionKey<'ast> { + pub module: ZirModuleId, pub id: FunctionIdentifier<'ast>, pub signature: Signature, } impl<'ast> FunctionKey<'ast> { - pub fn with_id>>(id: S) -> Self { + pub fn with_location, S: Into>>( + module: T, + id: S, + ) -> Self { FunctionKey { + module: module.into(), id: id.into(), signature: Signature::new(), } @@ -117,6 +123,11 @@ impl<'ast> FunctionKey<'ast> { self } + pub fn module>(mut self, module: T) -> Self { + self.module = module.into(); + self + } + pub fn to_slug(&self) -> String { format!("{}_{}", self.id, self.signature.to_slug()) } diff --git a/zokrates_core_test/tests/tests/fact_up_to_4.zok b/zokrates_core_test/tests/tests/fact_up_to_4.zok index d33b19c0..299cd995 100644 --- a/zokrates_core_test/tests/tests/fact_up_to_4.zok +++ b/zokrates_core_test/tests/tests/fact_up_to_4.zok @@ -1,8 +1,10 @@ +import "utils/casts/u32_to_field" as to_field + def main(field x) -> field: field f = 1 field counter = 0 - for field i in 1..5 do - f = if counter == x then f else f * i fi + for u32 i in 1..5 do + f = if counter == x then f else f * to_field(i) fi counter = if counter == x then counter else counter + 1 fi endfor return f \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/call.json b/zokrates_core_test/tests/tests/generics/call.json new file mode 100644 index 00000000..57ed1aee --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/call.json @@ -0,0 +1,4 @@ +{ + "curves": ["Bn128", "Bls12"], + "tests": [] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/call.zok b/zokrates_core_test/tests/tests/generics/call.zok new file mode 100644 index 00000000..a0e6aa05 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/call.zok @@ -0,0 +1,8 @@ +def foo(field[T] b) -> field: + return 1 + +def bar(field[T] b) -> field: + return foo(b) + +def main(field[3] a) -> field: + return foo(a) + bar(a) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/multidef.json b/zokrates_core_test/tests/tests/generics/multidef.json new file mode 100644 index 00000000..d2dabe11 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/multidef.json @@ -0,0 +1,23 @@ +{ + "curves": ["Bn128", "Bls12"], + "tests": [ + { + "input": { + "values": [ + "1", + "2", + "3" + ] + }, + "output": { + "Ok": { + "values": [ + "1", + "2", + "3" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/multidef.zok b/zokrates_core_test/tests/tests/generics/multidef.zok new file mode 100644 index 00000000..dbf3bd3f --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/multidef.zok @@ -0,0 +1,6 @@ +def foo(field[T] b) -> field[T]: + return b + +def main(field[3] a) -> field[3]: + field[3] res = foo(a) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok new file mode 100644 index 00000000..c7e1903f --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok @@ -0,0 +1,10 @@ +import "EMBED/u32_to_bits" as to_bits + +def main(u32 i) -> field: + bool[32] bits = to_bits(i) + field res = 0 + for u32 j in 0..32 do + u32 exponent = 32 - j - 1 + res = res + 2 ** exponent + endfor + return res \ No newline at end of file diff --git a/zokrates_test/src/lib.rs b/zokrates_test/src/lib.rs index 63fb8f08..990035b7 100644 --- a/zokrates_test/src/lib.rs +++ b/zokrates_test/src/lib.rs @@ -13,7 +13,7 @@ enum Curve { #[derive(Serialize, Deserialize, Clone)] struct Tests { - pub entry_point: PathBuf, + pub entry_point: Option, pub curves: Option>, pub max_constraint_count: Option, pub tests: Vec, @@ -92,6 +92,14 @@ pub fn test_inner(test_path: &str) { let curves = t.curves.clone().unwrap_or(vec![Curve::Bn128]); + let t = Tests { + entry_point: Some( + t.entry_point + .unwrap_or(PathBuf::from(String::from(test_path)).with_extension("zok")), + ), + ..t + }; + for c in &curves { match c { Curve::Bn128 => compile_and_run::(t.clone()), @@ -101,11 +109,13 @@ pub fn test_inner(test_path: &str) { } fn compile_and_run(t: Tests) { - let code = std::fs::read_to_string(&t.entry_point).unwrap(); + let entry_point = t.entry_point.unwrap(); + + let code = std::fs::read_to_string(&entry_point).unwrap(); let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap()); - let artifacts = compile::(code, t.entry_point.clone(), Some(&resolver)).unwrap(); + let artifacts = compile::(code, entry_point.clone(), Some(&resolver)).unwrap(); let bin = artifacts.prog(); @@ -115,7 +125,7 @@ fn compile_and_run(t: Tests) { println!( "{} at {}% of max", - t.entry_point.display(), + entry_point.display(), (count as f32) / (target_count as f32) * 100_f32 ); } @@ -131,7 +141,7 @@ fn compile_and_run(t: Tests) { match compare(output, test.output) { Err(e) => { - let mut code = File::open(&t.entry_point).unwrap(); + let mut code = File::open(&entry_point).unwrap(); let mut s = String::new(); code.read_to_string(&mut s).unwrap(); let context = format!(