From f8dc4d7649650ce55be366b6315d0a986a59db3b Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 30 Mar 2021 17:15:48 +0200 Subject: [PATCH 01/12] introduce constant definitions --- .../examples/functions/no_args_multiple.zok | 14 +- zokrates_core/src/absy/from_ast.rs | 52 ++- zokrates_core/src/absy/mod.rs | 57 +++- zokrates_core/src/absy/node.rs | 1 + zokrates_core/src/semantics.rs | 314 +++++++++++------- zokrates_core/src/static_analysis/inline.rs | 52 +++ .../src/static_analysis/propagate_unroll.rs | 2 + zokrates_core/src/typed_absy/abi.rs | 8 +- zokrates_core/src/typed_absy/mod.rs | 28 +- zokrates_parser/src/zokrates.pest | 5 +- zokrates_pest_ast/src/lib.rs | 29 +- 11 files changed, 408 insertions(+), 154 deletions(-) diff --git a/zokrates_cli/examples/functions/no_args_multiple.zok b/zokrates_cli/examples/functions/no_args_multiple.zok index cef72d30..c70ad021 100644 --- a/zokrates_cli/examples/functions/no_args_multiple.zok +++ b/zokrates_cli/examples/functions/no_args_multiple.zok @@ -1,10 +1,10 @@ -def const() -> field: +def constant() -> field: return 123123 -def add(field a,field b) -> field: - a=const() - return a+b +def add(field a, field b) -> field: + a = constant() + return a + b -def main(field a,field b) -> field: - field c = add(a, b+const()) - return const() +def main(field a, field b) -> field: + field c = add(a, b + constant()) + return constant() diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 5fbed664..d1aeffd5 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,6 +1,7 @@ use crate::absy; use crate::imports; +use crate::absy::SymbolDefinition; use num::ToPrimitive; use num_bigint::BigUint; use zokrates_pest_ast as pest; @@ -11,6 +12,11 @@ impl<'ast> From> for absy::Module<'ast> { prog.structs .into_iter() .map(|t| absy::SymbolDeclarationNode::from(t)) + .chain( + prog.constants + .into_iter() + .map(|f| absy::SymbolDeclarationNode::from(f)), + ) .chain( prog.functions .into_iter() @@ -65,7 +71,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'a absy::SymbolDeclaration { id, - symbol: absy::Symbol::HereType(ty), + symbol: absy::Symbol::Here(SymbolDefinition::Struct(ty)), } .span(span) } @@ -85,6 +91,28 @@ impl<'ast> From> for absy::StructDefinitionFieldNode<'as } } +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(definition: pest::ConstantDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { + use crate::absy::NodeValue; + + let span = definition.span; + let id = definition.id.span.as_str(); + + let ty = absy::ConstantDefinition { + id, + ty: definition.ty.into(), + expression: definition.expression.into(), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id, + symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)), + } + .span(span) + } +} + impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; @@ -128,7 +156,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { absy::SymbolDeclaration { id, - symbol: absy::Symbol::HereFunction(function), + symbol: absy::Symbol::Here(SymbolDefinition::Function(function)), } .span(span) } @@ -683,7 +711,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -701,7 +729,7 @@ mod tests { .outputs(vec![UnresolvedType::FieldElement.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -716,7 +744,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -731,7 +759,7 @@ mod tests { .outputs(vec![UnresolvedType::Boolean.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -747,7 +775,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![ absy::Parameter::private( @@ -785,7 +813,7 @@ mod tests { .outputs(vec![UnresolvedType::FieldElement.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -802,7 +830,7 @@ mod tests { absy::Module { symbols: vec![absy::SymbolDeclaration { id: "main", - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![absy::Parameter::private( absy::Variable::new("a", ty.clone().mock()).into(), @@ -818,7 +846,7 @@ mod tests { signature: UnresolvedSignature::new().inputs(vec![ty.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -866,7 +894,7 @@ mod tests { absy::Module { symbols: vec![absy::SymbolDeclaration { id: "main", - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -879,7 +907,7 @@ mod tests { signature: UnresolvedSignature::new(), } .into(), - ), + )), } .into()], imports: vec![], diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index a497c374..7be899c5 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -49,10 +49,26 @@ pub struct SymbolDeclaration<'ast> { pub symbol: Symbol<'ast>, } +#[derive(PartialEq, Clone)] +pub enum SymbolDefinition<'ast> { + Struct(StructDefinitionNode<'ast>), + Constant(ConstantDefinitionNode<'ast>), + Function(FunctionNode<'ast>), +} + +impl<'ast> fmt::Debug for SymbolDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SymbolDefinition::Struct(s) => write!(f, "Struct({:?})", s), + SymbolDefinition::Constant(c) => write!(f, "Constant({:?})", c), + SymbolDefinition::Function(func) => write!(f, "Function({:?})", func), + } + } +} + #[derive(PartialEq, Clone)] pub enum Symbol<'ast> { - HereType(StructDefinitionNode<'ast>), - HereFunction(FunctionNode<'ast>), + Here(SymbolDefinition<'ast>), There(SymbolImportNode<'ast>), Flat(FlatEmbed), } @@ -60,9 +76,8 @@ pub enum Symbol<'ast> { impl<'ast> fmt::Debug for Symbol<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Symbol::HereType(t) => write!(f, "HereType({:?})", t), - Symbol::HereFunction(fun) => write!(f, "HereFunction({:?})", fun), - Symbol::There(t) => write!(f, "There({:?})", t), + Symbol::Here(k) => write!(f, "Here({:?})", k), + Symbol::There(i) => write!(f, "There({:?})", i), Symbol::Flat(flat) => write!(f, "Flat({:?})", flat), } } @@ -71,8 +86,11 @@ impl<'ast> fmt::Debug for Symbol<'ast> { impl<'ast> fmt::Display for SymbolDeclaration<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.symbol { - Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t), - Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun), + Symbol::Here(ref kind) => match kind { + SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t), + SymbolDefinition::Constant(c) => write!(f, "{}", c), + SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func), + }, Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id), Symbol::Flat(ref flat_fun) => { write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature()) @@ -216,6 +234,31 @@ impl<'ast> fmt::Debug for Module<'ast> { } } +#[derive(Clone, PartialEq)] +pub struct ConstantDefinition<'ast> { + pub id: Identifier<'ast>, + pub ty: UnresolvedTypeNode, + pub expression: ExpressionNode<'ast>, +} + +pub type ConstantDefinitionNode<'ast> = Node>; + +impl<'ast> fmt::Display for ConstantDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "const {} {} = {}", self.ty, self.id, self.expression) + } +} + +impl<'ast> fmt::Debug for ConstantDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "ConstantDefinition({:?}, {:?}, {:?})", + self.ty, self.id, self.expression + ) + } +} + /// A function defined locally #[derive(Clone, PartialEq)] pub struct Function<'ast> { diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index cf8d3fb4..72607db1 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -84,6 +84,7 @@ impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl NodeValue for UnresolvedType {} impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {} +impl<'ast> NodeValue for ConstantDefinition<'ast> {} impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Module<'ast> {} impl<'ast> NodeValue for SymbolImport<'ast> {} diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 4e074645..71bdeb93 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -59,6 +59,7 @@ struct State<'ast, T: Field> { #[derive(PartialEq, Hash, Eq, Debug)] enum SymbolType { Type, + Constant, Functions(BTreeSet), } @@ -69,14 +70,14 @@ struct SymbolUnifier { } impl SymbolUnifier { - fn insert_type>(&mut self, id: S) -> bool { + fn insert_symbol>(&mut self, id: S, ty: SymbolType) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { // if anything is already called `id`, we cannot introduce this type Entry::Occupied(..) => false, // otherwise, we can! Entry::Vacant(v) => { - v.insert(SymbolType::Type); + v.insert(ty); true } } @@ -88,8 +89,8 @@ impl SymbolUnifier { // if anything is already called `id`, it depends what it is Entry::Occupied(mut o) => { match o.get_mut() { - // if it's a Type, then we can't introduce a function - SymbolType::Type => false, + // if it's a Type or a Constant, then we can't introduce a function + SymbolType::Type | SymbolType::Constant => false, // if it's a Function, we can introduce a new function only if it has a different signature SymbolType::Functions(signatures) => signatures.insert(signature), } @@ -205,6 +206,7 @@ impl<'ast> FunctionQuery<'ast> { pub struct ScopedVariable<'ast> { id: Variable<'ast>, level: usize, + constant: bool, } /// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent @@ -282,6 +284,35 @@ impl<'ast> Checker<'ast> { }) } + fn check_constant_definition( + &mut self, + c: ConstantDefinitionNode<'ast>, + module_id: &ModuleId, + types: &TypeMap, + ) -> Result, ErrorInner> { + let pos = c.pos(); + let ty = self.check_type(c.value.ty, module_id, &types)?; + let expression = self.check_expression(c.value.expression, module_id, types)?; + + match ty == expression.get_type() { + true => Ok(TypedConstant { + id: crate::typed_absy::Identifier::from(c.value.id), + ty, + expression, + }), + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`", + expression, + expression.get_type(), + c.value.id, + ty + ), + }), + } + } + fn check_struct_type_declaration( &mut self, id: String, @@ -335,6 +366,7 @@ impl<'ast> Checker<'ast> { module_id: &ModuleId, state: &mut State<'ast, T>, functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, + constants: &mut HashMap, TypedConstant<'ast, T>>, symbol_unifier: &mut SymbolUnifier, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -343,67 +375,99 @@ impl<'ast> Checker<'ast> { let declaration = declaration.value; match declaration.symbol.clone() { - Symbol::HereType(t) => { - match self.check_struct_type_declaration( - declaration.id.to_string(), - t.clone(), - module_id, - &state.types, - ) { - Ok(ty) => { - match symbol_unifier.insert_type(declaration.id) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - } - .in_file(module_id), - ), - true => {} - }; - state - .types - .entry(module_id.clone()) - .or_default() - .insert(declaration.id.to_string(), ty); - } - Err(e) => errors.extend(e.into_iter().map(|inner| Error { - inner, - module_id: module_id.clone(), - })), - } - } - Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) { - Ok(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, + Symbol::Here(kind) => match kind { + SymbolDefinition::Struct(t) => { + match self.check_struct_type_declaration( + declaration.id.to_string(), + t.clone(), + module_id, + &state.types, + ) { + Ok(ty) => { + match symbol_unifier.insert_symbol(declaration.id, SymbolType::Type) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), ), - } - .in_file(module_id), - ), - true => {} - }; - - self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature.clone()), - ); - functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature.clone()), - TypedFunctionSymbol::Here(funct), - ); + true => {} + }; + state + .types + .entry(module_id.clone()) + .or_default() + .insert(declaration.id.to_string(), ty); + } + Err(e) => errors.extend(e.into_iter().map(|inner| Error { + inner, + module_id: module_id.clone(), + })), + } } - Err(e) => { - errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + SymbolDefinition::Constant(c) => { + match self.check_constant_definition(c, module_id, &state.types) { + Ok(c) => { + match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) + { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => {} + }; + constants + .insert(identifier::Identifier::from(declaration.id), c.clone()); + self.insert_into_scope(Variable::with_id_and_type(c.id, c.ty), true); + } + Err(e) => { + errors.push(e.in_file(module_id)); + } + } + } + SymbolDefinition::Function(f) => { + match self.check_function(f, module_id, &state.types) { + Ok(funct) => { + match symbol_unifier + .insert_function(declaration.id, funct.signature.clone()) + { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => {} + }; + + self.functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature.clone()), + ); + functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature.clone()), + TypedFunctionSymbol::Here(funct), + ); + } + Err(e) => { + errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + } + } } }, Symbol::There(import) => { @@ -450,7 +514,7 @@ impl<'ast> Checker<'ast> { }; // we imported a type, so the symbol it gets bound to should not already exist - match symbol_unifier.insert_type(declaration.id) { + match symbol_unifier.insert_symbol(declaration.id, SymbolType::Type) { false => { errors.push(Error { module_id: module_id.clone(), @@ -557,6 +621,7 @@ impl<'ast> Checker<'ast> { ) -> Result<(), Vec> { let mut errors = vec![]; let mut checked_functions = HashMap::new(); + let mut checked_constants = HashMap::new(); // check if the module was already removed from the untyped ones let to_insert = match state.modules.remove(module_id) { @@ -569,7 +634,7 @@ impl<'ast> Checker<'ast> { // we need to create an entry in the types map to store types for this module state.types.entry(module_id.clone()).or_default(); - // we keep track of the introduced symbols to avoid colisions between types and functions + // we keep track of the introduced symbols to avoid collisions between types and functions let mut symbol_unifier = SymbolUnifier::default(); // we go through symbol declarations and check them @@ -579,6 +644,7 @@ impl<'ast> Checker<'ast> { module_id, state, &mut checked_functions, + &mut checked_constants, &mut symbol_unifier, ) { Ok(()) => {} @@ -590,6 +656,7 @@ impl<'ast> Checker<'ast> { Some(TypedModule { functions: checked_functions, + constants: checked_constants, }) } }; @@ -663,7 +730,7 @@ impl<'ast> Checker<'ast> { for arg in funct.arguments { match self.check_parameter(arg, module_id, types) { Ok(a) => { - self.insert_into_scope(a.id.clone()); + self.insert_into_scope(a.id.clone(), false); arguments_checked.push(a); } Err(e) => errors.extend(e), @@ -873,7 +940,7 @@ impl<'ast> Checker<'ast> { } Statement::Declaration(var) => { let var = self.check_variable(var, module_id, types)?; - match self.insert_into_scope(var.clone()) { + match self.insert_into_scope(var.clone(), false) { true => Ok(TypedStatement::Declaration(var)), false => Err(ErrorInner { pos: Some(pos), @@ -909,7 +976,7 @@ impl<'ast> Checker<'ast> { false => Err(ErrorInner { pos: Some(pos), message: format!( - "Expression {} of type {} cannot be assigned to {} of type {}", + "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", checked_expr, expression_type, var, var_type ), }), @@ -972,7 +1039,7 @@ impl<'ast> Checker<'ast> { } .map_err(|e| vec![e])?; - self.insert_into_scope(var.clone()); + self.insert_into_scope(var.clone(), false); let mut checked_statements = vec![]; @@ -1045,10 +1112,16 @@ impl<'ast> Checker<'ast> { // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) { - Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( - variable_name, - var.id._type.clone(), - ))), + Some(var) => match var.constant { + false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( + variable_name, + var.id._type.clone(), + ))), + true => Err(ErrorInner { + pos: Some(assignee.pos()), + message: format!("Assignment to constant variable `{}`", variable_name), + }), + }, None => Err(ErrorInner { pos: Some(assignee.pos()), message: format!("Variable `{}` is undeclared", variable_name), @@ -2359,13 +2432,15 @@ impl<'ast> Checker<'ast> { Type::FieldElement, ), level: 0, + constant: false, }) } - fn insert_into_scope(&mut self, v: Variable<'ast>) -> bool { + fn insert_into_scope(&mut self, v: Variable<'ast>, constant: bool) -> bool { self.scope.insert(ScopedVariable { id: v, level: self.level, + constant, }) } @@ -2555,15 +2630,15 @@ mod tests { let mut unifier = SymbolUnifier::default(); - assert!(unifier.insert_type("foo")); - assert!(!unifier.insert_type("foo")); + assert!(unifier.insert_symbol("foo", SymbolType::Type)); + assert!(!unifier.insert_symbol("foo", SymbolType::Type)); assert!(!unifier.insert_function("foo", Signature::new())); assert!(unifier.insert_function("bar", Signature::new())); assert!(!unifier.insert_function("bar", Signature::new())); assert!( unifier.insert_function("bar", Signature::new().inputs(vec![Type::FieldElement])) ); - assert!(!unifier.insert_type("bar")); + assert!(!unifier.insert_symbol("bar", SymbolType::Type)); } #[test] @@ -2580,7 +2655,7 @@ mod tests { let foo: Module = Module { symbols: vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()], imports: vec![], @@ -2616,6 +2691,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default() }) ); } @@ -2633,12 +2709,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], @@ -2675,12 +2751,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function1()), + symbol: Symbol::Here(SymbolDefinition::Function(function1())), } .mock(), ], @@ -2726,12 +2802,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct1()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct1())), } .mock(), ], @@ -2764,12 +2840,14 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(StructDefinition { fields: vec![] }.mock()), + symbol: Symbol::Here(SymbolDefinition::Struct( + StructDefinition { fields: vec![] }.mock(), + )), } .mock(), ], @@ -2805,7 +2883,7 @@ mod tests { let bar = Module::with_symbols(vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -2820,7 +2898,7 @@ mod tests { .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), ], @@ -2856,7 +2934,7 @@ mod tests { let bar = Module::with_symbols(vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -2864,7 +2942,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { @@ -2948,10 +3026,12 @@ mod tests { scope.insert(ScopedVariable { id: Variable::field_element("a"), level: 0, + constant: false, }); scope.insert(ScopedVariable { id: Variable::field_element("b"), level: 0, + constant: false, }); let mut checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( @@ -3019,12 +3099,12 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "bar", - symbol: Symbol::HereFunction(bar), + symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), ]; @@ -3135,17 +3215,17 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "bar", - symbol: Symbol::HereFunction(bar), + symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ]; @@ -3529,12 +3609,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -3622,12 +3702,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -3738,12 +3818,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -4002,12 +4082,12 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main1), + symbol: Symbol::Here(SymbolDefinition::Function(main1)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main2), + symbol: Symbol::Here(SymbolDefinition::Function(main2)), } .mock(), ]; @@ -4121,7 +4201,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType(s.mock()), + symbol: Symbol::Here(SymbolDefinition::Struct(s.mock())), } .mock()], }; @@ -4305,7 +4385,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -4314,12 +4394,12 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -4328,7 +4408,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), ], @@ -4373,7 +4453,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -4382,7 +4462,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock()], }; @@ -4406,7 +4486,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -4415,7 +4495,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock()], }; @@ -4441,7 +4521,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "bar", @@ -4450,12 +4530,12 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -4464,7 +4544,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), ], @@ -4607,7 +4687,7 @@ mod tests { ) .mock(), &PathBuf::from(MODULE_ID).into(), - &state.types, + &state.types ), Ok(TypedStatement::Declaration(Variable::with_id_and_type( "a", diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 4fad20a6..6733ea3b 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -19,6 +19,7 @@ use crate::typed_absy::types::{FunctionKey, FunctionKeyHash, Type, UBitwidth}; use crate::typed_absy::{folder::*, *}; use std::collections::HashMap; +use std::convert::TryInto; use zokrates_field::Field; #[derive(Debug, PartialEq, Eq, Hash, Clone)] @@ -145,6 +146,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -298,6 +300,12 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { + FieldElementExpression::Identifier(ref id) => { + match self.module().constants.get(id).cloned() { + Some(c) => fold_field_expression(self, c.expression.try_into().unwrap()), + None => fold_field_expression(self, e), + } + } FieldElementExpression::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); @@ -344,6 +352,12 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { + BooleanExpression::Identifier(ref id) => { + match self.module().constants.get(id).cloned() { + Some(c) => fold_boolean_expression(self, c.expression.try_into().unwrap()), + None => fold_boolean_expression(self, e), + } + } BooleanExpression::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); @@ -392,6 +406,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { + ArrayExpressionInner::Identifier(ref id) => { + match self.module().constants.get(id).cloned() { + Some(c) => { + let expr: ArrayExpression<'ast, T> = c.expression.try_into().unwrap(); + fold_array_expression(self, expr).into_inner() + } + None => fold_array_expression_inner(self, ty, size, e), + } + } ArrayExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); @@ -439,6 +462,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { + StructExpressionInner::Identifier(ref id) => { + match self.module().constants.get(id).cloned() { + Some(c) => { + let expr: StructExpression<'ast, T> = c.expression.try_into().unwrap(); + fold_struct_expression(self, expr).into_inner() + } + None => fold_struct_expression_inner(self, ty, e), + } + } StructExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); @@ -486,6 +518,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { match e { + UExpressionInner::Identifier(ref id) => { + match self.module().constants.get(id).cloned() { + Some(c) => { + let expr: UExpression<'ast, T> = c.expression.try_into().unwrap(); + fold_uint_expression(self, expr).into_inner() + } + None => fold_uint_expression_inner(self, size, e), + } + } UExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); @@ -581,6 +622,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let foo = TypedModule { @@ -597,6 +639,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }; let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] @@ -696,6 +739,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let foo = TypedModule { @@ -719,6 +763,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }; let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] @@ -875,6 +920,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let foo: TypedModule = TypedModule { @@ -893,6 +939,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }; let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] @@ -1065,6 +1112,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let foo = TypedModule { @@ -1081,6 +1129,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }; let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] @@ -1176,6 +1225,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let modules: HashMap<_, _> = vec![("main".into(), main)].into_iter().collect(); @@ -1283,6 +1333,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }; let id = TypedModule { @@ -1304,6 +1355,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }; let modules = vec![("main".into(), main), ("id".into(), id)] diff --git a/zokrates_core/src/static_analysis/propagate_unroll.rs b/zokrates_core/src/static_analysis/propagate_unroll.rs index 67354b53..5a461206 100644 --- a/zokrates_core/src/static_analysis/propagate_unroll.rs +++ b/zokrates_core/src/static_analysis/propagate_unroll.rs @@ -90,6 +90,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -215,6 +216,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index f8ff7a87..e9c25fbb 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -61,7 +61,13 @@ mod tests { ); let mut modules = HashMap::new(); - modules.insert("main".into(), TypedModule { functions }); + modules.insert( + "main".into(), + TypedModule { + functions, + constants: Default::default(), + }, + ); let typed_ast: TypedProgram = TypedProgram { main: "main".into(), diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 0bb7f7c9..6ba33a9e 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -45,6 +45,9 @@ pub type TypedModules<'ast, T> = HashMap>; /// in a given `TypedModule`, hence the use of a HashMap pub type TypedFunctionSymbols<'ast, T> = HashMap, TypedFunctionSymbol<'ast, T>>; +/// A collection of `TypedConstant`s +pub type TypedConstants<'ast, T> = HashMap, TypedConstant<'ast, T>>; + /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)] pub struct TypedProgram<'ast, T> { @@ -102,11 +105,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> { } } -/// A typed program as a collection of functions. Types have been resolved during semantic checking. +/// A typed module as a collection of functions. Types have been resolved during semantic checking. #[derive(PartialEq, Clone)] pub struct TypedModule<'ast, T> { - /// Functions of the program + /// Functions of the module pub functions: TypedFunctionSymbols<'ast, T>, + /// Constants defined in module + pub constants: TypedConstants<'ast, T>, } #[derive(Clone, PartialEq)] @@ -248,6 +253,25 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { } } +#[derive(Clone, PartialEq)] +pub struct TypedConstant<'ast, T> { + pub id: Identifier<'ast>, + pub ty: Type, + pub expression: TypedExpression<'ast, T>, +} + +impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TypedConstant({:?}, {:?}, ...)", self.id, self.ty) + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "const {} {} = {}", self.ty, self.id, self.expression) + } +} + /// Something we can assign to. #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee<'ast, T> { diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index b6afdd04..e896bed5 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -1,5 +1,5 @@ -file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ function_definition* ~ EOI } +file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ const_definition* ~ NEWLINE* ~ function_definition* ~ EOI } pragma = { "#pragma" ~ "curve" ~ curve } curve = @{ (ASCII_ALPHANUMERIC | "_") * } @@ -9,6 +9,7 @@ from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ iden main_import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} import_source = @{(!"\"" ~ ANY)*} function_definition = {"def" ~ identifier ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* } +const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*} return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? } parameter_list = _{(parameter ~ ("," ~ parameter)*)?} @@ -128,6 +129,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) } // the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly // accepted -keyword = @{"assert"|"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| +keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| "in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32" } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 4ba8e299..66303b46 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -10,12 +10,13 @@ extern crate lazy_static; pub use ast::{ Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, CallAccess, - ConstantExpression, DecimalNumberExpression, DefinitionStatement, Expression, FieldType, File, - FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, - InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement, - OptionallyTypedAssignee, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, - TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, + ConstantDefinition, ConstantExpression, DecimalNumberExpression, DefinitionStatement, + Expression, FieldType, File, FromExpression, Function, IdentifierExpression, ImportDirective, + ImportSource, InlineArrayExpression, InlineStructExpression, InlineStructMember, + IterationStatement, OptionallyTypedAssignee, Parameter, PostfixExpression, Range, + RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, + StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, + UnaryOperator, Visibility, }; mod ast { @@ -173,6 +174,7 @@ mod ast { pub pragma: Option>, pub imports: Vec>, pub structs: Vec>, + pub constants: Vec>, pub functions: Vec>, pub eoi: EOI, #[pest_ast(outer())] @@ -225,6 +227,16 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::const_definition))] + pub struct ConstantDefinition<'ast> { + pub ty: Type<'ast>, + pub id: IdentifierExpression<'ast>, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_directive))] pub enum ImportDirective<'ast> { @@ -866,6 +878,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -919,6 +932,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -990,6 +1004,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -1048,6 +1063,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -1084,6 +1100,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), From 1ca985809b6d2405dc20eca2e447d452a1facbac Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 2 Apr 2021 17:55:58 +0200 Subject: [PATCH 02/12] fix clippy warning --- zokrates_core/src/absy/from_ast.rs | 2 +- zokrates_core/src/absy/mod.rs | 2 +- zokrates_core/src/semantics.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 64cddf2a..0db771b3 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -105,7 +105,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode< absy::SymbolDeclaration { id, - symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)), + symbol: absy::Symbol::Here(SymbolDefinition::Constant(box ty)), } .span(span) } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 4df76858..a1bca482 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -53,7 +53,7 @@ pub struct SymbolDeclaration<'ast> { #[derive(PartialEq, Clone)] pub enum SymbolDefinition<'ast> { Struct(StructDefinitionNode<'ast>), - Constant(ConstantDefinitionNode<'ast>), + Constant(Box>), Function(FunctionNode<'ast>), } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 3901e3d2..1e3f19d3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -453,7 +453,7 @@ impl<'ast, T: Field> Checker<'ast, T> { })), } } - SymbolDefinition::Constant(c) => { + SymbolDefinition::Constant(box c) => { match self.check_constant_definition(c, module_id, &state.types) { Ok(c) => { match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) From 8f4ee002762e9e4895f2784630699d374d71b612 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 5 Apr 2021 15:53:09 +0200 Subject: [PATCH 03/12] constant inliner tests --- changelogs/unreleased/792-dark64 | 1 + zokrates_core/src/semantics.rs | 4 +- .../src/static_analysis/constant_inliner.rs | 525 +++++++++++++++++- .../src/static_analysis/reducer/mod.rs | 20 +- zokrates_core/src/typed_absy/abi.rs | 2 +- zokrates_core/src/typed_absy/mod.rs | 2 +- 6 files changed, 538 insertions(+), 16 deletions(-) create mode 100644 changelogs/unreleased/792-dark64 diff --git a/changelogs/unreleased/792-dark64 b/changelogs/unreleased/792-dark64 new file mode 100644 index 00000000..fbc1a5dc --- /dev/null +++ b/changelogs/unreleased/792-dark64 @@ -0,0 +1 @@ +Introduce constant definitions to the language (`const` keyword) \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 1e3f19d3..d7abd022 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -701,7 +701,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Some(TypedModule { functions: checked_functions, - constants: checked_constants, + constants: Some(checked_constants).filter(|m| !m.is_empty()), }) } }; @@ -3152,7 +3152,7 @@ mod tests { )] .into_iter() .collect(), - constants: Default::default() + constants: None }) ); } diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index ad643ec6..645dbc5e 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -27,8 +27,11 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - self.constants = p.constants.clone(); - fold_module(self, p) + self.constants = p.constants.clone().unwrap_or_default(); + TypedModule { + functions: fold_module(self, p).functions, + constants: None, + } } fn fold_field_expression( @@ -111,3 +114,521 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed_absy::types::DeclarationSignature; + use crate::typed_absy::{ + DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier, + TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement, + }; + use zokrates_field::Bn128Field; + + #[test] + fn inline_const_field() { + // const field a = 1 + // + // def main() -> field: + // return a + + let const_id = Identifier::from("a"); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(const_id.clone()).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let mut constants = TypedConstants::::new(); + constants.insert( + const_id.clone(), + TypedConstant { + id: const_id.clone(), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + ))), + }, + ); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: Some(constants), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Number(Bn128Field::from(1)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: None, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_boolean() { + // const bool a = true + // + // def main() -> bool: + // return a + + let const_id = Identifier::from("a"); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( + const_id.clone(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + }; + + let mut constants = TypedConstants::::new(); + constants.insert( + const_id.clone(), + TypedConstant { + id: const_id.clone(), + ty: GType::Boolean, + expression: (TypedExpression::Boolean(BooleanExpression::Value(true))), + }, + ); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: Some(constants), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + BooleanExpression::Value(true).into() + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: None, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_uint() { + // const u32 a = 0x00000001 + // + // def main() -> u32: + // return a + + let const_id = Identifier::from("a"); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( + const_id.clone(), + ) + .annotate(UBitwidth::B32) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + }; + + let mut constants = TypedConstants::::new(); + constants.insert( + const_id.clone(), + TypedConstant { + id: const_id.clone(), + ty: GType::Uint(UBitwidth::B32), + expression: (UExpressionInner::Value(1u128) + .annotate(UBitwidth::B32) + .into()), + }, + ); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: Some(constants), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128) + .annotate(UBitwidth::B32) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: None, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_field_array() { + // const field[2] a = [2, 2] + // + // def main() -> field: + // return a[0] + a[1] + + let const_id = Identifier::from("a"); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(const_id.clone()) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(0u128).annotate(UBitwidth::B32), + ) + .into(), + FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(const_id.clone()) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(1u128).annotate(UBitwidth::B32), + ) + .into(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let mut constants = TypedConstants::::new(); + constants.insert( + const_id.clone(), + TypedConstant { + id: const_id.clone(), + ty: GType::FieldElement, + expression: TypedExpression::Array( + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + ), + }, + ); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: Some(constants), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + FieldElementExpression::Select( + box ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(0u128).annotate(UBitwidth::B32), + ) + .into(), + FieldElementExpression::Select( + box ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(1u128).annotate(UBitwidth::B32), + ) + .into(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: None, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_nested_const_field() { + // const field a = 1 + // const field b = a + 1 + // + // def main() -> field: + // return b + + let const_a_id = Identifier::from("a"); + let const_b_id = Identifier::from("b"); + + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(const_b_id.clone()).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let mut constants = TypedConstants::::new(); + constants.extend(vec![ + ( + const_a_id.clone(), + TypedConstant { + id: const_a_id.clone(), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + ))), + }, + ), + ( + const_b_id.clone(), + TypedConstant { + id: const_b_id.clone(), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Identifier(const_a_id.clone()), + box FieldElementExpression::Number(Bn128Field::from(1)), + ))), + }, + ), + ]); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: Some(constants), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: None, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } +} diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 57e17db2..05191af1 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -547,7 +547,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E )] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -769,7 +769,7 @@ mod tests { ] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -835,7 +835,7 @@ mod tests { )] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -964,7 +964,7 @@ mod tests { ] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1045,7 +1045,7 @@ mod tests { )] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1183,7 +1183,7 @@ mod tests { ] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1264,7 +1264,7 @@ mod tests { )] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1441,7 +1441,7 @@ mod tests { ] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1545,7 +1545,7 @@ mod tests { )] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() @@ -1629,7 +1629,7 @@ mod tests { ] .into_iter() .collect(), - constants: Default::default(), + constants: None, }, )] .into_iter() diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index f3adbdea..7b189e97 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -69,7 +69,7 @@ mod tests { "main".into(), TypedModule { functions, - constants: Default::default(), + constants: None, }, ); diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index def99d03..759421e6 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -144,7 +144,7 @@ pub struct TypedModule<'ast, T> { /// Functions of the module pub functions: TypedFunctionSymbols<'ast, T>, /// Constants defined in module - pub constants: TypedConstants<'ast, T>, + pub constants: Option>, } #[derive(Clone, PartialEq)] From dafef03b1f5e7455e62923ffab2def6334441c63 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 8 Apr 2021 11:29:21 +0200 Subject: [PATCH 04/12] fix imports, more tests --- .../examples/book/constant_definition.zok | 5 + .../compile_errors/constant_assignment.zok | 5 + zokrates_cli/examples/imports/bar.zok | 8 +- zokrates_cli/examples/imports/baz.zok | 7 +- zokrates_cli/examples/imports/foo.zok | 11 +- zokrates_cli/examples/imports/import.zok | 11 - .../examples/imports/import_constants.zok | 6 + .../examples/imports/import_functions.zok | 6 + .../examples/imports/import_structs.zok | 8 + .../examples/imports/import_with_alias.zok | 8 +- zokrates_core/src/semantics.rs | 45 +- .../src/static_analysis/constant_inliner.rs | 440 +++++++++++++----- zokrates_core/src/typed_absy/folder.rs | 26 +- zokrates_core/src/typed_absy/mod.rs | 21 +- .../tests/tests/constants/array.json | 16 + .../tests/tests/constants/array.zok | 4 + .../tests/tests/constants/bool.json | 16 + .../tests/tests/constants/bool.zok | 4 + .../tests/tests/constants/field.json | 16 + .../tests/tests/constants/field.zok | 4 + .../tests/tests/constants/nested.json | 16 + .../tests/tests/constants/nested.zok | 6 + .../tests/tests/constants/struct.json | 16 + .../tests/tests/constants/struct.zok | 9 + .../tests/tests/constants/uint.json | 16 + .../tests/tests/constants/uint.zok | 4 + 26 files changed, 589 insertions(+), 145 deletions(-) create mode 100644 zokrates_cli/examples/book/constant_definition.zok create mode 100644 zokrates_cli/examples/compile_errors/constant_assignment.zok delete mode 100644 zokrates_cli/examples/imports/import.zok create mode 100644 zokrates_cli/examples/imports/import_constants.zok create mode 100644 zokrates_cli/examples/imports/import_functions.zok create mode 100644 zokrates_cli/examples/imports/import_structs.zok create mode 100644 zokrates_core_test/tests/tests/constants/array.json create mode 100644 zokrates_core_test/tests/tests/constants/array.zok create mode 100644 zokrates_core_test/tests/tests/constants/bool.json create mode 100644 zokrates_core_test/tests/tests/constants/bool.zok create mode 100644 zokrates_core_test/tests/tests/constants/field.json create mode 100644 zokrates_core_test/tests/tests/constants/field.zok create mode 100644 zokrates_core_test/tests/tests/constants/nested.json create mode 100644 zokrates_core_test/tests/tests/constants/nested.zok create mode 100644 zokrates_core_test/tests/tests/constants/struct.json create mode 100644 zokrates_core_test/tests/tests/constants/struct.zok create mode 100644 zokrates_core_test/tests/tests/constants/uint.json create mode 100644 zokrates_core_test/tests/tests/constants/uint.zok diff --git a/zokrates_cli/examples/book/constant_definition.zok b/zokrates_cli/examples/book/constant_definition.zok new file mode 100644 index 00000000..b6850f89 --- /dev/null +++ b/zokrates_cli/examples/book/constant_definition.zok @@ -0,0 +1,5 @@ +const field ONE = 1 +const field TWO = ONE + ONE + +def main() -> field: + return TWO \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/constant_assignment.zok b/zokrates_cli/examples/compile_errors/constant_assignment.zok new file mode 100644 index 00000000..e04da9f6 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/constant_assignment.zok @@ -0,0 +1,5 @@ +const field a = 1 + +def main() -> field: + a = 2 // not allowed + return a \ No newline at end of file diff --git a/zokrates_cli/examples/imports/bar.zok b/zokrates_cli/examples/imports/bar.zok index c7d3af12..34eb1c5a 100644 --- a/zokrates_cli/examples/imports/bar.zok +++ b/zokrates_cli/examples/imports/bar.zok @@ -1,5 +1,7 @@ -struct Bar { -} +struct Bar {} + +const field ONE = 1 +const field BAR = 21 * ONE def main() -> field: - return 21 \ No newline at end of file + return BAR \ No newline at end of file diff --git a/zokrates_cli/examples/imports/baz.zok b/zokrates_cli/examples/imports/baz.zok index 9fd704a3..84cc641d 100644 --- a/zokrates_cli/examples/imports/baz.zok +++ b/zokrates_cli/examples/imports/baz.zok @@ -1,5 +1,6 @@ -struct Baz { -} +struct Baz {} + +const field BAZ = 123 def main() -> field: - return 123 \ No newline at end of file + return BAZ \ No newline at end of file diff --git a/zokrates_cli/examples/imports/foo.zok b/zokrates_cli/examples/imports/foo.zok index 43018b20..8dddfd8b 100644 --- a/zokrates_cli/examples/imports/foo.zok +++ b/zokrates_cli/examples/imports/foo.zok @@ -1,9 +1,10 @@ from "./baz" import Baz - -import "./baz" from "./baz" import main as my_function +import "./baz" + +const field FOO = 144 def main() -> field: - field a = my_function() - Baz b = Baz {} - return baz() \ No newline at end of file + Baz b = Baz {} + assert(baz() == my_function()) + return FOO \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import.zok b/zokrates_cli/examples/imports/import.zok deleted file mode 100644 index bc4e7669..00000000 --- a/zokrates_cli/examples/imports/import.zok +++ /dev/null @@ -1,11 +0,0 @@ -from "./bar" import Bar as MyBar -from "./bar" import Bar - -import "./foo" -import "./bar" - -def main() -> field: - MyBar my_bar = MyBar {} - Bar bar = Bar {} - assert(my_bar == bar) - return foo() + bar() \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_constants.zok b/zokrates_cli/examples/imports/import_constants.zok new file mode 100644 index 00000000..2abaaea5 --- /dev/null +++ b/zokrates_cli/examples/imports/import_constants.zok @@ -0,0 +1,6 @@ +from "./foo" import FOO +from "./bar" import BAR +from "./baz" import BAZ + +def main() -> bool: + return FOO == BAR + BAZ \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_functions.zok b/zokrates_cli/examples/imports/import_functions.zok new file mode 100644 index 00000000..32628eb9 --- /dev/null +++ b/zokrates_cli/examples/imports/import_functions.zok @@ -0,0 +1,6 @@ +import "./foo" +import "./bar" +import "./baz" + +def main() -> bool: + return foo() == bar() + baz() \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_structs.zok b/zokrates_cli/examples/imports/import_structs.zok new file mode 100644 index 00000000..61c94b02 --- /dev/null +++ b/zokrates_cli/examples/imports/import_structs.zok @@ -0,0 +1,8 @@ +from "./bar" import Bar as MyBar +from "./bar" import Bar + +def main(): + MyBar my_bar = MyBar {} + Bar bar = Bar {} + assert(my_bar == bar) + return \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_with_alias.zok b/zokrates_cli/examples/imports/import_with_alias.zok index e9bb2194..77cdd263 100644 --- a/zokrates_cli/examples/imports/import_with_alias.zok +++ b/zokrates_cli/examples/imports/import_with_alias.zok @@ -1,4 +1,8 @@ -import "./foo" as d +from "./bar" import main as bar +from "./baz" import main as baz +import "./foo" as f def main() -> field: - return d() \ No newline at end of file + field foo = f() + assert(foo == bar() + baz()) + return foo \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index d7abd022..c1ec8228 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -407,7 +407,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &mut State<'ast, T>, functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, - constants: &mut HashMap, TypedConstant<'ast, T>>, + constants: &mut HashMap, TypedConstantSymbol<'ast, T>>, symbol_unifier: &mut SymbolUnifier<'ast>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -470,8 +470,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ), true => {} }; - constants - .insert(identifier::Identifier::from(declaration.id), c.clone()); + constants.insert(declaration.id, TypedConstantSymbol::Here(c.clone())); self.insert_into_scope(Variable::with_id_and_type(c.id, c.ty), true); } Err(e) => { @@ -549,8 +548,21 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(import.symbol_id) .cloned(); - match (function_candidates.len(), type_candidate) { - (0, Some(t)) => { + // find constant definition candidate + let const_candidate = state + .typed_modules + .get(&import.module_id) + .unwrap() + .constants + .as_ref() + .and_then(|tc| tc.get(import.symbol_id)) + .and_then(|sym| match sym { + TypedConstantSymbol::Here(tc) => Some(tc), + _ => None, + }); + + match (function_candidates.len(), type_candidate, const_candidate) { + (0, Some(t), None) => { // rename the type to the declared symbol let t = match t { @@ -585,7 +597,26 @@ impl<'ast, T: Field> Checker<'ast, T> { .or_default() .insert(declaration.id.to_string(), t); } - (0, None) => { + (0, None, Some(c)) => { + match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) { + false => { + errors.push(Error { + module_id: module_id.to_path_buf(), + inner: ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + }}); + } + true => { + constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id)); + self.insert_into_scope(Variable::with_id_and_type(c.id.clone(), c.ty.clone()), true); + } + }; + } + (0, None, None) => { errors.push(ErrorInner { pos: Some(pos), message: format!( @@ -594,7 +625,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }.in_file(module_id)); } - (_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"), + (_, Some(_), Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"), _ => { for candidate in function_candidates { diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 645dbc5e..9c04dcaa 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,36 +1,92 @@ -use crate::typed_absy::folder::{ - fold_array_expression, fold_array_expression_inner, fold_boolean_expression, - fold_field_expression, fold_module, fold_struct_expression, fold_struct_expression_inner, - fold_uint_expression, fold_uint_expression_inner, Folder, -}; -use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, FieldElementExpression, - StructExpression, StructExpressionInner, StructType, TypedConstants, TypedModule, TypedProgram, - UBitwidth, UExpression, UExpressionInner, -}; -use std::collections::HashMap; +use crate::typed_absy::folder::*; +use crate::typed_absy::*; use std::convert::TryInto; use zokrates_field::Field; pub struct ConstantInliner<'ast, T: Field> { - constants: TypedConstants<'ast, T>, + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, } impl<'ast, T: Field> ConstantInliner<'ast, T> { + fn with_modules_and_location( + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, + ) -> Self { + ConstantInliner { modules, location } + } + pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - let mut inliner = ConstantInliner { - constants: HashMap::new(), - }; + // initialize an inliner over all modules, starting from the main module + let mut inliner = + ConstantInliner::with_modules_and_location(p.modules.clone(), p.main.clone()); + inliner.fold_program(p) } + + pub fn module(&self) -> &TypedModule<'ast, T> { + self.modules.get(&self.location).unwrap() + } + + pub fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + let prev = self.location.clone(); + self.location = location; + prev + } + + pub fn get_constant(&mut self, id: &Identifier) -> Option> { + self.modules + .get(&self.location) + .unwrap() + .constants + .as_ref() + .and_then(|c| c.get(id.clone().try_into().unwrap())) + .cloned() + .and_then(|tc| { + let symbol = self.fold_constant_symbol(tc); + match symbol { + TypedConstantSymbol::Here(tc) => Some(tc), + _ => None, + } + }) + } } impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { - fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - self.constants = p.constants.clone().unwrap_or_default(); - TypedModule { - functions: fold_module(self, p).functions, - constants: None, + fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + TypedProgram { + modules: p + .modules + .into_iter() + .map(|(module_id, module)| { + self.change_location(module_id.clone()); + (module_id, self.fold_module(module)) + }) + .collect(), + main: p.main, + } + } + + fn fold_constant_symbol( + &mut self, + p: TypedConstantSymbol<'ast, T>, + ) -> TypedConstantSymbol<'ast, T> { + match p { + TypedConstantSymbol::There(module_id, id) => { + let location = self.change_location(module_id); + let symbol = self + .module() + .constants + .as_ref() + .and_then(|c| c.get(id)) + .unwrap() + .to_owned(); + + let symbol = self.fold_constant_symbol(symbol); + let _ = self.change_location(location); + symbol + } + _ => fold_constant_symbol(self, p), } } @@ -39,7 +95,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { - FieldElementExpression::Identifier(ref id) => match self.constants.get(id).cloned() { + FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { Some(c) => fold_field_expression(self, c.expression.try_into().unwrap()), None => fold_field_expression(self, e), }, @@ -52,7 +108,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::Identifier(ref id) => match self.constants.get(id).cloned() { + BooleanExpression::Identifier(ref id) => match self.get_constant(id) { Some(c) => fold_boolean_expression(self, c.expression.try_into().unwrap()), None => fold_boolean_expression(self, e), }, @@ -66,14 +122,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { match e { - UExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() { + UExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let expr: UExpression<'ast, T> = c.expression.try_into().unwrap(); - fold_uint_expression(self, expr).into_inner() + fold_uint_expression(self, c.expression.try_into().unwrap()).into_inner() } None => fold_uint_expression_inner(self, size, e), }, - // default e => fold_uint_expression_inner(self, size, e), } } @@ -84,14 +138,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { - ArrayExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() { + ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let expr: ArrayExpression<'ast, T> = c.expression.try_into().unwrap(); - fold_array_expression(self, expr).into_inner() + fold_array_expression(self, c.expression.try_into().unwrap()).into_inner() } None => fold_array_expression_inner(self, ty, e), }, - // default e => fold_array_expression_inner(self, ty, e), } } @@ -102,14 +154,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { - StructExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() { + StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let expr: StructExpression<'ast, T> = c.expression.try_into().unwrap(); - fold_struct_expression(self, expr).into_inner() + fold_struct_expression(self, c.expression.try_into().unwrap()).into_inner() } None => fold_struct_expression_inner(self, ty, e), }, - // default e => fold_struct_expression_inner(self, ty, e), } } @@ -132,28 +182,29 @@ mod tests { // def main() -> field: // return a - let const_id = Identifier::from("a"); + let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(const_id.clone()).into(), + FieldElementExpression::Identifier(Identifier::from(const_id)).into(), ])], signature: DeclarationSignature::new() .inputs(vec![]) .outputs(vec![DeclarationType::FieldElement]), }; - let mut constants = TypedConstants::::new(); - constants.insert( - const_id.clone(), - TypedConstant { - id: const_id.clone(), + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_id), ty: GType::FieldElement, expression: (TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(1), ))), - }, - ); + }), + )] + .into_iter() + .collect(); let program = TypedProgram { main: "main".into(), @@ -170,7 +221,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants: Some(constants.clone()), }, )] .into_iter() @@ -204,7 +255,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Some(constants), }, )] .into_iter() @@ -221,11 +272,11 @@ mod tests { // def main() -> bool: // return a - let const_id = Identifier::from("a"); + let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( - const_id.clone(), + Identifier::from(const_id), ) .into()])], signature: DeclarationSignature::new() @@ -233,15 +284,16 @@ mod tests { .outputs(vec![DeclarationType::Boolean]), }; - let mut constants = TypedConstants::::new(); - constants.insert( - const_id.clone(), - TypedConstant { - id: const_id.clone(), + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_id), ty: GType::Boolean, expression: (TypedExpression::Boolean(BooleanExpression::Value(true))), - }, - ); + }), + )] + .into_iter() + .collect(); let program = TypedProgram { main: "main".into(), @@ -258,7 +310,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants: Some(constants.clone()), }, )] .into_iter() @@ -292,7 +344,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Some(constants), }, )] .into_iter() @@ -309,11 +361,11 @@ mod tests { // def main() -> u32: // return a - let const_id = Identifier::from("a"); + let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( - const_id.clone(), + Identifier::from(const_id), ) .annotate(UBitwidth::B32) .into()])], @@ -322,17 +374,18 @@ mod tests { .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), }; - let mut constants = TypedConstants::::new(); - constants.insert( - const_id.clone(), - TypedConstant { - id: const_id.clone(), + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_id), ty: GType::Uint(UBitwidth::B32), expression: (UExpressionInner::Value(1u128) .annotate(UBitwidth::B32) .into()), - }, - ); + }), + )] + .into_iter() + .collect(); let program = TypedProgram { main: "main".into(), @@ -349,7 +402,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants: Some(constants.clone()), }, )] .into_iter() @@ -383,7 +436,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Some(constants), }, )] .into_iter() @@ -400,18 +453,18 @@ mod tests { // def main() -> field: // return a[0] + a[1] - let const_id = Identifier::from("a"); + let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( FieldElementExpression::Select( - box ArrayExpressionInner::Identifier(const_id.clone()) + box ArrayExpressionInner::Identifier(Identifier::from(const_id)) .annotate(GType::FieldElement, 2usize), box UExpressionInner::Value(0u128).annotate(UBitwidth::B32), ) .into(), FieldElementExpression::Select( - box ArrayExpressionInner::Identifier(const_id.clone()) + box ArrayExpressionInner::Identifier(Identifier::from(const_id)) .annotate(GType::FieldElement, 2usize), box UExpressionInner::Value(1u128).annotate(UBitwidth::B32), ) @@ -423,11 +476,10 @@ mod tests { .outputs(vec![DeclarationType::FieldElement]), }; - let mut constants = TypedConstants::::new(); - constants.insert( - const_id.clone(), - TypedConstant { - id: const_id.clone(), + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_id), ty: GType::FieldElement, expression: TypedExpression::Array( ArrayExpressionInner::Value( @@ -439,8 +491,10 @@ mod tests { ) .annotate(GType::FieldElement, 2usize), ), - }, - ); + }), + )] + .into_iter() + .collect(); let program = TypedProgram { main: "main".into(), @@ -457,7 +511,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants: Some(constants.clone()), }, )] .into_iter() @@ -515,7 +569,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Some(constants), }, )] .into_iter() @@ -533,44 +587,19 @@ mod tests { // def main() -> field: // return b - let const_a_id = Identifier::from("a"); - let const_b_id = Identifier::from("b"); + let const_a_id = "a"; + let const_b_id = "b"; let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(const_b_id.clone()).into(), + FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(), ])], signature: DeclarationSignature::new() .inputs(vec![]) .outputs(vec![DeclarationType::FieldElement]), }; - let mut constants = TypedConstants::::new(); - constants.extend(vec![ - ( - const_a_id.clone(), - TypedConstant { - id: const_a_id.clone(), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(1), - ))), - }, - ), - ( - const_b_id.clone(), - TypedConstant { - id: const_b_id.clone(), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement(FieldElementExpression::Add( - box FieldElementExpression::Identifier(const_a_id.clone()), - box FieldElementExpression::Number(Bn128Field::from(1)), - ))), - }, - ), - ]); - let program = TypedProgram { main: "main".into(), modules: vec![( @@ -586,7 +615,37 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants: Some( + vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_a_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(1)), + )), + }), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_b_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Add( + box FieldElementExpression::Identifier( + Identifier::from(const_a_id), + ), + box FieldElementExpression::Number(Bn128Field::from(1)), + ), + )), + }), + ), + ] + .into_iter() + .collect(), + ), }, )] .into_iter() @@ -622,7 +681,35 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Some( + vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_a_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(1)), + )), + }), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(const_b_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + ), + )), + }), + ), + ] + .into_iter() + .collect(), + ), }, )] .into_iter() @@ -631,4 +718,139 @@ mod tests { assert_eq!(program, expected_program) } + + #[test] + fn inline_imported_constant() { + // --------------------- + // module `foo` + // -------------------- + // const field FOO = 42 + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // from "foo" import FOO + // + // def main() -> field: + // return FOO + + let foo_const_id = "FOO"; + let foo_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main") + .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![], + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + }), + )] + .into_iter() + .collect(), + constants: Some( + vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(foo_const_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(42)), + )), + }), + )] + .into_iter() + .collect(), + ), + }; + + let main_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + )] + .into_iter() + .collect(), + constants: Some( + vec![( + foo_const_id, + TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id), + )] + .into_iter() + .collect(), + ), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), main_module), + ("foo".into(), foo_module.clone()), + ] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + let expected_main_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Number(Bn128Field::from(42)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + )] + .into_iter() + .collect(), + constants: Some( + vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant { + id: Identifier::from(foo_const_id), + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(42)), + )), + }), + )] + .into_iter() + .collect(), + ), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), expected_main_module), + ("foo".into(), foo_module), + ] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } } diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 348450d7..f72bbf3b 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -13,6 +13,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_module(self, p) } + fn fold_constant_symbol( + &mut self, + p: TypedConstantSymbol<'ast, T>, + ) -> TypedConstantSymbol<'ast, T> { + fold_constant_symbol(self, p) + } + fn fold_function_symbol( &mut self, s: TypedFunctionSymbol<'ast, T>, @@ -193,12 +200,16 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( p: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { + constants: p.constants.map(|tc| { + tc.into_iter() + .map(|(key, tc)| (key, f.fold_constant_symbol(tc))) + .collect() + }), functions: p .functions .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) .collect(), - constants: p.constants, } } @@ -692,6 +703,19 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + p: TypedConstantSymbol<'ast, T>, +) -> TypedConstantSymbol<'ast, T> { + match p { + TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(TypedConstant { + expression: f.fold_expression(tc.expression), + ..tc + }), + there => there, + } +} + pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedFunctionSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 759421e6..373a2a1d 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -61,8 +61,17 @@ pub type TypedModules<'ast, T> = HashMap = HashMap, TypedFunctionSymbol<'ast, T>>; -/// A collection of `TypedConstant`s -pub type TypedConstants<'ast, T> = HashMap, TypedConstant<'ast, T>>; +pub type ConstantIdentifier<'ast> = &'ast str; + +#[derive(Clone, Debug, PartialEq)] +pub enum TypedConstantSymbol<'ast, T> { + Here(TypedConstant<'ast, T>), + There(OwnedTypedModuleId, ConstantIdentifier<'ast>), +} + +/// A collection of `TypedConstantSymbol`s +pub type TypedConstantSymbols<'ast, T> = + HashMap, TypedConstantSymbol<'ast, T>>; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)] @@ -144,7 +153,7 @@ pub struct TypedModule<'ast, T> { /// Functions of the module pub functions: TypedFunctionSymbols<'ast, T>, /// Constants defined in module - pub constants: Option>, + pub constants: Option>, } #[derive(Clone, PartialEq)] @@ -320,7 +329,11 @@ pub struct TypedConstant<'ast, T> { impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TypedConstant({:?}, {:?}, ...)", self.id, self.ty) + write!( + f, + "TypedConstant({:?}, {:?}, {:?})", + self.id, self.ty, self.expression + ) } } diff --git a/zokrates_core_test/tests/tests/constants/array.json b/zokrates_core_test/tests/tests/constants/array.json new file mode 100644 index 00000000..7cbbb2ba --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/array.zok", + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1", "2"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/array.zok b/zokrates_core_test/tests/tests/constants/array.zok new file mode 100644 index 00000000..cce74dc7 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array.zok @@ -0,0 +1,4 @@ +const field[2] ARRAY = [1, 2] + +def main() -> field[2]: + return ARRAY \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/bool.json b/zokrates_core_test/tests/tests/constants/bool.json new file mode 100644 index 00000000..f11aea9a --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/bool.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/bool.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/bool.zok b/zokrates_core_test/tests/tests/constants/bool.zok new file mode 100644 index 00000000..00e0ae11 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/bool.zok @@ -0,0 +1,4 @@ +const bool BOOLEAN = true + +def main() -> bool: + return BOOLEAN \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/field.json b/zokrates_core_test/tests/tests/constants/field.json new file mode 100644 index 00000000..2ec19a1f --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/field.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/field.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/field.zok b/zokrates_core_test/tests/tests/constants/field.zok new file mode 100644 index 00000000..4408b12e --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/field.zok @@ -0,0 +1,4 @@ +const field ONE = 1 + +def main() -> field: + return ONE \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/nested.json b/zokrates_core_test/tests/tests/constants/nested.json new file mode 100644 index 00000000..61cfd309 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/nested.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/nested.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["8"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/nested.zok b/zokrates_core_test/tests/tests/constants/nested.zok new file mode 100644 index 00000000..a7861aeb --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/nested.zok @@ -0,0 +1,6 @@ +const field A = 2 +const field B = 2 +const field[2] ARRAY = [A * 2, B * 2] + +def main() -> field: + return ARRAY[0] + ARRAY[1] \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/struct.json b/zokrates_core_test/tests/tests/constants/struct.json new file mode 100644 index 00000000..4d77d484 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/struct.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/struct.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["4"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/struct.zok b/zokrates_core_test/tests/tests/constants/struct.zok new file mode 100644 index 00000000..92e705ca --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/struct.zok @@ -0,0 +1,9 @@ +struct Foo { + field a + field b +} + +const Foo FOO = Foo { a: 2, b: 2 } + +def main() -> field: + return FOO.a + FOO.b \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/uint.json b/zokrates_core_test/tests/tests/constants/uint.json new file mode 100644 index 00000000..a2fccab8 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/uint.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/uint.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/uint.zok b/zokrates_core_test/tests/tests/constants/uint.zok new file mode 100644 index 00000000..914308ec --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/uint.zok @@ -0,0 +1,4 @@ +const u32 ONE = 0x00000001 + +def main() -> u32: + return ONE \ No newline at end of file From 42900ed3bdae0734510dfcb7471ab8e7d26c9186 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 19 Apr 2021 14:38:28 +0200 Subject: [PATCH 05/12] update book --- zokrates_book/src/SUMMARY.md | 5 +++-- zokrates_book/src/language/constants.md | 17 +++++++++++++++++ zokrates_book/src/language/imports.md | 5 ++++- .../examples/book/constant_definition.zok | 5 ++--- .../examples/book/constant_reference.zok | 5 +++++ 5 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 zokrates_book/src/language/constants.md create mode 100644 zokrates_cli/examples/book/constant_reference.zok diff --git a/zokrates_book/src/SUMMARY.md b/zokrates_book/src/SUMMARY.md index 6ccefb31..7755bbf7 100644 --- a/zokrates_book/src/SUMMARY.md +++ b/zokrates_book/src/SUMMARY.md @@ -8,11 +8,12 @@ - [Variables](language/variables.md) - [Types](language/types.md) - [Operators](language/operators.md) - - [Functions](language/functions.md) - [Control flow](language/control_flow.md) + - [Constants](language/constants.md) + - [Functions](language/functions.md) + - [Generics](language/generics.md) - [Imports](language/imports.md) - [Comments](language/comments.md) - - [Generics](language/generics.md) - [Macros](language/macros.md) - [Toolbox](toolbox/index.md) diff --git a/zokrates_book/src/language/constants.md b/zokrates_book/src/language/constants.md new file mode 100644 index 00000000..8bad0721 --- /dev/null +++ b/zokrates_book/src/language/constants.md @@ -0,0 +1,17 @@ +## Constants + +Constants must be globally defined outside all other scopes by using a `const` keyword. Constants can be set only to a constant expression. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/constant_definition.zok}} +``` + +The value of a constant can't be changed through reassignment, and it can't be redeclared. Constants are essentially inlined wherever they are used, meaning that they are copied directly into the relevant context when used. + +Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is defined before the constant. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/constant_reference.zok}} +``` + +The naming convention for constants are similar to that of variables. All characters in a constant name are usually in uppercase. \ No newline at end of file diff --git a/zokrates_book/src/language/imports.md b/zokrates_book/src/language/imports.md index fef1083d..8d4d17de 100644 --- a/zokrates_book/src/language/imports.md +++ b/zokrates_book/src/language/imports.md @@ -44,7 +44,7 @@ from "./path/to/my/module" import main as module Note that this legacy method is likely to become deprecated, so it is recommended to use the preferred way instead. ### Symbols -Two types of symbols can be imported +Three types of symbols can be imported #### Functions Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is inferred. @@ -52,6 +52,9 @@ Functions are imported by name. If many functions have the same name but differe #### User-defined types User-defined types declared with the `struct` keyword are imported by name. +#### Constants +Constants declared with the `const` keyword are imported by name. + ### Relative Imports You can import a resource in the same folder directly, like this: diff --git a/zokrates_cli/examples/book/constant_definition.zok b/zokrates_cli/examples/book/constant_definition.zok index b6850f89..016f231c 100644 --- a/zokrates_cli/examples/book/constant_definition.zok +++ b/zokrates_cli/examples/book/constant_definition.zok @@ -1,5 +1,4 @@ -const field ONE = 1 -const field TWO = ONE + ONE +const field BN128_GROUP_MODULUS = 21888242871839275222246405745257275088548364400416034343698204186575808495617 def main() -> field: - return TWO \ No newline at end of file + return BN128_GROUP_MODULUS \ No newline at end of file diff --git a/zokrates_cli/examples/book/constant_reference.zok b/zokrates_cli/examples/book/constant_reference.zok new file mode 100644 index 00000000..b6850f89 --- /dev/null +++ b/zokrates_cli/examples/book/constant_reference.zok @@ -0,0 +1,5 @@ +const field ONE = 1 +const field TWO = ONE + ONE + +def main() -> field: + return TWO \ No newline at end of file From b94b72080f91e47a4a085ced8ef76fdee5d50c84 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 19 Apr 2021 14:57:53 +0200 Subject: [PATCH 06/12] update example --- zokrates_cli/examples/book/constant_definition.zok | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_cli/examples/book/constant_definition.zok b/zokrates_cli/examples/book/constant_definition.zok index 016f231c..10b31ec7 100644 --- a/zokrates_cli/examples/book/constant_definition.zok +++ b/zokrates_cli/examples/book/constant_definition.zok @@ -1,4 +1,4 @@ -const field BN128_GROUP_MODULUS = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +const field PRIME = 31 def main() -> field: - return BN128_GROUP_MODULUS \ No newline at end of file + return PRIME \ No newline at end of file From 3b5ad3d13e28c14d479c717ddaa321d009888ca6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 19 Apr 2021 17:38:46 +0200 Subject: [PATCH 07/12] update textmate, add yaml version --- zokrates_core/src/semantics.rs | 2 - zokrates_parser/src/ace_mode/index.js | 2 +- zokrates_parser/src/textmate/package.json | 59 +- .../syntaxes/zokrates.tmLanguage.json | 1233 +++++++++-------- .../src/textmate/zokrates.tmLanguage.yaml | 349 +++++ 5 files changed, 1019 insertions(+), 626 deletions(-) create mode 100644 zokrates_parser/src/textmate/zokrates.tmLanguage.yaml diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 50602985..b76f6f1d 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -784,7 +784,6 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast>, ) -> Result, Vec> { - // assert!(self.scope.is_empty()); assert!(self.return_types.is_none()); self.enter_scope(); @@ -911,7 +910,6 @@ impl<'ast, T: Field> Checker<'ast, T> { } self.return_types = None; - // assert!(self.scope.is_empty()); Ok(TypedFunction { arguments: arguments_checked, diff --git a/zokrates_parser/src/ace_mode/index.js b/zokrates_parser/src/ace_mode/index.js index ec77a4fb..537e0f76 100644 --- a/zokrates_parser/src/ace_mode/index.js +++ b/zokrates_parser/src/ace_mode/index.js @@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac var ZoKratesHighlightRules = function () { var keywords = ( - "assert|as|bool|byte|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64" + "assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64" ); var keywordMapper = this.createKeywordMapper({ diff --git a/zokrates_parser/src/textmate/package.json b/zokrates_parser/src/textmate/package.json index ebacd5de..a72f04a0 100644 --- a/zokrates_parser/src/textmate/package.json +++ b/zokrates_parser/src/textmate/package.json @@ -1,27 +1,36 @@ { - "name": "zokrates", - "displayName": "zokrates", - "description": "Syntax highlighting for the ZoKrates language", - "publisher": "zokrates", - "repository": "https://github.com/ZoKrates/ZoKrates", - "version": "0.0.1", - "engines": { - "vscode": "^1.53.0" - }, - "categories": [ - "Programming Languages" + "name": "zokrates", + "displayName": "zokrates", + "description": "Syntax highlighting for the ZoKrates language", + "publisher": "zokrates", + "repository": "https://github.com/ZoKrates/ZoKrates", + "version": "0.0.1", + "engines": { + "vscode": "^1.53.0" + }, + "categories": [ + "Programming Languages" + ], + "contributes": { + "languages": [ + { + "id": "zokrates", + "aliases": [ + "ZoKrates", + "zokrates" + ], + "extensions": [ + ".zok" + ], + "configuration": "./language-configuration.json" + } ], - "contributes": { - "languages": [{ - "id": "zokrates", - "aliases": ["ZoKrates", "zokrates"], - "extensions": [".zok"], - "configuration": "./language-configuration.json" - }], - "grammars": [{ - "language": "zokrates", - "scopeName": "source.zok", - "path": "./syntaxes/zokrates.tmLanguage.json" - }] - } -} \ No newline at end of file + "grammars": [ + { + "language": "zokrates", + "scopeName": "source.zok", + "path": "./syntaxes/zokrates.tmLanguage.json" + } + ] + } +} diff --git a/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json b/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json index 3730240c..c55cadb3 100644 --- a/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json +++ b/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json @@ -1,600 +1,637 @@ { - "$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json", - "name": "ZoKrates", - "fileTypes": [ - "zok" - ], - "scopeName": "source.zok", - "patterns": [ - { - "comment": "attributes", - "name": "meta.attribute.zokrates", - "begin": "(#)(\\!?)(\\[)", - "beginCaptures": { - "1": { - "name": "punctuation.definition.attribute.zokrates" - }, - "2": { - "name": "keyword.operator.attribute.inner.zokrates" - }, - "3": { - "name": "punctuation.brackets.attribute.zokrates" - } - }, - "end": "\\]", - "endCaptures": { - "0": { - "name": "punctuation.brackets.attribute.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#lifetimes" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#gtypes" - }, - { - "include": "#types" - } - ] - }, - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#types" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#variables" - } - ], - "repository": { - "comments": { - "patterns": [ - { - "comment": "line comments", - "name": "comment.line.double-slash.zokrates", - "match": "\\s*//.*" - } - ] - }, - "block-comments": { - "patterns": [ - { - "comment": "empty block comments", - "name": "comment.block.zokrates", - "match": "/\\*\\*/" - }, - { - "comment": "block comments", - "name": "comment.block.zokrates", - "begin": "/\\*(?!\\*)", - "end": "\\*/", - "patterns": [ - { - "include": "#block-comments" - } - ] - } - ] - }, - "constants": { - "patterns": [ - { - "comment": "ALL CAPS constants", - "name": "constant.other.caps.zokrates", - "match": "\\b[A-Z]{2}[A-Z0-9_]*\\b" - }, - { - "comment": "decimal integers and floats", - "name": "constant.numeric.decimal.zokrates", - "match": "\\b\\d[\\d_]*(u128|u16|u32|u64|u8|f)?\\b", - "captures": { - "5": { - "name": "entity.name.type.numeric.zokrates" - } - } - }, - { - "comment": "hexadecimal integers", - "name": "constant.numeric.hex.zokrates", - "match": "\\b0x[\\da-fA-F_]+\\b" - }, - { - "comment": "booleans", - "name": "constant.language.bool.zokrates", - "match": "\\b(true|false)\\b" - } - ] - }, - "imports": { - "patterns": [ - { - "comment": "explicit import statement", - "name": "meta.import.explicit.zokrates", - "match": "\\b(from)\\s+(\\\".*\\\")(import)\\s+([A-Za-z0-9_]+)\\s+((as)\\s+[A-Za-z0-9_]+)?\\b", - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#strings" - } - ] - }, - { - "comment": "main import statement", - "name": "meta.import.explicit.zokrates", - "match": "\\b(import)\\s+(\\\".*\\\")\\s+((as)\\s+[A-Za-z0-9_]+)?\\b", - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#strings" - } - ] - } - ] - }, - "functions": { - "patterns": [ - { - "comment": "function definition", - "name": "meta.function.definition.zokrates", - "begin": "\\b(def)\\s+([A-Za-z0-9_]+)((\\()|(<))", - "beginCaptures": { - "1": { - "name": "keyword.other.def.zokrates" - }, - "2": { - "name": "entity.name.function.zokrates" - }, - "4": { - "name": "punctuation.brackets.round.zokrates" - }, - "5": { - "name": "punctuation.brackets.angle.zokrates" - } - }, - "end": "\\:|;", - "endCaptures": { - "0": { - "name": "keyword.punctuation.colon.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "function/method calls, chaining", - "name": "meta.function.call.zokrates", - "begin": "([A-Za-z0-9_]+)(\\()", - "beginCaptures": { - "1": { - "name": "entity.name.function.zokrates" - }, - "2": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "end": "\\)", - "endCaptures": { - "0": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "function/method calls with turbofish", - "name": "meta.function.call.zokrates", - "begin": "([A-Za-z0-9_]+)(?=::<.*>\\()", - "beginCaptures": { - "1": { - "name": "entity.name.function.zokrates" - } - }, - "end": "\\)", - "endCaptures": { - "0": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - } - ] - }, - "keywords": { - "patterns": [ - { - "comment": "argument visibility", - "name": "keyword.visibility.zokrates", - "match": "\\b(public|private)\\b" - }, - { - "comment": "control flow keywords", - "name": "keyword.control.zokrates", - "match": "\\b(do|else|for|do|endfor|if|then|fi|return|assert)\\b" - }, - { - "comment": "storage keywords", - "name": "storage.type.zokrates", - "match": "\\b(struct)\\b" - }, - { - "comment": "def", - "name": "keyword.other.def.zokrates", - "match": "\\bdef\\b" - }, - { - "comment": "import keywords", - "name": "keyword.other.import.zokrates", - "match": "\\b(import|from|as)\\b" - }, - { - "comment": "logical operators", - "name": "keyword.operator.logical.zokrates", - "match": "(\\^|\\||\\|\\||&|&&|<<|>>|!)(?!=)" - }, - { - "comment": "single equal", - "name": "keyword.operator.assignment.equal.zokrates", - "match": "(?])=(?!=|>)" - }, - { - "comment": "comparison operators", - "name": "keyword.operator.comparison.zokrates", - "match": "(=(=)?(?!>)|!=|<=|(?=)" - }, - { - "comment": "math operators", - "name": "keyword.operator.math.zokrates", - "match": "(([+%]|(\\*(?!\\w)))(?!=))|(-(?!>))|(/(?!/))" - }, - { - "comment": "less than, greater than (special case)", - "match": "(?:\\b|(?:(\\))|(\\])|(\\})))[ \\t]+([<>])[ \\t]+(?:\\b|(?:(\\()|(\\[)|(\\{)))", - "captures": { - "1": { - "name": "punctuation.brackets.round.zokrates" - }, - "2": { - "name": "punctuation.brackets.square.zokrates" - }, - "3": { - "name": "punctuation.brackets.curly.zokrates" - }, - "4": { - "name": "keyword.operator.comparison.zokrates" - }, - "5": { - "name": "punctuation.brackets.round.zokrates" - }, - "6": { - "name": "punctuation.brackets.square.zokrates" - }, - "7": { - "name": "punctuation.brackets.curly.zokrates" - } - } - }, - { - "comment": "dot access", - "name": "keyword.operator.access.dot.zokrates", - "match": "\\.(?!\\.)" - }, - { - "comment": "ranges, range patterns", - "name": "keyword.operator.range.zokrates", - "match": "\\.{2}(=|\\.)?" - }, - { - "comment": "colon", - "name": "keyword.operator.colon.zokrates", - "match": ":(?!:)" - }, - { - "comment": "dashrocket, skinny arrow", - "name": "keyword.operator.arrow.skinny.zokrates", - "match": "->" - } - ] - }, - "types": { - "patterns": [ - { - "comment": "numeric types", - "match": "(?", - "endCaptures": { - "0": { - "name": "punctuation.brackets.angle.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "primitive types", - "name": "entity.name.type.primitive.zokrates", - "match": "\\b(bool)\\b" - }, - { - "comment": "struct declarations", - "match": "\\b(struct)\\s+([A-Z][A-Za-z0-9]*)\\b", - "captures": { - "1": { - "name": "storage.type.zokrates" - }, - "2": { - "name": "entity.name.type.struct.zokrates" - } - } - }, - { - "comment": "types", - "name": "entity.name.type.zokrates", - "match": "\\b[A-Z][A-Za-z0-9]*\\b(?!!)" - } - ] - }, - "punctuation": { - "patterns": [ - { - "comment": "comma", - "name": "punctuation.comma.zokrates", - "match": "," - }, - { - "comment": "parentheses, round brackets", - "name": "punctuation.brackets.round.zokrates", - "match": "[()]" - }, - { - "comment": "square brackets", - "name": "punctuation.brackets.square.zokrates", - "match": "[\\[\\]]" - }, - { - "comment": "angle brackets", - "name": "punctuation.brackets.angle.zokrates", - "match": "(?]" - } - ] - }, - "strings": { - "patterns": [ - { - "comment": "double-quoted strings and byte strings", - "name": "string.quoted.double.zokrates", - "begin": "(b?)(\")", - "beginCaptures": { - "1": { - "name": "string.quoted.byte.raw.zokrates" - }, - "2": { - "name": "punctuation.definition.string.zokrates" - } - }, - "end": "\"", - "endCaptures": { - "0": { - "name": "punctuation.definition.string.zokrates" - } - } - }, - { - "comment": "double-quoted raw strings and raw byte strings", - "name": "string.quoted.double.zokrates", - "begin": "(b?r)(#*)(\")", - "beginCaptures": { - "1": { - "name": "string.quoted.byte.raw.zokrates" - }, - "2": { - "name": "punctuation.definition.string.raw.zokrates" - }, - "3": { - "name": "punctuation.definition.string.zokrates" - } - }, - "end": "(\")(\\2)", - "endCaptures": { - "1": { - "name": "punctuation.definition.string.zokrates" - }, - "2": { - "name": "punctuation.definition.string.raw.zokrates" - } - } - } - ] - }, - "variables": { - "patterns": [ - { - "comment": "variables", - "name": "variable.other.zokrates", - "match": "\\b(?\\()", + "beginCaptures": { + "1": { + "name": "entity.name.function.zokrates" + } + }, + "end": "\\)", + "endCaptures": { + "0": { + "name": "punctuation.brackets.round.zokrates" + } + }, + "patterns": [ + { + "include": "#block-comments" + }, + { + "include": "#comments" + }, + { + "include": "#keywords" + }, + { + "include": "#constants" + }, + { + "include": "#functions" + }, + { + "include": "#punctuation" + }, + { + "include": "#strings" + }, + { + "include": "#types" + }, + { + "include": "#variables" + } + ] + } + ] + }, + "keywords": { + "patterns": [ + { + "comment": "argument visibility", + "name": "keyword.visibility.zokrates", + "match": "\\b(public|private)\\b" + }, + { + "comment": "control flow keywords", + "name": "keyword.control.zokrates", + "match": "\\b(do|else|for|do|endfor|if|then|fi|return|assert)\\b" + }, + { + "comment": "storage keywords", + "name": "storage.type.zokrates", + "match": "\\b(struct)\\b" + }, + { + "comment": "const", + "name": "keyword.other.const.zokrates", + "match": "\\bconst\\b" + }, + { + "comment": "def", + "name": "keyword.other.def.zokrates", + "match": "\\bdef\\b" + }, + { + "comment": "import keywords", + "name": "keyword.other.import.zokrates", + "match": "\\b(import|from|as)\\b" + }, + { + "comment": "logical operators", + "name": "keyword.operator.logical.zokrates", + "match": "(\\^|\\||\\|\\||&|&&|<<|>>|!)(?!=)" + }, + { + "comment": "single equal", + "name": "keyword.operator.assignment.equal.zokrates", + "match": "(?])=(?!=|>)" + }, + { + "comment": "comparison operators", + "name": "keyword.operator.comparison.zokrates", + "match": "(=(=)?(?!>)|!=|<=|(?=)" + }, + { + "comment": "math operators", + "name": "keyword.operator.math.zokrates", + "match": "(([+%]|(\\*(?!\\w)))(?!=))|(-(?!>))|(/(?!/))" + }, + { + "comment": "less than, greater than (special case)", + "match": "(?:\\b|(?:(\\))|(\\])|(\\})))[ \\t]+([<>])[ \\t]+(?:\\b|(?:(\\()|(\\[)|(\\{)))", + "captures": { + "1": { + "name": "punctuation.brackets.round.zokrates" + }, + "2": { + "name": "punctuation.brackets.square.zokrates" + }, + "3": { + "name": "punctuation.brackets.curly.zokrates" + }, + "4": { + "name": "keyword.operator.comparison.zokrates" + }, + "5": { + "name": "punctuation.brackets.round.zokrates" + }, + "6": { + "name": "punctuation.brackets.square.zokrates" + }, + "7": { + "name": "punctuation.brackets.curly.zokrates" + } + } + }, + { + "comment": "dot access", + "name": "keyword.operator.access.dot.zokrates", + "match": "\\.(?!\\.)" + }, + { + "comment": "ranges, range patterns", + "name": "keyword.operator.range.zokrates", + "match": "\\.{2}(=|\\.)?" + }, + { + "comment": "colon", + "name": "keyword.operator.colon.zokrates", + "match": ":(?!:)" + }, + { + "comment": "dashrocket, skinny arrow", + "name": "keyword.operator.arrow.skinny.zokrates", + "match": "->" + } + ] + }, + "types": { + "patterns": [ + { + "comment": "numeric types", + "match": "(?", + "endCaptures": { + "0": { + "name": "punctuation.brackets.angle.zokrates" + } + }, + "patterns": [ + { + "include": "#block-comments" + }, + { + "include": "#comments" + }, + { + "include": "#keywords" + }, + { + "include": "#punctuation" + }, + { + "include": "#types" + }, + { + "include": "#variables" + } + ] + }, + { + "comment": "primitive types", + "name": "entity.name.type.primitive.zokrates", + "match": "\\b(bool)\\b" + }, + { + "comment": "struct declarations", + "match": "\\b(struct)\\s+([A-Z][A-Za-z0-9]*)\\b", + "captures": { + "1": { + "name": "storage.type.zokrates" + }, + "2": { + "name": "entity.name.type.struct.zokrates" + } + } + }, + { + "comment": "types", + "name": "entity.name.type.zokrates", + "match": "\\b[A-Z][A-Za-z0-9]*\\b(?!!)" + } + ] + }, + "punctuation": { + "patterns": [ + { + "comment": "comma", + "name": "punctuation.comma.zokrates", + "match": "," + }, + { + "comment": "parentheses, round brackets", + "name": "punctuation.brackets.round.zokrates", + "match": "[()]" + }, + { + "comment": "square brackets", + "name": "punctuation.brackets.square.zokrates", + "match": "[\\[\\]]" + }, + { + "comment": "angle brackets", + "name": "punctuation.brackets.angle.zokrates", + "match": "(?]" + } + ] + }, + "strings": { + "patterns": [ + { + "comment": "double-quoted strings and byte strings", + "name": "string.quoted.double.zokrates", + "begin": "(b?)(\")", + "beginCaptures": { + "1": { + "name": "string.quoted.byte.raw.zokrates" + }, + "2": { + "name": "punctuation.definition.string.zokrates" + } + }, + "end": "\"", + "endCaptures": { + "0": { + "name": "punctuation.definition.string.zokrates" + } + } + }, + { + "comment": "double-quoted raw strings and raw byte strings", + "name": "string.quoted.double.zokrates", + "begin": "(b?r)(#*)(\")", + "beginCaptures": { + "1": { + "name": "string.quoted.byte.raw.zokrates" + }, + "2": { + "name": "punctuation.definition.string.raw.zokrates" + }, + "3": { + "name": "punctuation.definition.string.zokrates" + } + }, + "end": "(\")(\\2)", + "endCaptures": { + "1": { + "name": "punctuation.definition.string.zokrates" + }, + "2": { + "name": "punctuation.definition.string.raw.zokrates" + } + } + } + ] + }, + "variables": { + "patterns": [ + { + "comment": "variables", + "name": "variable.other.zokrates", + "match": "\\b(?\()' + beginCaptures: + '1': {name: entity.name.function.zokrates} + end: \) + endCaptures: + '0': {name: punctuation.brackets.round.zokrates} + patterns: + - {include: '#block-comments'} + - {include: '#comments'} + - {include: '#keywords'} + - {include: '#constants'} + - {include: '#functions'} + - {include: '#punctuation'} + - {include: '#strings'} + - {include: '#types'} + - {include: '#variables'} + keywords: + patterns: + - + comment: 'argument visibility' + name: keyword.visibility.zokrates + match: \b(public|private)\b + - + comment: 'control flow keywords' + name: keyword.control.zokrates + match: \b(do|else|for|do|endfor|if|then|fi|return|assert)\b + - + comment: 'storage keywords' + name: storage.type.zokrates + match: \b(struct)\b + - + comment: const + name: keyword.other.const.zokrates + match: \bconst\b + - + comment: def + name: keyword.other.def.zokrates + match: \bdef\b + - + comment: 'import keywords' + name: keyword.other.import.zokrates + match: \b(import|from|as)\b + - + comment: 'logical operators' + name: keyword.operator.logical.zokrates + match: '(\^|\||\|\||&|&&|<<|>>|!)(?!=)' + - + comment: 'single equal' + name: keyword.operator.assignment.equal.zokrates + match: '(?])=(?!=|>)' + - + comment: 'comparison operators' + name: keyword.operator.comparison.zokrates + match: '(=(=)?(?!>)|!=|<=|(?=)' + - + comment: 'math operators' + name: keyword.operator.math.zokrates + match: '(([+%]|(\*(?!\w)))(?!=))|(-(?!>))|(/(?!/))' + - + comment: 'less than, greater than (special case)' + match: '(?:\b|(?:(\))|(\])|(\})))[ \t]+([<>])[ \t]+(?:\b|(?:(\()|(\[)|(\{)))' + captures: + '1': {name: punctuation.brackets.round.zokrates} + '2': {name: punctuation.brackets.square.zokrates} + '3': {name: punctuation.brackets.curly.zokrates} + '4': {name: keyword.operator.comparison.zokrates} + '5': {name: punctuation.brackets.round.zokrates} + '6': {name: punctuation.brackets.square.zokrates} + '7': {name: punctuation.brackets.curly.zokrates} + - + comment: 'dot access' + name: keyword.operator.access.dot.zokrates + match: '\.(?!\.)' + - + comment: 'ranges, range patterns' + name: keyword.operator.range.zokrates + match: '\.{2}(=|\.)?' + - + comment: colon + name: keyword.operator.colon.zokrates + match: ':(?!:)' + - + comment: 'dashrocket, skinny arrow' + name: keyword.operator.arrow.skinny.zokrates + match: '->' + types: + patterns: + - + comment: 'numeric types' + match: '(?' + endCaptures: + '0': {name: punctuation.brackets.angle.zokrates} + patterns: + - {include: '#block-comments'} + - {include: '#comments'} + - {include: '#keywords'} + - {include: '#punctuation'} + - {include: '#types'} + - {include: '#variables'} + - + comment: 'primitive types' + name: entity.name.type.primitive.zokrates + match: \b(bool)\b + - + comment: 'struct declarations' + match: '\b(struct)\s+([A-Z][A-Za-z0-9]*)\b' + captures: + '1': {name: storage.type.zokrates} + '2': {name: entity.name.type.struct.zokrates} + - + comment: types + name: entity.name.type.zokrates + match: '\b[A-Z][A-Za-z0-9]*\b(?!!)' + punctuation: + patterns: + - + comment: comma + name: punctuation.comma.zokrates + match: ',' + - + comment: 'parentheses, round brackets' + name: punctuation.brackets.round.zokrates + match: '[()]' + - + comment: 'square brackets' + name: punctuation.brackets.square.zokrates + match: '[\[\]]' + - + comment: 'angle brackets' + name: punctuation.brackets.angle.zokrates + match: '(?]' + strings: + patterns: + - + comment: 'double-quoted strings and byte strings' + name: string.quoted.double.zokrates + begin: '(b?)(")' + beginCaptures: + '1': {name: string.quoted.byte.raw.zokrates} + '2': {name: punctuation.definition.string.zokrates} + end: '"' + endCaptures: + '0': {name: punctuation.definition.string.zokrates} + - + comment: 'double-quoted raw strings and raw byte strings' + name: string.quoted.double.zokrates + begin: '(b?r)(#*)(")' + beginCaptures: + '1': {name: string.quoted.byte.raw.zokrates} + '2': {name: punctuation.definition.string.raw.zokrates} + '3': {name: punctuation.definition.string.zokrates} + end: '(")(\2)' + endCaptures: + '1': {name: punctuation.definition.string.zokrates} + '2': {name: punctuation.definition.string.raw.zokrates} + variables: + patterns: + - + comment: variables + name: variable.other.zokrates + match: '\b(? Date: Wed, 21 Apr 2021 18:15:12 +0200 Subject: [PATCH 08/12] refactoring --- zokrates_book/src/language/constants.md | 4 +- zokrates_core/src/absy/from_ast.rs | 1 - zokrates_core/src/absy/mod.rs | 13 +- zokrates_core/src/semantics.rs | 152 +++++++----- .../src/static_analysis/constant_inliner.rs | 224 +++++++++--------- .../src/static_analysis/propagation.rs | 2 +- .../src/static_analysis/reducer/mod.rs | 20 +- zokrates_core/src/typed_absy/abi.rs | 2 +- zokrates_core/src/typed_absy/folder.rs | 6 +- zokrates_core/src/typed_absy/mod.rs | 11 +- zokrates_core/src/typed_absy/result_folder.rs | 22 +- 11 files changed, 251 insertions(+), 206 deletions(-) diff --git a/zokrates_book/src/language/constants.md b/zokrates_book/src/language/constants.md index 8bad0721..a6c16db7 100644 --- a/zokrates_book/src/language/constants.md +++ b/zokrates_book/src/language/constants.md @@ -6,9 +6,9 @@ Constants must be globally defined outside all other scopes by using a `const` k {{#include ../../../zokrates_cli/examples/book/constant_definition.zok}} ``` -The value of a constant can't be changed through reassignment, and it can't be redeclared. Constants are essentially inlined wherever they are used, meaning that they are copied directly into the relevant context when used. +The value of a constant can't be changed through reassignment, and it can't be redeclared. -Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is defined before the constant. +Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is already defined. ```zokrates {{#include ../../../zokrates_cli/examples/book/constant_reference.zok}} diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 3f6edb31..10436d86 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -112,7 +112,6 @@ impl<'ast> From> for absy::SymbolDeclarationNode< let id = definition.id.span.as_str(); let ty = absy::ConstantDefinition { - id, ty: definition.ty.into(), expression: definition.expression.into(), } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 028212a5..b9fccd77 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -90,7 +90,11 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> { match self.symbol { Symbol::Here(ref kind) => match kind { SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t), - SymbolDefinition::Constant(c) => write!(f, "{}", c), + SymbolDefinition::Constant(c) => write!( + f, + "const {} {} = {}", + c.value.ty, self.id, c.value.expression + ), SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func), }, Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id), @@ -166,7 +170,6 @@ type StructDefinitionFieldNode<'ast> = Node>; #[derive(Clone, PartialEq)] pub struct ConstantDefinition<'ast> { - pub id: Identifier<'ast>, pub ty: UnresolvedTypeNode<'ast>, pub expression: ExpressionNode<'ast>, } @@ -175,7 +178,7 @@ pub type ConstantDefinitionNode<'ast> = Node>; impl<'ast> fmt::Display for ConstantDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "const {} {} = {}", self.ty, self.id, self.expression) + write!(f, "const {}({})", self.ty, self.expression) } } @@ -183,8 +186,8 @@ impl<'ast> fmt::Debug for ConstantDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "ConstantDefinition({:?}, {:?}, {:?})", - self.ty, self.id, self.expression + "ConstantDefinition({:?}, {:?})", + self.ty, self.expression ) } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index b76f6f1d..6debe421 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -46,6 +46,7 @@ impl ErrorInner { } type TypeMap<'ast> = HashMap>>; +type ConstantMap<'ast, T> = HashMap>; /// The global state of the program during semantic checks #[derive(Debug)] @@ -56,6 +57,8 @@ struct State<'ast, T> { typed_modules: TypedModules<'ast, T>, /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof types: TypeMap<'ast>, + // The user-defined constants + constants: ConstantMap<'ast, T>, } /// A symbol for a given name: either a type or a group of functions. Not both! @@ -73,14 +76,27 @@ struct SymbolUnifier<'ast> { } impl<'ast> SymbolUnifier<'ast> { - fn insert_symbol>(&mut self, id: S, ty: SymbolType<'ast>) -> bool { - let s_type = self.symbols.entry(id.into()); - match s_type { - // if anything is already called `id`, we cannot introduce this type + fn insert_type>(&mut self, id: S) -> bool { + let e = self.symbols.entry(id.into()); + match e { + // if anything is already called `id`, we cannot introduce the symbol Entry::Occupied(..) => false, // otherwise, we can! Entry::Vacant(v) => { - v.insert(ty); + v.insert(SymbolType::Type); + true + } + } + } + + fn insert_constant>(&mut self, id: S) -> bool { + let e = self.symbols.entry(id.into()); + match e { + // if anything is already called `id`, we cannot introduce this constant + Entry::Occupied(..) => false, + // otherwise, we can! + Entry::Vacant(v) => { + v.insert(SymbolType::Constant); true } } @@ -117,6 +133,7 @@ impl<'ast, T: Field> State<'ast, T> { modules, typed_modules: HashMap::new(), types: HashMap::new(), + constants: HashMap::new(), } } } @@ -230,7 +247,12 @@ impl<'ast, T: Field> FunctionQuery<'ast, T> { pub struct ScopedVariable<'ast, T> { id: Variable<'ast, T>, level: usize, - constant: bool, +} + +impl<'ast, T> ScopedVariable<'ast, T> { + fn is_constant(&self) -> bool { + self.level == 0 + } } /// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent @@ -312,6 +334,7 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_constant_definition( &mut self, + id: &'ast str, c: ConstantDefinitionNode<'ast>, module_id: &ModuleId, types: &TypeMap<'ast>, @@ -320,7 +343,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let ty = self.check_type(c.value.ty.clone(), module_id, &types)?; let checked_expr = self.check_expression(c.value.expression.clone(), module_id, types)?; - match ty.clone() { + match ty { Type::FieldElement => { FieldElementExpression::try_from_typed(checked_expr).map(TypedExpression::from) } @@ -330,10 +353,13 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Uint(bitwidth) => { UExpression::try_from_typed(checked_expr, bitwidth).map(TypedExpression::from) } - Type::Array(array_ty) => ArrayExpression::try_from_typed(checked_expr, *array_ty.ty) - .map(TypedExpression::from), - Type::Struct(struct_ty) => { - StructExpression::try_from_typed(checked_expr, struct_ty).map(TypedExpression::from) + Type::Array(ref array_ty) => { + ArrayExpression::try_from_typed(checked_expr, *array_ty.ty.clone()) + .map(TypedExpression::from) + } + Type::Struct(ref struct_ty) => { + StructExpression::try_from_typed(checked_expr, struct_ty.clone()) + .map(TypedExpression::from) } Type::Int => Err(checked_expr), // Integers cannot be assigned } @@ -343,15 +369,11 @@ impl<'ast, T: Field> Checker<'ast, T> { "Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`", e, e.get_type(), - c.value.id, + id, ty ), }) - .map(|e| TypedConstant { - id: identifier::Identifier::from(c.value.id), - ty, - expression: e, - }) + .map(|e| TypedConstant { ty, expression: e }) } fn check_struct_type_declaration( @@ -425,7 +447,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &state.types, ) { Ok(ty) => { - match symbol_unifier.insert_symbol(declaration.id, SymbolType::Type) { + match symbol_unifier.insert_type(declaration.id) { false => errors.push( ErrorInner { pos: Some(pos), @@ -454,10 +476,10 @@ impl<'ast, T: Field> Checker<'ast, T> { } } SymbolDefinition::Constant(box c) => { - match self.check_constant_definition(c, module_id, &state.types) { + match self.check_constant_definition(declaration.id, c, module_id, &state.types) + { Ok(c) => { - match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) - { + match symbol_unifier.insert_constant(declaration.id) { false => errors.push( ErrorInner { pos: Some(pos), @@ -468,10 +490,23 @@ impl<'ast, T: Field> Checker<'ast, T> { } .in_file(module_id), ), - true => {} + true => { + constants.insert( + declaration.id, + TypedConstantSymbol::Here(c.clone()), + ); + self.insert_into_scope(Variable::with_id_and_type( + declaration.id, + c.ty.clone(), + )); + assert!(state + .constants + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id, TypedConstantSymbol::Here(c)) + .is_none()); + } }; - constants.insert(declaration.id, TypedConstantSymbol::Here(c.clone())); - self.insert_into_scope(Variable::with_id_and_type(c.id, c.ty), true); } Err(e) => { errors.push(e.in_file(module_id)); @@ -550,16 +585,15 @@ impl<'ast, T: Field> Checker<'ast, T> { // find constant definition candidate let const_candidate = state - .typed_modules - .get(&import.module_id) - .unwrap() .constants - .as_ref() - .and_then(|tc| tc.get(import.symbol_id)) + .entry(import.module_id.to_path_buf()) + .or_default() + .get(import.symbol_id) .and_then(|sym| match sym { TypedConstantSymbol::Here(tc) => Some(tc), _ => None, - }); + }) + .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { (0, Some(t), None) => { @@ -577,7 +611,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }; // we imported a type, so the symbol it gets bound to should not already exist - match symbol_unifier.insert_symbol(declaration.id, SymbolType::Type) { + match symbol_unifier.insert_type(declaration.id) { false => { errors.push(Error { module_id: module_id.to_path_buf(), @@ -598,7 +632,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .insert(declaration.id.to_string(), t); } (0, None, Some(c)) => { - match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) { + match symbol_unifier.insert_constant(declaration.id) { false => { errors.push(Error { module_id: module_id.to_path_buf(), @@ -612,7 +646,13 @@ impl<'ast, T: Field> Checker<'ast, T> { } true => { constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id)); - self.insert_into_scope(Variable::with_id_and_type(c.id.clone(), c.ty.clone()), true); + self.insert_into_scope(Variable::with_id_and_type(declaration.id, c.ty.clone())); + + state + .constants + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id, TypedConstantSymbol::Here(c)); // we insert as `Here` to avoid later recursive search } }; } @@ -732,7 +772,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Some(TypedModule { functions: checked_functions, - constants: Some(checked_constants).filter(|m| !m.is_empty()), + constants: checked_constants, }) } }; @@ -814,7 +854,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Uint(UBitwidth::B32), ); // we don't have to check for conflicts here, because this was done when checking the signature - self.insert_into_scope(v.clone(), false); + self.insert_into_scope(v.clone()); } for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { @@ -825,7 +865,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let decl_v = DeclarationVariable::with_id_and_type(arg.id.value.id, decl_ty.clone()); - match self.insert_into_scope(decl_v.clone(), false) { + match self.insert_into_scope(decl_v.clone()) { true => {} false => { errors.push(ErrorInner { @@ -1216,7 +1256,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } .map_err(|e| vec![e])?; - self.insert_into_scope(var.clone(), false); + self.insert_into_scope(var.clone()); let mut checked_statements = vec![]; @@ -1315,7 +1355,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } Statement::Declaration(var) => { let var = self.check_variable(var, module_id, types)?; - match self.insert_into_scope(var.clone(), false) { + match self.insert_into_scope(var.clone()) { true => Ok(TypedStatement::Declaration(var)), false => Err(ErrorInner { pos: Some(pos), @@ -1498,15 +1538,15 @@ impl<'ast, T: Field> Checker<'ast, T> { // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) { - Some(var) => match var.constant { - false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( - variable_name, - var.id._type.clone(), - ))), + Some(var) => match var.is_constant() { true => Err(ErrorInner { pos: Some(assignee.pos()), message: format!("Assignment to constant variable `{}`", variable_name), }), + false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( + variable_name, + var.id._type.clone(), + ))), }, None => Err(ErrorInner { pos: Some(assignee.pos()), @@ -2919,15 +2959,13 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::FieldElement, ), level: 0, - constant: false, }) } - fn insert_into_scope>>(&mut self, v: U, constant: bool) -> bool { + fn insert_into_scope>>(&mut self, v: U) -> bool { self.scope.insert(ScopedVariable { id: v.into(), level: self.level, - constant, }) } @@ -3123,9 +3161,9 @@ mod tests { let mut unifier = SymbolUnifier::default(); // the `foo` type - assert!(unifier.insert_symbol("foo", SymbolType::Type)); + assert!(unifier.insert_type("foo")); // the `foo` type annot be declared a second time - assert!(!unifier.insert_symbol("foo", SymbolType::Type)); + assert!(!unifier.insert_type("foo")); // the `foo` function cannot be declared as the name is already taken by a type assert!(!unifier.insert_function("foo", DeclarationSignature::new())); // the `bar` type @@ -3163,7 +3201,7 @@ mod tests { ))]) )); // a `bar` type isn't allowed as the name is already taken by at least one function - assert!(!unifier.insert_symbol("bar", SymbolType::Type)); + assert!(!unifier.insert_type("bar")); } #[test] @@ -3220,7 +3258,7 @@ mod tests { )] .into_iter() .collect(), - constants: None + constants: TypedConstantSymbols::default() }) ); } @@ -3755,6 +3793,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + assert_eq!( checker.check_statement(statement, &*MODULE_ID, &types), Err(vec![ErrorInner { @@ -3779,14 +3819,13 @@ mod tests { let mut scope = HashSet::new(); scope.insert(ScopedVariable { id: Variable::field_element("a"), - level: 0, - constant: false, + level: 1, }); scope.insert(ScopedVariable { id: Variable::field_element("b"), - level: 0, - constant: false, + level: 1, }); + let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( checker.check_statement(statement, &*MODULE_ID, &types), @@ -5770,6 +5809,7 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); checker .check_statement( @@ -5803,6 +5843,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + checker .check_statement( Statement::Declaration( @@ -5853,6 +5895,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + checker .check_statement( Statement::Declaration( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 9c04dcaa..199a5a8b 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -39,14 +39,13 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> { .get(&self.location) .unwrap() .constants - .as_ref() - .and_then(|c| c.get(id.clone().try_into().unwrap())) + .get(id.clone().try_into().unwrap()) .cloned() .and_then(|tc| { let symbol = self.fold_constant_symbol(tc); match symbol { TypedConstantSymbol::Here(tc) => Some(tc), - _ => None, + _ => unreachable!(), } }) } @@ -67,6 +66,21 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } } + fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { + TypedModule { + constants: p + .constants + .into_iter() + .map(|(key, tc)| (key, self.fold_constant_symbol(tc))) + .collect(), + functions: p + .functions + .into_iter() + .map(|(key, fun)| (key, self.fold_function_symbol(fun))) + .collect(), + } + } + fn fold_constant_symbol( &mut self, p: TypedConstantSymbol<'ast, T>, @@ -74,13 +88,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match p { TypedConstantSymbol::There(module_id, id) => { let location = self.change_location(module_id); - let symbol = self - .module() - .constants - .as_ref() - .and_then(|c| c.get(id)) - .unwrap() - .to_owned(); + let symbol = self.module().constants.get(id).cloned().unwrap(); let symbol = self.fold_constant_symbol(symbol); let _ = self.change_location(location); @@ -196,7 +204,6 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_id), ty: GType::FieldElement, expression: (TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(1), @@ -221,7 +228,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants.clone()), + constants: constants.clone(), }, )] .into_iter() @@ -255,7 +262,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants, }, )] .into_iter() @@ -287,7 +294,6 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_id), ty: GType::Boolean, expression: (TypedExpression::Boolean(BooleanExpression::Value(true))), }), @@ -310,7 +316,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants.clone()), + constants: constants.clone(), }, )] .into_iter() @@ -344,7 +350,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants, }, )] .into_iter() @@ -377,7 +383,6 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_id), ty: GType::Uint(UBitwidth::B32), expression: (UExpressionInner::Value(1u128) .annotate(UBitwidth::B32) @@ -402,7 +407,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants.clone()), + constants: constants.clone(), }, )] .into_iter() @@ -436,7 +441,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants, }, )] .into_iter() @@ -479,7 +484,6 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_id), ty: GType::FieldElement, expression: TypedExpression::Array( ArrayExpressionInner::Value( @@ -511,7 +515,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants.clone()), + constants: constants.clone(), }, )] .into_iter() @@ -569,7 +573,7 @@ mod tests { )] .into_iter() .collect(), - constants: Some(constants), + constants, }, )] .into_iter() @@ -615,37 +619,33 @@ mod tests { )] .into_iter() .collect(), - constants: Some( - vec![ - ( - const_a_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_a_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(1)), - )), - }), - ), - ( - const_b_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_b_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Add( - box FieldElementExpression::Identifier( - Identifier::from(const_a_id), - ), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), - )), - }), - ), - ] - .into_iter() - .collect(), - ), + constants: vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(1)), + )), + }), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Add( + box FieldElementExpression::Identifier(Identifier::from( + const_a_id, + )), + box FieldElementExpression::Number(Bn128Field::from(1)), + ), + )), + }), + ), + ] + .into_iter() + .collect(), }, )] .into_iter() @@ -681,35 +681,31 @@ mod tests { )] .into_iter() .collect(), - constants: Some( - vec![ - ( - const_a_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_a_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(1)), - )), - }), - ), - ( - const_b_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(const_b_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), - )), - }), - ), - ] - .into_iter() - .collect(), - ), + constants: vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Number(Bn128Field::from(1)), + )), + }), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement( + FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + ), + )), + }), + ), + ] + .into_iter() + .collect(), }, )] .into_iter() @@ -750,20 +746,17 @@ mod tests { )] .into_iter() .collect(), - constants: Some( - vec![( - foo_const_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(foo_const_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(42)), - )), - }), - )] - .into_iter() - .collect(), - ), + constants: vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + ))), + }), + )] + .into_iter() + .collect(), }; let main_module = TypedModule { @@ -785,14 +778,12 @@ mod tests { )] .into_iter() .collect(), - constants: Some( - vec![( - foo_const_id, - TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id), - )] - .into_iter() - .collect(), - ), + constants: vec![( + foo_const_id, + TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id), + )] + .into_iter() + .collect(), }; let program = TypedProgram { @@ -825,20 +816,17 @@ mod tests { )] .into_iter() .collect(), - constants: Some( - vec![( - foo_const_id, - TypedConstantSymbol::Here(TypedConstant { - id: Identifier::from(foo_const_id), - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(42)), - )), - }), - )] - .into_iter() - .collect(), - ), + constants: vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant { + ty: GType::FieldElement, + expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + ))), + }), + )] + .into_iter() + .collect(), }; let expected_program: TypedProgram = TypedProgram { diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 45fc1779..63c7cd5b 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -250,7 +250,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } }) .collect::>()?, - constants: m.constants, + ..m }) } diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 05191af1..57e17db2 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -547,7 +547,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E )] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -769,7 +769,7 @@ mod tests { ] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -835,7 +835,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -964,7 +964,7 @@ mod tests { ] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1045,7 +1045,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1183,7 +1183,7 @@ mod tests { ] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1264,7 +1264,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1441,7 +1441,7 @@ mod tests { ] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1545,7 +1545,7 @@ mod tests { )] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() @@ -1629,7 +1629,7 @@ mod tests { ] .into_iter() .collect(), - constants: None, + constants: Default::default(), }, )] .into_iter() diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 7b189e97..f3adbdea 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -69,7 +69,7 @@ mod tests { "main".into(), TypedModule { functions, - constants: None, + constants: Default::default(), }, ); diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 2393b84b..30385b1b 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -200,16 +200,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( p: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { - constants: p.constants.map(|tc| { - tc.into_iter() - .map(|(key, tc)| (key, f.fold_constant_symbol(tc))) - .collect() - }), functions: p .functions .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) .collect(), + ..p } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 2789f0aa..101e1395 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -153,7 +153,7 @@ pub struct TypedModule<'ast, T> { /// Functions of the module pub functions: TypedFunctionSymbols<'ast, T>, /// Constants defined in module - pub constants: Option>, + pub constants: TypedConstantSymbols<'ast, T>, } #[derive(Clone, PartialEq)] @@ -322,24 +322,19 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { #[derive(Clone, PartialEq)] pub struct TypedConstant<'ast, T> { - pub id: Identifier<'ast>, pub ty: Type<'ast, T>, pub expression: TypedExpression<'ast, T>, } impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "TypedConstant({:?}, {:?}, {:?})", - self.id, self.ty, self.expression - ) + write!(f, "TypedConstant({:?}, {:?})", self.ty, self.expression) } } impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "const {} {} = {}", self.ty, self.id, self.expression) + write!(f, "const {}({})", self.ty, self.expression) } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 977e04f2..82ae443a 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -21,6 +21,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_module(self, p) } + fn fold_constant_symbol( + &mut self, + s: TypedConstantSymbol<'ast, T>, + ) -> Result, Self::Error> { + fold_constant_symbol(self, s) + } + fn fold_function_symbol( &mut self, s: TypedFunctionSymbol<'ast, T>, @@ -793,6 +800,19 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } +pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedConstantSymbol<'ast, T>, +) -> Result, F::Error> { + match s { + TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(TypedConstant { + expression: f.fold_expression(tc.expression)?, + ..tc + })), + there => Ok(there), + } +} + pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedFunctionSymbol<'ast, T>, @@ -813,7 +833,7 @@ pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( .into_iter() .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) .collect::>()?, - constants: p.constants, + ..p }) } From cd276760b1b7c86e9da827c521391b82d39fc6e6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 22 Apr 2021 13:38:54 +0200 Subject: [PATCH 09/12] allow clippy rule, remove unreachable arms --- zokrates_core/src/absy/from_ast.rs | 2 +- zokrates_core/src/absy/mod.rs | 3 +- zokrates_core/src/semantics.rs | 17 +++---- .../src/static_analysis/constant_inliner.rs | 48 ++++++++++--------- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 10436d86..652932c6 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -119,7 +119,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode< absy::SymbolDeclaration { id, - symbol: absy::Symbol::Here(SymbolDefinition::Constant(box ty)), + symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)), } .span(span) } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index b9fccd77..d070860d 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -51,10 +51,11 @@ pub struct SymbolDeclaration<'ast> { pub symbol: Symbol<'ast>, } +#[allow(clippy::large_enum_variant)] #[derive(PartialEq, Clone)] pub enum SymbolDefinition<'ast> { Struct(StructDefinitionNode<'ast>), - Constant(Box>), + Constant(ConstantDefinitionNode<'ast>), Function(FunctionNode<'ast>), } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 6debe421..443adaa8 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -46,7 +46,8 @@ impl ErrorInner { } type TypeMap<'ast> = HashMap>>; -type ConstantMap<'ast, T> = HashMap>; +type ConstantMap<'ast, T> = + HashMap, Type<'ast, T>>>; /// The global state of the program during semantic checks #[derive(Debug)] @@ -475,7 +476,7 @@ impl<'ast, T: Field> Checker<'ast, T> { })), } } - SymbolDefinition::Constant(box c) => { + SymbolDefinition::Constant(c) => { match self.check_constant_definition(declaration.id, c, module_id, &state.types) { Ok(c) => { @@ -503,7 +504,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .constants .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id, TypedConstantSymbol::Here(c)) + .insert(declaration.id, c.ty) .is_none()); } }; @@ -589,10 +590,6 @@ impl<'ast, T: Field> Checker<'ast, T> { .entry(import.module_id.to_path_buf()) .or_default() .get(import.symbol_id) - .and_then(|sym| match sym { - TypedConstantSymbol::Here(tc) => Some(tc), - _ => None, - }) .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { @@ -631,7 +628,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .or_default() .insert(declaration.id.to_string(), t); } - (0, None, Some(c)) => { + (0, None, Some(ty)) => { match symbol_unifier.insert_constant(declaration.id) { false => { errors.push(Error { @@ -646,13 +643,13 @@ impl<'ast, T: Field> Checker<'ast, T> { } true => { constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id)); - self.insert_into_scope(Variable::with_id_and_type(declaration.id, c.ty.clone())); + self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone())); state .constants .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id, TypedConstantSymbol::Here(c)); // we insert as `Here` to avoid later recursive search + .insert(declaration.id, ty); } }; } diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 199a5a8b..fff3ca64 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -17,37 +17,50 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> { } pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - // initialize an inliner over all modules, starting from the main module let mut inliner = ConstantInliner::with_modules_and_location(p.modules.clone(), p.main.clone()); inliner.fold_program(p) } - pub fn module(&self) -> &TypedModule<'ast, T> { + fn module(&self) -> &TypedModule<'ast, T> { self.modules.get(&self.location).unwrap() } - pub fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { let prev = self.location.clone(); self.location = location; prev } - pub fn get_constant(&mut self, id: &Identifier) -> Option> { + fn get_constant(&mut self, id: &Identifier) -> Option> { self.modules .get(&self.location) .unwrap() .constants .get(id.clone().try_into().unwrap()) .cloned() - .and_then(|tc| { - let symbol = self.fold_constant_symbol(tc); - match symbol { - TypedConstantSymbol::Here(tc) => Some(tc), - _ => unreachable!(), - } - }) + .map(|symbol| self.get_canonical_constant(symbol)) + } + + fn get_canonical_constant( + &mut self, + symbol: TypedConstantSymbol<'ast, T>, + ) -> TypedConstant<'ast, T> { + match symbol { + TypedConstantSymbol::There(module_id, id) => { + let location = self.change_location(module_id); + let symbol = self.module().constants.get(id).cloned().unwrap(); + + let symbol = self.get_canonical_constant(symbol); + let _ = self.change_location(location); + symbol + } + TypedConstantSymbol::Here(tc) => TypedConstant { + expression: self.fold_expression(tc.expression), + ..tc + }, + } } } @@ -85,17 +98,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { &mut self, p: TypedConstantSymbol<'ast, T>, ) -> TypedConstantSymbol<'ast, T> { - match p { - TypedConstantSymbol::There(module_id, id) => { - let location = self.change_location(module_id); - let symbol = self.module().constants.get(id).cloned().unwrap(); - - let symbol = self.fold_constant_symbol(symbol); - let _ = self.change_location(location); - symbol - } - _ => fold_constant_symbol(self, p), - } + let tc = self.get_canonical_constant(p); + TypedConstantSymbol::Here(tc) } fn fold_field_expression( From ae7e095cacdab3367b35c25fe32b813704b43b75 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 26 Apr 2021 13:16:25 +0200 Subject: [PATCH 10/12] refactoring, fix constant aliasing --- .../examples/imports/import_with_alias.zok | 4 +- zokrates_core/src/semantics.rs | 8 +- .../src/static_analysis/constant_inliner.rs | 145 +++++++++--------- zokrates_core/src/typed_absy/folder.rs | 27 +++- zokrates_core/src/typed_absy/mod.rs | 90 ++++++++++- zokrates_core/src/typed_absy/result_folder.rs | 22 ++- 6 files changed, 196 insertions(+), 100 deletions(-) diff --git a/zokrates_cli/examples/imports/import_with_alias.zok b/zokrates_cli/examples/imports/import_with_alias.zok index dd8994bf..5e013691 100644 --- a/zokrates_cli/examples/imports/import_with_alias.zok +++ b/zokrates_cli/examples/imports/import_with_alias.zok @@ -1,8 +1,8 @@ from "./bar" import main as bar -from "./baz" import main as baz +from "./baz" import BAZ as baz import "./foo" as f def main() -> field: field foo = f() - assert(foo == bar() + baz()) + assert(foo == bar() + baz) return foo \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 443adaa8..4c9836c6 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -374,7 +374,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ty ), }) - .map(|e| TypedConstant { ty, expression: e }) + .map(|e| TypedConstant::new(ty, e)) } fn check_struct_type_declaration( @@ -498,13 +498,13 @@ impl<'ast, T: Field> Checker<'ast, T> { ); self.insert_into_scope(Variable::with_id_and_type( declaration.id, - c.ty.clone(), + c.get_type(), )); assert!(state .constants .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id, c.ty) + .insert(declaration.id, c.get_type()) .is_none()); } }; @@ -642,7 +642,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }}); } true => { - constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id)); + constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, import.symbol_id)); self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone())); state diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index fff3ca64..3e082fbf 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -9,17 +9,12 @@ pub struct ConstantInliner<'ast, T: Field> { } impl<'ast, T: Field> ConstantInliner<'ast, T> { - fn with_modules_and_location( - modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, - ) -> Self { + pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self { ConstantInliner { modules, location } } pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - let mut inliner = - ConstantInliner::with_modules_and_location(p.modules.clone(), p.main.clone()); - + let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone()); inliner.fold_program(p) } @@ -56,10 +51,7 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> { let _ = self.change_location(location); symbol } - TypedConstantSymbol::Here(tc) => TypedConstant { - expression: self.fold_expression(tc.expression), - ..tc - }, + TypedConstantSymbol::Here(tc) => self.fold_constant(tc), } } } @@ -96,9 +88,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { fn fold_constant_symbol( &mut self, - p: TypedConstantSymbol<'ast, T>, + s: TypedConstantSymbol<'ast, T>, ) -> TypedConstantSymbol<'ast, T> { - let tc = self.get_canonical_constant(p); + let tc = self.get_canonical_constant(s); TypedConstantSymbol::Here(tc) } @@ -108,7 +100,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> FieldElementExpression<'ast, T> { match e { FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => fold_field_expression(self, c.expression.try_into().unwrap()), + Some(c) => { + let c = self.fold_constant(c); + fold_field_expression(self, c.try_into().unwrap()) + } None => fold_field_expression(self, e), }, e => fold_field_expression(self, e), @@ -121,7 +116,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> BooleanExpression<'ast, T> { match e { BooleanExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => fold_boolean_expression(self, c.expression.try_into().unwrap()), + Some(c) => { + let c = self.fold_constant(c); + fold_boolean_expression(self, c.try_into().unwrap()) + } None => fold_boolean_expression(self, e), }, e => fold_boolean_expression(self, e), @@ -136,7 +134,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { UExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - fold_uint_expression(self, c.expression.try_into().unwrap()).into_inner() + let c = self.fold_constant(c); + fold_uint_expression(self, c.try_into().unwrap()).into_inner() } None => fold_uint_expression_inner(self, size, e), }, @@ -152,7 +151,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - fold_array_expression(self, c.expression.try_into().unwrap()).into_inner() + let c = self.fold_constant(c); + fold_array_expression(self, c.try_into().unwrap()).into_inner() } None => fold_array_expression_inner(self, ty, e), }, @@ -168,7 +168,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - fold_struct_expression(self, c.expression.try_into().unwrap()).into_inner() + let c = self.fold_constant(c); + fold_struct_expression(self, c.try_into().unwrap()).into_inner() } None => fold_struct_expression_inner(self, ty, e), }, @@ -207,12 +208,10 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(1), - ))), - }), + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), + )), )] .into_iter() .collect(); @@ -297,10 +296,10 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::Boolean, - expression: (TypedExpression::Boolean(BooleanExpression::Value(true))), - }), + TypedConstantSymbol::Here(TypedConstant::new( + GType::Boolean, + TypedExpression::Boolean(BooleanExpression::Value(true)), + )), )] .into_iter() .collect(); @@ -386,12 +385,12 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::Uint(UBitwidth::B32), - expression: (UExpressionInner::Value(1u128) + TypedConstantSymbol::Here(TypedConstant::new( + GType::Uint(UBitwidth::B32), + UExpressionInner::Value(1u128) .annotate(UBitwidth::B32) - .into()), - }), + .into(), + )), )] .into_iter() .collect(); @@ -487,9 +486,9 @@ mod tests { let constants: TypedConstantSymbols<_> = vec![( const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: TypedExpression::Array( + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::Array( ArrayExpressionInner::Value( vec![ FieldElementExpression::Number(Bn128Field::from(2)).into(), @@ -499,7 +498,7 @@ mod tests { ) .annotate(GType::FieldElement, 2usize), ), - }), + )), )] .into_iter() .collect(); @@ -626,26 +625,24 @@ mod tests { constants: vec![ ( const_a_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(1)), + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), )), - }), + )), ), ( const_b_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from( - const_a_id, - )), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Identifier(Identifier::from( + const_a_id, + )), + box FieldElementExpression::Number(Bn128Field::from(1)), )), - }), + )), ), ] .into_iter() @@ -688,24 +685,22 @@ mod tests { constants: vec![ ( const_a_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(1)), + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), )), - }), + )), ), ( const_b_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement( - FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), )), - }), + )), ), ] .into_iter() @@ -752,12 +747,12 @@ mod tests { .collect(), constants: vec![( foo_const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(42), - ))), - }), + )), + )), )] .into_iter() .collect(), @@ -822,12 +817,12 @@ mod tests { .collect(), constants: vec![( foo_const_id, - TypedConstantSymbol::Here(TypedConstant { - ty: GType::FieldElement, - expression: (TypedExpression::FieldElement(FieldElementExpression::Number( + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(42), - ))), - }), + )), + )), )] .into_iter() .collect(), diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 30385b1b..610aa710 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -13,11 +13,15 @@ pub trait Folder<'ast, T: Field>: Sized { fold_module(self, p) } + fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> { + fold_constant(self, c) + } + fn fold_constant_symbol( &mut self, - p: TypedConstantSymbol<'ast, T>, + s: TypedConstantSymbol<'ast, T>, ) -> TypedConstantSymbol<'ast, T> { - fold_constant_symbol(self, p) + fold_constant_symbol(self, s) } fn fold_function_symbol( @@ -719,15 +723,22 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + c: TypedConstant<'ast, T>, +) -> TypedConstant<'ast, T> { + TypedConstant { + expression: f.fold_expression(c.expression), + ..c + } +} + pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - p: TypedConstantSymbol<'ast, T>, + s: TypedConstantSymbol<'ast, T>, ) -> TypedConstantSymbol<'ast, T> { - match p { - TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(TypedConstant { - expression: f.fold_expression(tc.expression), - ..tc - }), + match s { + TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)), there => there, } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 101e1395..408920d2 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -196,22 +196,31 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = self - .functions + .constants .iter() .map(|(key, symbol)| match symbol { + TypedConstantSymbol::Here(tc) => { + format!("const {} {} = {}", tc.ty, key, tc.expression) + } + TypedConstantSymbol::There(module_id, id) => { + format!("from \"{}\" import {} as {}", module_id.display(), id, key) + } + }) + .chain(self.functions.iter().map(|(key, symbol)| match symbol { TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), TypedFunctionSymbol::There(ref fun_key) => format!( - "import {} from \"{}\" as {} // with signature {}", - fun_key.id, + "from \"{}\" import {} as {} // with signature {}", fun_key.module.display(), + fun_key.id, key.id, key.signature ), TypedFunctionSymbol::Flat(ref flat_fun) => { format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature()) } - }) + })) .collect::>(); + write!(f, "{}", res.join("\n")) } } @@ -220,8 +229,13 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedModule<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "module(\n\tfunctions:\n\t\t{:?}\n)", + "TypedModule(\n\tFunctions:\n\t\t{:?}\n\tConstants:\n\t\t{:?}\n)", self.functions + .iter() + .map(|x| format!("{:?}", x)) + .collect::>() + .join("\n\t\t"), + self.constants .iter() .map(|x| format!("{:?}", x)) .collect::>() @@ -322,8 +336,14 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { #[derive(Clone, PartialEq)] pub struct TypedConstant<'ast, T> { - pub ty: Type<'ast, T>, - pub expression: TypedExpression<'ast, T>, + ty: Type<'ast, T>, + expression: TypedExpression<'ast, T>, +} + +impl<'ast, T> TypedConstant<'ast, T> { + pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self { + TypedConstant { ty, expression } + } } impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> { @@ -338,6 +358,12 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { } } +impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { + self.ty.clone() + } +} + /// Something we can assign to. #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee<'ast, T> { @@ -1256,6 +1282,56 @@ impl<'ast, T> TryFrom> for StructExpression<'ast, T> { } } +impl<'ast, T> TryFrom> for FieldElementExpression<'ast, T> { + type Error = (); + + fn try_from( + tc: TypedConstant<'ast, T>, + ) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for BooleanExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for UExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for ArrayExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for StructExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for IntExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 82ae443a..a939b296 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -21,6 +21,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_module(self, p) } + fn fold_constant( + &mut self, + s: TypedConstant<'ast, T>, + ) -> Result, Self::Error> { + fold_constant(self, s) + } + fn fold_constant_symbol( &mut self, s: TypedConstantSymbol<'ast, T>, @@ -800,15 +807,22 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } +pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + c: TypedConstant<'ast, T>, +) -> Result, F::Error> { + Ok(TypedConstant { + expression: f.fold_expression(c.expression)?, + ..c + }) +} + pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedConstantSymbol<'ast, T>, ) -> Result, F::Error> { match s { - TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(TypedConstant { - expression: f.fold_expression(tc.expression)?, - ..tc - })), + TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)), there => Ok(there), } } From 231204d7a2288ecd4caca00d73bf90368d25170c Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 26 Apr 2021 19:53:09 +0200 Subject: [PATCH 11/12] fold type --- .../src/static_analysis/constant_inliner.rs | 28 ++++++++----------- zokrates_core/src/typed_absy/folder.rs | 12 ++++---- zokrates_core/src/typed_absy/mod.rs | 4 +-- zokrates_core/src/typed_absy/result_folder.rs | 16 +++++------ 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 3e082fbf..67693282 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -71,14 +71,14 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } } - fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { + fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { TypedModule { - constants: p + constants: m .constants .into_iter() .map(|(key, tc)| (key, self.fold_constant_symbol(tc))) .collect(), - functions: p + functions: m .functions .into_iter() .map(|(key, fun)| (key, self.fold_function_symbol(fun))) @@ -100,10 +100,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> FieldElementExpression<'ast, T> { match e { FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => { - let c = self.fold_constant(c); - fold_field_expression(self, c.try_into().unwrap()) - } + Some(c) => self.fold_constant(c).try_into().unwrap(), None => fold_field_expression(self, e), }, e => fold_field_expression(self, e), @@ -116,10 +113,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> BooleanExpression<'ast, T> { match e { BooleanExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => { - let c = self.fold_constant(c); - fold_boolean_expression(self, c.try_into().unwrap()) - } + Some(c) => self.fold_constant(c).try_into().unwrap(), None => fold_boolean_expression(self, e), }, e => fold_boolean_expression(self, e), @@ -134,8 +128,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { UExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let c = self.fold_constant(c); - fold_uint_expression(self, c.try_into().unwrap()).into_inner() + let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() } None => fold_uint_expression_inner(self, size, e), }, @@ -151,8 +145,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let c = self.fold_constant(c); - fold_array_expression(self, c.try_into().unwrap()).into_inner() + let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() } None => fold_array_expression_inner(self, ty, e), }, @@ -168,8 +162,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let c = self.fold_constant(c); - fold_struct_expression(self, c.try_into().unwrap()).into_inner() + let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() } None => fold_struct_expression_inner(self, ty, e), }, diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 610aa710..db19977b 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -9,8 +9,8 @@ pub trait Folder<'ast, T: Field>: Sized { fold_program(self, p) } - fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - fold_module(self, p) + fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { + fold_module(self, m) } fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> { @@ -201,15 +201,15 @@ pub trait Folder<'ast, T: Field>: Sized { pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { - functions: p + functions: m .functions .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) .collect(), - ..p + ..m } } @@ -728,8 +728,8 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( c: TypedConstant<'ast, T>, ) -> TypedConstant<'ast, T> { TypedConstant { + ty: f.fold_type(c.ty), expression: f.fold_expression(c.expression), - ..c } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 408920d2..07b1391c 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -199,10 +199,10 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { .constants .iter() .map(|(key, symbol)| match symbol { - TypedConstantSymbol::Here(tc) => { + TypedConstantSymbol::Here(ref tc) => { format!("const {} {} = {}", tc.ty, key, tc.expression) } - TypedConstantSymbol::There(module_id, id) => { + TypedConstantSymbol::There(ref module_id, ref id) => { format!("from \"{}\" import {} as {}", module_id.display(), id, key) } }) diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index a939b296..c42d8df3 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -16,16 +16,16 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_module( &mut self, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> Result, Self::Error> { - fold_module(self, p) + fold_module(self, m) } fn fold_constant( &mut self, - s: TypedConstant<'ast, T>, + c: TypedConstant<'ast, T>, ) -> Result, Self::Error> { - fold_constant(self, s) + fold_constant(self, c) } fn fold_constant_symbol( @@ -812,8 +812,8 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( c: TypedConstant<'ast, T>, ) -> Result, F::Error> { Ok(TypedConstant { + ty: f.fold_type(c.ty)?, expression: f.fold_expression(c.expression)?, - ..c }) } @@ -839,15 +839,15 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> Result, F::Error> { Ok(TypedModule { - functions: p + functions: m .functions .into_iter() .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) .collect::>()?, - ..p + ..m }) } From 9ec3ac3cf779a9fc78a292693fe9428e30a1ea8d Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 27 Apr 2021 13:34:16 +0200 Subject: [PATCH 12/12] visit constants in default folder --- .../src/static_analysis/constant_inliner.rs | 15 --------------- zokrates_core/src/typed_absy/folder.rs | 6 +++++- zokrates_core/src/typed_absy/result_folder.rs | 6 +++++- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 67693282..360927d7 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -71,21 +71,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } } - fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - TypedModule { - constants: m - .constants - .into_iter() - .map(|(key, tc)| (key, self.fold_constant_symbol(tc))) - .collect(), - functions: m - .functions - .into_iter() - .map(|(key, fun)| (key, self.fold_function_symbol(fun))) - .collect(), - } - } - fn fold_constant_symbol( &mut self, s: TypedConstantSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index db19977b..36c1453f 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -204,12 +204,16 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( m: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { + constants: m + .constants + .into_iter() + .map(|(key, tc)| (key, f.fold_constant_symbol(tc))) + .collect(), functions: m .functions .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) .collect(), - ..m } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index c42d8df3..245ab2a5 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -842,12 +842,16 @@ pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( m: TypedModule<'ast, T>, ) -> Result, F::Error> { Ok(TypedModule { + constants: m + .constants + .into_iter() + .map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc))) + .collect::>()?, functions: m .functions .into_iter() .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) .collect::>()?, - ..m }) }