diff --git a/changelogs/unreleased/792-dark64 b/changelogs/unreleased/792-dark64 new file mode 100644 index 00000000..fbc1a5dc --- /dev/null +++ b/changelogs/unreleased/792-dark64 @@ -0,0 +1 @@ +Introduce constant definitions to the language (`const` keyword) \ No newline at end of file diff --git a/zokrates_book/src/SUMMARY.md b/zokrates_book/src/SUMMARY.md index 6ccefb31..7755bbf7 100644 --- a/zokrates_book/src/SUMMARY.md +++ b/zokrates_book/src/SUMMARY.md @@ -8,11 +8,12 @@ - [Variables](language/variables.md) - [Types](language/types.md) - [Operators](language/operators.md) - - [Functions](language/functions.md) - [Control flow](language/control_flow.md) + - [Constants](language/constants.md) + - [Functions](language/functions.md) + - [Generics](language/generics.md) - [Imports](language/imports.md) - [Comments](language/comments.md) - - [Generics](language/generics.md) - [Macros](language/macros.md) - [Toolbox](toolbox/index.md) diff --git a/zokrates_book/src/language/constants.md b/zokrates_book/src/language/constants.md new file mode 100644 index 00000000..a6c16db7 --- /dev/null +++ b/zokrates_book/src/language/constants.md @@ -0,0 +1,17 @@ +## Constants + +Constants must be globally defined outside all other scopes by using a `const` keyword. Constants can be set only to a constant expression. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/constant_definition.zok}} +``` + +The value of a constant can't be changed through reassignment, and it can't be redeclared. + +Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is already defined. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/constant_reference.zok}} +``` + +The naming convention for constants are similar to that of variables. All characters in a constant name are usually in uppercase. \ No newline at end of file diff --git a/zokrates_book/src/language/imports.md b/zokrates_book/src/language/imports.md index fef1083d..8d4d17de 100644 --- a/zokrates_book/src/language/imports.md +++ b/zokrates_book/src/language/imports.md @@ -44,7 +44,7 @@ from "./path/to/my/module" import main as module Note that this legacy method is likely to become deprecated, so it is recommended to use the preferred way instead. ### Symbols -Two types of symbols can be imported +Three types of symbols can be imported #### Functions Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is inferred. @@ -52,6 +52,9 @@ Functions are imported by name. If many functions have the same name but differe #### User-defined types User-defined types declared with the `struct` keyword are imported by name. +#### Constants +Constants declared with the `const` keyword are imported by name. + ### Relative Imports You can import a resource in the same folder directly, like this: diff --git a/zokrates_cli/examples/book/constant_definition.zok b/zokrates_cli/examples/book/constant_definition.zok new file mode 100644 index 00000000..10b31ec7 --- /dev/null +++ b/zokrates_cli/examples/book/constant_definition.zok @@ -0,0 +1,4 @@ +const field PRIME = 31 + +def main() -> field: + return PRIME \ No newline at end of file diff --git a/zokrates_cli/examples/book/constant_reference.zok b/zokrates_cli/examples/book/constant_reference.zok new file mode 100644 index 00000000..b6850f89 --- /dev/null +++ b/zokrates_cli/examples/book/constant_reference.zok @@ -0,0 +1,5 @@ +const field ONE = 1 +const field TWO = ONE + ONE + +def main() -> field: + return TWO \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/constant_assignment.zok b/zokrates_cli/examples/compile_errors/constant_assignment.zok new file mode 100644 index 00000000..e04da9f6 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/constant_assignment.zok @@ -0,0 +1,5 @@ +const field a = 1 + +def main() -> field: + a = 2 // not allowed + return a \ No newline at end of file diff --git a/zokrates_cli/examples/functions/no_args_multiple.zok b/zokrates_cli/examples/functions/no_args_multiple.zok index cef72d30..c70ad021 100644 --- a/zokrates_cli/examples/functions/no_args_multiple.zok +++ b/zokrates_cli/examples/functions/no_args_multiple.zok @@ -1,10 +1,10 @@ -def const() -> field: +def constant() -> field: return 123123 -def add(field a,field b) -> field: - a=const() - return a+b +def add(field a, field b) -> field: + a = constant() + return a + b -def main(field a,field b) -> field: - field c = add(a, b+const()) - return const() +def main(field a, field b) -> field: + field c = add(a, b + constant()) + return constant() diff --git a/zokrates_cli/examples/imports/bar.zok b/zokrates_cli/examples/imports/bar.zok index c7d3af12..34eb1c5a 100644 --- a/zokrates_cli/examples/imports/bar.zok +++ b/zokrates_cli/examples/imports/bar.zok @@ -1,5 +1,7 @@ -struct Bar { -} +struct Bar {} + +const field ONE = 1 +const field BAR = 21 * ONE def main() -> field: - return 21 \ No newline at end of file + return BAR \ No newline at end of file diff --git a/zokrates_cli/examples/imports/baz.zok b/zokrates_cli/examples/imports/baz.zok index 9fd704a3..84cc641d 100644 --- a/zokrates_cli/examples/imports/baz.zok +++ b/zokrates_cli/examples/imports/baz.zok @@ -1,5 +1,6 @@ -struct Baz { -} +struct Baz {} + +const field BAZ = 123 def main() -> field: - return 123 \ No newline at end of file + return BAZ \ No newline at end of file diff --git a/zokrates_cli/examples/imports/foo.zok b/zokrates_cli/examples/imports/foo.zok index 43018b20..8dddfd8b 100644 --- a/zokrates_cli/examples/imports/foo.zok +++ b/zokrates_cli/examples/imports/foo.zok @@ -1,9 +1,10 @@ from "./baz" import Baz - -import "./baz" from "./baz" import main as my_function +import "./baz" + +const field FOO = 144 def main() -> field: - field a = my_function() - Baz b = Baz {} - return baz() \ No newline at end of file + Baz b = Baz {} + assert(baz() == my_function()) + return FOO \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import.zok b/zokrates_cli/examples/imports/import.zok deleted file mode 100644 index bc4e7669..00000000 --- a/zokrates_cli/examples/imports/import.zok +++ /dev/null @@ -1,11 +0,0 @@ -from "./bar" import Bar as MyBar -from "./bar" import Bar - -import "./foo" -import "./bar" - -def main() -> field: - MyBar my_bar = MyBar {} - Bar bar = Bar {} - assert(my_bar == bar) - return foo() + bar() \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_constants.zok b/zokrates_cli/examples/imports/import_constants.zok new file mode 100644 index 00000000..2abaaea5 --- /dev/null +++ b/zokrates_cli/examples/imports/import_constants.zok @@ -0,0 +1,6 @@ +from "./foo" import FOO +from "./bar" import BAR +from "./baz" import BAZ + +def main() -> bool: + return FOO == BAR + BAZ \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_functions.zok b/zokrates_cli/examples/imports/import_functions.zok new file mode 100644 index 00000000..32628eb9 --- /dev/null +++ b/zokrates_cli/examples/imports/import_functions.zok @@ -0,0 +1,6 @@ +import "./foo" +import "./bar" +import "./baz" + +def main() -> bool: + return foo() == bar() + baz() \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_structs.zok b/zokrates_cli/examples/imports/import_structs.zok new file mode 100644 index 00000000..61c94b02 --- /dev/null +++ b/zokrates_cli/examples/imports/import_structs.zok @@ -0,0 +1,8 @@ +from "./bar" import Bar as MyBar +from "./bar" import Bar + +def main(): + MyBar my_bar = MyBar {} + Bar bar = Bar {} + assert(my_bar == bar) + return \ No newline at end of file diff --git a/zokrates_cli/examples/imports/import_with_alias.zok b/zokrates_cli/examples/imports/import_with_alias.zok index e9bb2194..5e013691 100644 --- a/zokrates_cli/examples/imports/import_with_alias.zok +++ b/zokrates_cli/examples/imports/import_with_alias.zok @@ -1,4 +1,8 @@ -import "./foo" as d +from "./bar" import main as bar +from "./baz" import BAZ as baz +import "./foo" as f def main() -> field: - return d() \ No newline at end of file + field foo = f() + assert(foo == bar() + baz) + return foo \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 1ba9b958..652932c6 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,6 +1,7 @@ use crate::absy; use crate::imports; +use crate::absy::SymbolDefinition; use num_bigint::BigUint; use zokrates_pest_ast as pest; @@ -10,6 +11,11 @@ impl<'ast> From> for absy::Module<'ast> { prog.structs .into_iter() .map(absy::SymbolDeclarationNode::from) + .chain( + prog.constants + .into_iter() + .map(absy::SymbolDeclarationNode::from), + ) .chain( prog.functions .into_iter() @@ -78,7 +84,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'a absy::SymbolDeclaration { id, - symbol: absy::Symbol::HereType(ty), + symbol: absy::Symbol::Here(SymbolDefinition::Struct(ty)), } .span(span) } @@ -98,6 +104,27 @@ impl<'ast> From> for absy::StructDefinitionFieldNode<'as } } +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(definition: pest::ConstantDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { + use crate::absy::NodeValue; + + let span = definition.span; + let id = definition.id.span.as_str(); + + let ty = absy::ConstantDefinition { + ty: definition.ty.into(), + expression: definition.expression.into(), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id, + symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)), + } + .span(span) + } +} + impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; @@ -148,7 +175,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { absy::SymbolDeclaration { id, - symbol: absy::Symbol::HereFunction(function), + symbol: absy::Symbol::Here(SymbolDefinition::Function(function)), } .span(span) } @@ -754,7 +781,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -771,7 +798,7 @@ mod tests { .outputs(vec![UnresolvedType::FieldElement.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -786,7 +813,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -801,7 +828,7 @@ mod tests { .outputs(vec![UnresolvedType::Boolean.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -817,7 +844,7 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![ absy::Parameter::private( @@ -854,7 +881,7 @@ mod tests { .outputs(vec![UnresolvedType::FieldElement.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -871,7 +898,7 @@ mod tests { absy::Module { symbols: vec![absy::SymbolDeclaration { id: "main", - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![absy::Parameter::private( absy::Variable::new("a", ty.clone().mock()).into(), @@ -887,7 +914,7 @@ mod tests { signature: UnresolvedSignature::new().inputs(vec![ty.mock()]), } .into(), - ), + )), } .into()], imports: vec![], @@ -945,7 +972,7 @@ mod tests { absy::Module { symbols: vec![absy::SymbolDeclaration { id: "main", - symbol: absy::Symbol::HereFunction( + symbol: absy::Symbol::Here(SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -958,7 +985,7 @@ mod tests { signature: UnresolvedSignature::new(), } .into(), - ), + )), } .into()], imports: vec![], diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 798bb93b..d070860d 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -51,10 +51,27 @@ pub struct SymbolDeclaration<'ast> { pub symbol: Symbol<'ast>, } +#[allow(clippy::large_enum_variant)] +#[derive(PartialEq, Clone)] +pub enum SymbolDefinition<'ast> { + Struct(StructDefinitionNode<'ast>), + Constant(ConstantDefinitionNode<'ast>), + Function(FunctionNode<'ast>), +} + +impl<'ast> fmt::Debug for SymbolDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SymbolDefinition::Struct(s) => write!(f, "Struct({:?})", s), + SymbolDefinition::Constant(c) => write!(f, "Constant({:?})", c), + SymbolDefinition::Function(func) => write!(f, "Function({:?})", func), + } + } +} + #[derive(PartialEq, Clone)] pub enum Symbol<'ast> { - HereType(StructDefinitionNode<'ast>), - HereFunction(FunctionNode<'ast>), + Here(SymbolDefinition<'ast>), There(SymbolImportNode<'ast>), Flat(FlatEmbed), } @@ -62,9 +79,8 @@ pub enum Symbol<'ast> { impl<'ast> fmt::Debug for Symbol<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Symbol::HereType(t) => write!(f, "HereType({:?})", t), - Symbol::HereFunction(fun) => write!(f, "HereFunction({:?})", fun), - Symbol::There(t) => write!(f, "There({:?})", t), + Symbol::Here(k) => write!(f, "Here({:?})", k), + Symbol::There(i) => write!(f, "There({:?})", i), Symbol::Flat(flat) => write!(f, "Flat({:?})", flat), } } @@ -73,8 +89,15 @@ impl<'ast> fmt::Debug for Symbol<'ast> { impl<'ast> fmt::Display for SymbolDeclaration<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.symbol { - Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t), - Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun), + Symbol::Here(ref kind) => match kind { + SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t), + SymbolDefinition::Constant(c) => write!( + f, + "const {} {} = {}", + c.value.ty, self.id, c.value.expression + ), + SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func), + }, Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id), Symbol::Flat(ref flat_fun) => { write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature()) @@ -146,6 +169,30 @@ impl<'ast> fmt::Display for StructDefinitionField<'ast> { type StructDefinitionFieldNode<'ast> = Node>; +#[derive(Clone, PartialEq)] +pub struct ConstantDefinition<'ast> { + pub ty: UnresolvedTypeNode<'ast>, + pub expression: ExpressionNode<'ast>, +} + +pub type ConstantDefinitionNode<'ast> = Node>; + +impl<'ast> fmt::Display for ConstantDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "const {}({})", self.ty, self.expression) + } +} + +impl<'ast> fmt::Debug for ConstantDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "ConstantDefinition({:?}, {:?})", + self.ty, self.expression + ) + } +} + /// An import #[derive(Debug, Clone, PartialEq)] pub struct SymbolImport<'ast> { diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 745054b4..5b949d49 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -84,6 +84,7 @@ impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {} +impl<'ast> NodeValue for ConstantDefinition<'ast> {} impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Module<'ast> {} impl<'ast> NodeValue for SymbolImport<'ast> {} diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 1f6a3727..5e434b3c 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -47,6 +47,8 @@ impl ErrorInner { } type TypeMap<'ast> = HashMap>>; +type ConstantMap<'ast, T> = + HashMap, Type<'ast, T>>>; /// The global state of the program during semantic checks #[derive(Debug)] @@ -57,12 +59,15 @@ struct State<'ast, T> { typed_modules: TypedModules<'ast, T>, /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof types: TypeMap<'ast>, + // The user-defined constants + constants: ConstantMap<'ast, T>, } /// A symbol for a given name: either a type or a group of functions. Not both! #[derive(PartialEq, Hash, Eq, Debug)] enum SymbolType<'ast> { Type, + Constant, Functions(BTreeSet>), } @@ -74,9 +79,9 @@ struct SymbolUnifier<'ast> { impl<'ast> SymbolUnifier<'ast> { fn insert_type>(&mut self, id: S) -> bool { - let s_type = self.symbols.entry(id.into()); - match s_type { - // if anything is already called `id`, we cannot introduce this type + let e = self.symbols.entry(id.into()); + match e { + // if anything is already called `id`, we cannot introduce the symbol Entry::Occupied(..) => false, // otherwise, we can! Entry::Vacant(v) => { @@ -86,6 +91,19 @@ impl<'ast> SymbolUnifier<'ast> { } } + fn insert_constant>(&mut self, id: S) -> bool { + let e = self.symbols.entry(id.into()); + match e { + // if anything is already called `id`, we cannot introduce this constant + Entry::Occupied(..) => false, + // otherwise, we can! + Entry::Vacant(v) => { + v.insert(SymbolType::Constant); + true + } + } + } + fn insert_function>( &mut self, id: S, @@ -96,8 +114,8 @@ impl<'ast> SymbolUnifier<'ast> { // if anything is already called `id`, it depends what it is Entry::Occupied(mut o) => { match o.get_mut() { - // if it's a Type, then we can't introduce a function - SymbolType::Type => false, + // if it's a Type or a Constant, then we can't introduce a function + SymbolType::Type | SymbolType::Constant => false, // if it's a Function, we can introduce it only if it has a different signature SymbolType::Functions(signatures) => signatures.insert(signature), } @@ -117,6 +135,7 @@ impl<'ast, T: Field> State<'ast, T> { modules, typed_modules: HashMap::new(), types: HashMap::new(), + constants: HashMap::new(), } } } @@ -248,6 +267,12 @@ pub struct ScopedVariable<'ast, T> { level: usize, } +impl<'ast, T> ScopedVariable<'ast, T> { + fn is_constant(&self) -> bool { + self.level == 0 + } +} + /// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent impl<'ast, T> PartialEq for ScopedVariable<'ast, T> { fn eq(&self, other: &Self) -> bool { @@ -325,6 +350,50 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } + fn check_constant_definition( + &mut self, + id: &'ast str, + c: ConstantDefinitionNode<'ast>, + module_id: &ModuleId, + types: &TypeMap<'ast>, + ) -> Result, ErrorInner> { + let pos = c.pos(); + let ty = self.check_type(c.value.ty.clone(), module_id, &types)?; + let checked_expr = self.check_expression(c.value.expression.clone(), module_id, types)?; + + match ty { + Type::FieldElement => { + FieldElementExpression::try_from_typed(checked_expr).map(TypedExpression::from) + } + Type::Boolean => { + BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from) + } + Type::Uint(bitwidth) => { + UExpression::try_from_typed(checked_expr, bitwidth).map(TypedExpression::from) + } + Type::Array(ref array_ty) => { + ArrayExpression::try_from_typed(checked_expr, *array_ty.ty.clone()) + .map(TypedExpression::from) + } + Type::Struct(ref struct_ty) => { + StructExpression::try_from_typed(checked_expr, struct_ty.clone()) + .map(TypedExpression::from) + } + Type::Int => Err(checked_expr), // Integers cannot be assigned + } + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`", + e, + e.get_type(), + id, + ty + ), + }) + .map(|e| TypedConstant::new(ty, e)) + } + fn check_struct_type_declaration( &mut self, id: String, @@ -378,6 +447,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &mut State<'ast, T>, functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, + constants: &mut HashMap, TypedConstantSymbol<'ast, T>>, symbol_unifier: &mut SymbolUnifier<'ast>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -386,76 +456,120 @@ impl<'ast, T: Field> Checker<'ast, T> { let declaration = declaration.value; match declaration.symbol.clone() { - Symbol::HereType(t) => { - match self.check_struct_type_declaration( - declaration.id.to_string(), - t.clone(), - module_id, - &state.types, - ) { - Ok(ty) => { - match symbol_unifier.insert_type(declaration.id) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - } - .in_file(module_id), - ), - true => { - // there should be no entry in the map for this type yet - assert!(state - .types - .entry(module_id.to_path_buf()) - .or_default() - .insert(declaration.id.to_string(), ty) - .is_none()); - } - }; - } - Err(e) => errors.extend(e.into_iter().map(|inner| Error { - inner, - module_id: module_id.to_path_buf(), - })), - } - } - Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) { - Ok(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, + Symbol::Here(kind) => match kind { + SymbolDefinition::Struct(t) => { + match self.check_struct_type_declaration( + declaration.id.to_string(), + t.clone(), + module_id, + &state.types, + ) { + Ok(ty) => { + match symbol_unifier.insert_type(declaration.id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), ), - } - .in_file(module_id), - ), - true => {} - }; - - self.functions.insert( - DeclarationFunctionKey::with_location( - module_id.to_path_buf(), - declaration.id, - ) - .signature(funct.signature.clone()), - ); - functions.insert( - DeclarationFunctionKey::with_location( - module_id.to_path_buf(), - declaration.id, - ) - .signature(funct.signature.clone()), - TypedFunctionSymbol::Here(funct), - ); + true => { + // there should be no entry in the map for this type yet + assert!(state + .types + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id.to_string(), ty) + .is_none()); + } + }; + } + Err(e) => errors.extend(e.into_iter().map(|inner| Error { + inner, + module_id: module_id.to_path_buf(), + })), + } } - Err(e) => { - errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + SymbolDefinition::Constant(c) => { + match self.check_constant_definition(declaration.id, c, module_id, &state.types) + { + Ok(c) => { + match symbol_unifier.insert_constant(declaration.id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => { + constants.insert( + declaration.id, + TypedConstantSymbol::Here(c.clone()), + ); + self.insert_into_scope(Variable::with_id_and_type( + declaration.id, + c.get_type(), + )); + assert!(state + .constants + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id, c.get_type()) + .is_none()); + } + }; + } + Err(e) => { + errors.push(e.in_file(module_id)); + } + } + } + SymbolDefinition::Function(f) => { + match self.check_function(f, module_id, &state.types) { + Ok(funct) => { + match symbol_unifier + .insert_function(declaration.id, funct.signature.clone()) + { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => {} + }; + + self.functions.insert( + DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration.id, + ) + .signature(funct.signature.clone()), + ); + functions.insert( + DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration.id, + ) + .signature(funct.signature.clone()), + TypedFunctionSymbol::Here(funct), + ); + } + Err(e) => { + errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + } + } } }, Symbol::There(import) => { @@ -487,8 +601,16 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(import.symbol_id) .cloned(); - match (function_candidates.len(), type_candidate) { - (0, Some(t)) => { + // find constant definition candidate + let const_candidate = state + .constants + .entry(import.module_id.to_path_buf()) + .or_default() + .get(import.symbol_id) + .cloned(); + + match (function_candidates.len(), type_candidate, const_candidate) { + (0, Some(t), None) => { // rename the type to the declared symbol let t = match t { @@ -523,7 +645,32 @@ impl<'ast, T: Field> Checker<'ast, T> { .or_default() .insert(declaration.id.to_string(), t); } - (0, None) => { + (0, None, Some(ty)) => { + match symbol_unifier.insert_constant(declaration.id) { + false => { + errors.push(Error { + module_id: module_id.to_path_buf(), + inner: ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + }}); + } + true => { + constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, import.symbol_id)); + self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone())); + + state + .constants + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id, ty); + } + }; + } + (0, None, None) => { errors.push(ErrorInner { pos: Some(pos), message: format!( @@ -532,7 +679,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }.in_file(module_id)); } - (_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"), + (_, Some(_), Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"), _ => { for candidate in function_candidates { @@ -609,6 +756,7 @@ impl<'ast, T: Field> Checker<'ast, T> { state: &mut State<'ast, T>, ) -> Result<(), Vec> { let mut checked_functions = HashMap::new(); + let mut checked_constants = HashMap::new(); // check if the module was already removed from the untyped ones let to_insert = match state.modules.remove(module_id) { @@ -621,7 +769,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // we need to create an entry in the types map to store types for this module state.types.entry(module_id.to_path_buf()).or_default(); - // we keep track of the introduced symbols to avoid colisions between types and functions + // we keep track of the introduced symbols to avoid collisions between types and functions let mut symbol_unifier = SymbolUnifier::default(); // we go through symbol declarations and check them @@ -631,12 +779,14 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id, state, &mut checked_functions, + &mut checked_constants, &mut symbol_unifier, )? } Some(TypedModule { functions: checked_functions, + constants: checked_constants, }) } }; @@ -688,7 +838,6 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast>, ) -> Result, Vec> { - assert!(self.scope.is_empty()); assert!(self.return_types.is_none()); self.enter_scope(); @@ -815,7 +964,6 @@ impl<'ast, T: Field> Checker<'ast, T> { } self.return_types = None; - assert!(self.scope.is_empty()); Ok(TypedFunction { arguments: arguments_checked, @@ -1275,7 +1423,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| ErrorInner { pos: Some(pos), message: format!( - "Expression {} of type {} cannot be assigned to {} of type {}", + "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, e.get_type(), var.clone(), @@ -1408,10 +1556,16 @@ impl<'ast, T: Field> Checker<'ast, T> { // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) { - Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( - variable_name, - var.id._type.clone(), - ))), + Some(var) => match var.is_constant() { + true => Err(ErrorInner { + pos: Some(assignee.pos()), + message: format!("Assignment to constant variable `{}`", variable_name), + }), + false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( + variable_name, + var.id._type.clone(), + ))), + }, None => Err(ErrorInner { pos: Some(assignee.pos()), message: format!("Variable `{}` is undeclared", variable_name), @@ -3094,7 +3248,7 @@ mod tests { let foo: Module = Module { symbols: vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()], imports: vec![], @@ -3134,6 +3288,7 @@ mod tests { )] .into_iter() .collect(), + constants: TypedConstantSymbols::default() }) ); } @@ -3151,12 +3306,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], @@ -3230,12 +3385,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(f0), + symbol: Symbol::Here(SymbolDefinition::Function(f0)), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(f1), + symbol: Symbol::Here(SymbolDefinition::Function(f1)), } .mock(), ], @@ -3268,12 +3423,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], @@ -3321,12 +3476,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], @@ -3361,12 +3516,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function1()), + symbol: Symbol::Here(SymbolDefinition::Function(function1())), } .mock(), ], @@ -3411,12 +3566,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct1()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct1())), } .mock(), ], @@ -3448,12 +3603,14 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(StructDefinition { fields: vec![] }.mock()), + symbol: Symbol::Here(SymbolDefinition::Struct( + StructDefinition { fields: vec![] }.mock(), + )), } .mock(), ], @@ -3488,7 +3645,7 @@ mod tests { let bar = Module::with_symbols(vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -3503,7 +3660,7 @@ mod tests { .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), ], @@ -3537,7 +3694,7 @@ mod tests { let bar = Module::with_symbols(vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(function0()), + symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -3545,7 +3702,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(struct0()), + symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { @@ -3667,6 +3824,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + assert_eq!( checker.check_statement(statement, &*MODULE_ID, &types), Err(vec![ErrorInner { @@ -3691,12 +3850,13 @@ mod tests { let mut scope = HashSet::new(); scope.insert(ScopedVariable { id: Variable::field_element("a"), - level: 0, + level: 1, }); scope.insert(ScopedVariable { id: Variable::field_element("b"), - level: 0, + level: 1, }); + let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( checker.check_statement(statement, &*MODULE_ID, &types), @@ -3762,12 +3922,12 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "bar", - symbol: Symbol::HereFunction(bar), + symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), ]; @@ -3877,17 +4037,17 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "bar", - symbol: Symbol::HereFunction(bar), + symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ]; @@ -4265,12 +4425,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -4352,12 +4512,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -4468,12 +4628,12 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo), + symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], @@ -4761,12 +4921,12 @@ mod tests { let symbols = vec![ SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main1), + symbol: Symbol::Here(SymbolDefinition::Function(main1)), } .mock(), SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main2), + symbol: Symbol::Here(SymbolDefinition::Function(main2)), } .mock(), ]; @@ -4879,7 +5039,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType(s.mock()), + symbol: Symbol::Here(SymbolDefinition::Struct(s.mock())), } .mock()], }; @@ -5009,7 +5169,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -5018,12 +5178,12 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -5032,7 +5192,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), ], @@ -5078,7 +5238,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -5087,7 +5247,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock()], }; @@ -5111,7 +5271,7 @@ mod tests { imports: vec![], symbols: vec![SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -5120,7 +5280,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock()], }; @@ -5146,7 +5306,7 @@ mod tests { symbols: vec![ SymbolDeclaration { id: "Foo", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "bar", @@ -5155,12 +5315,12 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), SymbolDeclaration { id: "Bar", - symbol: Symbol::HereType( + symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { id: "foo", @@ -5169,7 +5329,7 @@ mod tests { .mock()], } .mock(), - ), + )), } .mock(), ], @@ -5638,17 +5798,17 @@ mod tests { let m = Module::with_symbols(vec![ absy::SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo_field), + symbol: Symbol::Here(SymbolDefinition::Function(foo_field)), } .mock(), absy::SymbolDeclaration { id: "foo", - symbol: Symbol::HereFunction(foo_u32), + symbol: Symbol::Here(SymbolDefinition::Function(foo_u32)), } .mock(), absy::SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction(main), + symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ]); @@ -5680,6 +5840,7 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); checker .check_statement( @@ -5713,6 +5874,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + checker .check_statement( Statement::Declaration( @@ -5763,6 +5926,8 @@ mod tests { let types = HashMap::new(); let mut checker: Checker = Checker::new(); + checker.enter_scope(); + checker .check_statement( Statement::Declaration( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs new file mode 100644 index 00000000..360927d7 --- /dev/null +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -0,0 +1,822 @@ +use crate::typed_absy::folder::*; +use crate::typed_absy::*; +use std::convert::TryInto; +use zokrates_field::Field; + +pub struct ConstantInliner<'ast, T: Field> { + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, +} + +impl<'ast, T: Field> ConstantInliner<'ast, T> { + pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self { + ConstantInliner { modules, location } + } + + pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone()); + inliner.fold_program(p) + } + + fn module(&self) -> &TypedModule<'ast, T> { + self.modules.get(&self.location).unwrap() + } + + fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + let prev = self.location.clone(); + self.location = location; + prev + } + + fn get_constant(&mut self, id: &Identifier) -> Option> { + self.modules + .get(&self.location) + .unwrap() + .constants + .get(id.clone().try_into().unwrap()) + .cloned() + .map(|symbol| self.get_canonical_constant(symbol)) + } + + fn get_canonical_constant( + &mut self, + symbol: TypedConstantSymbol<'ast, T>, + ) -> TypedConstant<'ast, T> { + match symbol { + TypedConstantSymbol::There(module_id, id) => { + let location = self.change_location(module_id); + let symbol = self.module().constants.get(id).cloned().unwrap(); + + let symbol = self.get_canonical_constant(symbol); + let _ = self.change_location(location); + symbol + } + TypedConstantSymbol::Here(tc) => self.fold_constant(tc), + } + } +} + +impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { + fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + TypedProgram { + modules: p + .modules + .into_iter() + .map(|(module_id, module)| { + self.change_location(module_id.clone()); + (module_id, self.fold_module(module)) + }) + .collect(), + main: p.main, + } + } + + fn fold_constant_symbol( + &mut self, + s: TypedConstantSymbol<'ast, T>, + ) -> TypedConstantSymbol<'ast, T> { + let tc = self.get_canonical_constant(s); + TypedConstantSymbol::Here(tc) + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + match e { + FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { + Some(c) => self.fold_constant(c).try_into().unwrap(), + None => fold_field_expression(self, e), + }, + e => fold_field_expression(self, e), + } + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + match e { + BooleanExpression::Identifier(ref id) => match self.get_constant(id) { + Some(c) => self.fold_constant(c).try_into().unwrap(), + None => fold_boolean_expression(self, e), + }, + e => fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + size: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + match e { + UExpressionInner::Identifier(ref id) => match self.get_constant(id) { + Some(c) => { + let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() + } + None => fold_uint_expression_inner(self, size, e), + }, + e => fold_uint_expression_inner(self, size, e), + } + } + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + match e { + ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { + Some(c) => { + let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() + } + None => fold_array_expression_inner(self, ty, e), + }, + e => fold_array_expression_inner(self, ty, e), + } + } + + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + match e { + StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { + Some(c) => { + let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + e.into_inner() + } + None => fold_struct_expression_inner(self, ty, e), + }, + e => fold_struct_expression_inner(self, ty, e), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed_absy::types::DeclarationSignature; + use crate::typed_absy::{ + DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier, + TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement, + }; + use zokrates_field::Bn128Field; + + #[test] + fn inline_const_field() { + // const field a = 1 + // + // def main() -> field: + // return a + + let const_id = "a"; + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(const_id)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), + )), + )] + .into_iter() + .collect(); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: constants.clone(), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Number(Bn128Field::from(1)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_boolean() { + // const bool a = true + // + // def main() -> bool: + // return a + + let const_id = "a"; + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( + Identifier::from(const_id), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + }; + + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::Boolean, + TypedExpression::Boolean(BooleanExpression::Value(true)), + )), + )] + .into_iter() + .collect(); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: constants.clone(), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + BooleanExpression::Value(true).into() + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_uint() { + // const u32 a = 0x00000001 + // + // def main() -> u32: + // return a + + let const_id = "a"; + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( + Identifier::from(const_id), + ) + .annotate(UBitwidth::B32) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + }; + + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::Uint(UBitwidth::B32), + UExpressionInner::Value(1u128) + .annotate(UBitwidth::B32) + .into(), + )), + )] + .into_iter() + .collect(); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: constants.clone(), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128) + .annotate(UBitwidth::B32) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_const_field_array() { + // const field[2] a = [2, 2] + // + // def main() -> field: + // return a[0] + a[1] + + let const_id = "a"; + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(Identifier::from(const_id)) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(0u128).annotate(UBitwidth::B32), + ) + .into(), + FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(Identifier::from(const_id)) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(1u128).annotate(UBitwidth::B32), + ) + .into(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let constants: TypedConstantSymbols<_> = vec![( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::Array( + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + ), + )), + )] + .into_iter() + .collect(); + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: constants.clone(), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + FieldElementExpression::Select( + box ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(0u128).annotate(UBitwidth::B32), + ) + .into(), + FieldElementExpression::Select( + box ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + box UExpressionInner::Value(1u128).annotate(UBitwidth::B32), + ) + .into(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants, + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_nested_const_field() { + // const field a = 1 + // const field b = a + 1 + // + // def main() -> field: + // return b + + let const_a_id = "a"; + let const_b_id = "b"; + + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + )] + .into_iter() + .collect(), + constants: vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + )), + )), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Identifier(Identifier::from( + const_a_id, + )), + box FieldElementExpression::Number(Bn128Field::from(1)), + )), + )), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + constants: vec![ + ( + const_a_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + )), + )), + ), + ( + const_b_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + )), + )), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_imported_constant() { + // --------------------- + // module `foo` + // -------------------- + // const field FOO = 42 + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // from "foo" import FOO + // + // def main() -> field: + // return FOO + + let foo_const_id = "FOO"; + let foo_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main") + .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![], + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + }), + )] + .into_iter() + .collect(), + constants: vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + )), + )), + )] + .into_iter() + .collect(), + }; + + let main_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + )] + .into_iter() + .collect(), + constants: vec![( + foo_const_id, + TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id), + )] + .into_iter() + .collect(), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), main_module), + ("foo".into(), foo_module.clone()), + ] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + let expected_main_module = TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Number(Bn128Field::from(42)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + )] + .into_iter() + .collect(), + constants: vec![( + foo_const_id, + TypedConstantSymbol::Here(TypedConstant::new( + GType::FieldElement, + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + )), + )), + )] + .into_iter() + .collect(), + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), expected_main_module), + ("foo".into(), foo_module), + ] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e30c5e54..4b91e49e 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,6 +5,7 @@ //! @date 2018 mod bounds_checker; +mod constant_inliner; mod flat_propagation; mod flatten_complex_types; mod propagation; @@ -28,6 +29,7 @@ use self::variable_read_remover::VariableReadRemover; use self::variable_write_remover::VariableWriteRemover; use crate::flat_absy::FlatProg; use crate::ir::Prog; +use crate::static_analysis::constant_inliner::ConstantInliner; use crate::typed_absy::{abi::Abi, TypedProgram}; use crate::zir::ZirProgram; use std::fmt; @@ -73,8 +75,11 @@ impl fmt::Display for Error { impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> { - let r = reduce_program(self).map_err(Error::from)?; - + // inline user-defined constants + let r = ConstantInliner::inline(self); + // reduce the program to a single function + let r = reduce_program(r).map_err(Error::from)?; + // generate abi let abi = r.abi(); // propagate diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 39f6409f..0d72e6c5 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -264,6 +264,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } }) .collect::>()?, + ..m }) } diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 24fd36ed..7d393773 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -547,6 +547,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -761,6 +762,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -826,6 +828,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -959,6 +962,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1043,6 +1047,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1185,6 +1190,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1269,6 +1275,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1447,6 +1454,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1558,6 +1566,7 @@ mod tests { )] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() @@ -1646,6 +1655,7 @@ mod tests { ] .into_iter() .collect(), + constants: Default::default(), }, )] .into_iter() diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 3cd1f425..f3adbdea 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -65,7 +65,13 @@ mod tests { ); let mut modules = HashMap::new(); - modules.insert("main".into(), TypedModule { functions }); + modules.insert( + "main".into(), + TypedModule { + functions, + constants: Default::default(), + }, + ); let typed_ast: TypedProgram = TypedProgram { main: "main".into(), diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 5aaacc15..36c1453f 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -9,8 +9,19 @@ pub trait Folder<'ast, T: Field>: Sized { fold_program(self, p) } - fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - fold_module(self, p) + fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { + fold_module(self, m) + } + + fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> { + fold_constant(self, c) + } + + fn fold_constant_symbol( + &mut self, + s: TypedConstantSymbol<'ast, T>, + ) -> TypedConstantSymbol<'ast, T> { + fold_constant_symbol(self, s) } fn fold_function_symbol( @@ -190,10 +201,15 @@ pub trait Folder<'ast, T: Field>: Sized { pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { - functions: p + constants: m + .constants + .into_iter() + .map(|(key, tc)| (key, f.fold_constant_symbol(tc))) + .collect(), + functions: m .functions .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) @@ -711,6 +727,26 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + c: TypedConstant<'ast, T>, +) -> TypedConstant<'ast, T> { + TypedConstant { + ty: f.fold_type(c.ty), + expression: f.fold_expression(c.expression), + } +} + +pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedConstantSymbol<'ast, T>, +) -> TypedConstantSymbol<'ast, T> { + match s { + TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)), + there => there, + } +} + pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedFunctionSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 7b5935da..07b1391c 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -61,6 +61,18 @@ pub type TypedModules<'ast, T> = HashMap = HashMap, TypedFunctionSymbol<'ast, T>>; +pub type ConstantIdentifier<'ast> = &'ast str; + +#[derive(Clone, Debug, PartialEq)] +pub enum TypedConstantSymbol<'ast, T> { + Here(TypedConstant<'ast, T>), + There(OwnedTypedModuleId, ConstantIdentifier<'ast>), +} + +/// A collection of `TypedConstantSymbol`s +pub type TypedConstantSymbols<'ast, T> = + HashMap, TypedConstantSymbol<'ast, T>>; + /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)] pub struct TypedProgram<'ast, T> { @@ -135,11 +147,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> { } } -/// A typed program as a collection of functions. Types have been resolved during semantic checking. +/// A typed module as a collection of functions. Types have been resolved during semantic checking. #[derive(PartialEq, Clone)] pub struct TypedModule<'ast, T> { - /// Functions of the program + /// Functions of the module pub functions: TypedFunctionSymbols<'ast, T>, + /// Constants defined in module + pub constants: TypedConstantSymbols<'ast, T>, } #[derive(Clone, PartialEq)] @@ -182,22 +196,31 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = self - .functions + .constants .iter() .map(|(key, symbol)| match symbol { + TypedConstantSymbol::Here(ref tc) => { + format!("const {} {} = {}", tc.ty, key, tc.expression) + } + TypedConstantSymbol::There(ref module_id, ref id) => { + format!("from \"{}\" import {} as {}", module_id.display(), id, key) + } + }) + .chain(self.functions.iter().map(|(key, symbol)| match symbol { TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), TypedFunctionSymbol::There(ref fun_key) => format!( - "import {} from \"{}\" as {} // with signature {}", - fun_key.id, + "from \"{}\" import {} as {} // with signature {}", fun_key.module.display(), + fun_key.id, key.id, key.signature ), TypedFunctionSymbol::Flat(ref flat_fun) => { format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature()) } - }) + })) .collect::>(); + write!(f, "{}", res.join("\n")) } } @@ -206,8 +229,13 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedModule<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "module(\n\tfunctions:\n\t\t{:?}\n)", + "TypedModule(\n\tFunctions:\n\t\t{:?}\n\tConstants:\n\t\t{:?}\n)", self.functions + .iter() + .map(|x| format!("{:?}", x)) + .collect::>() + .join("\n\t\t"), + self.constants .iter() .map(|x| format!("{:?}", x)) .collect::>() @@ -306,6 +334,36 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { } } +#[derive(Clone, PartialEq)] +pub struct TypedConstant<'ast, T> { + ty: Type<'ast, T>, + expression: TypedExpression<'ast, T>, +} + +impl<'ast, T> TypedConstant<'ast, T> { + pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self { + TypedConstant { ty, expression } + } +} + +impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TypedConstant({:?}, {:?})", self.ty, self.expression) + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "const {}({})", self.ty, self.expression) + } +} + +impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { + self.ty.clone() + } +} + /// Something we can assign to. #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee<'ast, T> { @@ -1224,6 +1282,56 @@ impl<'ast, T> TryFrom> for StructExpression<'ast, T> { } } +impl<'ast, T> TryFrom> for FieldElementExpression<'ast, T> { + type Error = (); + + fn try_from( + tc: TypedConstant<'ast, T>, + ) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for BooleanExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for UExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for ArrayExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for StructExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + +impl<'ast, T> TryFrom> for IntExpression<'ast, T> { + type Error = (); + + fn try_from(tc: TypedConstant<'ast, T>) -> Result, Self::Error> { + tc.expression.try_into() + } +} + impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 16aa40db..245ab2a5 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -16,9 +16,23 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_module( &mut self, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> Result, Self::Error> { - fold_module(self, p) + fold_module(self, m) + } + + fn fold_constant( + &mut self, + c: TypedConstant<'ast, T>, + ) -> Result, Self::Error> { + fold_constant(self, c) + } + + fn fold_constant_symbol( + &mut self, + s: TypedConstantSymbol<'ast, T>, + ) -> Result, Self::Error> { + fold_constant_symbol(self, s) } fn fold_function_symbol( @@ -793,6 +807,26 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } +pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + c: TypedConstant<'ast, T>, +) -> Result, F::Error> { + Ok(TypedConstant { + ty: f.fold_type(c.ty)?, + expression: f.fold_expression(c.expression)?, + }) +} + +pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedConstantSymbol<'ast, T>, +) -> Result, F::Error> { + match s { + TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)), + there => Ok(there), + } +} + pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedFunctionSymbol<'ast, T>, @@ -805,10 +839,15 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, - p: TypedModule<'ast, T>, + m: TypedModule<'ast, T>, ) -> Result, F::Error> { Ok(TypedModule { - functions: p + constants: m + .constants + .into_iter() + .map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc))) + .collect::>()?, + functions: m .functions .into_iter() .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) diff --git a/zokrates_core_test/tests/tests/constants/array.json b/zokrates_core_test/tests/tests/constants/array.json new file mode 100644 index 00000000..7cbbb2ba --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/array.zok", + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1", "2"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/array.zok b/zokrates_core_test/tests/tests/constants/array.zok new file mode 100644 index 00000000..cce74dc7 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array.zok @@ -0,0 +1,4 @@ +const field[2] ARRAY = [1, 2] + +def main() -> field[2]: + return ARRAY \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/bool.json b/zokrates_core_test/tests/tests/constants/bool.json new file mode 100644 index 00000000..f11aea9a --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/bool.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/bool.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/bool.zok b/zokrates_core_test/tests/tests/constants/bool.zok new file mode 100644 index 00000000..00e0ae11 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/bool.zok @@ -0,0 +1,4 @@ +const bool BOOLEAN = true + +def main() -> bool: + return BOOLEAN \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/field.json b/zokrates_core_test/tests/tests/constants/field.json new file mode 100644 index 00000000..2ec19a1f --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/field.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/field.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/field.zok b/zokrates_core_test/tests/tests/constants/field.zok new file mode 100644 index 00000000..4408b12e --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/field.zok @@ -0,0 +1,4 @@ +const field ONE = 1 + +def main() -> field: + return ONE \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/nested.json b/zokrates_core_test/tests/tests/constants/nested.json new file mode 100644 index 00000000..61cfd309 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/nested.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/nested.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["8"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/nested.zok b/zokrates_core_test/tests/tests/constants/nested.zok new file mode 100644 index 00000000..a7861aeb --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/nested.zok @@ -0,0 +1,6 @@ +const field A = 2 +const field B = 2 +const field[2] ARRAY = [A * 2, B * 2] + +def main() -> field: + return ARRAY[0] + ARRAY[1] \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/struct.json b/zokrates_core_test/tests/tests/constants/struct.json new file mode 100644 index 00000000..4d77d484 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/struct.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/struct.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["4"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/struct.zok b/zokrates_core_test/tests/tests/constants/struct.zok new file mode 100644 index 00000000..92e705ca --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/struct.zok @@ -0,0 +1,9 @@ +struct Foo { + field a + field b +} + +const Foo FOO = Foo { a: 2, b: 2 } + +def main() -> field: + return FOO.a + FOO.b \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/uint.json b/zokrates_core_test/tests/tests/constants/uint.json new file mode 100644 index 00000000..a2fccab8 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/uint.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/uint.zok", + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/uint.zok b/zokrates_core_test/tests/tests/constants/uint.zok new file mode 100644 index 00000000..914308ec --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/uint.zok @@ -0,0 +1,4 @@ +const u32 ONE = 0x00000001 + +def main() -> u32: + return ONE \ No newline at end of file diff --git a/zokrates_parser/src/ace_mode/index.js b/zokrates_parser/src/ace_mode/index.js index ec77a4fb..537e0f76 100644 --- a/zokrates_parser/src/ace_mode/index.js +++ b/zokrates_parser/src/ace_mode/index.js @@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac var ZoKratesHighlightRules = function () { var keywords = ( - "assert|as|bool|byte|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64" + "assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64" ); var keywordMapper = this.createKeywordMapper({ diff --git a/zokrates_parser/src/textmate/package.json b/zokrates_parser/src/textmate/package.json index ebacd5de..a72f04a0 100644 --- a/zokrates_parser/src/textmate/package.json +++ b/zokrates_parser/src/textmate/package.json @@ -1,27 +1,36 @@ { - "name": "zokrates", - "displayName": "zokrates", - "description": "Syntax highlighting for the ZoKrates language", - "publisher": "zokrates", - "repository": "https://github.com/ZoKrates/ZoKrates", - "version": "0.0.1", - "engines": { - "vscode": "^1.53.0" - }, - "categories": [ - "Programming Languages" + "name": "zokrates", + "displayName": "zokrates", + "description": "Syntax highlighting for the ZoKrates language", + "publisher": "zokrates", + "repository": "https://github.com/ZoKrates/ZoKrates", + "version": "0.0.1", + "engines": { + "vscode": "^1.53.0" + }, + "categories": [ + "Programming Languages" + ], + "contributes": { + "languages": [ + { + "id": "zokrates", + "aliases": [ + "ZoKrates", + "zokrates" + ], + "extensions": [ + ".zok" + ], + "configuration": "./language-configuration.json" + } ], - "contributes": { - "languages": [{ - "id": "zokrates", - "aliases": ["ZoKrates", "zokrates"], - "extensions": [".zok"], - "configuration": "./language-configuration.json" - }], - "grammars": [{ - "language": "zokrates", - "scopeName": "source.zok", - "path": "./syntaxes/zokrates.tmLanguage.json" - }] - } -} \ No newline at end of file + "grammars": [ + { + "language": "zokrates", + "scopeName": "source.zok", + "path": "./syntaxes/zokrates.tmLanguage.json" + } + ] + } +} diff --git a/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json b/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json index 3730240c..c55cadb3 100644 --- a/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json +++ b/zokrates_parser/src/textmate/syntaxes/zokrates.tmLanguage.json @@ -1,600 +1,637 @@ { - "$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json", - "name": "ZoKrates", - "fileTypes": [ - "zok" - ], - "scopeName": "source.zok", - "patterns": [ - { - "comment": "attributes", - "name": "meta.attribute.zokrates", - "begin": "(#)(\\!?)(\\[)", - "beginCaptures": { - "1": { - "name": "punctuation.definition.attribute.zokrates" - }, - "2": { - "name": "keyword.operator.attribute.inner.zokrates" - }, - "3": { - "name": "punctuation.brackets.attribute.zokrates" - } - }, - "end": "\\]", - "endCaptures": { - "0": { - "name": "punctuation.brackets.attribute.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#lifetimes" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#gtypes" - }, - { - "include": "#types" - } - ] - }, - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#types" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#variables" - } - ], - "repository": { - "comments": { - "patterns": [ - { - "comment": "line comments", - "name": "comment.line.double-slash.zokrates", - "match": "\\s*//.*" - } - ] - }, - "block-comments": { - "patterns": [ - { - "comment": "empty block comments", - "name": "comment.block.zokrates", - "match": "/\\*\\*/" - }, - { - "comment": "block comments", - "name": "comment.block.zokrates", - "begin": "/\\*(?!\\*)", - "end": "\\*/", - "patterns": [ - { - "include": "#block-comments" - } - ] - } - ] - }, - "constants": { - "patterns": [ - { - "comment": "ALL CAPS constants", - "name": "constant.other.caps.zokrates", - "match": "\\b[A-Z]{2}[A-Z0-9_]*\\b" - }, - { - "comment": "decimal integers and floats", - "name": "constant.numeric.decimal.zokrates", - "match": "\\b\\d[\\d_]*(u128|u16|u32|u64|u8|f)?\\b", - "captures": { - "5": { - "name": "entity.name.type.numeric.zokrates" - } - } - }, - { - "comment": "hexadecimal integers", - "name": "constant.numeric.hex.zokrates", - "match": "\\b0x[\\da-fA-F_]+\\b" - }, - { - "comment": "booleans", - "name": "constant.language.bool.zokrates", - "match": "\\b(true|false)\\b" - } - ] - }, - "imports": { - "patterns": [ - { - "comment": "explicit import statement", - "name": "meta.import.explicit.zokrates", - "match": "\\b(from)\\s+(\\\".*\\\")(import)\\s+([A-Za-z0-9_]+)\\s+((as)\\s+[A-Za-z0-9_]+)?\\b", - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#strings" - } - ] - }, - { - "comment": "main import statement", - "name": "meta.import.explicit.zokrates", - "match": "\\b(import)\\s+(\\\".*\\\")\\s+((as)\\s+[A-Za-z0-9_]+)?\\b", - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#strings" - } - ] - } - ] - }, - "functions": { - "patterns": [ - { - "comment": "function definition", - "name": "meta.function.definition.zokrates", - "begin": "\\b(def)\\s+([A-Za-z0-9_]+)((\\()|(<))", - "beginCaptures": { - "1": { - "name": "keyword.other.def.zokrates" - }, - "2": { - "name": "entity.name.function.zokrates" - }, - "4": { - "name": "punctuation.brackets.round.zokrates" - }, - "5": { - "name": "punctuation.brackets.angle.zokrates" - } - }, - "end": "\\:|;", - "endCaptures": { - "0": { - "name": "keyword.punctuation.colon.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "function/method calls, chaining", - "name": "meta.function.call.zokrates", - "begin": "([A-Za-z0-9_]+)(\\()", - "beginCaptures": { - "1": { - "name": "entity.name.function.zokrates" - }, - "2": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "end": "\\)", - "endCaptures": { - "0": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "function/method calls with turbofish", - "name": "meta.function.call.zokrates", - "begin": "([A-Za-z0-9_]+)(?=::<.*>\\()", - "beginCaptures": { - "1": { - "name": "entity.name.function.zokrates" - } - }, - "end": "\\)", - "endCaptures": { - "0": { - "name": "punctuation.brackets.round.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#constants" - }, - { - "include": "#functions" - }, - { - "include": "#punctuation" - }, - { - "include": "#strings" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - } - ] - }, - "keywords": { - "patterns": [ - { - "comment": "argument visibility", - "name": "keyword.visibility.zokrates", - "match": "\\b(public|private)\\b" - }, - { - "comment": "control flow keywords", - "name": "keyword.control.zokrates", - "match": "\\b(do|else|for|do|endfor|if|then|fi|return|assert)\\b" - }, - { - "comment": "storage keywords", - "name": "storage.type.zokrates", - "match": "\\b(struct)\\b" - }, - { - "comment": "def", - "name": "keyword.other.def.zokrates", - "match": "\\bdef\\b" - }, - { - "comment": "import keywords", - "name": "keyword.other.import.zokrates", - "match": "\\b(import|from|as)\\b" - }, - { - "comment": "logical operators", - "name": "keyword.operator.logical.zokrates", - "match": "(\\^|\\||\\|\\||&|&&|<<|>>|!)(?!=)" - }, - { - "comment": "single equal", - "name": "keyword.operator.assignment.equal.zokrates", - "match": "(?])=(?!=|>)" - }, - { - "comment": "comparison operators", - "name": "keyword.operator.comparison.zokrates", - "match": "(=(=)?(?!>)|!=|<=|(?=)" - }, - { - "comment": "math operators", - "name": "keyword.operator.math.zokrates", - "match": "(([+%]|(\\*(?!\\w)))(?!=))|(-(?!>))|(/(?!/))" - }, - { - "comment": "less than, greater than (special case)", - "match": "(?:\\b|(?:(\\))|(\\])|(\\})))[ \\t]+([<>])[ \\t]+(?:\\b|(?:(\\()|(\\[)|(\\{)))", - "captures": { - "1": { - "name": "punctuation.brackets.round.zokrates" - }, - "2": { - "name": "punctuation.brackets.square.zokrates" - }, - "3": { - "name": "punctuation.brackets.curly.zokrates" - }, - "4": { - "name": "keyword.operator.comparison.zokrates" - }, - "5": { - "name": "punctuation.brackets.round.zokrates" - }, - "6": { - "name": "punctuation.brackets.square.zokrates" - }, - "7": { - "name": "punctuation.brackets.curly.zokrates" - } - } - }, - { - "comment": "dot access", - "name": "keyword.operator.access.dot.zokrates", - "match": "\\.(?!\\.)" - }, - { - "comment": "ranges, range patterns", - "name": "keyword.operator.range.zokrates", - "match": "\\.{2}(=|\\.)?" - }, - { - "comment": "colon", - "name": "keyword.operator.colon.zokrates", - "match": ":(?!:)" - }, - { - "comment": "dashrocket, skinny arrow", - "name": "keyword.operator.arrow.skinny.zokrates", - "match": "->" - } - ] - }, - "types": { - "patterns": [ - { - "comment": "numeric types", - "match": "(?", - "endCaptures": { - "0": { - "name": "punctuation.brackets.angle.zokrates" - } - }, - "patterns": [ - { - "include": "#block-comments" - }, - { - "include": "#comments" - }, - { - "include": "#keywords" - }, - { - "include": "#punctuation" - }, - { - "include": "#types" - }, - { - "include": "#variables" - } - ] - }, - { - "comment": "primitive types", - "name": "entity.name.type.primitive.zokrates", - "match": "\\b(bool)\\b" - }, - { - "comment": "struct declarations", - "match": "\\b(struct)\\s+([A-Z][A-Za-z0-9]*)\\b", - "captures": { - "1": { - "name": "storage.type.zokrates" - }, - "2": { - "name": "entity.name.type.struct.zokrates" - } - } - }, - { - "comment": "types", - "name": "entity.name.type.zokrates", - "match": "\\b[A-Z][A-Za-z0-9]*\\b(?!!)" - } - ] - }, - "punctuation": { - "patterns": [ - { - "comment": "comma", - "name": "punctuation.comma.zokrates", - "match": "," - }, - { - "comment": "parentheses, round brackets", - "name": "punctuation.brackets.round.zokrates", - "match": "[()]" - }, - { - "comment": "square brackets", - "name": "punctuation.brackets.square.zokrates", - "match": "[\\[\\]]" - }, - { - "comment": "angle brackets", - "name": "punctuation.brackets.angle.zokrates", - "match": "(?]" - } - ] - }, - "strings": { - "patterns": [ - { - "comment": "double-quoted strings and byte strings", - "name": "string.quoted.double.zokrates", - "begin": "(b?)(\")", - "beginCaptures": { - "1": { - "name": "string.quoted.byte.raw.zokrates" - }, - "2": { - "name": "punctuation.definition.string.zokrates" - } - }, - "end": "\"", - "endCaptures": { - "0": { - "name": "punctuation.definition.string.zokrates" - } - } - }, - { - "comment": "double-quoted raw strings and raw byte strings", - "name": "string.quoted.double.zokrates", - "begin": "(b?r)(#*)(\")", - "beginCaptures": { - "1": { - "name": "string.quoted.byte.raw.zokrates" - }, - "2": { - "name": "punctuation.definition.string.raw.zokrates" - }, - "3": { - "name": "punctuation.definition.string.zokrates" - } - }, - "end": "(\")(\\2)", - "endCaptures": { - "1": { - "name": "punctuation.definition.string.zokrates" - }, - "2": { - "name": "punctuation.definition.string.raw.zokrates" - } - } - } - ] - }, - "variables": { - "patterns": [ - { - "comment": "variables", - "name": "variable.other.zokrates", - "match": "\\b(?\\()", + "beginCaptures": { + "1": { + "name": "entity.name.function.zokrates" + } + }, + "end": "\\)", + "endCaptures": { + "0": { + "name": "punctuation.brackets.round.zokrates" + } + }, + "patterns": [ + { + "include": "#block-comments" + }, + { + "include": "#comments" + }, + { + "include": "#keywords" + }, + { + "include": "#constants" + }, + { + "include": "#functions" + }, + { + "include": "#punctuation" + }, + { + "include": "#strings" + }, + { + "include": "#types" + }, + { + "include": "#variables" + } + ] + } + ] + }, + "keywords": { + "patterns": [ + { + "comment": "argument visibility", + "name": "keyword.visibility.zokrates", + "match": "\\b(public|private)\\b" + }, + { + "comment": "control flow keywords", + "name": "keyword.control.zokrates", + "match": "\\b(do|else|for|do|endfor|if|then|fi|return|assert)\\b" + }, + { + "comment": "storage keywords", + "name": "storage.type.zokrates", + "match": "\\b(struct)\\b" + }, + { + "comment": "const", + "name": "keyword.other.const.zokrates", + "match": "\\bconst\\b" + }, + { + "comment": "def", + "name": "keyword.other.def.zokrates", + "match": "\\bdef\\b" + }, + { + "comment": "import keywords", + "name": "keyword.other.import.zokrates", + "match": "\\b(import|from|as)\\b" + }, + { + "comment": "logical operators", + "name": "keyword.operator.logical.zokrates", + "match": "(\\^|\\||\\|\\||&|&&|<<|>>|!)(?!=)" + }, + { + "comment": "single equal", + "name": "keyword.operator.assignment.equal.zokrates", + "match": "(?])=(?!=|>)" + }, + { + "comment": "comparison operators", + "name": "keyword.operator.comparison.zokrates", + "match": "(=(=)?(?!>)|!=|<=|(?=)" + }, + { + "comment": "math operators", + "name": "keyword.operator.math.zokrates", + "match": "(([+%]|(\\*(?!\\w)))(?!=))|(-(?!>))|(/(?!/))" + }, + { + "comment": "less than, greater than (special case)", + "match": "(?:\\b|(?:(\\))|(\\])|(\\})))[ \\t]+([<>])[ \\t]+(?:\\b|(?:(\\()|(\\[)|(\\{)))", + "captures": { + "1": { + "name": "punctuation.brackets.round.zokrates" + }, + "2": { + "name": "punctuation.brackets.square.zokrates" + }, + "3": { + "name": "punctuation.brackets.curly.zokrates" + }, + "4": { + "name": "keyword.operator.comparison.zokrates" + }, + "5": { + "name": "punctuation.brackets.round.zokrates" + }, + "6": { + "name": "punctuation.brackets.square.zokrates" + }, + "7": { + "name": "punctuation.brackets.curly.zokrates" + } + } + }, + { + "comment": "dot access", + "name": "keyword.operator.access.dot.zokrates", + "match": "\\.(?!\\.)" + }, + { + "comment": "ranges, range patterns", + "name": "keyword.operator.range.zokrates", + "match": "\\.{2}(=|\\.)?" + }, + { + "comment": "colon", + "name": "keyword.operator.colon.zokrates", + "match": ":(?!:)" + }, + { + "comment": "dashrocket, skinny arrow", + "name": "keyword.operator.arrow.skinny.zokrates", + "match": "->" + } + ] + }, + "types": { + "patterns": [ + { + "comment": "numeric types", + "match": "(?", + "endCaptures": { + "0": { + "name": "punctuation.brackets.angle.zokrates" + } + }, + "patterns": [ + { + "include": "#block-comments" + }, + { + "include": "#comments" + }, + { + "include": "#keywords" + }, + { + "include": "#punctuation" + }, + { + "include": "#types" + }, + { + "include": "#variables" + } + ] + }, + { + "comment": "primitive types", + "name": "entity.name.type.primitive.zokrates", + "match": "\\b(bool)\\b" + }, + { + "comment": "struct declarations", + "match": "\\b(struct)\\s+([A-Z][A-Za-z0-9]*)\\b", + "captures": { + "1": { + "name": "storage.type.zokrates" + }, + "2": { + "name": "entity.name.type.struct.zokrates" + } + } + }, + { + "comment": "types", + "name": "entity.name.type.zokrates", + "match": "\\b[A-Z][A-Za-z0-9]*\\b(?!!)" + } + ] + }, + "punctuation": { + "patterns": [ + { + "comment": "comma", + "name": "punctuation.comma.zokrates", + "match": "," + }, + { + "comment": "parentheses, round brackets", + "name": "punctuation.brackets.round.zokrates", + "match": "[()]" + }, + { + "comment": "square brackets", + "name": "punctuation.brackets.square.zokrates", + "match": "[\\[\\]]" + }, + { + "comment": "angle brackets", + "name": "punctuation.brackets.angle.zokrates", + "match": "(?]" + } + ] + }, + "strings": { + "patterns": [ + { + "comment": "double-quoted strings and byte strings", + "name": "string.quoted.double.zokrates", + "begin": "(b?)(\")", + "beginCaptures": { + "1": { + "name": "string.quoted.byte.raw.zokrates" + }, + "2": { + "name": "punctuation.definition.string.zokrates" + } + }, + "end": "\"", + "endCaptures": { + "0": { + "name": "punctuation.definition.string.zokrates" + } + } + }, + { + "comment": "double-quoted raw strings and raw byte strings", + "name": "string.quoted.double.zokrates", + "begin": "(b?r)(#*)(\")", + "beginCaptures": { + "1": { + "name": "string.quoted.byte.raw.zokrates" + }, + "2": { + "name": "punctuation.definition.string.raw.zokrates" + }, + "3": { + "name": "punctuation.definition.string.zokrates" + } + }, + "end": "(\")(\\2)", + "endCaptures": { + "1": { + "name": "punctuation.definition.string.zokrates" + }, + "2": { + "name": "punctuation.definition.string.raw.zokrates" + } + } + } + ] + }, + "variables": { + "patterns": [ + { + "comment": "variables", + "name": "variable.other.zokrates", + "match": "\\b(?\()' + beginCaptures: + '1': {name: entity.name.function.zokrates} + end: \) + endCaptures: + '0': {name: punctuation.brackets.round.zokrates} + patterns: + - {include: '#block-comments'} + - {include: '#comments'} + - {include: '#keywords'} + - {include: '#constants'} + - {include: '#functions'} + - {include: '#punctuation'} + - {include: '#strings'} + - {include: '#types'} + - {include: '#variables'} + keywords: + patterns: + - + comment: 'argument visibility' + name: keyword.visibility.zokrates + match: \b(public|private)\b + - + comment: 'control flow keywords' + name: keyword.control.zokrates + match: \b(do|else|for|do|endfor|if|then|fi|return|assert)\b + - + comment: 'storage keywords' + name: storage.type.zokrates + match: \b(struct)\b + - + comment: const + name: keyword.other.const.zokrates + match: \bconst\b + - + comment: def + name: keyword.other.def.zokrates + match: \bdef\b + - + comment: 'import keywords' + name: keyword.other.import.zokrates + match: \b(import|from|as)\b + - + comment: 'logical operators' + name: keyword.operator.logical.zokrates + match: '(\^|\||\|\||&|&&|<<|>>|!)(?!=)' + - + comment: 'single equal' + name: keyword.operator.assignment.equal.zokrates + match: '(?])=(?!=|>)' + - + comment: 'comparison operators' + name: keyword.operator.comparison.zokrates + match: '(=(=)?(?!>)|!=|<=|(?=)' + - + comment: 'math operators' + name: keyword.operator.math.zokrates + match: '(([+%]|(\*(?!\w)))(?!=))|(-(?!>))|(/(?!/))' + - + comment: 'less than, greater than (special case)' + match: '(?:\b|(?:(\))|(\])|(\})))[ \t]+([<>])[ \t]+(?:\b|(?:(\()|(\[)|(\{)))' + captures: + '1': {name: punctuation.brackets.round.zokrates} + '2': {name: punctuation.brackets.square.zokrates} + '3': {name: punctuation.brackets.curly.zokrates} + '4': {name: keyword.operator.comparison.zokrates} + '5': {name: punctuation.brackets.round.zokrates} + '6': {name: punctuation.brackets.square.zokrates} + '7': {name: punctuation.brackets.curly.zokrates} + - + comment: 'dot access' + name: keyword.operator.access.dot.zokrates + match: '\.(?!\.)' + - + comment: 'ranges, range patterns' + name: keyword.operator.range.zokrates + match: '\.{2}(=|\.)?' + - + comment: colon + name: keyword.operator.colon.zokrates + match: ':(?!:)' + - + comment: 'dashrocket, skinny arrow' + name: keyword.operator.arrow.skinny.zokrates + match: '->' + types: + patterns: + - + comment: 'numeric types' + match: '(?' + endCaptures: + '0': {name: punctuation.brackets.angle.zokrates} + patterns: + - {include: '#block-comments'} + - {include: '#comments'} + - {include: '#keywords'} + - {include: '#punctuation'} + - {include: '#types'} + - {include: '#variables'} + - + comment: 'primitive types' + name: entity.name.type.primitive.zokrates + match: \b(bool)\b + - + comment: 'struct declarations' + match: '\b(struct)\s+([A-Z][A-Za-z0-9]*)\b' + captures: + '1': {name: storage.type.zokrates} + '2': {name: entity.name.type.struct.zokrates} + - + comment: types + name: entity.name.type.zokrates + match: '\b[A-Z][A-Za-z0-9]*\b(?!!)' + punctuation: + patterns: + - + comment: comma + name: punctuation.comma.zokrates + match: ',' + - + comment: 'parentheses, round brackets' + name: punctuation.brackets.round.zokrates + match: '[()]' + - + comment: 'square brackets' + name: punctuation.brackets.square.zokrates + match: '[\[\]]' + - + comment: 'angle brackets' + name: punctuation.brackets.angle.zokrates + match: '(?]' + strings: + patterns: + - + comment: 'double-quoted strings and byte strings' + name: string.quoted.double.zokrates + begin: '(b?)(")' + beginCaptures: + '1': {name: string.quoted.byte.raw.zokrates} + '2': {name: punctuation.definition.string.zokrates} + end: '"' + endCaptures: + '0': {name: punctuation.definition.string.zokrates} + - + comment: 'double-quoted raw strings and raw byte strings' + name: string.quoted.double.zokrates + begin: '(b?r)(#*)(")' + beginCaptures: + '1': {name: string.quoted.byte.raw.zokrates} + '2': {name: punctuation.definition.string.raw.zokrates} + '3': {name: punctuation.definition.string.zokrates} + end: '(")(\2)' + endCaptures: + '1': {name: punctuation.definition.string.zokrates} + '2': {name: punctuation.definition.string.raw.zokrates} + variables: + patterns: + - + comment: variables + name: variable.other.zokrates + match: '\b(?" ~ ( "(" ~ type_list ~ ")" | ty ))? } constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" } constant_generics_list = _{ identifier ~ ("," ~ identifier)* } @@ -144,7 +145,7 @@ op_sub = {"-"} op_mul = {"*"} op_div = {"/"} op_rem = {"%"} -op_pow = @{"**"} +op_pow = @{"**"} op_not = {"!"} op_neg = {"-"} op_pos = {"+"} @@ -159,6 +160,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) } // the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly // accepted -keyword = @{"assert"|"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| +keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| "in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"|"u64" } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 31fae3b4..be50f71d 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -10,14 +10,14 @@ extern crate lazy_static; pub use ast::{ Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, - CallAccess, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, DecimalSuffix, - DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, FromExpression, Function, - HexLiteralExpression, HexNumberExpression, IdentifierExpression, ImportDirective, ImportSource, - ImportSymbol, InlineArrayExpression, InlineStructExpression, InlineStructMember, - IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter, PostfixExpression, - Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, - StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, - UnaryOperator, Underscore, Visibility, + CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, + DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, + FromExpression, Function, HexLiteralExpression, HexNumberExpression, IdentifierExpression, + ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression, InlineStructExpression, + InlineStructMember, IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter, + PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, + Statement, StructDefinition, StructField, TernaryExpression, ToExpression, Type, + UnaryExpression, UnaryOperator, Underscore, Visibility, }; mod ast { @@ -111,6 +111,7 @@ mod ast { pub pragma: Option>, pub imports: Vec>, pub structs: Vec>, + pub constants: Vec>, pub functions: Vec>, pub eoi: EOI, #[pest_ast(outer())] @@ -164,6 +165,16 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::const_definition))] + pub struct ConstantDefinition<'ast> { + pub ty: Type<'ast>, + pub id: IdentifierExpression<'ast>, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_directive))] pub enum ImportDirective<'ast> { @@ -1047,6 +1058,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { generics: vec![], id: IdentifierExpression { @@ -1107,6 +1119,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { generics: vec![], id: IdentifierExpression { @@ -1191,6 +1204,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { generics: vec![], id: IdentifierExpression { @@ -1259,6 +1273,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { generics: vec![], id: IdentifierExpression { @@ -1299,6 +1314,7 @@ mod tests { Ok(File { pragma: None, structs: vec![], + constants: vec![], functions: vec![Function { generics: vec![], id: IdentifierExpression {