From abc5d23b1f48e4c5379f10a5add834c44c5fe165 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 22 Jul 2019 14:59:15 +0200 Subject: [PATCH 01/35] add struct to parser, ambiguity in multidef apparently --- zokrates_parser/src/lib.rs | 42 +++++++++++++++++++++++++++++++ zokrates_parser/src/zokrates.pest | 13 ++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index 6e92d8b1..6f4accf1 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -125,6 +125,48 @@ mod tests { }; } + #[test] + fn parse_struct_def() { + parses_to! { + parser: ZoKratesParser, + input: "struct Foo { foo: field, bar: field[2] } + ", + rule: Rule::ty_struct_definition, + tokens: [ + ty_struct_definition(0, 40, [ + identifier(7, 10), + struct_field(13, 23, [ + identifier(13, 16), + ty(18, 23, [ + ty_basic(18, 23, [ + ty_field(18, 23) + ]) + ]) + ]), + struct_field(25, 39, [ + identifier(25, 28), + ty(30, 39, [ + ty_array(30, 39, [ + ty_basic(30, 35, [ + ty_field(30, 35) + ]), + expression(36, 37, [ + term(36, 37, [ + primary_expression(36, 37, [ + constant(36, 37, [ + decimal_number(36, 37) + ]) + ]) + ]) + ]) + ]) + ]) + ]) + ]) + ] + }; + } + #[test] fn parse_invalid_identifier_because_keyword() { fails_with! { diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index efa9411d..8cb67aa6 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -3,7 +3,10 @@ * Author: Jacob Eberhardt, Thibaut Schaeffer */ -file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ function_definition* ~ EOI } +// intuition for current issue: since type annotations are optional for multidef and struct types have arbitrary names, there is an ambiguity +// when parsing `a, b = foo()` as `a` could be the type or the identifier + +file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ function_definition* ~ EOI } import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} import_source = @{(!"\"" ~ ANY)*} function_definition = {"def" ~ identifier ~ "(" ~ parameter_list ~ ")" ~ "->" ~ "(" ~ type_list ~ ")" ~ ":" ~ NEWLINE* ~ statement* } @@ -17,8 +20,14 @@ ty_bool = {"bool"} ty_basic = { ty_field | ty_bool } // (unidimensional for now) arrays of (basic for now) types ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ } -ty = { ty_array | ty_basic } +ty = { ty_array | ty_basic | ty_struct } type_list = _{(ty ~ ("," ~ ty)*)?} +// structs +ty_struct = { identifier } +// type definitions +ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" } +struct_field_list = _{(struct_field ~ ("," ~ NEWLINE* ~ struct_field)*)? ~ ","? } +struct_field = { identifier ~ ":" ~ ty } vis_private = {"private"} vis_public = {"public"} From 2f7b26034a2a6c96e2f36fd9b96834b1b97dc6cd Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 17:35:31 +0200 Subject: [PATCH 02/35] fix parsing of multidef --- zokrates_parser/src/lib.rs | 20 ++++++++++++++++++++ zokrates_parser/src/zokrates.pest | 5 +---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index 6f4accf1..f89da45f 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -113,6 +113,26 @@ mod tests { }; } + #[test] + fn parse_single_def_to_multi() { + parses_to! { + parser: ZoKratesParser, + input: "a = foo() + ", + rule: Rule::statement, + tokens: [ + statement(0, 10, [ + multi_assignment_statement(0, 9, [ + optionally_typed_identifier(0, 1, [ + identifier(0, 1) + ]), + identifier(4, 7), + ]) + ]) + ] + }; + } + #[test] fn parse_invalid_identifier() { fails_with! { diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 8cb67aa6..6ff7e248 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -3,9 +3,6 @@ * Author: Jacob Eberhardt, Thibaut Schaeffer */ -// intuition for current issue: since type annotations are optional for multidef and struct types have arbitrary names, there is an ambiguity -// when parsing `a, b = foo()` as `a` could be the type or the identifier - file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ function_definition* ~ EOI } import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} import_source = @{(!"\"" ~ ANY)*} @@ -51,7 +48,7 @@ assignment_statement = {assignee ~ "=" ~ expression } // TODO: Is this optimal? expression_statement = {expression} optionally_typed_identifier_list = _{ optionally_typed_identifier ~ ("," ~ optionally_typed_identifier)* } -optionally_typed_identifier = { ty? ~ identifier } +optionally_typed_identifier = { (identifier) | (ty ~ identifier) } // we don't use { ty? ~ identifier } as with a single token, it gets parsed as `ty` but we want `identifier` // Expressions expression_list = _{(expression ~ ("," ~ expression)*)?} From 192f854b0270e1495ec6097dfa5d82455fdb01c0 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 17:54:20 +0200 Subject: [PATCH 03/35] add struct to ast --- zokrates_pest_ast/src/lib.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 68a7467e..70b1dca3 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -155,12 +155,31 @@ mod ast { #[pest_ast(rule(Rule::file))] pub struct File<'ast> { pub imports: Vec>, + pub structs: Vec>, pub functions: Vec>, pub eoi: EOI, #[pest_ast(outer())] pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::ty_struct_definition))] + pub struct StructDefinition<'ast> { + pub id: IdentifierExpression<'ast>, + pub fields: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::struct_field))] + pub struct StructField<'ast> { + pub id: IdentifierExpression<'ast>, + pub ty: Type<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::function_definition))] pub struct Function<'ast> { @@ -195,6 +214,7 @@ mod ast { pub enum Type<'ast> { Basic(BasicType<'ast>), Array(ArrayType<'ast>), + Struct(StructType<'ast>), } #[derive(Debug, FromPest, PartialEq, Clone)] @@ -219,7 +239,15 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_bool))] + pub struct StructType<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::ty_struct))] pub struct BooleanType<'ast> { + id: IdentifierExpression<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -700,6 +728,7 @@ mod tests { assert_eq!( generate_ast(&source), Ok(File { + structs: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -749,6 +778,7 @@ mod tests { assert_eq!( generate_ast(&source), Ok(File { + structs: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -816,6 +846,7 @@ mod tests { assert_eq!( generate_ast(&source), Ok(File { + structs: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -870,6 +901,7 @@ mod tests { assert_eq!( generate_ast(&source), Ok(File { + structs: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), @@ -902,6 +934,7 @@ mod tests { assert_eq!( generate_ast(&source), Ok(File { + structs: vec![], functions: vec![Function { id: IdentifierExpression { value: String::from("main"), From f35abd559c7b8d6556a7ddcc81b33a0928423736 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 18:30:01 +0200 Subject: [PATCH 04/35] add member access to parser --- zokrates_core/src/absy/from_ast.rs | 2 ++ zokrates_parser/src/lib.rs | 4 ++-- zokrates_parser/src/zokrates.pest | 4 ++-- zokrates_pest_ast/src/lib.rs | 32 ++++++++++++++++++++---------- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index bef7b4d4..c67fcc8c 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -467,6 +467,7 @@ impl<'ast, T: Field> From> for absy::ExpressionNod absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression)) .span(a.span) } + pest::Access::Member(_) => unimplemented!("member access is not implemented yet"), }) } } @@ -553,6 +554,7 @@ impl<'ast> From> for Type { }) .unwrap() } + pest::Type::Struct(s) => unimplemented!("struct declarations not supported yet"), } } } diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index f89da45f..b104cd9b 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -117,8 +117,8 @@ mod tests { fn parse_single_def_to_multi() { parses_to! { parser: ZoKratesParser, - input: "a = foo() - ", + input: r#"a = foo() + "#, rule: Rule::statement, tokens: [ statement(0, 10, [ diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 6ff7e248..2abd642c 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -15,7 +15,6 @@ parameter = {vis? ~ ty ~ identifier} ty_field = {"field"} ty_bool = {"bool"} ty_basic = { ty_field | ty_bool } -// (unidimensional for now) arrays of (basic for now) types ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ } ty = { ty_array | ty_basic | ty_struct } type_list = _{(ty ~ ("," ~ ty)*)?} @@ -63,9 +62,10 @@ to_expression = { expression } conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"} postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers. Not sure that's what we want. -access = { array_access | call_access } +access = { array_access | call_access | member_access } array_access = { "[" ~ range_or_expression ~ "]" } call_access = { "(" ~ expression_list ~ ")" } +member_access = { "." ~ identifier } primary_expression = { identifier | constant diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 70b1dca3..48b451da 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -212,16 +212,16 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty))] pub enum Type<'ast> { - Basic(BasicType<'ast>), + Basic(BasicType), Array(ArrayType<'ast>), Struct(StructType<'ast>), } #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_basic))] - pub enum BasicType<'ast> { + pub enum BasicType { Field(FieldType), - Boolean(BooleanType<'ast>), + Boolean(BooleanType), } #[derive(Debug, FromPest, PartialEq, Clone)] @@ -231,7 +231,7 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_array))] pub struct ArrayType<'ast> { - pub ty: BasicType<'ast>, + pub ty: BasicType, pub size: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -239,14 +239,11 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_bool))] - pub struct StructType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, - } + pub struct BooleanType {} #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_struct))] - pub struct BooleanType<'ast> { + pub struct StructType<'ast> { id: IdentifierExpression<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -476,6 +473,7 @@ mod ast { pub enum Access<'ast> { Call(CallAccess<'ast>), Select(ArrayAccess<'ast>), + Member(MemberAccess<'ast>), } #[derive(Debug, FromPest, PartialEq, Clone)] @@ -494,6 +492,14 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::member_access))] + pub struct MemberAccess<'ast> { + pub id: IdentifierExpression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, PartialEq, Clone)] pub struct BinaryExpression<'ast> { pub op: BinaryOperator, @@ -1002,13 +1008,19 @@ mod tests { #[test] fn playground() { let source = r#"import "heyman" as yo + + struct Foo { + foo: field[2], + bar: Bar + } + def main(private field[23] a) -> (bool[234 + 6]): field a = 1 a[32 + x][55] = y for field i in 0..3 do a == 1 + 2 + 3+ 4+ 5+ 6+ 6+ 7+ 8 + 4+ 5+ 3+ 4+ 2+ 3 endfor - a == 1 + a.member == 1 return a "#; let res = generate_ast(&source); From 0bd41abf2a34eb9644819f9a8dd109028141d805 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 19:09:54 +0200 Subject: [PATCH 05/35] add struct definition to untyped ast --- zokrates_core/src/absy/from_ast.rs | 11 +++++ zokrates_core/src/absy/mod.rs | 67 +++++++++++++++++++++++++++++- zokrates_core/src/absy/node.rs | 3 ++ zokrates_core/src/imports.rs | 1 + zokrates_pest_ast/src/lib.rs | 4 +- 5 files changed, 83 insertions(+), 3 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index c67fcc8c..8a508f81 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -7,6 +7,11 @@ use zokrates_pest_ast as pest; impl<'ast, T: Field> From> for absy::Module<'ast, T> { fn from(prog: pest::File<'ast>) -> absy::Module { absy::Module { + types: prog + .structs + .into_iter() + .map(|t| absy::TypeDeclarationNode::from(t)) + .collect(), functions: prog .functions .into_iter() @@ -30,6 +35,12 @@ impl<'ast> From> for absy::ImportNode<'ast> { } } +impl<'ast> From> for absy::TypeDeclarationNode<'ast> { + fn from(definition: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { + unimplemented!() + } +} + impl<'ast, T: Field> From> for absy::FunctionDeclarationNode<'ast, T> { fn from(function: pest::Function<'ast>) -> absy::FunctionDeclarationNode { use absy::NodeValue; diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index fa0c4a59..ec6e7ae5 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -13,7 +13,7 @@ pub mod variable; pub use crate::absy::node::{Node, NodeValue}; pub use crate::absy::parameter::{Parameter, ParameterNode}; pub use crate::absy::variable::{Variable, VariableNode}; -use crate::types::{FunctionIdentifier, Signature}; +use crate::types::{FunctionIdentifier, Signature, Type}; use embed::FlatEmbed; use crate::imports::ImportNode; @@ -34,6 +34,9 @@ pub type Modules<'ast, T> = HashMap>; /// A collection of `FunctionDeclaration`. Duplicates are allowed here as they are fine syntatically. pub type FunctionDeclarations<'ast, T> = Vec>; +/// A collection of `StructDeclaration`. Duplicates are allowed here as they are fine syntatically. +pub type TypeDeclarations<'ast> = Vec>; + /// A `Program` is a collection of `Module`s and an id of the main `Module` pub struct Program<'ast, T: Field> { pub modules: HashMap>, @@ -47,6 +50,23 @@ pub struct FunctionDeclaration<'ast, T: Field> { pub symbol: FunctionSymbol<'ast, T>, } +/// A declaration of a `FunctionSymbol`, be it from an import or a function definition +#[derive(PartialEq, Debug, Clone)] +pub struct TypeDeclaration<'ast> { + pub id: Identifier<'ast>, + pub symbol: TypeSymbol<'ast>, +} + +impl<'ast> fmt::Display for TypeDeclaration<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.symbol { + TypeSymbol::Here(ref s) => write!(f, "struct {} {}", self.id, s), + } + } +} + +type TypeDeclarationNode<'ast> = Node>; + impl<'ast, T: Field> fmt::Display for FunctionDeclaration<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.symbol { @@ -67,6 +87,8 @@ type FunctionDeclarationNode<'ast, T> = Node>; /// A module as a collection of `FunctionDeclaration`s #[derive(Clone, PartialEq)] pub struct Module<'ast, T: Field> { + /// Structs of the module + pub types: TypeDeclarations<'ast>, /// Functions of the module pub functions: FunctionDeclarations<'ast, T>, pub imports: Vec>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty @@ -80,6 +102,49 @@ pub enum FunctionSymbol<'ast, T: Field> { Flat(FlatEmbed), } +/// A user defined type, a struct defined in this module for now // TODO allow importing types +#[derive(Debug, Clone, PartialEq)] +pub enum TypeSymbol<'ast> { + Here(StructTypeNode<'ast>), +} + +/// A struct type definition +#[derive(Debug, Clone, PartialEq)] +pub struct StructType<'ast> { + fields: Vec>, +} + +impl<'ast> fmt::Display for StructType<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + self.fields + .iter() + .map(|fi| fi.to_string()) + .collect::>() + .join("\n") + ) + } +} + +type StructTypeNode<'ast> = Node>; + +/// A struct type definition +#[derive(Debug, Clone, PartialEq)] +pub struct StructField<'ast> { + id: Identifier<'ast>, + ty: Type, +} + +impl<'ast> fmt::Display for StructField<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: {},", self.id, self.ty) + } +} + +type StructFieldNode<'ast> = Node>; + /// A function import #[derive(Debug, Clone, PartialEq)] pub struct FunctionImport<'ast> { diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index e9cf423c..13d4e44e 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -75,6 +75,9 @@ impl<'ast, T: Field> NodeValue for ExpressionList<'ast, T> {} impl<'ast, T: Field> NodeValue for Assignee<'ast, T> {} impl<'ast, T: Field> NodeValue for Statement<'ast, T> {} impl<'ast, T: Field> NodeValue for FunctionDeclaration<'ast, T> {} +impl<'ast> NodeValue for TypeDeclaration<'ast> {} +impl<'ast> NodeValue for StructType<'ast> {} +impl<'ast> NodeValue for StructField<'ast> {} impl<'ast, T: Field> NodeValue for Function<'ast, T> {} impl<'ast, T: Field> NodeValue for Module<'ast, T> {} impl<'ast> NodeValue for FunctionImport<'ast> {} diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 2b34ba60..4fbd2bb5 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -223,6 +223,7 @@ impl Importer { Ok(Module { imports: vec![], functions: functions, + ..destination }) } } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 48b451da..46a90d52 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -13,8 +13,8 @@ pub use ast::{ ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, TernaryExpression, ToExpression, - Type, UnaryExpression, UnaryOperator, Visibility, + ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, + TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, }; mod ast { From 6eb0efe1ceab242c572b4e57eff55e6f87e156ff Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 19:15:44 +0200 Subject: [PATCH 06/35] add member access to untyped absy --- zokrates_core/src/absy/from_ast.rs | 4 ++-- zokrates_core/src/absy/mod.rs | 3 +++ zokrates_core/src/semantics.rs | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 8a508f81..00dcc34e 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -36,7 +36,7 @@ impl<'ast> From> for absy::ImportNode<'ast> { } impl<'ast> From> for absy::TypeDeclarationNode<'ast> { - fn from(definition: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { + fn from(_: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { unimplemented!() } } @@ -565,7 +565,7 @@ impl<'ast> From> for Type { }) .unwrap() } - pest::Type::Struct(s) => unimplemented!("struct declarations not supported yet"), + pest::Type::Struct(_) => unimplemented!("struct declarations not supported yet"), } } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index ec6e7ae5..42569bb0 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -483,6 +483,7 @@ pub enum Expression<'ast, T: Field> { Box>, Box>, ), + Member(Box>, Box>), Or(Box>, Box>), } @@ -532,6 +533,7 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> { write!(f, "]") } Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index), + Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), } } @@ -571,6 +573,7 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> { write!(f, "]") } Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index), + Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 193e9a78..57b9449d 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1103,6 +1103,7 @@ impl<'ast> Checker<'ast> { }, } } + Expression::Member(..) => unimplemented!(), Expression::InlineArray(expressions) => { // we should have at least one expression let size = expressions.len(); From b475c5baa6cfb5aa864ae44dd47bcfcdb2d323f8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 19:27:52 +0200 Subject: [PATCH 07/35] add conversion in from_flat. wip --- zokrates_core/src/absy/from_ast.rs | 37 ++++++++++++++++++++++++++++-- zokrates_pest_ast/src/lib.rs | 2 +- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 00dcc34e..a3b6efff 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -36,8 +36,41 @@ impl<'ast> From> for absy::ImportNode<'ast> { } impl<'ast> From> for absy::TypeDeclarationNode<'ast> { - fn from(_: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { - unimplemented!() + fn from(definition: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { + use absy::NodeValue; + + let span = definition.span; + + let id = definition.id.span.as_str(); + + let ty = absy::StructType { + fields: definition + .fields + .into_iter() + .map(|f| absy::StructFieldNode::from(f)) + .collect(), + } + .span(span.clone()); // TODO check + + absy::TypeDeclaration { + id, + symbol: absy::TypeSymbol::Here(ty), + } + .span(span) + } +} + +impl<'ast> From> for absy::StructFieldNode<'ast> { + fn from(field: pest::StructField<'ast>) -> absy::StructFieldNode { + use absy::NodeValue; + + let span = field.span; + + let id = field.id.span.as_str(); + + let ty = Type::from(field.ty); + + absy::StructField { id, ty }.span(span) } } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 46a90d52..c24ce49b 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -13,7 +13,7 @@ pub use ast::{ ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, + ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, }; From e81214edefea3d4c3d3b5e7511b7ae216bdd1c4c Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 26 Jul 2019 22:19:45 +0200 Subject: [PATCH 08/35] implement member access in from_ast --- zokrates_core/src/absy/from_ast.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index a3b6efff..83103033 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -511,7 +511,9 @@ impl<'ast, T: Field> From> for absy::ExpressionNod absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression)) .span(a.span) } - pest::Access::Member(_) => unimplemented!("member access is not implemented yet"), + pest::Access::Member(m) => { + absy::Expression::Member(box acc, box m.id.span.as_str()).span(m.span) + } }) } } From c84c0c5ca760915f29caad536b4bbc687c90b491 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 31 Jul 2019 16:49:21 +0200 Subject: [PATCH 09/35] implement semantics and flattening for simple cases --- zokrates_core/src/absy/from_ast.rs | 42 +-- zokrates_core/src/absy/mod.rs | 14 +- zokrates_core/src/absy/variable.rs | 22 +- zokrates_core/src/flatten/mod.rs | 208 ++++++++++++- zokrates_core/src/semantics.rs | 310 +++++++++++++++++--- zokrates_core/src/static_analysis/inline.rs | 26 +- zokrates_core/src/static_analysis/unroll.rs | 1 + zokrates_core/src/typed_absy/folder.rs | 70 +++++ zokrates_core/src/typed_absy/mod.rs | 141 ++++++++- zokrates_core/src/typed_absy/parameter.rs | 16 +- zokrates_core/src/typed_absy/variable.rs | 24 +- zokrates_core/src/types/mod.rs | 51 ++++ zokrates_core/src/types/signature.rs | 62 ++++ zokrates_parser/src/zokrates.pest | 2 +- zokrates_pest_ast/src/lib.rs | 2 +- 15 files changed, 889 insertions(+), 102 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 83103033..d1bbbe82 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,6 +1,6 @@ use absy; use imports; -use types::Type; +use types::UnresolvedType; use zokrates_field::field::Field; use zokrates_pest_ast as pest; @@ -68,7 +68,7 @@ impl<'ast> From> for absy::StructFieldNode<'ast> { let id = field.id.span.as_str(); - let ty = Type::from(field.ty); + let ty = UnresolvedType::from(field.ty); absy::StructField { id, ty }.span(span) } @@ -80,7 +80,7 @@ impl<'ast, T: Field> From> for absy::FunctionDeclarationNod let span = function.span; - let signature = absy::Signature::new() + let signature = absy::UnresolvedSignature::new() .inputs( function .parameters @@ -94,7 +94,7 @@ impl<'ast, T: Field> From> for absy::FunctionDeclarationNod .returns .clone() .into_iter() - .map(|r| Type::from(r)) + .map(|r| UnresolvedType::from(r)) .collect(), ); @@ -135,8 +135,8 @@ impl<'ast> From> for absy::ParameterNode<'ast> { }) .unwrap_or(false); - let variable = - absy::Variable::new(param.id.span.as_str(), Type::from(param.ty)).span(param.id.span); + let variable = absy::Variable::new(param.id.span.as_str(), UnresolvedType::from(param.ty)) + .span(param.id.span); absy::Parameter::new(variable, private).span(param.span) } @@ -167,7 +167,8 @@ fn statements_from_multi_assignment<'ast, T: Field>( .filter(|i| i.ty.is_some()) .map(|i| { absy::Statement::Declaration( - absy::Variable::new(i.id.span.as_str(), Type::from(i.ty.unwrap())).span(i.id.span), + absy::Variable::new(i.id.span.as_str(), UnresolvedType::from(i.ty.unwrap())) + .span(i.id.span), ) .span(i.span) }); @@ -202,8 +203,11 @@ fn statements_from_definition<'ast, T: Field>( vec![ absy::Statement::Declaration( - absy::Variable::new(definition.id.span.as_str(), Type::from(definition.ty)) - .span(definition.id.span.clone()), + absy::Variable::new( + definition.id.span.as_str(), + UnresolvedType::from(definition.ty), + ) + .span(definition.id.span.clone()), ) .span(definition.span.clone()), absy::Statement::Definition( @@ -262,7 +266,7 @@ impl<'ast, T: Field> From> for absy::StatementNod let from = absy::ExpressionNode::from(statement.from); let to = absy::ExpressionNode::from(statement.to); let index = statement.index.span.as_str(); - let ty = Type::from(statement.ty); + let ty = UnresolvedType::from(statement.ty); let statements: Vec> = statement .statements .into_iter() @@ -564,17 +568,17 @@ impl<'ast, T: Field> From> for absy::AssigneeNode<'ast, T> } } -impl<'ast> From> for Type { - fn from(t: pest::Type<'ast>) -> Type { +impl<'ast> From> for UnresolvedType { + fn from(t: pest::Type<'ast>) -> UnresolvedType { match t { pest::Type::Basic(t) => match t { - pest::BasicType::Field(_) => Type::FieldElement, - pest::BasicType::Boolean(_) => Type::Boolean, + pest::BasicType::Field(_) => UnresolvedType::FieldElement, + pest::BasicType::Boolean(_) => UnresolvedType::Boolean, }, pest::Type::Array(t) => { let inner_type = match t.ty { - pest::BasicType::Field(_) => Type::FieldElement, - pest::BasicType::Boolean(_) => Type::Boolean, + pest::BasicType::Field(_) => UnresolvedType::FieldElement, + pest::BasicType::Boolean(_) => UnresolvedType::Boolean, }; t.size @@ -595,12 +599,12 @@ impl<'ast> From> for Type { ), }) .fold(None, |acc, s| match acc { - None => Some(Type::array(inner_type.clone(), s)), - Some(acc) => Some(Type::array(acc, s)), + None => Some(UnresolvedType::array(inner_type.clone(), s)), + Some(acc) => Some(UnresolvedType::array(acc, s)), }) .unwrap() } - pest::Type::Struct(_) => unimplemented!("struct declarations not supported yet"), + pest::Type::Struct(s) => UnresolvedType::User(s.id.span.as_str().to_string()), } } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 42569bb0..2c46cf15 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -13,7 +13,7 @@ pub mod variable; pub use crate::absy::node::{Node, NodeValue}; pub use crate::absy::parameter::{Parameter, ParameterNode}; pub use crate::absy::variable::{Variable, VariableNode}; -use crate::types::{FunctionIdentifier, Signature, Type}; +use crate::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType}; use embed::FlatEmbed; use crate::imports::ImportNode; @@ -50,7 +50,7 @@ pub struct FunctionDeclaration<'ast, T: Field> { pub symbol: FunctionSymbol<'ast, T>, } -/// A declaration of a `FunctionSymbol`, be it from an import or a function definition +/// A declaration of a `TypeSymbol`, be it from an import or a function definition #[derive(PartialEq, Debug, Clone)] pub struct TypeDeclaration<'ast> { pub id: Identifier<'ast>, @@ -111,7 +111,7 @@ pub enum TypeSymbol<'ast> { /// A struct type definition #[derive(Debug, Clone, PartialEq)] pub struct StructType<'ast> { - fields: Vec>, + pub fields: Vec>, } impl<'ast> fmt::Display for StructType<'ast> { @@ -128,13 +128,13 @@ impl<'ast> fmt::Display for StructType<'ast> { } } -type StructTypeNode<'ast> = Node>; +pub type StructTypeNode<'ast> = Node>; /// A struct type definition #[derive(Debug, Clone, PartialEq)] pub struct StructField<'ast> { - id: Identifier<'ast>, - ty: Type, + pub id: Identifier<'ast>, + pub ty: UnresolvedType, } impl<'ast> fmt::Display for StructField<'ast> { @@ -220,7 +220,7 @@ pub struct Function<'ast, T: Field> { /// Vector of statements that are executed when running the function pub statements: Vec>, /// function signature - pub signature: Signature, + pub signature: UnresolvedSignature, } pub type FunctionNode<'ast, T> = Node>; diff --git a/zokrates_core/src/absy/variable.rs b/zokrates_core/src/absy/variable.rs index 8eefd9cb..a41acc0b 100644 --- a/zokrates_core/src/absy/variable.rs +++ b/zokrates_core/src/absy/variable.rs @@ -1,19 +1,19 @@ use crate::absy::Node; use std::fmt; -use types::Type; +use types::UnresolvedType; use crate::absy::Identifier; #[derive(Clone, PartialEq, Hash, Eq)] pub struct Variable<'ast> { pub id: Identifier<'ast>, - pub _type: Type, + pub _type: UnresolvedType, } pub type VariableNode<'ast> = Node>; impl<'ast> Variable<'ast> { - pub fn new>(id: S, t: Type) -> Variable<'ast> { + pub fn new>(id: S, t: UnresolvedType) -> Variable<'ast> { Variable { id: id.into(), _type: t, @@ -23,32 +23,36 @@ impl<'ast> Variable<'ast> { pub fn field_element>(id: S) -> Variable<'ast> { Variable { id: id.into(), - _type: Type::FieldElement, + _type: UnresolvedType::FieldElement, } } pub fn boolean>(id: S) -> Variable<'ast> { Variable { id: id.into(), - _type: Type::Boolean, + _type: UnresolvedType::Boolean, } } pub fn field_array>(id: S, size: usize) -> Variable<'ast> { Variable { id: id.into(), - _type: Type::array(Type::FieldElement, size), + _type: UnresolvedType::array(UnresolvedType::FieldElement, size), } } - pub fn array>(id: S, inner_ty: Type, size: usize) -> Variable<'ast> { + pub fn array>( + id: S, + inner_ty: UnresolvedType, + size: usize, + ) -> Variable<'ast> { Variable { id: id.into(), - _type: Type::array(inner_ty, size), + _type: UnresolvedType::array(inner_ty, size), } } - pub fn get_type(&self) -> Type { + pub fn get_type(&self) -> UnresolvedType { self._type.clone() } } diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 8edeb375..49d9d165 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -9,7 +9,7 @@ use crate::flat_absy::*; use crate::helpers::{DirectiveStatement, Helper, RustHelper}; use crate::typed_absy::*; use crate::types::Type; -use crate::types::{FunctionKey, Signature}; +use crate::types::{FunctionKey, MemberId, Signature}; use std::collections::HashMap; use std::convert::TryFrom; use types::FunctionIdentifier; @@ -39,6 +39,7 @@ trait Flatten<'ast, T: Field>: TryFrom, Error: std::fmt -> Self; fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self; + fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self; } impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { @@ -62,6 +63,10 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { FieldElementExpression::Select(box array, box index) } + + fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { + FieldElementExpression::Member(box s, id) + } } impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { @@ -85,6 +90,37 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { BooleanExpression::Select(box array, box index) } + + fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { + BooleanExpression::Member(box s, id) + } +} + +impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> { + fn flatten( + self, + flattener: &mut Flattener<'ast, T>, + symbols: &TypedFunctionSymbols<'ast, T>, + statements_flattened: &mut Vec>, + ) -> Vec> { + unimplemented!() + } + + fn if_else( + condition: BooleanExpression<'ast, T>, + consequence: Self, + alternative: Self, + ) -> Self { + unimplemented!() + } + + fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + unimplemented!() + } + + fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { + unimplemented!() + } } impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> { @@ -111,6 +147,11 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> { statements_flattened, self, ), + Type::Struct(..) => flattener.flatten_array_expression::>( + symbols, + statements_flattened, + self, + ), } } @@ -138,6 +179,10 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> { inner: ArrayExpressionInner::Select(box array, box index), } } + + fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { + unimplemented!() + } } impl<'ast, T: Field> Flattener<'ast, T> { @@ -238,6 +283,45 @@ impl<'ast, T: Field> Flattener<'ast, T> { res.into_iter().map(|r| r.into()).collect() } + fn flatten_member_expression>( + &mut self, + symbols: &TypedFunctionSymbols<'ast, T>, + statements_flattened: &mut Vec>, + s: StructExpression<'ast, T>, + member_id: MemberId, + ) -> Vec> { + let members = s.ty; + + match s.inner { + StructExpressionInner::Identifier(id) => { + // the struct is encoded as a sequence, so we need to identify the offset at which this member starts + let offset = members + .iter() + .map(|(id, ty)| (id, ty.get_primitive_count())) + .fold((false, 0), |acc, (id, count)| { + if acc.0 && *id != member_id { + (false, acc.1 + count) + } else { + (true, acc.1) + } + }) + .1; + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + self.layout.get(&id).unwrap()[offset..(offset + size)] + .into_iter() + .map(|i| i.clone().into()) + .collect() + } + _ => unimplemented!(), + } + } + fn flatten_select_expression>( &mut self, symbols: &TypedFunctionSymbols<'ast, T>, @@ -280,6 +364,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) .flatten(self, symbols, statements_flattened) } + ArrayExpressionInner::Member(box s, id) => unimplemented!(), ArrayExpressionInner::Select(box array, box index) => { assert!(n < T::from(size)); let n = n.to_dec_string().parse::().unwrap(); @@ -355,6 +440,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), ) } + ArrayExpressionInner::Member(box s, id) => U::member(s, id), ArrayExpressionInner::Select(box array, box index) => U::select( ArrayExpression { ty: ty.clone(), @@ -671,6 +757,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { alternative, )[0] .clone(), + BooleanExpression::Member(box s, id) => self + .flatten_member_expression::>( + symbols, + statements_flattened, + s, + id, + )[0] + .clone(), BooleanExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -817,7 +911,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened, e, ), + Type::Struct(..) => self.flatten_array_expression::>( + symbols, + statements_flattened, + e, + ), }, + TypedExpression::Struct(e) => { + self.flatten_struct_expression(symbols, statements_flattened, e) + } } } @@ -1050,6 +1152,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value exprs_flattened.expressions[0].clone() } + FieldElementExpression::Member(box s, id) => self + .flatten_member_expression::>( + symbols, + statements_flattened, + s, + id, + )[0] + .clone(), FieldElementExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1061,6 +1171,78 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } + fn flatten_struct_expression( + &mut self, + symbols: &HashMap, TypedFunctionSymbol<'ast, T>>, + statements_flattened: &mut Vec>, + expr: StructExpression<'ast, T>, + ) -> Vec> { + let ty = expr.get_type(); + //assert_eq!(U::get_type(), inner_type); + + match expr.inner { + StructExpressionInner::Identifier(x) => self + .layout + .get(&x) + .unwrap() + .iter() + .map(|v| FlatExpression::Identifier(v.clone())) + .collect(), + // StructExpressionInner::Value(values) => { + // values + // .into_iter() + // .flat_map(|v| { + // U::try_from(v) + // .unwrap() + // .flatten(self, symbols, statements_flattened) + // }) + // .collect() + // } + StructExpressionInner::FunctionCall(key, param_expressions) => { + let exprs_flattened = self.flatten_function_call( + symbols, + statements_flattened, + key.id, + vec![ty], + param_expressions, + ); + exprs_flattened.expressions + } + // StructExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0 + // ..size) + // .flat_map(|i| { + // U::if_else( + // *condition.clone(), + // U::select( + // *consequence.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // U::select( + // *alternative.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .flatten(self, symbols, statements_flattened) + // }) + // .collect(), + StructExpressionInner::Member(box s, id) => self + .flatten_member_expression::>( + symbols, + statements_flattened, + s, + id, + ), + StructExpressionInner::Select(box array, box index) => self + .flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ), + _ => unimplemented!("yeah well"), + } + } + /// # Remarks /// * U is the inner type fn flatten_array_expression>( @@ -1120,6 +1302,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten(self, symbols, statements_flattened) }) .collect(), + ArrayExpressionInner::Member(box s, id) => self + .flatten_member_expression::>( + symbols, + statements_flattened, + s, + id, + ), ArrayExpressionInner::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1282,6 +1471,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|(v, r)| FlatStatement::Definition(v, r)), ); } + Type::Struct(..) => { + let vars = match assignee { + TypedAssignee::Identifier(v) => self.use_variable(&v), + _ => unimplemented!(), + }; + statements_flattened.extend( + vars.into_iter() + .zip(rhs.into_iter()) + .map(|(v, r)| FlatStatement::Definition(v, r)), + ); + } } } TypedStatement::Condition(expr1, expr2) => { @@ -1483,11 +1683,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// * `name` - a String that holds the name of the variable fn use_variable(&mut self, variable: &Variable<'ast>) -> Vec { - let vars = match variable.get_type() { - Type::FieldElement => self.issue_new_variables(1), - Type::Boolean => self.issue_new_variables(1), - Type::Array(ty, size) => self.issue_new_variables(ty.get_primitive_count() * size), - }; + let vars = self.issue_new_variables(variable.get_type().get_primitive_count()); self.layout.insert(variable.id.clone(), vars.clone()); vars diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 57b9449d..03e03633 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -4,17 +4,17 @@ //! @author Thibaut Schaeffer //! @date 2017 -use crate::absy::variable::Variable; use crate::absy::Identifier; use crate::absy::*; use crate::typed_absy::*; +use crate::typed_absy::{Parameter, Variable}; use std::collections::{HashMap, HashSet}; use std::fmt; use zokrates_field::field::Field; use crate::parser::Position; -use crate::types::{FunctionKey, Type}; +use crate::types::{FunctionKey, Signature, Type, UnresolvedSignature, UnresolvedType, UserTypeId}; use std::hash::{Hash, Hasher}; @@ -125,6 +125,7 @@ impl<'ast> Eq for ScopedVariable<'ast> {} pub struct Checker<'ast> { scope: HashSet>, functions: HashSet>, + types: HashMap>, level: usize, } @@ -133,6 +134,7 @@ impl<'ast> Checker<'ast> { Checker { scope: HashSet::new(), functions: HashSet::new(), + types: HashMap::new(), level: 0, } } @@ -178,6 +180,44 @@ impl<'ast> Checker<'ast> { }) } + fn check_type_symbol( + &mut self, + s: TypeSymbol<'ast>, + module_id: &ModuleId, + modules: &mut Modules<'ast, T>, + typed_modules: &mut TypedModules<'ast, T>, + ) -> Result> { + match s { + TypeSymbol::Here(t) => { + self.check_struct_type_declaration(t, module_id, modules, typed_modules) + } + } + } + + fn check_struct_type_declaration( + &mut self, + s: StructTypeNode<'ast>, + module_id: &ModuleId, + modules: &mut Modules<'ast, T>, + typed_modules: &mut TypedModules<'ast, T>, + ) -> Result> { + let pos = s.pos(); + let s = s.value; + + let fields = s + .fields + .into_iter() + .map(|n| { + ( + n.value.id.to_string(), + self.check_type(n.value.ty, module_id).unwrap(), + ) + }) + .collect(); + + Ok(Type::Struct(fields)) + } + fn check_module( &mut self, module_id: &ModuleId, @@ -194,13 +234,32 @@ impl<'ast> Checker<'ast> { // if it was not, check it Some(module) => { assert_eq!(module.imports.len(), 0); + + for declaration in module.types { + let pos = declaration.pos(); + let declaration = declaration.value; + + let ty = self + .check_type_symbol(declaration.symbol, module_id, modules, typed_modules) + .unwrap(); + self.types + .entry(module_id.clone()) + .or_default() + .insert(declaration.id.to_string(), ty); + } + for declaration in module.functions { self.enter_scope(); let pos = declaration.pos(); let declaration = declaration.value; - match self.check_function_symbol(declaration.symbol, modules, typed_modules) { + match self.check_function_symbol( + declaration.symbol, + module_id, + modules, + typed_modules, + ) { Ok(checked_function_symbols) => { for funct in checked_function_symbols { let query = FunctionQuery::new( @@ -298,7 +357,7 @@ impl<'ast> Checker<'ast> { fn check_for_var(&self, var: &VariableNode) -> Result<(), Error> { match var.value.get_type() { - Type::FieldElement => Ok(()), + UnresolvedType::FieldElement => Ok(()), t => Err(Error { pos: Some(var.pos()), message: format!("Variable in for loop cannot have type {}", t), @@ -309,6 +368,7 @@ impl<'ast> Checker<'ast> { fn check_function( &mut self, funct_node: FunctionNode<'ast, T>, + module_id: &ModuleId, ) -> Result, Vec> { let mut errors = vec![]; let funct = funct_node.value; @@ -316,13 +376,18 @@ impl<'ast> Checker<'ast> { assert_eq!(funct.arguments.len(), funct.signature.inputs.len()); for arg in &funct.arguments { - self.insert_into_scope(arg.value.id.value.clone()); + let v = self + .check_variable(arg.value.id.clone(), module_id) + .unwrap(); + self.insert_into_scope(v); } let mut statements_checked = vec![]; + let signature = self.check_signature(funct.signature, module_id)?; + for stat in funct.statements.into_iter() { - match self.check_statement(stat, &funct.signature.outputs) { + match self.check_statement(stat, &signature.outputs, module_id) { Ok(statement) => { statements_checked.push(statement); } @@ -340,16 +405,57 @@ impl<'ast> Checker<'ast> { arguments: funct .arguments .into_iter() - .map(|a| a.value.into()) + .map(|a| self.check_parameter(a, module_id)) .collect(), statements: statements_checked, - signature: funct.signature, + signature, }) } + fn check_parameter(&self, p: ParameterNode<'ast>, module_id: &ModuleId) -> Parameter<'ast> { + Parameter { + id: self.check_variable(p.value.id, module_id).unwrap(), + private: p.value.private, + } + } + + fn check_signature( + &self, + signature: UnresolvedSignature, + module_id: &ModuleId, + ) -> Result> { + Ok(Signature { + inputs: signature + .inputs + .into_iter() + .map(|t| self.check_type(t, module_id).unwrap()) + .collect(), + outputs: signature + .outputs + .into_iter() + .map(|t| self.check_type(t, module_id).unwrap()) + .collect(), + }) + } + + fn check_type(&self, ty: UnresolvedType, module_id: &ModuleId) -> Result> { + match ty { + UnresolvedType::FieldElement => Ok(Type::FieldElement), + UnresolvedType::Boolean => Ok(Type::Boolean), + UnresolvedType::Array(t, size) => Ok(Type::Array( + box self.check_type(*t, module_id).unwrap(), + size, + )), + UnresolvedType::User(id) => { + Ok(self.types.get(module_id).unwrap().get(&id).unwrap().clone()) + } + } + } + fn check_function_symbol( &mut self, funct_symbol: FunctionSymbol<'ast, T>, + module_id: &ModuleId, modules: &mut Modules<'ast, T>, typed_modules: &mut TypedModules<'ast, T>, ) -> Result>, Vec> { @@ -358,7 +464,7 @@ impl<'ast> Checker<'ast> { match funct_symbol { FunctionSymbol::Here(funct_node) => self - .check_function(funct_node) + .check_function(funct_node, module_id) .map(|f| vec![TypedFunctionSymbol::Here(f)]), FunctionSymbol::There(import_node) => { let pos = import_node.pos(); @@ -410,10 +516,22 @@ impl<'ast> Checker<'ast> { } } + fn check_variable( + &self, + v: crate::absy::VariableNode<'ast>, + module_id: &ModuleId, + ) -> Result, Vec> { + Ok(Variable::with_id_and_type( + v.value.id.into(), + self.check_type(v.value._type, module_id).unwrap(), + )) + } + fn check_statement( &mut self, stat: StatementNode<'ast, T>, header_return_types: &Vec, + module_id: &ModuleId, ) -> Result, Error> { let pos = stat.pos(); @@ -425,7 +543,7 @@ impl<'ast> Checker<'ast> { expression_list_checked.push(e_checked); } - let return_statement_types: Vec = expression_list_checked + let return_statement_types: Vec<_> = expression_list_checked .iter() .map(|e| e.get_type()) .collect(); @@ -450,13 +568,16 @@ impl<'ast> Checker<'ast> { }), } } - Statement::Declaration(var) => match self.insert_into_scope(var.clone().value) { - true => Ok(TypedStatement::Declaration(var.value.into())), - false => Err(Error { - pos: Some(pos), - message: format!("Duplicate declaration for variable named {}", var.value.id), - }), - }, + Statement::Declaration(var) => { + let var = self.check_variable(var, module_id).unwrap(); + match self.insert_into_scope(var.clone()) { + true => Ok(TypedStatement::Declaration(var)), + false => Err(Error { + pos: Some(pos), + message: format!("Duplicate declaration for variable named {}", var.id), + }), + } + } Statement::Definition(assignee, expr) => { // we create multidef when rhs is a function call to benefit from inference // check rhs is not a function call here @@ -510,22 +631,20 @@ impl<'ast> Checker<'ast> { self.check_for_var(&var)?; - self.insert_into_scope(var.clone().value); + let var = self.check_variable(var, module_id).unwrap(); + + self.insert_into_scope(var.clone()); let mut checked_statements = vec![]; for stat in statements { - let checked_stat = self.check_statement(stat, header_return_types)?; + let checked_stat = + self.check_statement(stat, header_return_types, module_id)?; checked_statements.push(checked_stat); } self.exit_scope(); - Ok(TypedStatement::For( - var.value.into(), - from, - to, - checked_statements, - )) + Ok(TypedStatement::For(var, from, to, checked_statements)) } Statement::MultipleDefinition(assignees, rhs) => { match rhs.value { @@ -568,8 +687,8 @@ impl<'ast> Checker<'ast> { let f = &candidates[0]; // we can infer the left hand side to be typed as the return values - let lhs: Vec<_> = var_names.iter().zip(f.signature.outputs.iter()).map(|(name, ty)| - Variable::new(*name, ty.clone()) + let lhs: Vec = var_names.iter().zip(f.signature.outputs.iter()).map(|(name, ty)| + Variable::with_id_and_type(crate::typed_absy::Identifier::from(*name), ty.clone()) ).collect(); let assignees: Vec<_> = lhs.iter().map(|v| v.clone().into()).collect(); @@ -605,12 +724,10 @@ impl<'ast> Checker<'ast> { // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) { - Some(var) => Ok(TypedAssignee::Identifier( - crate::typed_absy::Variable::with_id_and_type( - variable_name.into(), - var.id.get_type(), - ), - )), + Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( + variable_name.into(), + var.id._type.clone(), + ))), None => Err(Error { pos: Some(assignee.pos()), message: format!("Undeclared variable: {:?}", variable_name), @@ -705,6 +822,11 @@ impl<'ast> Checker<'ast> { inner: ArrayExpressionInner::Identifier(name.into()), } .into()), + Type::Struct(members) => Ok(StructExpression { + ty: members, + inner: StructExpressionInner::Identifier(name.into()), + } + .into()), }, None => Err(Error { pos: Some(pos), @@ -896,6 +1018,17 @@ impl<'ast> Checker<'ast> { ), } .into()), + Type::Struct(members) => Ok(StructExpression { + ty: members.clone(), + inner: StructExpressionInner::FunctionCall( + FunctionKey { + id: f.id.clone(), + signature: f.signature.clone(), + }, + arguments_checked, + ), + } + .into()), _ => unimplemented!(), }, n => Err(Error { @@ -1090,6 +1223,11 @@ impl<'ast> Checker<'ast> { inner: ArrayExpressionInner::Select(box a, box i), } .into()), + Type::Struct(members) => Ok(StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select(box a, box i), + } + .into()), } } (a, e) => Err(Error { @@ -1103,7 +1241,61 @@ impl<'ast> Checker<'ast> { }, } } - Expression::Member(..) => unimplemented!(), + Expression::Member(box e, box id) => { + let e = self.check_expression(e)?; + + match e { + TypedExpression::Struct(s) => { + // check that the struct has that field and return the type if it does + let ty = + s.ty.iter() + .find(|(member_id, ty)| member_id == id) + .map(|(member_id, ty)| ty); + + match ty { + Some(ty) => match ty { + Type::FieldElement => { + Ok(FieldElementExpression::Member(box s, id.to_string()).into()) + } + Type::Boolean => { + Ok(BooleanExpression::Member(box s, id.to_string()).into()) + } + Type::Array(box ty, size) => Ok(ArrayExpression { + ty: ty.clone(), + size: *size, + inner: ArrayExpressionInner::Member(box s, id.to_string()), + } + .into()), + Type::Struct(members) => Ok(StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Member(box s, id.to_string()), + } + .into()), + }, + None => Err(Error { + pos: Some(pos), + message: format!( + "{} doesn't have member {}. Members are {}", + TypedExpression::Struct(s.clone()), + id, + s.ty.iter() + .map(|(member_id, _)| member_id.to_string()) + .collect::>() + .join(", ") + ), + }), + } + } + e => Err(Error { + pos: Some(pos), + message: format!( + "Cannot access member {} on expression of type {}", + id, + e.get_type() + ), + }), + } + } Expression::InlineArray(expressions) => { // we should have at least one expression let size = expressions.len(); @@ -1211,6 +1403,49 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } + Ok(ArrayExpression { + ty, + size: unwrapped_expressions.len(), + inner: ArrayExpressionInner::Value(unwrapped_expressions), + } + .into()) + } + ty @ Type::Struct(..) => { + // we check all expressions have that same type + let mut unwrapped_expressions = vec![]; + + for e in expressions_checked { + let unwrapped_e = match e { + TypedExpression::Struct(e) => { + if e.get_type() == ty { + Ok(e) + } else { + Err(Error { + pos: Some(pos), + + message: format!( + "Expected {} to have type {}, but type is {}", + e, + ty, + e.get_type() + ), + }) + } + } + e => Err(Error { + pos: Some(pos), + + message: format!( + "Expected {} to have type {}, but type is {}", + e, + ty, + e.get_type() + ), + }), + }?; + unwrapped_expressions.push(unwrapped_e.into()); + } + Ok(ArrayExpression { ty, size: unwrapped_expressions.len(), @@ -1266,9 +1501,12 @@ impl<'ast> Checker<'ast> { } } - fn get_scope(&self, variable_name: &Identifier<'ast>) -> Option<&ScopedVariable> { + fn get_scope(&self, variable_name: &'ast str) -> Option<&'ast ScopedVariable> { self.scope.get(&ScopedVariable { - id: Variable::new(*variable_name, Type::FieldElement), + id: Variable::with_id_and_type( + crate::typed_absy::Identifier::from(variable_name), + Type::FieldElement, + ), level: 0, }) } diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 8c824677..0db17e20 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use typed_absy::{folder::*, *}; -use types::{FunctionKey, Type}; +use types::{FunctionKey, MemberId, Type}; use zokrates_field::field::Field; /// An inliner @@ -260,6 +260,30 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { e => fold_array_expression_inner(self, ty, size, e), } } + + fn fold_struct_expression_inner( + &mut self, + ty: &Vec<(MemberId, Type)>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + match e { + StructExpressionInner::FunctionCall(key, exps) => { + let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); + + match self.try_inline_call(&key, exps) { + Ok(mut ret) => match ret.pop().unwrap() { + TypedExpression::Struct(e) => e.inner, + _ => unreachable!(), + }, + Err((key, expressions)) => { + StructExpressionInner::FunctionCall(key, expressions) + } + } + } + // default + e => fold_struct_expression_inner(self, ty, e), + } + } } #[cfg(test)] diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs index cd560ac6..22d69a68 100644 --- a/zokrates_core/src/static_analysis/unroll.rs +++ b/zokrates_core/src/static_analysis/unroll.rs @@ -140,6 +140,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { ), }, TypedExpression::Array(..) => unimplemented!(), + TypedExpression::Struct(..) => unimplemented!(), }; vec![TypedStatement::Definition( diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 0f7889d7..32bf506c 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -60,6 +60,7 @@ pub trait Folder<'ast, T: Field>: Sized { TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(), TypedExpression::Boolean(e) => self.fold_boolean_expression(e).into(), TypedExpression::Array(e) => self.fold_array_expression(e).into(), + TypedExpression::Struct(e) => self.fold_struct_expression(e).into(), } } @@ -67,6 +68,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_array_expression(self, e) } + fn fold_struct_expression( + &mut self, + e: StructExpression<'ast, T>, + ) -> StructExpression<'ast, T> { + fold_struct_expression(self, e) + } + fn fold_expression_list( &mut self, es: TypedExpressionList<'ast, T>, @@ -105,6 +113,13 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> ArrayExpressionInner<'ast, T> { fold_array_expression_inner(self, ty, size, e) } + fn fold_struct_expression_inner( + &mut self, + ty: &Vec<(MemberId, Type)>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + fold_struct_expression_inner(self, ty, e) + } } pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( @@ -178,6 +193,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( box f.fold_array_expression(alternative), ) } + ArrayExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s); + ArrayExpressionInner::Member(box s, id) + } ArrayExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); let index = f.fold_field_expression(index); @@ -186,6 +205,39 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _: &Vec<(MemberId, Type)>, + e: StructExpressionInner<'ast, T>, +) -> StructExpressionInner<'ast, T> { + match e { + StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)), + StructExpressionInner::Value(exprs) => { + StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()) + } + StructExpressionInner::FunctionCall(id, exps) => { + let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); + StructExpressionInner::FunctionCall(id, exps) + } + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + StructExpressionInner::IfElse( + box f.fold_boolean_expression(condition), + box f.fold_struct_expression(consequence), + box f.fold_struct_expression(alternative), + ) + } + StructExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s); + StructExpressionInner::Member(box s, id) + } + StructExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array); + let index = f.fold_field_expression(index); + StructExpressionInner::Select(box array, box index) + } + } +} + pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, @@ -230,6 +282,10 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); FieldElementExpression::FunctionCall(key, exps) } + FieldElementExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s); + FieldElementExpression::Member(box s, id) + } FieldElementExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); let index = f.fold_field_expression(index); @@ -290,6 +346,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( let alt = f.fold_boolean_expression(alt); BooleanExpression::IfElse(box cond, box cons, box alt) } + BooleanExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s); + BooleanExpression::Member(box s, id) + } BooleanExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); let index = f.fold_field_expression(index); @@ -327,6 +387,16 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: StructExpression<'ast, T>, +) -> StructExpression<'ast, T> { + StructExpression { + inner: f.fold_struct_expression_inner(&e.ty, e.inner), + ..e + } +} + pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedFunctionSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index cf3cc036..b3020f06 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -12,7 +12,7 @@ mod variable; pub use crate::typed_absy::parameter::Parameter; pub use crate::typed_absy::variable::Variable; -use crate::types::{FunctionKey, Signature, Type}; +use crate::types::{FunctionKey, MemberId, Signature, Type}; use embed::FlatEmbed; use std::collections::HashMap; use std::convert::TryFrom; @@ -359,6 +359,7 @@ pub enum TypedExpression<'ast, T: Field> { Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), Array(ArrayExpression<'ast, T>), + Struct(StructExpression<'ast, T>), } impl<'ast, T: Field> From> for TypedExpression<'ast, T> { @@ -379,12 +380,19 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> } } +impl<'ast, T: Field> From> for TypedExpression<'ast, T> { + fn from(e: StructExpression<'ast, T>) -> TypedExpression { + TypedExpression::Struct(e) + } +} + impl<'ast, T: Field> fmt::Display for TypedExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { TypedExpression::Boolean(ref e) => write!(f, "{}", e), TypedExpression::FieldElement(ref e) => write!(f, "{}", e), - TypedExpression::Array(ref e) => write!(f, "{}", e.inner), + TypedExpression::Array(ref e) => write!(f, "{}", e), + TypedExpression::Struct(ref s) => write!(f, "{}", s), } } } @@ -395,6 +403,7 @@ impl<'ast, T: Field> fmt::Debug for TypedExpression<'ast, T> { TypedExpression::Boolean(ref e) => write!(f, "{:?}", e), TypedExpression::FieldElement(ref e) => write!(f, "{:?}", e), TypedExpression::Array(ref e) => write!(f, "{:?}", e), + TypedExpression::Struct(ref s) => write!(f, "{}", s), } } } @@ -411,12 +420,25 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> { } } +impl<'ast, T: Field> fmt::Display for StructExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl<'ast, T: Field> fmt::Debug for StructExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self.inner) + } +} + impl<'ast, T: Field> Typed for TypedExpression<'ast, T> { fn get_type(&self) -> Type { match *self { TypedExpression::Boolean(_) => Type::Boolean, TypedExpression::FieldElement(_) => Type::FieldElement, TypedExpression::Array(ref e) => e.get_type(), + TypedExpression::Struct(ref s) => s.get_type(), } } } @@ -427,6 +449,12 @@ impl<'ast, T: Field> Typed for ArrayExpression<'ast, T> { } } +impl<'ast, T: Field> Typed for StructExpression<'ast, T> { + fn get_type(&self) -> Type { + Type::Struct(self.ty.clone()) + } +} + pub trait MultiTyped { fn get_types(&self) -> &Vec; } @@ -475,6 +503,7 @@ pub enum FieldElementExpression<'ast, T: Field> { Box>, ), FunctionCall(FunctionKey<'ast>, Vec>), + Member(Box>, MemberId), Select( Box>, Box>, @@ -520,6 +549,7 @@ pub enum BooleanExpression<'ast, T: Field> { Box>, Box>, ), + Member(Box>, MemberId), Select( Box>, Box>, @@ -543,6 +573,7 @@ pub enum ArrayExpressionInner<'ast, T: Field> { Box>, Box>, ), + Member(Box>, MemberId), Select( Box>, Box>, @@ -559,6 +590,29 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { } } +#[derive(Clone, PartialEq, Hash, Eq)] +pub struct StructExpression<'ast, T: Field> { + pub ty: Vec<(MemberId, Type)>, + pub inner: StructExpressionInner<'ast, T>, +} + +#[derive(Clone, PartialEq, Hash, Eq)] +pub enum StructExpressionInner<'ast, T: Field> { + Identifier(Identifier<'ast>), + Value(Vec>), + FunctionCall(FunctionKey<'ast>, Vec>), + IfElse( + Box>, + Box>, + Box>, + ), + Member(Box>, MemberId), + Select( + Box>, + Box>, + ), +} + // Downcasts impl<'ast, T: Field> TryFrom> for FieldElementExpression<'ast, T> { @@ -596,6 +650,17 @@ impl<'ast, T: Field> TryFrom> for ArrayExpression<'ast, } } +impl<'ast, T: Field> TryFrom> for StructExpression<'ast, T> { + type Error = (); + + fn try_from(te: TypedExpression<'ast, T>) -> Result, Self::Error> { + match te { + TypedExpression::Struct(e) => Ok(e), + _ => Err(()), + } + } +} + impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -623,6 +688,7 @@ impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> { } write!(f, ")") } + FieldElementExpression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), FieldElementExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), } } @@ -646,6 +712,7 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> { "if {} then {} else {} fi", condition, consequent, alternative ), + BooleanExpression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), BooleanExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), } } @@ -679,11 +746,48 @@ impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> { "if {} then {} else {} fi", condition, consequent, alternative ), + ArrayExpressionInner::Member(ref s, ref id) => write!(f, "{}.{}", s, id), ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), } } } +impl<'ast, T: Field> fmt::Display for StructExpressionInner<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + StructExpressionInner::Identifier(ref var) => write!(f, "{}", var), + StructExpressionInner::Value(ref values) => write!( + f, + "[{}]", + values + .iter() + .map(|o| o.to_string()) + .collect::>() + .join(", ") + ), + StructExpressionInner::FunctionCall(ref key, ref p) => { + r#try!(write!(f, "{}(", key.id,)); + for (i, param) in p.iter().enumerate() { + r#try!(write!(f, "{}", param)); + if i < p.len() - 1 { + r#try!(write!(f, ", ")); + } + } + write!(f, ")") + } + StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => { + write!( + f, + "if {} then {} else {} fi", + condition, consequent, alternative + ) + } + StructExpressionInner::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), + StructExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), + } + } +} + impl<'ast, T: Field> fmt::Debug for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) @@ -714,6 +818,9 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> { r#try!(f.debug_list().entries(p.iter()).finish()); write!(f, ")") } + FieldElementExpression::Member(ref struc, ref id) => { + write!(f, "Member({:?}, {:?})", struc, id) + } FieldElementExpression::Select(ref id, ref index) => { write!(f, "Select({:?}, {:?})", id, index) } @@ -736,6 +843,9 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> { "IfElse({:?}, {:?}, {:?})", condition, consequent, alternative ), + ArrayExpressionInner::Member(ref struc, ref id) => { + write!(f, "Member({:?}, {:?})", struc, id) + } ArrayExpressionInner::Select(ref id, ref index) => { write!(f, "Select({:?}, {:?})", id, index) } @@ -743,6 +853,33 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> { } } +impl<'ast, T: Field> fmt::Debug for StructExpressionInner<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + StructExpressionInner::Identifier(ref var) => write!(f, "{:?}", var), + StructExpressionInner::Value(ref values) => write!(f, "{:?}", values), + StructExpressionInner::FunctionCall(ref i, ref p) => { + r#try!(write!(f, "FunctionCall({:?}, (", i)); + r#try!(f.debug_list().entries(p.iter()).finish()); + write!(f, ")") + } + StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => { + write!( + f, + "IfElse({:?}, {:?}, {:?})", + condition, consequent, alternative + ) + } + StructExpressionInner::Member(ref struc, ref id) => { + write!(f, "Member({:?}, {:?})", struc, id) + } + StructExpressionInner::Select(ref id, ref index) => { + write!(f, "Select({:?}, {:?})", id, index) + } + } + } +} + impl<'ast, T: Field> fmt::Display for TypedExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { diff --git a/zokrates_core/src/typed_absy/parameter.rs b/zokrates_core/src/typed_absy/parameter.rs index 7437a31e..58cf8c2f 100644 --- a/zokrates_core/src/typed_absy/parameter.rs +++ b/zokrates_core/src/typed_absy/parameter.rs @@ -31,11 +31,11 @@ impl<'ast> fmt::Debug for Parameter<'ast> { } } -impl<'ast> From> for Parameter<'ast> { - fn from(p: absy::Parameter<'ast>) -> Parameter { - Parameter { - private: p.private, - id: p.id.value.into(), - } - } -} +// impl<'ast> From> for Parameter<'ast> { +// fn from(p: absy::Parameter<'ast>) -> Parameter { +// Parameter { +// private: p.private, +// id: p.id.value.into(), +// } +// } +// } diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index 1fbc2154..cf7b7b86 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -48,15 +48,15 @@ impl<'ast> fmt::Debug for Variable<'ast> { } } -impl<'ast> From> for Variable<'ast> { - fn from(v: absy::Variable) -> Variable { - Variable::with_id_and_type( - Identifier { - id: v.id, - version: 0, - stack: vec![], - }, - v._type, - ) - } -} +// impl<'ast> From> for Variable<'ast> { +// fn from(v: absy::Variable) -> Variable { +// Variable::with_id_and_type( +// Identifier { +// id: v.id, +// version: 0, +// stack: vec![], +// }, +// v._type, +// ) +// } +// } diff --git a/zokrates_core/src/types/mod.rs b/zokrates_core/src/types/mod.rs index 390f2243..40cb595a 100644 --- a/zokrates_core/src/types/mod.rs +++ b/zokrates_core/src/types/mod.rs @@ -1,8 +1,13 @@ pub use crate::types::signature::Signature; +pub use crate::types::signature::UnresolvedSignature; use std::fmt; pub type Identifier<'ast> = &'ast str; +pub type MemberId = String; + +pub type UserTypeId = String; + mod signature; #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -10,6 +15,26 @@ pub enum Type { FieldElement, Boolean, Array(Box, usize), + Struct(Vec<(MemberId, Type)>), +} + +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)] +pub enum UnresolvedType { + FieldElement, + Boolean, + Array(Box, usize), + User(UserTypeId), +} + +impl fmt::Display for UnresolvedType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + UnresolvedType::FieldElement => write!(f, "field"), + UnresolvedType::Boolean => write!(f, "bool"), + UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size), + UnresolvedType::User(i) => write!(f, "{}", i), + } + } } impl fmt::Display for Type { @@ -18,6 +43,15 @@ impl fmt::Display for Type { Type::FieldElement => write!(f, "field"), Type::Boolean => write!(f, "bool"), Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size), + Type::Struct(ref members) => write!( + f, + "{{{}}}", + members + .iter() + .map(|(id, t)| format!("{}: {}", id, t)) + .collect::>() + .join(", ") + ), } } } @@ -28,6 +62,15 @@ impl fmt::Debug for Type { Type::FieldElement => write!(f, "field"), Type::Boolean => write!(f, "bool"), Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size), + Type::Struct(ref members) => write!( + f, + "{{{}}}", + members + .iter() + .map(|(id, t)| format!("{}: {}", id, t)) + .collect::>() + .join(", ") + ), } } } @@ -42,6 +85,7 @@ impl Type { Type::FieldElement => String::from("f"), Type::Boolean => String::from("b"), Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size), + Type::Struct(members) => unimplemented!(), } } @@ -51,10 +95,17 @@ impl Type { Type::FieldElement => 1, Type::Boolean => 1, Type::Array(ty, size) => size * ty.get_primitive_count(), + Type::Struct(members) => members.iter().map(|(_, t)| t.get_primitive_count()).sum(), } } } +impl UnresolvedType { + pub fn array(ty: UnresolvedType, size: usize) -> Self { + UnresolvedType::Array(box ty, size) + } +} + #[derive(Clone, PartialEq, Hash, Eq)] pub struct Variable<'ast> { pub id: Identifier<'ast>, diff --git a/zokrates_core/src/types/signature.rs b/zokrates_core/src/types/signature.rs index 66096c2d..bca1e7e7 100644 --- a/zokrates_core/src/types/signature.rs +++ b/zokrates_core/src/types/signature.rs @@ -1,6 +1,68 @@ use crate::types::Type; use std::fmt; +pub use self::unresolved::UnresolvedSignature; + +mod unresolved { + use super::*; + use types::UnresolvedType; + + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] + pub struct UnresolvedSignature { + pub inputs: Vec, + pub outputs: Vec, + } + + impl fmt::Debug for UnresolvedSignature { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Signature(inputs: {:?}, outputs: {:?})", + self.inputs, self.outputs + ) + } + } + + impl fmt::Display for UnresolvedSignature { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + r#try!(write!(f, "(")); + for (i, t) in self.inputs.iter().enumerate() { + r#try!(write!(f, "{}", t)); + if i < self.inputs.len() - 1 { + r#try!(write!(f, ", ")); + } + } + r#try!(write!(f, ") -> (")); + for (i, t) in self.outputs.iter().enumerate() { + r#try!(write!(f, "{}", t)); + if i < self.outputs.len() - 1 { + r#try!(write!(f, ", ")); + } + } + write!(f, ")") + } + } + + impl UnresolvedSignature { + pub fn new() -> UnresolvedSignature { + UnresolvedSignature { + inputs: vec![], + outputs: vec![], + } + } + + pub fn inputs(mut self, inputs: Vec) -> Self { + self.inputs = inputs; + self + } + + pub fn outputs(mut self, outputs: Vec) -> Self { + self.outputs = outputs; + self + } + } +} + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Signature { pub inputs: Vec, diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 2abd642c..b33f91b6 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -21,7 +21,7 @@ type_list = _{(ty ~ ("," ~ ty)*)?} // structs ty_struct = { identifier } // type definitions -ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" } +ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* } struct_field_list = _{(struct_field ~ ("," ~ NEWLINE* ~ struct_field)*)? ~ ","? } struct_field = { identifier ~ ":" ~ ty } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index c24ce49b..ad957d6b 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -244,7 +244,7 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_struct))] pub struct StructType<'ast> { - id: IdentifierExpression<'ast>, + pub id: IdentifierExpression<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, } From 266320f7d9916bae5ebb16ffbbd24ffb55766360 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 2 Aug 2019 12:49:43 +0200 Subject: [PATCH 10/35] wip --- t.code | 16 + zokrates_core/src/absy/from_ast.rs | 40 ++- zokrates_core/src/absy/mod.rs | 100 +++--- zokrates_core/src/absy/node.rs | 5 +- zokrates_core/src/imports.rs | 26 +- zokrates_core/src/semantics.rs | 507 +++++++++++++++++++++-------- zokrates_core/src/types/mod.rs | 5 + 7 files changed, 473 insertions(+), 226 deletions(-) create mode 100644 t.code diff --git a/t.code b/t.code new file mode 100644 index 00000000..d96d3796 --- /dev/null +++ b/t.code @@ -0,0 +1,16 @@ +struct Foo { + a: field, + b: field[2], +} + +struct Bar { + a: Foo, + b: field[2] +} + +def f(Foo a) -> (Foo): + return a + +def main(Bar a) -> (Foo): + return f(a.a) + diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index d1bbbe82..d8012a78 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -7,15 +7,25 @@ use zokrates_pest_ast as pest; impl<'ast, T: Field> From> for absy::Module<'ast, T> { fn from(prog: pest::File<'ast>) -> absy::Module { absy::Module { - types: prog + // types: prog + // .structs + // .into_iter() + // .map(|t| absy::TypeDeclarationNode::from(t)) + // .collect(), + // functions: prog + // .functions + // .into_iter() + // .map(|f| absy::FunctionDeclarationNode::from(f)) + // .collect(), + symbols: prog .structs .into_iter() - .map(|t| absy::TypeDeclarationNode::from(t)) - .collect(), - functions: prog - .functions - .into_iter() - .map(|f| absy::FunctionDeclarationNode::from(f)) + .map(|t| absy::SymbolDeclarationNode::from(t)) + .chain( + prog.functions + .into_iter() + .map(|f| absy::SymbolDeclarationNode::from(f)), + ) .collect(), imports: prog .imports @@ -35,8 +45,8 @@ impl<'ast> From> for absy::ImportNode<'ast> { } } -impl<'ast> From> for absy::TypeDeclarationNode<'ast> { - fn from(definition: pest::StructDefinition<'ast>) -> absy::TypeDeclarationNode { +impl<'ast, T: Field> From> for absy::SymbolDeclarationNode<'ast, T> { + fn from(definition: pest::StructDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast, T> { use absy::NodeValue; let span = definition.span; @@ -52,9 +62,9 @@ impl<'ast> From> for absy::TypeDeclarationNode<'ast } .span(span.clone()); // TODO check - absy::TypeDeclaration { + absy::SymbolDeclaration { id, - symbol: absy::TypeSymbol::Here(ty), + symbol: absy::Symbol::HereType(ty), } .span(span) } @@ -74,8 +84,8 @@ impl<'ast> From> for absy::StructFieldNode<'ast> { } } -impl<'ast, T: Field> From> for absy::FunctionDeclarationNode<'ast, T> { - fn from(function: pest::Function<'ast>) -> absy::FunctionDeclarationNode { +impl<'ast, T: Field> From> for absy::SymbolDeclarationNode<'ast, T> { + fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode { use absy::NodeValue; let span = function.span; @@ -115,9 +125,9 @@ impl<'ast, T: Field> From> for absy::FunctionDeclarationNod } .span(span.clone()); // TODO check - absy::FunctionDeclaration { + absy::SymbolDeclaration { id, - symbol: absy::FunctionSymbol::Here(function), + symbol: absy::Symbol::HereFunction(function), } .span(span) } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 2c46cf15..bbedd529 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -31,11 +31,8 @@ pub type ModuleId = String; /// A collection of `Module`s pub type Modules<'ast, T> = HashMap>; -/// A collection of `FunctionDeclaration`. Duplicates are allowed here as they are fine syntatically. -pub type FunctionDeclarations<'ast, T> = Vec>; - -/// A collection of `StructDeclaration`. Duplicates are allowed here as they are fine syntatically. -pub type TypeDeclarations<'ast> = Vec>; +/// A collection of `SymbolDeclaration`. Duplicates are allowed here as they are fine syntatically. +pub type Declarations<'ast, T> = Vec>; /// A `Program` is a collection of `Module`s and an id of the main `Module` pub struct Program<'ast, T: Field> { @@ -45,34 +42,26 @@ pub struct Program<'ast, T: Field> { /// A declaration of a `FunctionSymbol`, be it from an import or a function definition #[derive(PartialEq, Debug, Clone)] -pub struct FunctionDeclaration<'ast, T: Field> { +pub struct SymbolDeclaration<'ast, T: Field> { pub id: Identifier<'ast>, - pub symbol: FunctionSymbol<'ast, T>, + pub symbol: Symbol<'ast, T>, } -/// A declaration of a `TypeSymbol`, be it from an import or a function definition #[derive(PartialEq, Debug, Clone)] -pub struct TypeDeclaration<'ast> { - pub id: Identifier<'ast>, - pub symbol: TypeSymbol<'ast>, +pub enum Symbol<'ast, T: Field> { + HereType(StructTypeNode<'ast>), + HereFunction(FunctionNode<'ast, T>), + There(SymbolImportNode<'ast>), + Flat(FlatEmbed), } -impl<'ast> fmt::Display for TypeDeclaration<'ast> { +impl<'ast, T: Field> fmt::Display for SymbolDeclaration<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.symbol { - TypeSymbol::Here(ref s) => write!(f, "struct {} {}", self.id, s), - } - } -} - -type TypeDeclarationNode<'ast> = Node>; - -impl<'ast, T: Field> fmt::Display for FunctionDeclaration<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.symbol { - FunctionSymbol::Here(ref fun) => write!(f, "def {}{}", self.id, fun), - FunctionSymbol::There(ref import) => write!(f, "import {} as {}", import, self.id), - FunctionSymbol::Flat(ref flat_fun) => write!( + Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t), + Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun), + Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id), + Symbol::Flat(ref flat_fun) => write!( f, "def {}{}:\n\t// hidden", self.id, @@ -82,31 +71,30 @@ impl<'ast, T: Field> fmt::Display for FunctionDeclaration<'ast, T> { } } -type FunctionDeclarationNode<'ast, T> = Node>; +type SymbolDeclarationNode<'ast, T> = Node>; /// A module as a collection of `FunctionDeclaration`s #[derive(Clone, PartialEq)] pub struct Module<'ast, T: Field> { - /// Structs of the module - pub types: TypeDeclarations<'ast>, - /// Functions of the module - pub functions: FunctionDeclarations<'ast, T>, + /// Symbols of the module + pub symbols: Declarations<'ast, T>, pub imports: Vec>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty } -/// A function, be it defined in this module, imported from another module or a flat embed -#[derive(Debug, Clone, PartialEq)] -pub enum FunctionSymbol<'ast, T: Field> { - Here(FunctionNode<'ast, T>), - There(FunctionImportNode<'ast>), - Flat(FlatEmbed), -} +// /// A function, be it defined in this module, imported from another module or a flat embed +// #[derive(Debug, Clone, PartialEq)] +// pub enum FunctionSymbol<'ast, T: Field> { +// Here(FunctionNode<'ast, T>), +// There(FunctionImportNode<'ast>), +// Flat(FlatEmbed), +// } -/// A user defined type, a struct defined in this module for now // TODO allow importing types -#[derive(Debug, Clone, PartialEq)] -pub enum TypeSymbol<'ast> { - Here(StructTypeNode<'ast>), -} +// /// A user defined type, a struct defined in this module for now // TODO allow importing types +// #[derive(Debug, Clone, PartialEq)] +// pub enum TypeSymbol<'ast> { +// Here(StructTypeNode<'ast>), +// There(TypeImportNode<'ast>), +// } /// A struct type definition #[derive(Debug, Clone, PartialEq)] @@ -145,32 +133,32 @@ impl<'ast> fmt::Display for StructField<'ast> { type StructFieldNode<'ast> = Node>; -/// A function import +/// An import #[derive(Debug, Clone, PartialEq)] -pub struct FunctionImport<'ast> { - /// the id of the function in the target module. Note: there may be many candidates as imports statements do not specify the signature - pub function_id: Identifier<'ast>, +pub struct SymbolImport<'ast> { + /// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however. + pub symbol_id: Identifier<'ast>, /// the id of the module to import from pub module_id: ModuleId, } -type FunctionImportNode<'ast> = Node>; +type SymbolImportNode<'ast> = Node>; -impl<'ast> FunctionImport<'ast> { +impl<'ast> SymbolImport<'ast> { pub fn with_id_in_module>, U: Into>( - function_id: S, + symbol_id: S, module_id: U, ) -> Self { - FunctionImport { - function_id: function_id.into(), + SymbolImport { + symbol_id: symbol_id.into(), module_id: module_id.into(), } } } -impl<'ast> fmt::Display for FunctionImport<'ast> { +impl<'ast> fmt::Display for SymbolImport<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} from {}", self.function_id, self.module_id) + write!(f, "{} from {}", self.symbol_id, self.module_id) } } @@ -184,7 +172,7 @@ impl<'ast, T: Field> fmt::Display for Module<'ast, T> { .collect::>(), ); res.extend( - self.functions + self.symbols .iter() .map(|x| format!("{}", x)) .collect::>(), @@ -197,13 +185,13 @@ impl<'ast, T: Field> fmt::Debug for Module<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "module(\n\timports:\n\t\t{}\n\tfunctions:\n\t\t{}\n)", + "module(\n\timports:\n\t\t{}\n\tsymbols:\n\t\t{}\n)", self.imports .iter() .map(|x| format!("{:?}", x)) .collect::>() .join("\n\t\t"), - self.functions + self.symbols .iter() .map(|x| format!("{:?}", x)) .collect::>() diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 13d4e44e..4b8a2748 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -74,13 +74,12 @@ impl<'ast, T: Field> NodeValue for Expression<'ast, T> {} impl<'ast, T: Field> NodeValue for ExpressionList<'ast, T> {} impl<'ast, T: Field> NodeValue for Assignee<'ast, T> {} impl<'ast, T: Field> NodeValue for Statement<'ast, T> {} -impl<'ast, T: Field> NodeValue for FunctionDeclaration<'ast, T> {} -impl<'ast> NodeValue for TypeDeclaration<'ast> {} +impl<'ast, T: Field> NodeValue for SymbolDeclaration<'ast, T> {} impl<'ast> NodeValue for StructType<'ast> {} impl<'ast> NodeValue for StructField<'ast> {} impl<'ast, T: Field> NodeValue for Function<'ast, T> {} impl<'ast, T: Field> NodeValue for Module<'ast, T> {} -impl<'ast> NodeValue for FunctionImport<'ast> {} +impl<'ast> NodeValue for SymbolImport<'ast> {} impl<'ast> NodeValue for Variable<'ast> {} impl<'ast> NodeValue for Parameter<'ast> {} impl<'ast> NodeValue for Import<'ast> {} diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 4fbd2bb5..ab7f14b9 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -124,7 +124,7 @@ impl Importer { modules: &mut HashMap>, arena: &'ast Arena, ) -> Result, CompileErrors> { - let mut functions: Vec<_> = vec![]; + let mut symbols: Vec<_> = vec![]; for import in destination.imports { let pos = import.pos(); @@ -136,10 +136,10 @@ impl Importer { "EMBED/sha256round" => { let alias = alias.unwrap_or("sha256round"); - functions.push( - FunctionDeclaration { + symbols.push( + SymbolDeclaration { id: &alias, - symbol: FunctionSymbol::Flat(FlatEmbed::Sha256Round), + symbol: Symbol::Flat(FlatEmbed::Sha256Round), } .start_end(pos.0, pos.1), ); @@ -147,10 +147,10 @@ impl Importer { "EMBED/unpack" => { let alias = alias.unwrap_or("unpack"); - functions.push( - FunctionDeclaration { + symbols.push( + SymbolDeclaration { id: &alias, - symbol: FunctionSymbol::Flat(FlatEmbed::Unpack), + symbol: Symbol::Flat(FlatEmbed::Unpack), } .start_end(pos.0, pos.1), ); @@ -185,11 +185,11 @@ impl Importer { modules.insert(import.source.to_string(), compiled); - functions.push( - FunctionDeclaration { + symbols.push( + SymbolDeclaration { id: &alias, - symbol: FunctionSymbol::There( - FunctionImport::with_id_in_module( + symbol: Symbol::There( + SymbolImport::with_id_in_module( "main", import.source.clone(), ) @@ -218,11 +218,11 @@ impl Importer { } } - functions.extend(destination.functions); + symbols.extend(destination.symbols); Ok(Module { imports: vec![], - functions: functions, + symbols, ..destination }) } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 03e03633..58600c2c 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -159,17 +159,17 @@ impl<'ast> Checker<'ast> { let mut errors = vec![]; - match Checker::check_single_main(modules.get(&program.main).unwrap()) { - Ok(_) => {} - Err(e) => errors.push(e), - }; - // recursively type-check modules starting with `main` match self.check_module(&program.main, &mut modules, &mut typed_modules) { Ok(()) => {} Err(e) => errors.extend(e), }; + match Checker::check_single_main(typed_modules.get(&program.main).unwrap()) { + Ok(_) => {} + Err(e) => errors.push(e), + }; + if errors.len() > 0 { return Err(errors); } @@ -180,19 +180,61 @@ impl<'ast> Checker<'ast> { }) } - fn check_type_symbol( - &mut self, - s: TypeSymbol<'ast>, - module_id: &ModuleId, - modules: &mut Modules<'ast, T>, - typed_modules: &mut TypedModules<'ast, T>, - ) -> Result> { - match s { - TypeSymbol::Here(t) => { - self.check_struct_type_declaration(t, module_id, modules, typed_modules) - } - } - } + // fn check_type_symbol( + // &mut self, + // s: StructTypeNode<'ast>, + // module_id: &ModuleId, + // modules: &mut Modules<'ast, T>, + // typed_modules: &mut TypedModules<'ast, T>, + // ) -> Result> { + + // let mut errors = vec![]; + + // match s { + // TypeSymbol::Here(t) => { + // self.check_struct_type_declaration(t, module_id, modules, typed_modules) + // } + // TypeSymbol::There(import_node) => { + // let pos = import_node.pos(); + // let import = import_node.value; + + // let res = + // match Checker::new().check_module(&import.module_id, modules, typed_modules) { + // Ok(()) => { + // match self + // .types + // .get(&import.module_id) + // .unwrap() + // .get(import.type_id) + // { + // Some(ty) => Some(ty), + // None => { + // errors.push(Error { + // pos: Some(pos), + // message: format!( + // "Type {} not found in module {}", + // import.type_id, import.module_id + // ), + // }); + // None + // } + // } + // } + // Err(e) => { + // errors.extend(e); + // None + // } + // }; + + // // return if any errors occured + // if errors.len() > 0 { + // return Err(errors); + // } + + // Ok(res.unwrap().clone()) + // } + // } + // } fn check_struct_type_declaration( &mut self, @@ -235,81 +277,268 @@ impl<'ast> Checker<'ast> { Some(module) => { assert_eq!(module.imports.len(), 0); - for declaration in module.types { + let ids = HashSet::new(); + + for declaration in module.symbols { let pos = declaration.pos(); let declaration = declaration.value; - let ty = self - .check_type_symbol(declaration.symbol, module_id, modules, typed_modules) - .unwrap(); - self.types - .entry(module_id.clone()) - .or_default() - .insert(declaration.id.to_string(), ty); - } + match declaration.symbol { + Symbol::HereType(t) => { + match ids.insert(declaration.id) { + true => errors.push(Error { + pos: Some(pos), + message: format!( + "Another symbol with id {} is already defined", + declaration.id, + ), + }), + false => {} + }; - for declaration in module.functions { - self.enter_scope(); + let ty = self + .check_struct_type_declaration(t, module_id, modules, typed_modules) + .unwrap(); - let pos = declaration.pos(); - let declaration = declaration.value; + self.types + .entry(module_id.clone()) + .or_default() + .insert(declaration.id.to_string(), ty); + } + Symbol::HereFunction(f) => { + self.enter_scope(); - match self.check_function_symbol( - declaration.symbol, - module_id, - modules, - typed_modules, - ) { - Ok(checked_function_symbols) => { - for funct in checked_function_symbols { - let query = FunctionQuery::new( - declaration.id.clone(), - &funct.signature(&typed_modules).inputs, - &funct - .signature(&typed_modules) - .outputs - .clone() - .into_iter() - .map(|o| Some(o)) - .collect(), - ); + match self.check_function(f, module_id) { + Ok(funct) => { + let query = FunctionQuery::new( + declaration.id.clone(), + &funct.signature(&typed_modules).inputs, + &funct + .signature(&typed_modules) + .outputs + .clone() + .into_iter() + .map(|o| Some(o)) + .collect(), + ); - let candidates = self.find_candidates(&query); + let candidates = self.find_candidates(&query); - match candidates.len() { - 1 => { - errors.push(Error { + match candidates.len() { + 1 => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Duplicate definition for function {} with signature {}", + declaration.id, + funct.signature(&typed_modules) + ), + }); + } + 0 => {} + _ => panic!( + "duplicate function declaration should have been caught" + ), + } + + ids.insert(declaration.id); + + self.functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature(&typed_modules).clone()), + ); + checked_functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature(&typed_modules).clone()), + funct, + ); + } + Err(e) => { + errors.extend(e); + } + } + + self.exit_scope(); + } + Symbol::There(import) => { + let pos = import.pos(); + let import = import.value; + + match Checker::new().check_module( + &import.module_id, + modules, + typed_modules, + ) { + Ok(()) => { + // find candidates in the checked module + let function_candidates: Vec<_> = typed_modules + .get(&import.module_id) + .unwrap() + .functions + .iter() + .filter(|(k, _)| k.id == import.symbol_id) + .map(|(_, v)| FunctionKey { + id: import.symbol_id.clone(), + signature: v.signature(&typed_modules).clone(), + }) + .collect(); + + // find candidates in the types + let type_candidate = self + .types + .get(&import.module_id) + .map(|m| m.get(import.symbol_id)); + + match (function_candidates.len(), type_candidate) { + (0, Some(t)) => errors.push(Error { pos: Some(pos), message: format!( - "Duplicate definition for function {} with signature {}", - declaration.id, - funct.signature(&typed_modules) + "Duplicate symbol {} in module {}", + import.symbol_id, import.module_id ), - }); + }), + (0, None) => unreachable!(), + _ => { + ids.insert(declaration.id); + for candidate in function_candidates { + self.functions + .insert(candidate.clone().id(declaration.id)); + checked_functions.insert( + candidate.clone().id(declaration.id), + TypedFunctionSymbol::There( + candidate, + import.module_id.clone(), + ), + ); + } + } } - 0 => {} - _ => panic!( - "duplicate function declaration should have been caught" - ), } - self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature(&typed_modules).clone()), - ); - checked_functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature(&typed_modules).clone()), - funct, - ); - } + Err(e) => { + errors.extend(e); + } + }; } - Err(e) => { - errors.extend(e); + Symbol::Flat(funct) => { + let query = FunctionQuery::new( + declaration.id.clone(), + &funct.signature::().inputs, + &funct + .signature::() + .outputs + .clone() + .into_iter() + .map(|o| Some(o)) + .collect(), + ); + + let candidates = self.find_candidates(&query); + + match candidates.len() { + 1 => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Duplicate definition for function {} with signature {}", + declaration.id, + funct.signature::() + ), + }); + } + 0 => {} + _ => { + panic!("duplicate function declaration should have been caught") + } + } + ids.insert(declaration.id); + self.functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature::().clone()), + ); + checked_functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature::().clone()), + TypedFunctionSymbol::Flat(funct), + ); } } - - self.exit_scope(); } + + // for declaration in module.types { + // let pos = declaration.pos(); + // let declaration = declaration.value; + + // let ty = self + // .check_type_symbol(declaration.symbol, module_id, modules, typed_modules) + // .unwrap(); + // self.types + // .entry(module_id.clone()) + // .or_default() + // .insert(declaration.id.to_string(), ty); + // } + + // for declaration in module.functions { + // self.enter_scope(); + + // let pos = declaration.pos(); + // let declaration = declaration.value; + + // match self.check_function_symbol( + // declaration.symbol, + // module_id, + // modules, + // typed_modules, + // ) { + // Ok(checked_function_symbols) => { + // for funct in checked_function_symbols { + // let query = FunctionQuery::new( + // declaration.id.clone(), + // &funct.signature(&typed_modules).inputs, + // &funct + // .signature(&typed_modules) + // .outputs + // .clone() + // .into_iter() + // .map(|o| Some(o)) + // .collect(), + // ); + + // let candidates = self.find_candidates(&query); + + // match candidates.len() { + // 1 => { + // errors.push(Error { + // pos: Some(pos), + // message: format!( + // "Duplicate definition for function {} with signature {}", + // declaration.id, + // funct.signature(&typed_modules) + // ), + // }); + // } + // 0 => {} + // _ => panic!( + // "duplicate function declaration should have been caught" + // ), + // } + // self.functions.insert( + // FunctionKey::with_id(declaration.id.clone()) + // .signature(funct.signature(&typed_modules).clone()), + // ); + // checked_functions.insert( + // FunctionKey::with_id(declaration.id.clone()) + // .signature(funct.signature(&typed_modules).clone()), + // funct, + // ); + // } + // } + // Err(e) => { + // errors.extend(e); + // } + // } + + // self.exit_scope(); + // } Some(TypedModule { functions: checked_functions, }) @@ -336,11 +565,11 @@ impl<'ast> Checker<'ast> { Ok(()) } - fn check_single_main(module: &Module) -> Result<(), Error> { + fn check_single_main(module: &TypedModule) -> Result<(), Error> { match module .functions .iter() - .filter(|node| node.value.id == "main") + .filter(|(key, _)| key.id == "main") .count() { 1 => Ok(()), @@ -369,7 +598,7 @@ impl<'ast> Checker<'ast> { &mut self, funct_node: FunctionNode<'ast, T>, module_id: &ModuleId, - ) -> Result, Vec> { + ) -> Result, Vec> { let mut errors = vec![]; let funct = funct_node.value; @@ -401,7 +630,7 @@ impl<'ast> Checker<'ast> { return Err(errors); } - Ok(TypedFunction { + Ok(TypedFunctionSymbol::Here(TypedFunction { arguments: funct .arguments .into_iter() @@ -409,7 +638,7 @@ impl<'ast> Checker<'ast> { .collect(), statements: statements_checked, signature, - }) + })) } fn check_parameter(&self, p: ParameterNode<'ast>, module_id: &ModuleId) -> Parameter<'ast> { @@ -452,69 +681,69 @@ impl<'ast> Checker<'ast> { } } - fn check_function_symbol( - &mut self, - funct_symbol: FunctionSymbol<'ast, T>, - module_id: &ModuleId, - modules: &mut Modules<'ast, T>, - typed_modules: &mut TypedModules<'ast, T>, - ) -> Result>, Vec> { - let mut symbols = vec![]; - let mut errors = vec![]; + // fn check_function_symbol( + // &mut self, + // funct_symbol: FunctionSymbol<'ast, T>, + // module_id: &ModuleId, + // modules: &mut Modules<'ast, T>, + // typed_modules: &mut TypedModules<'ast, T>, + // ) -> Result>, Vec> { + // let mut symbols = vec![]; + // let mut errors = vec![]; - match funct_symbol { - FunctionSymbol::Here(funct_node) => self - .check_function(funct_node, module_id) - .map(|f| vec![TypedFunctionSymbol::Here(f)]), - FunctionSymbol::There(import_node) => { - let pos = import_node.pos(); - let import = import_node.value; + // match funct_symbol { + // FunctionSymbol::Here(funct_node) => self + // .check_function(funct_node, module_id) + // .map(|f| vec![TypedFunctionSymbol::Here(f)]), + // FunctionSymbol::There(import_node) => { + // let pos = import_node.pos(); + // let import = import_node.value; - match Checker::new().check_module(&import.module_id, modules, typed_modules) { - Ok(()) => { - // find candidates in the checked module - let candidates: Vec<_> = typed_modules - .get(&import.module_id) - .unwrap() - .functions - .iter() - .filter(|(k, _)| k.id == import.function_id) - .map(|(_, v)| FunctionKey { - id: import.function_id.clone(), - signature: v.signature(&typed_modules).clone(), - }) - .collect(); + // match Checker::new().check_module(&import.module_id, modules, typed_modules) { + // Ok(()) => { + // // find candidates in the checked module + // let candidates: Vec<_> = typed_modules + // .get(&import.module_id) + // .unwrap() + // .functions + // .iter() + // .filter(|(k, _)| k.id == import.function_id) + // .map(|(_, v)| FunctionKey { + // id: import.function_id.clone(), + // signature: v.signature(&typed_modules).clone(), + // }) + // .collect(); - match candidates.len() { - 0 => errors.push(Error { - pos: Some(pos), - message: format!( - "Function {} not found in module {}", - import.function_id, import.module_id - ), - }), - _ => { - symbols.extend(candidates.into_iter().map(|f| { - TypedFunctionSymbol::There(f, import.module_id.clone()) - })) - } - } - } - Err(e) => { - errors.extend(e); - } - }; + // match candidates.len() { + // 0 => errors.push(Error { + // pos: Some(pos), + // message: format!( + // "Function {} not found in module {}", + // import.function_id, import.module_id + // ), + // }), + // _ => { + // symbols.extend(candidates.into_iter().map(|f| { + // TypedFunctionSymbol::There(f, import.module_id.clone()) + // })) + // } + // } + // } + // Err(e) => { + // errors.extend(e); + // } + // }; - // return if any errors occured - if errors.len() > 0 { - return Err(errors); - } + // // return if any errors occured + // if errors.len() > 0 { + // return Err(errors); + // } - Ok(symbols) - } - FunctionSymbol::Flat(flat_fun) => Ok(vec![TypedFunctionSymbol::Flat(flat_fun)]), - } - } + // Ok(symbols) + // } + // FunctionSymbol::Flat(flat_fun) => Ok(vec![TypedFunctionSymbol::Flat(flat_fun)]), + // } + // } fn check_variable( &self, diff --git a/zokrates_core/src/types/mod.rs b/zokrates_core/src/types/mod.rs index 40cb595a..c162450f 100644 --- a/zokrates_core/src/types/mod.rs +++ b/zokrates_core/src/types/mod.rs @@ -145,6 +145,11 @@ impl<'ast> FunctionKey<'ast> { self } + pub fn id>>(mut self, id: S) -> Self { + self.id = id.into(); + self + } + pub fn to_slug(&self) -> String { format!("{}_{}", self.id, self.signature.to_slug()) } From edab5c20f8cf773296e4755c7a3fe1a3951cde34 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 2 Aug 2019 18:46:44 +0200 Subject: [PATCH 11/35] fix mut --- zokrates_core/src/semantics.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 58600c2c..c16a7397 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -277,7 +277,7 @@ impl<'ast> Checker<'ast> { Some(module) => { assert_eq!(module.imports.len(), 0); - let ids = HashSet::new(); + let mut ids = HashSet::new(); for declaration in module.symbols { let pos = declaration.pos(); From ff21034f073221feb4cbf211e805fc795d3e429a Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 2 Aug 2019 19:58:44 +0200 Subject: [PATCH 12/35] introduce targetted imports --- zokrates_cli/examples/imports/foo.code | 2 +- zokrates_core/src/absy/from_ast.rs | 15 +++++++++++++-- zokrates_core/src/imports.rs | 23 ++++++++++++++++------- zokrates_parser/src/zokrates.pest | 5 ++++- zokrates_pest_ast/src/lib.rs | 19 ++++++++++++++++++- 5 files changed, 52 insertions(+), 12 deletions(-) diff --git a/zokrates_cli/examples/imports/foo.code b/zokrates_cli/examples/imports/foo.code index 9b2256b4..b624cbab 100644 --- a/zokrates_cli/examples/imports/foo.code +++ b/zokrates_cli/examples/imports/foo.code @@ -1,4 +1,4 @@ -import "./baz.code" +from "./baz.code" import main as baz def main() -> (field): return baz() \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 09a56420..4858b382 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -24,9 +24,20 @@ impl<'ast, T: Field> From> for absy::Module<'ast, T> { impl<'ast> From> for absy::ImportNode<'ast> { fn from(import: pest::ImportDirective<'ast>) -> absy::ImportNode { use absy::NodeValue; - imports::Import::new(import.source.span.as_str()) + + match import { + pest::ImportDirective::Main(import) => { + imports::Import::new(None, import.source.span.as_str()) + .alias(import.alias.map(|a| a.span.as_str())) + .span(import.span) + } + pest::ImportDirective::From(import) => imports::Import::new( + Some(import.symbol.span.as_str()), + import.source.span.as_str(), + ) .alias(import.alias.map(|a| a.span.as_str())) - .span(import.span) + .span(import.span), + } } } diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 2b34ba60..1dc85bbe 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -57,15 +57,17 @@ impl From for Error { #[derive(PartialEq, Clone)] pub struct Import<'ast> { source: Identifier<'ast>, + symbol: Option>, alias: Option>, } pub type ImportNode<'ast> = Node>; impl<'ast> Import<'ast> { - pub fn new(source: Identifier<'ast>) -> Import<'ast> { + pub fn new(symbol: Option>, source: Identifier<'ast>) -> Import<'ast> { Import { - source: source, + symbol, + source, alias: None, } } @@ -74,9 +76,14 @@ impl<'ast> Import<'ast> { &self.alias } - pub fn new_with_alias(source: Identifier<'ast>, alias: Identifier<'ast>) -> Import<'ast> { + pub fn new_with_alias( + symbol: Option>, + source: Identifier<'ast>, + alias: Identifier<'ast>, + ) -> Import<'ast> { Import { - source: source, + symbol, + source, alias: Some(alias), } } @@ -190,7 +197,7 @@ impl Importer { id: &alias, symbol: FunctionSymbol::There( FunctionImport::with_id_in_module( - "main", + import.symbol.unwrap_or("main"), import.source.clone(), ) .start_end(pos.0, pos.1), @@ -235,8 +242,9 @@ mod tests { #[test] fn create_with_no_alias() { assert_eq!( - Import::new("./foo/bar/baz.code"), + Import::new(None, "./foo/bar/baz.code"), Import { + symbol: None, source: "./foo/bar/baz.code", alias: None, } @@ -246,8 +254,9 @@ mod tests { #[test] fn create_with_alias() { assert_eq!( - Import::new_with_alias("./foo/bar/baz.code", &"myalias"), + Import::new_with_alias(None, "./foo/bar/baz.code", &"myalias"), Import { + symbol: None, source: "./foo/bar/baz.code", alias: Some("myalias"), } diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index ab5cbd5a..96f481bb 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -4,7 +4,10 @@ */ file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ function_definition* ~ EOI } -import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} + +import_directive = { main_import_directive | from_import_directive } +from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ identifier ~ ("as" ~ identifier)? ~ NEWLINE*} +main_import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} import_source = @{(!"\"" ~ ANY)*} function_definition = {"def" ~ identifier ~ "(" ~ parameter_list ~ ")" ~ "->" ~ "(" ~ type_list ~ ")" ~ ":" ~ NEWLINE* ~ statement* } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 4e61be89..add62e81 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -174,13 +174,30 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_directive))] - pub struct ImportDirective<'ast> { + pub enum ImportDirective<'ast> { + Main(MainImportDirective<'ast>), + From(FromImportDirective<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::main_import_directive))] + pub struct MainImportDirective<'ast> { pub source: ImportSource<'ast>, pub alias: Option>, #[pest_ast(outer())] pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::from_import_directive))] + pub struct FromImportDirective<'ast> { + pub source: ImportSource<'ast>, + pub symbol: IdentifierExpression<'ast>, + pub alias: Option>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_source))] pub struct ImportSource<'ast> { From 0673381bce4b28262860692ee9ec8b1afb3e595c Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 7 Aug 2019 16:25:09 +0200 Subject: [PATCH 13/35] implement flattening --- t.code | 14 +- zokrates_core/src/absy/from_ast.rs | 19 +- zokrates_core/src/absy/types.rs | 1 - zokrates_core/src/absy/variable.rs | 2 +- zokrates_core/src/flatten/mod.rs | 472 ++++++++++++++++++----- zokrates_core/src/semantics.rs | 14 +- zokrates_core/src/typed_absy/types.rs | 3 - zokrates_core/src/typed_absy/variable.rs | 1 - zokrates_parser/src/zokrates.pest | 3 +- zokrates_pest_ast/src/lib.rs | 22 +- 10 files changed, 420 insertions(+), 131 deletions(-) diff --git a/t.code b/t.code index 351d0176..cd71dab9 100644 --- a/t.code +++ b/t.code @@ -1,13 +1,15 @@ from "./u.code" import Foo struct Bar { - a: Foo, - b: field[2] + a: field, + b: field, + c: field, } -def f(Foo a) -> (Foo): - return a +struct Baz { + a: Bar +} -def main(Bar a) -> (Foo): - return f(a.a) +def main(Baz a, Baz b, bool c) -> (Baz): + return if c then a else b fi.a diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 9feec687..39f61c5f 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -8,16 +8,6 @@ use zokrates_pest_ast as pest; impl<'ast, T: Field> From> for absy::Module<'ast, T> { fn from(prog: pest::File<'ast>) -> absy::Module { absy::Module { - // types: prog - // .structs - // .into_iter() - // .map(|t| absy::TypeDeclarationNode::from(t)) - // .collect(), - // functions: prog - // .functions - // .into_iter() - // .map(|f| absy::FunctionDeclarationNode::from(f)) - // .collect(), symbols: prog .structs .into_iter() @@ -602,8 +592,13 @@ impl<'ast> From> for UnresolvedTypeNode { }, pest::Type::Array(t) => { let inner_type = match t.ty { - pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), - pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + pest::BasicOrStructType::Basic(t) => match t { + pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), + pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + }, + pest::BasicOrStructType::Struct(t) => { + UnresolvedType::User(t.span.as_str().to_string()).span(t.span) + } }; let span = t.span; diff --git a/zokrates_core/src/absy/types.rs b/zokrates_core/src/absy/types.rs index bc2f6bac..d5e22303 100644 --- a/zokrates_core/src/absy/types.rs +++ b/zokrates_core/src/absy/types.rs @@ -39,7 +39,6 @@ pub use self::signature::UnresolvedSignature; mod signature { use std::fmt; - use super::*; use absy::UnresolvedTypeNode; #[derive(Clone, PartialEq, Serialize, Deserialize)] diff --git a/zokrates_core/src/absy/variable.rs b/zokrates_core/src/absy/variable.rs index bf96e69f..f03b3f0d 100644 --- a/zokrates_core/src/absy/variable.rs +++ b/zokrates_core/src/absy/variable.rs @@ -1,5 +1,5 @@ use crate::absy::types::UnresolvedType; -use crate::absy::{Node, NodeValue, UnresolvedTypeNode}; +use crate::absy::{Node, UnresolvedTypeNode}; use std::fmt; use crate::absy::Identifier; diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index acb3fef7..7c967cdd 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -101,7 +101,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> { symbols: &TypedFunctionSymbols<'ast, T>, statements_flattened: &mut Vec>, ) -> Vec> { - unimplemented!() + flattener.flatten_struct_expression(symbols, statements_flattened, self) } fn if_else( @@ -109,15 +109,42 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> { consequence: Self, alternative: Self, ) -> Self { - unimplemented!() + StructExpression { + ty: consequence.ty.clone(), + inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative), + } } fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { - unimplemented!() + let members = match array.inner_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select(box array, box index), + } } - fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - unimplemented!() + fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { + let members = s.ty.clone(); + + let ty = members + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + let members = match ty { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + StructExpression { + ty: members, + inner: StructExpressionInner::Member(box s, member_id), + } } } @@ -178,8 +205,25 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> { } } - fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - unimplemented!() + fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { + let members = s.ty.clone(); + + let ty = members + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + let (ty, size) = match ty { + Type::Array(box ty, size) => (ty, size), + _ => unreachable!(), + }; + + ArrayExpression { + ty, + size, + inner: ArrayExpressionInner::Member(box s, member_id), + } } } @@ -281,7 +325,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { res.into_iter().map(|r| r.into()).collect() } - fn flatten_member_expression>( + fn flatten_member_expression( &mut self, symbols: &TypedFunctionSymbols<'ast, T>, statements_flattened: &mut Vec>, @@ -289,35 +333,270 @@ impl<'ast, T: Field> Flattener<'ast, T> { member_id: MemberId, ) -> Vec> { let members = s.ty; + let expected_output_size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); - match s.inner { - StructExpressionInner::Identifier(id) => { - // the struct is encoded as a sequence, so we need to identify the offset at which this member starts - let offset = members - .iter() - .map(|(id, ty)| (id, ty.get_primitive_count())) - .fold((false, 0), |acc, (id, count)| { - if acc.0 && *id != member_id { - (false, acc.1 + count) - } else { - (true, acc.1) + let res = + match s.inner { + StructExpressionInner::Value(values) => { + // If the struct has an explicit value, we get the value at the given member + assert_eq!(values.len(), members.len()); + values + .into_iter() + .zip(members.into_iter()) + .filter(|(_, (id, _))| *id == member_id) + .flat_map(|(v, (_, t))| match t { + Type::FieldElement => FieldElementExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + Type::Array(box ty, size) => ArrayExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + Type::Struct(members) => StructExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + }) + .collect() + } + StructExpressionInner::Identifier(id) => { + // If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values. + // the struct is encoded as a sequence, so we need to identify the offset at which this member starts + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + self.layout.get(&id).unwrap()[offset..(offset + size)] + .into_iter() + .map(|i| i.clone().into()) + .collect() + } + StructExpressionInner::Select(box array, box index) => { + // If the struct is an array element `array[index]`, we're accessing `array[index].member` + // We construct `array := array.map(|e| e.member)` and access `array[index]` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + match ty { + Type::FieldElement => { + let array = ArrayExpression { + size: array.size, + ty: Type::FieldElement, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + FieldElementExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) } - }) - .1; - // we also need the size of this member - let size = members - .iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1 - .get_primitive_count(); - self.layout.get(&id).unwrap()[offset..(offset + size)] - .into_iter() - .map(|i| i.clone().into()) - .collect() - } - _ => unimplemented!(), - } + Type::Boolean => { + let array = ArrayExpression { + size: array.size, + ty: Type::Boolean, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + BooleanExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Struct(m) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Struct(m.clone()), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + StructExpression { + ty: m.clone(), + inner: StructExpressionInner::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Array(box ty, size) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Array(box ty.clone(), size), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + ArrayExpression { + size, + ty: ty.clone(), + inner: ArrayExpressionInner::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + } + } + StructExpressionInner::FunctionCall(..) => unreachable!(), + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + // if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member` + // we reduce to `if c then a.member else b.member` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + match ty { + Type::FieldElement => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + FieldElementExpression::member(consequence.clone(), member_id.clone()), + FieldElementExpression::member(alternative.clone(), member_id), + ), + Type::Boolean => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + BooleanExpression::member(consequence.clone(), member_id.clone()), + BooleanExpression::member(alternative.clone(), member_id), + ), + Type::Struct(m) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + StructExpression::member(consequence.clone(), member_id.clone()), + StructExpression::member(alternative.clone(), member_id), + ), + Type::Array(box ty, size) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + ArrayExpression::member(consequence.clone(), member_id.clone()), + ArrayExpression::member(alternative.clone(), member_id), + ), + } + } + StructExpressionInner::Member(box s0, m_id) => { + let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id); + + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + + e[offset..(offset + size)].into() + } + }; + + assert_eq!(res.len(), expected_output_size); + res } fn flatten_select_expression>( @@ -362,7 +641,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) .flatten(self, symbols, statements_flattened) } - ArrayExpressionInner::Member(box s, id) => unimplemented!(), + ArrayExpressionInner::Member(box s, id) => { + assert!(n < T::from(size)); + let n = n.to_dec_string().parse::().unwrap(); + self.flatten_member_expression(symbols, statements_flattened, s, id) + [n * ty.get_primitive_count()..(n + 1) * ty.get_primitive_count()] + .to_vec() + } ArrayExpressionInner::Select(box array, box index) => { assert!(n < T::from(size)); let n = n.to_dec_string().parse::().unwrap(); @@ -755,14 +1040,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { alternative, )[0] .clone(), - BooleanExpression::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - )[0] - .clone(), + BooleanExpression::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone() + } BooleanExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1150,14 +1430,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value exprs_flattened.expressions[0].clone() } - FieldElementExpression::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - )[0] - .clone(), + FieldElementExpression::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone() + } FieldElementExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1176,9 +1451,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { expr: StructExpression<'ast, T>, ) -> Vec> { let ty = expr.get_type(); - //assert_eq!(U::get_type(), inner_type); + let expected_output_size = expr.get_type().get_primitive_count(); + let members = expr.ty; - match expr.inner { + let res = match expr.inner { StructExpressionInner::Identifier(x) => self .layout .get(&x) @@ -1186,16 +1462,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .iter() .map(|v| FlatExpression::Identifier(v.clone())) .collect(), - // StructExpressionInner::Value(values) => { - // values - // .into_iter() - // .flat_map(|v| { - // U::try_from(v) - // .unwrap() - // .flatten(self, symbols, statements_flattened) - // }) - // .collect() - // } + StructExpressionInner::Value(values) => values + .into_iter() + .flat_map(|v| self.flatten_expression(symbols, statements_flattened, v)) + .collect(), StructExpressionInner::FunctionCall(key, param_expressions) => { let exprs_flattened = self.flatten_function_call( symbols, @@ -1206,30 +1476,40 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); exprs_flattened.expressions } - // StructExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0 - // ..size) - // .flat_map(|i| { - // U::if_else( - // *condition.clone(), - // U::select( - // *consequence.clone(), - // FieldElementExpression::Number(T::from(i)), - // ), - // U::select( - // *alternative.clone(), - // FieldElementExpression::Number(T::from(i)), - // ), - // ) - // .flatten(self, symbols, statements_flattened) - // }) - // .collect(), - StructExpressionInner::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - ), + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + members + .into_iter() + .flat_map(|(id, ty)| match ty { + Type::FieldElement => FieldElementExpression::if_else( + condition.clone(), + FieldElementExpression::member(consequence.clone(), id.clone()), + FieldElementExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Boolean => BooleanExpression::if_else( + condition.clone(), + BooleanExpression::member(consequence.clone(), id.clone()), + BooleanExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Struct(..) => StructExpression::if_else( + condition.clone(), + StructExpression::member(consequence.clone(), id.clone()), + StructExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Array(..) => ArrayExpression::if_else( + condition.clone(), + ArrayExpression::member(consequence.clone(), id.clone()), + ArrayExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + }) + .collect() + } + StructExpressionInner::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id) + } StructExpressionInner::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1237,8 +1517,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { array, index, ), - _ => unimplemented!("yeah well"), - } + }; + + assert_eq!(res.len(), expected_output_size); + res } /// # Remarks @@ -1300,13 +1582,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten(self, symbols, statements_flattened) }) .collect(), - ArrayExpressionInner::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - ), + ArrayExpressionInner::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id) + } ArrayExpressionInner::Select(box array, box index) => self .flatten_select_expression::>( symbols, diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index bb979727..8d6f69f0 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1072,6 +1072,16 @@ impl<'ast> Checker<'ast> { unimplemented!("handle consequence alternative inner type mismatch") } }, + (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { + if consequence.get_type() == alternative.get_type() { + Ok(StructExpression { + ty: consequence.ty.clone(), + inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative) + }.into()) + } else { + unimplemented!("handle consequence alternative inner type mismatch") + } + }, _ => unimplemented!() } false => Err(Error { @@ -1367,8 +1377,8 @@ impl<'ast> Checker<'ast> { // check that the struct has that field and return the type if it does let ty = s.ty.iter() - .find(|(member_id, ty)| member_id == id) - .map(|(member_id, ty)| ty); + .find(|(member_id, _)| member_id == id) + .map(|(_, ty)| ty); match ty { Some(ty) => match ty { diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 249b7e1e..e3a49e96 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,12 +1,9 @@ -use absy::UnresolvedTypeNode; use std::fmt; pub type Identifier<'ast> = &'ast str; pub type MemberId = String; -pub type UserTypeId = String; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum Type { FieldElement, diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index d28646b3..f38236c3 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -1,4 +1,3 @@ -use crate::absy; use crate::typed_absy::types::Type; use crate::typed_absy::Identifier; use std::fmt; diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index a2bcea3d..e1ca18d2 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -18,7 +18,8 @@ parameter = {vis? ~ ty ~ identifier} ty_field = {"field"} ty_bool = {"bool"} ty_basic = { ty_field | ty_bool } -ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ } +ty_basic_or_struct = { ty_basic | ty_struct } +ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ } ty = { ty_array | ty_basic | ty_struct } type_list = _{(ty ~ ("," ~ ty)*)?} // structs diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 0d8a4c59..2033b671 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -9,12 +9,13 @@ extern crate lazy_static; pub use ast::{ Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, - AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, CallAccess, - ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function, - IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement, - MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, - TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, + AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, + CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression, + Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, + IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range, + RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, + StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, + UnaryOperator, Visibility, }; mod ast { @@ -251,12 +252,19 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_array))] pub struct ArrayType<'ast> { - pub ty: BasicType<'ast>, + pub ty: BasicOrStructType<'ast>, pub size: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::ty_basic_or_struct))] + pub enum BasicOrStructType<'ast> { + Struct(StructType<'ast>), + Basic(BasicType<'ast>), + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_bool))] pub struct BooleanType<'ast> { From 9dfcdd3f9ac3f0f6fe1f03bbc0781173e10adc98 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 7 Aug 2019 18:50:43 +0200 Subject: [PATCH 14/35] implement inline struct --- t.code | 5 +- zokrates_core/src/absy/from_ast.rs | 20 +++ zokrates_core/src/absy/mod.rs | 18 ++- zokrates_core/src/semantics.rs | 229 ++++++++++++++++++++--------- zokrates_parser/src/zokrates.pest | 6 +- zokrates_pest_ast/src/lib.rs | 31 +++- 6 files changed, 233 insertions(+), 76 deletions(-) diff --git a/t.code b/t.code index cd71dab9..00aa02d9 100644 --- a/t.code +++ b/t.code @@ -10,6 +10,7 @@ struct Baz { a: Bar } -def main(Baz a, Baz b, bool c) -> (Baz): - return if c then a else b fi.a +def main(Bar a, Bar b, bool c) -> (Bar): + Bar bar = Bar { a: 1, b: 1, c: 1 } + return if false then a else bar fi diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 39f61c5f..6160c9d9 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -323,6 +323,7 @@ impl<'ast, T: Field> From> for absy::ExpressionNode<'ast, pest::Expression::Identifier(e) => absy::ExpressionNode::from(e), pest::Expression::Postfix(e) => absy::ExpressionNode::from(e), pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e), + pest::Expression::InlineStruct(e) => absy::ExpressionNode::from(e), pest::Expression::ArrayInitializer(e) => absy::ExpressionNode::from(e), pest::Expression::Unary(e) => absy::ExpressionNode::from(e), } @@ -475,6 +476,25 @@ impl<'ast, T: Field> From> for absy::Expressio } } +impl<'ast, T: Field> From> for absy::ExpressionNode<'ast, T> { + fn from(s: pest::InlineStructExpression<'ast>) -> absy::ExpressionNode<'ast, T> { + use absy::NodeValue; + absy::Expression::InlineStruct( + s.ty.span.as_str().to_string(), + s.members + .into_iter() + .map(|member| { + ( + member.id.span.as_str(), + absy::ExpressionNode::from(member.expression), + ) + }) + .collect(), + ) + .span(s.span) + } +} + impl<'ast, T: Field> From> for absy::ExpressionNode<'ast, T> { diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 19d8524d..e463a342 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -13,7 +13,7 @@ pub mod variable; pub use crate::absy::node::{Node, NodeValue}; pub use crate::absy::parameter::{Parameter, ParameterNode}; -use crate::absy::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType}; +use crate::absy::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType, UserTypeId}; pub use crate::absy::variable::{Variable, VariableNode}; use embed::FlatEmbed; @@ -455,6 +455,7 @@ pub enum Expression<'ast, T: Field> { And(Box>, Box>), Not(Box>), InlineArray(Vec>), + InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast, T>)>), Select( Box>, Box>, @@ -508,6 +509,16 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> { } write!(f, "]") } + Expression::InlineStruct(ref id, ref members) => { + r#try!(write!(f, "{} {{", id)); + for (i, (member_id, e)) in members.iter().enumerate() { + r#try!(write!(f, "{}: {}", member_id, e)); + if i < members.len() - 1 { + r#try!(write!(f, ", ")); + } + } + write!(f, "}}") + } Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index), Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), @@ -548,6 +559,11 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> { r#try!(f.debug_list().entries(exprs.iter()).finish()); write!(f, "]") } + Expression::InlineStruct(ref id, ref members) => { + r#try!(write!(f, "InlineStruct({:?}, [", id)); + r#try!(f.debug_list().entries(members.iter()).finish()); + write!(f, "]") + } Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index), Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id), Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8d6f69f0..3e41c598 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -657,7 +657,7 @@ impl<'ast> Checker<'ast> { Statement::Return(list) => { let mut expression_list_checked = vec![]; for e in list.value.expressions { - let e_checked = self.check_expression(e)?; + let e_checked = self.check_expression(e, module_id, &types)?; expression_list_checked.push(e_checked); } @@ -705,11 +705,11 @@ impl<'ast> Checker<'ast> { } // check the expression to be assigned - let checked_expr = self.check_expression(expr)?; + let checked_expr = self.check_expression(expr, module_id, &types)?; let expression_type = checked_expr.get_type(); // check that the assignee is declared and is well formed - let var = self.check_assignee(assignee)?; + let var = self.check_assignee(assignee, module_id, &types)?; let var_type = var.get_type(); @@ -726,8 +726,8 @@ impl<'ast> Checker<'ast> { } } Statement::Condition(lhs, rhs) => { - let checked_lhs = self.check_expression(lhs)?; - let checked_rhs = self.check_expression(rhs)?; + let checked_lhs = self.check_expression(lhs, module_id, &types)?; + let checked_rhs = self.check_expression(rhs, module_id, &types)?; if checked_lhs.get_type() == checked_rhs.get_type() { Ok(TypedStatement::Condition(checked_lhs, checked_rhs)) @@ -789,7 +789,7 @@ impl<'ast> Checker<'ast> { // find arguments types let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg)?; + let arg_checked = self.check_expression(arg, module_id, &types)?; arguments_checked.push(arg_checked); } @@ -837,6 +837,8 @@ impl<'ast> Checker<'ast> { fn check_assignee( &mut self, assignee: AssigneeNode<'ast, T>, + module_id: &ModuleId, + types: &TypeMap, ) -> Result, Error> { let pos = assignee.pos(); // check that the assignee is declared @@ -852,9 +854,11 @@ impl<'ast> Checker<'ast> { }), }, Assignee::ArrayElement(box assignee, box index) => { - let checked_assignee = self.check_assignee(assignee)?; + let checked_assignee = self.check_assignee(assignee, module_id, &types)?; let checked_index = match index { - RangeOrExpression::Expression(e) => self.check_expression(e)?, + RangeOrExpression::Expression(e) => { + self.check_expression(e, module_id, &types)? + } r => unimplemented!( "Using slices in assignments is not supported yet, found {}", r @@ -885,12 +889,15 @@ impl<'ast> Checker<'ast> { fn check_spread_or_expression( &mut self, spread_or_expression: SpreadOrExpression<'ast, T>, + module_id: &ModuleId, + types: &TypeMap, ) -> Result>, Error> { match spread_or_expression { SpreadOrExpression::Spread(s) => { let pos = s.pos(); - let checked_expression = self.check_expression(s.value.expression)?; + let checked_expression = + self.check_expression(s.value.expression, module_id, &types)?; match checked_expression { TypedExpression::Array(e) => { let size = e.size(); @@ -914,13 +921,17 @@ impl<'ast> Checker<'ast> { }), } } - SpreadOrExpression::Expression(e) => self.check_expression(e).map(|r| vec![r]), + SpreadOrExpression::Expression(e) => { + self.check_expression(e, module_id, &types).map(|r| vec![r]) + } } } fn check_expression( &mut self, expr: ExpressionNode<'ast, T>, + module_id: &ModuleId, + types: &TypeMap, ) -> Result, Error> { let pos = expr.pos(); @@ -953,8 +964,8 @@ impl<'ast> Checker<'ast> { } } Expression::Add(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { @@ -972,8 +983,8 @@ impl<'ast> Checker<'ast> { } } Expression::Sub(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { @@ -991,8 +1002,8 @@ impl<'ast> Checker<'ast> { } } Expression::Mult(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { @@ -1010,8 +1021,8 @@ impl<'ast> Checker<'ast> { } } Expression::Div(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { @@ -1029,8 +1040,8 @@ impl<'ast> Checker<'ast> { } } Expression::Pow(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => Ok( @@ -1048,9 +1059,9 @@ impl<'ast> Checker<'ast> { } } Expression::IfElse(box condition, box consequence, box alternative) => { - let condition_checked = self.check_expression(condition)?; - let consequence_checked = self.check_expression(consequence)?; - let alternative_checked = self.check_expression(alternative)?; + let condition_checked = self.check_expression(condition, module_id, &types)?; + let consequence_checked = self.check_expression(consequence, module_id, &types)?; + let alternative_checked = self.check_expression(alternative, module_id, &types)?; match condition_checked { TypedExpression::Boolean(condition) => { @@ -1104,7 +1115,7 @@ impl<'ast> Checker<'ast> { // check the arguments let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg)?; + let arg_checked = self.check_expression(arg, module_id, &types)?; arguments_checked.push(arg_checked); } @@ -1181,8 +1192,8 @@ impl<'ast> Checker<'ast> { } } Expression::Lt(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::Lt(box e1, box e2).into()) @@ -1200,8 +1211,8 @@ impl<'ast> Checker<'ast> { } } Expression::Le(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::Le(box e1, box e2).into()) @@ -1219,8 +1230,8 @@ impl<'ast> Checker<'ast> { } } Expression::Eq(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::Eq(box e1, box e2).into()) @@ -1238,8 +1249,8 @@ impl<'ast> Checker<'ast> { } } Expression::Ge(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::Ge(box e1, box e2).into()) @@ -1257,8 +1268,8 @@ impl<'ast> Checker<'ast> { } } Expression::Gt(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::Gt(box e1, box e2).into()) @@ -1276,7 +1287,7 @@ impl<'ast> Checker<'ast> { } } Expression::Select(box array, box index) => { - let array = self.check_expression(array)?; + let array = self.check_expression(array, module_id, &types)?; match index { RangeOrExpression::Range(r) => match array { @@ -1338,39 +1349,43 @@ impl<'ast> Checker<'ast> { } _ => panic!(""), }, - RangeOrExpression::Expression(e) => match (array, self.check_expression(e)?) { - (TypedExpression::Array(a), TypedExpression::FieldElement(i)) => { - match a.inner_type() { - Type::FieldElement => { - Ok(FieldElementExpression::Select(box a, box i).into()) + RangeOrExpression::Expression(e) => { + match (array, self.check_expression(e, module_id, &types)?) { + (TypedExpression::Array(a), TypedExpression::FieldElement(i)) => { + match a.inner_type() { + Type::FieldElement => { + Ok(FieldElementExpression::Select(box a, box i).into()) + } + Type::Boolean => { + Ok(BooleanExpression::Select(box a, box i).into()) + } + Type::Array(box ty, size) => Ok(ArrayExpression { + size: *size, + ty: ty.clone(), + inner: ArrayExpressionInner::Select(box a, box i), + } + .into()), + Type::Struct(members) => Ok(StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select(box a, box i), + } + .into()), } - Type::Boolean => Ok(BooleanExpression::Select(box a, box i).into()), - Type::Array(box ty, size) => Ok(ArrayExpression { - size: *size, - ty: ty.clone(), - inner: ArrayExpressionInner::Select(box a, box i), - } - .into()), - Type::Struct(members) => Ok(StructExpression { - ty: members.clone(), - inner: StructExpressionInner::Select(box a, box i), - } - .into()), } + (a, e) => Err(Error { + pos: Some(pos), + message: format!( + "Cannot access element {} on expression of type {}", + e, + a.get_type() + ), + }), } - (a, e) => Err(Error { - pos: Some(pos), - message: format!( - "Cannot access element {} on expression of type {}", - e, - a.get_type() - ), - }), - }, + } } } Expression::Member(box e, box id) => { - let e = self.check_expression(e)?; + let e = self.check_expression(e, module_id, &types)?; match e { TypedExpression::Struct(s) => { @@ -1431,7 +1446,7 @@ impl<'ast> Checker<'ast> { // check each expression, getting its type let mut expressions_checked = vec![]; for e in expressions { - let e_checked = self.check_spread_or_expression(e)?; + let e_checked = self.check_spread_or_expression(e, module_id, &types)?; expressions_checked.extend(e_checked); } @@ -1583,9 +1598,87 @@ impl<'ast> Checker<'ast> { } } } + Expression::InlineStruct(id, inline_members) => { + let ty = self.check_type( + UnresolvedType::User(id.clone()).at(42, 42, 42), + module_id, + &types, + )?; + let members = match ty { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + // check that we provided the required number of values + + if members.len() != inline_members.len() { + return Err(Error { + pos: Some(pos), + message: format!( + "Inline struct {} does not match {} : {}", + Expression::InlineStruct(id.clone(), inline_members), + id, + Type::Struct(members) + ), + }); + } + + // check that the mapping of values matches the expected type + // put the value into a map, pick members from this map following declared members, and try to parse them + + let mut inline_members_map = inline_members + .clone() + .into_iter() + .map(|(id, v)| (id.to_string(), v)) + .collect::>(); + let mut result: Vec> = vec![]; + + for (member_id, ty) in &members { + match inline_members_map.remove(member_id) { + Some(value) => { + let expression_checked = + self.check_expression(value, module_id, &types)?; + let checked_type = expression_checked.get_type(); + if checked_type != *ty { + return Err(Error { + pos: Some(pos), + message: format!( + "Member {} of struct {} has type {}, found {} of type {}", + member_id, + id.clone(), + ty, + expression_checked, + checked_type, + ), + }); + } else { + result.push(expression_checked.into()); + } + } + None => { + return Err(Error { + pos: Some(pos), + message: format!( + "Member {} of struct {} : {} not found in value {}", + member_id, + id.clone(), + Type::Struct(members.clone()), + Expression::InlineStruct(id.clone(), inline_members), + ), + }) + } + } + } + + Ok(StructExpression { + ty: members, + inner: StructExpressionInner::Value(result), + } + .into()) + } Expression::And(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::And(box e1, box e2).into()) @@ -1602,8 +1695,8 @@ impl<'ast> Checker<'ast> { } } Expression::Or(box e1, box e2) => { - let e1_checked = self.check_expression(e1)?; - let e2_checked = self.check_expression(e2)?; + let e1_checked = self.check_expression(e1, module_id, &types)?; + let e2_checked = self.check_expression(e2, module_id, &types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::Or(box e1, box e2).into()) @@ -1616,7 +1709,7 @@ impl<'ast> Checker<'ast> { } } Expression::Not(box e) => { - let e_checked = self.check_expression(e)?; + let e_checked = self.check_expression(e, module_id, &types)?; match e_checked { TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), e => Err(Error { diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index e1ca18d2..417d201c 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -57,7 +57,7 @@ optionally_typed_identifier = { (identifier) | (ty ~ identifier) } // we don't u expression_list = _{(expression ~ ("," ~ expression)*)?} expression = { term ~ (op_binary ~ term)* } -term = { ("(" ~ expression ~ ")") | conditional_expression | postfix_expression | primary_expression | inline_array_expression | array_initializer_expression | unary_expression } +term = { ("(" ~ expression ~ ")") | inline_struct_expression | conditional_expression | postfix_expression | primary_expression | inline_array_expression | array_initializer_expression | unary_expression } spread = { "..." ~ expression } range = { from_expression? ~ ".." ~ to_expression? } from_expression = { expression } @@ -75,6 +75,10 @@ primary_expression = { identifier | constant } +inline_struct_expression = { identifier ~ "{" ~ NEWLINE* ~ inline_struct_member_list ~ NEWLINE* ~ "}" } +inline_struct_member_list = _{(inline_struct_member ~ ("," ~ NEWLINE* ~ inline_struct_member)*)? ~ ","? } +inline_struct_member = { identifier ~ ":" ~ expression } + inline_array_expression = { "[" ~ inline_array_inner ~ "]" } inline_array_inner = _{(spread_or_expression ~ ("," ~ spread_or_expression)*)?} spread_or_expression = { spread | expression } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 2033b671..d5e79556 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -12,10 +12,10 @@ pub use ast::{ AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, - IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range, - RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, - StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, - UnaryOperator, Visibility, + InlineStructExpression, InlineStructMember, IterationStatement, MultiAssignmentStatement, + Parameter, PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, + SpreadOrExpression, Statement, StructDefinition, StructField, TernaryExpression, ToExpression, + Type, UnaryExpression, UnaryOperator, Visibility, }; mod ast { @@ -122,6 +122,9 @@ mod ast { Rule::postfix_expression => Expression::Postfix( PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(), ), + Rule::inline_struct_expression => Expression::InlineStruct( + InlineStructExpression::from_pest(&mut pair.into_inner()).unwrap(), + ), Rule::inline_array_expression => Expression::InlineArray( InlineArrayExpression::from_pest(&mut pair.into_inner()).unwrap(), ), @@ -412,6 +415,7 @@ mod ast { Identifier(IdentifierExpression<'ast>), Constant(ConstantExpression<'ast>), InlineArray(InlineArrayExpression<'ast>), + InlineStruct(InlineStructExpression<'ast>), ArrayInitializer(ArrayInitializerExpression<'ast>), Unary(UnaryExpression<'ast>), } @@ -481,6 +485,24 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::inline_struct_expression))] + pub struct InlineStructExpression<'ast> { + pub ty: IdentifierExpression<'ast>, + pub members: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::inline_struct_member))] + pub struct InlineStructMember<'ast> { + pub id: IdentifierExpression<'ast>, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::array_initializer_expression))] pub struct ArrayInitializerExpression<'ast> { @@ -586,6 +608,7 @@ mod ast { Expression::Ternary(t) => &t.span, Expression::Postfix(p) => &p.span, Expression::InlineArray(a) => &a.span, + Expression::InlineStruct(s) => &s.span, Expression::ArrayInitializer(a) => &a.span, Expression::Unary(u) => &u.span, } From 8601d6599f4f9e15fb1bd918247df4d331cd1409 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 8 Aug 2019 00:37:44 +0200 Subject: [PATCH 15/35] fix all tests, remove all warnings --- t.code | 4 +- .../examples/imports/import_github.code | 6 - zokrates_core/src/absy/from_ast.rs | 91 ++-- zokrates_core/src/flatten/mod.rs | 453 +++++++++--------- zokrates_core/src/semantics.rs | 451 +++++++++++------ zokrates_core/src/typed_absy/mod.rs | 26 +- zokrates_core/src/typed_absy/types.rs | 27 +- zokrates_parser/src/lib.rs | 10 +- zokrates_pest_ast/src/lib.rs | 36 +- 9 files changed, 640 insertions(+), 464 deletions(-) delete mode 100644 zokrates_cli/examples/imports/import_github.code diff --git a/t.code b/t.code index 00aa02d9..a38f522b 100644 --- a/t.code +++ b/t.code @@ -1,8 +1,8 @@ -from "./u.code" import Foo +from "./u.code" import Fooo struct Bar { a: field, - b: field, + a: field, c: field, } diff --git a/zokrates_cli/examples/imports/import_github.code b/zokrates_cli/examples/imports/import_github.code deleted file mode 100644 index ec30f95e..00000000 --- a/zokrates_cli/examples/imports/import_github.code +++ /dev/null @@ -1,6 +0,0 @@ -// See also mock URLs in zokrates_github_resolver -import "github.com/Zokrates/ZoKrates/master/zokrates_cli/examples/imports/foo.code" -import "github.com/Zokrates/ZoKrates/master/zokrates_cli/examples/imports/bar.code" - -def main() -> (field): - return foo() + bar() diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 6160c9d9..bd77b686 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,6 +1,4 @@ use absy; -use absy::types::UnresolvedType; -use absy::UnresolvedTypeNode; use imports; use zokrates_field::field::Field; use zokrates_pest_ast as pest; @@ -80,7 +78,7 @@ impl<'ast> From> for absy::StructFieldNode<'ast> { let id = field.id.span.as_str(); - let ty = UnresolvedTypeNode::from(field.ty); + let ty = absy::UnresolvedTypeNode::from(field.ty); absy::StructField { id, ty }.span(span) } @@ -98,7 +96,7 @@ impl<'ast, T: Field> From> for absy::SymbolDeclarationNode< .parameters .clone() .into_iter() - .map(|p| UnresolvedTypeNode::from(p.ty)) + .map(|p| absy::UnresolvedTypeNode::from(p.ty)) .collect(), ) .outputs( @@ -106,7 +104,7 @@ impl<'ast, T: Field> From> for absy::SymbolDeclarationNode< .returns .clone() .into_iter() - .map(|r| UnresolvedTypeNode::from(r)) + .map(|r| absy::UnresolvedTypeNode::from(r)) .collect(), ); @@ -147,9 +145,11 @@ impl<'ast> From> for absy::ParameterNode<'ast> { }) .unwrap_or(false); - let variable = - absy::Variable::new(param.id.span.as_str(), UnresolvedTypeNode::from(param.ty)) - .span(param.id.span); + let variable = absy::Variable::new( + param.id.span.as_str(), + absy::UnresolvedTypeNode::from(param.ty), + ) + .span(param.id.span); absy::Parameter::new(variable, private).span(param.span) } @@ -180,8 +180,11 @@ fn statements_from_multi_assignment<'ast, T: Field>( .filter(|i| i.ty.is_some()) .map(|i| { absy::Statement::Declaration( - absy::Variable::new(i.id.span.as_str(), UnresolvedTypeNode::from(i.ty.unwrap())) - .span(i.id.span), + absy::Variable::new( + i.id.span.as_str(), + absy::UnresolvedTypeNode::from(i.ty.unwrap()), + ) + .span(i.id.span), ) .span(i.span) }); @@ -218,7 +221,7 @@ fn statements_from_definition<'ast, T: Field>( absy::Statement::Declaration( absy::Variable::new( definition.id.span.as_str(), - UnresolvedTypeNode::from(definition.ty), + absy::UnresolvedTypeNode::from(definition.ty), ) .span(definition.id.span.clone()), ) @@ -279,7 +282,7 @@ impl<'ast, T: Field> From> for absy::StatementNod let from = absy::ExpressionNode::from(statement.from); let to = absy::ExpressionNode::from(statement.to); let index = statement.index.span.as_str(); - let ty = UnresolvedTypeNode::from(statement.ty); + let ty = absy::UnresolvedTypeNode::from(statement.ty); let statements: Vec> = statement .statements .into_iter() @@ -601,23 +604,25 @@ impl<'ast, T: Field> From> for absy::AssigneeNode<'ast, T> } } -impl<'ast> From> for UnresolvedTypeNode { - fn from(t: pest::Type<'ast>) -> UnresolvedTypeNode { +impl<'ast> From> for absy::UnresolvedTypeNode { + fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode { use absy::NodeValue; match t { pest::Type::Basic(t) => match t { - pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), - pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + pest::BasicType::Field(t) => absy::UnresolvedType::FieldElement.span(t.span), + pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span), }, pest::Type::Array(t) => { let inner_type = match t.ty { pest::BasicOrStructType::Basic(t) => match t { - pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), - pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + pest::BasicType::Field(t) => { + absy::UnresolvedType::FieldElement.span(t.span) + } + pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span), }, pest::BasicOrStructType::Struct(t) => { - UnresolvedType::User(t.span.as_str().to_string()).span(t.span) + absy::UnresolvedType::User(t.span.as_str().to_string()).span(t.span) } }; @@ -641,14 +646,14 @@ impl<'ast> From> for UnresolvedTypeNode { ), }) .fold(None, |acc, s| match acc { - None => Some(UnresolvedType::array(inner_type.clone(), s)), - Some(acc) => Some(UnresolvedType::array(acc.span(span.clone()), s)), + None => Some(absy::UnresolvedType::array(inner_type.clone(), s)), + Some(acc) => Some(absy::UnresolvedType::array(acc.span(span.clone()), s)), }) .unwrap() .span(span.clone()) } pest::Type::Struct(s) => { - UnresolvedType::User(s.id.span.as_str().to_string()).span(s.span) + absy::UnresolvedType::User(s.id.span.as_str().to_string()).span(s.span) } } } @@ -657,6 +662,7 @@ impl<'ast> From> for UnresolvedTypeNode { #[cfg(test)] mod tests { use super::*; + use absy::NodeValue; use zokrates_field::field::FieldPrime; #[test] @@ -665,9 +671,9 @@ mod tests { "; let ast = pest::generate_ast(&source).unwrap(); let expected: absy::Module = absy::Module { - functions: vec![absy::FunctionDeclaration { + symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::FunctionSymbol::Here( + symbol: absy::Symbol::HereFunction( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -680,9 +686,9 @@ mod tests { .into(), ) .into()], - signature: absy::Signature::new() + signature: absy::UnresolvedSignature::new() .inputs(vec![]) - .outputs(vec![Type::FieldElement]), + .outputs(vec![absy::UnresolvedType::FieldElement.mock()]), } .into(), ), @@ -699,9 +705,9 @@ mod tests { "; let ast = pest::generate_ast(&source).unwrap(); let expected: absy::Module = absy::Module { - functions: vec![absy::FunctionDeclaration { + symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::FunctionSymbol::Here( + symbol: absy::Symbol::HereFunction( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -711,9 +717,9 @@ mod tests { .into(), ) .into()], - signature: absy::Signature::new() + signature: absy::UnresolvedSignature::new() .inputs(vec![]) - .outputs(vec![Type::Boolean]), + .outputs(vec![absy::UnresolvedType::Boolean.mock()]), } .into(), ), @@ -731,17 +737,25 @@ mod tests { let ast = pest::generate_ast(&source).unwrap(); let expected: absy::Module = absy::Module { - functions: vec![absy::FunctionDeclaration { + symbols: vec![absy::SymbolDeclaration { id: &source[4..8], - symbol: absy::FunctionSymbol::Here( + symbol: absy::Symbol::HereFunction( absy::Function { arguments: vec![ absy::Parameter::private( - absy::Variable::field_element(&source[23..24]).into(), + absy::Variable::new( + &source[23..24], + UnresolvedType::FieldElement.mock(), + ) + .into(), ) .into(), absy::Parameter::public( - absy::Variable::boolean(&source[31..32]).into(), + absy::Variable::new( + &source[31..32], + UnresolvedType::Boolean.mock(), + ) + .into(), ) .into(), ], @@ -755,9 +769,12 @@ mod tests { .into(), ) .into()], - signature: absy::Signature::new() - .inputs(vec![Type::FieldElement, Type::Boolean]) - .outputs(vec![Type::FieldElement]), + signature: absy::UnresolvedSignature::new() + .inputs(vec![ + absy::UnresolvedType::FieldElement.mock(), + absy::UnresolvedType::Boolean.mock(), + ]) + .outputs(vec![absy::UnresolvedType::FieldElement.mock()]), } .into(), ), diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 7c967cdd..b9295c98 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -340,73 +340,138 @@ impl<'ast, T: Field> Flattener<'ast, T> { .1 .get_primitive_count(); - let res = - match s.inner { - StructExpressionInner::Value(values) => { - // If the struct has an explicit value, we get the value at the given member - assert_eq!(values.len(), members.len()); - values - .into_iter() - .zip(members.into_iter()) - .filter(|(_, (id, _))| *id == member_id) - .flat_map(|(v, (_, t))| match t { - Type::FieldElement => FieldElementExpression::try_from(v) - .unwrap() - .flatten(self, symbols, statements_flattened), - Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten( - self, - symbols, - statements_flattened, + let res = match s.inner { + StructExpressionInner::Value(values) => { + // If the struct has an explicit value, we get the value at the given member + assert_eq!(values.len(), members.len()); + values + .into_iter() + .zip(members.into_iter()) + .filter(|(_, (id, _))| *id == member_id) + .flat_map(|(v, (_, t))| match t { + Type::FieldElement => FieldElementExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + Type::Array(..) => ArrayExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + Type::Struct(..) => StructExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + }) + .collect() + } + StructExpressionInner::Identifier(id) => { + // If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values. + // the struct is encoded as a sequence, so we need to identify the offset at which this member starts + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + self.layout.get(&id).unwrap()[offset..(offset + size)] + .into_iter() + .map(|i| i.clone().into()) + .collect() + } + StructExpressionInner::Select(box array, box index) => { + // If the struct is an array element `array[index]`, we're accessing `array[index].member` + // We construct `array := array.map(|e| e.member)` and access `array[index]` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + match ty { + Type::FieldElement => { + let array = ArrayExpression { + size: array.size, + ty: Type::FieldElement, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + FieldElementExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number(T::from(i)), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), ), - Type::Array(box ty, size) => ArrayExpression::try_from(v) - .unwrap() - .flatten(self, symbols, statements_flattened), - Type::Struct(members) => StructExpression::try_from(v) - .unwrap() - .flatten(self, symbols, statements_flattened), - }) - .collect() - } - StructExpressionInner::Identifier(id) => { - // If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values. - // the struct is encoded as a sequence, so we need to identify the offset at which this member starts - let offset = members - .iter() - .take_while(|(id, _)| *id != member_id) - .map(|(_, ty)| ty.get_primitive_count()) - .sum(); - - // we also need the size of this member - let size = members - .iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1 - .get_primitive_count(); - self.layout.get(&id).unwrap()[offset..(offset + size)] - .into_iter() - .map(|i| i.clone().into()) - .collect() - } - StructExpressionInner::Select(box array, box index) => { - // If the struct is an array element `array[index]`, we're accessing `array[index].member` - // We construct `array := array.map(|e| e.member)` and access `array[index]` - let ty = members - .clone() - .into_iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1; - - match ty { - Type::FieldElement => { - let array = ArrayExpression { - size: array.size, - ty: Type::FieldElement, - inner: ArrayExpressionInner::Value( - (0..array.size) - .map(|i| { - FieldElementExpression::Member( + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Boolean => { + let array = ArrayExpression { + size: array.size, + ty: Type::Boolean, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + BooleanExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number(T::from(i)), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Struct(m) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Struct(m.clone()), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + StructExpression { + ty: m.clone(), + inner: StructExpressionInner::Member( box StructExpression { ty: members.clone(), inner: StructExpressionInner::Select( @@ -417,27 +482,31 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), }, member_id.clone(), - ) - .into() - }) - .collect(), - ), - }; - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Boolean => { - let array = ArrayExpression { - size: array.size, - ty: Type::Boolean, - inner: ArrayExpressionInner::Value( - (0..array.size) - .map(|i| { - BooleanExpression::Member( + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Array(box ty, size) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Array(box ty.clone(), size), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + ArrayExpression { + size, + ty: ty.clone(), + inner: ArrayExpressionInner::Member( box StructExpression { ty: members.clone(), inner: StructExpressionInner::Select( @@ -448,152 +517,84 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), }, member_id.clone(), - ) - .into() - }) - .collect(), - ), - }; - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Struct(m) => { - let array = ArrayExpression { - size: array.size, - ty: Type::Struct(m.clone()), - inner: ArrayExpressionInner::Value( - (0..array.size) - .map(|i| { - StructExpression { - ty: m.clone(), - inner: StructExpressionInner::Member( - box StructExpression { - ty: members.clone(), - inner: StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number( - T::from(i), - ), - ), - }, - member_id.clone(), - ), - } - .into() - }) - .collect(), - ), - }; - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Array(box ty, size) => { - let array = ArrayExpression { - size: array.size, - ty: Type::Array(box ty.clone(), size), - inner: ArrayExpressionInner::Value( - (0..array.size) - .map(|i| { - ArrayExpression { - size, - ty: ty.clone(), - inner: ArrayExpressionInner::Member( - box StructExpression { - ty: members.clone(), - inner: StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number( - T::from(i), - ), - ), - }, - member_id.clone(), - ), - } - .into() - }) - .collect(), - ), - }; - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) } } - StructExpressionInner::FunctionCall(..) => unreachable!(), - StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - // if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member` - // we reduce to `if c then a.member else b.member` - let ty = members - .clone() - .into_iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1; + } + StructExpressionInner::FunctionCall(..) => unreachable!(), + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + // if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member` + // we reduce to `if c then a.member else b.member` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; - match ty { - Type::FieldElement => self.flatten_if_else_expression( - symbols, - statements_flattened, - condition.clone(), - FieldElementExpression::member(consequence.clone(), member_id.clone()), - FieldElementExpression::member(alternative.clone(), member_id), - ), - Type::Boolean => self.flatten_if_else_expression( - symbols, - statements_flattened, - condition.clone(), - BooleanExpression::member(consequence.clone(), member_id.clone()), - BooleanExpression::member(alternative.clone(), member_id), - ), - Type::Struct(m) => self.flatten_if_else_expression( - symbols, - statements_flattened, - condition.clone(), - StructExpression::member(consequence.clone(), member_id.clone()), - StructExpression::member(alternative.clone(), member_id), - ), - Type::Array(box ty, size) => self.flatten_if_else_expression( - symbols, - statements_flattened, - condition.clone(), - ArrayExpression::member(consequence.clone(), member_id.clone()), - ArrayExpression::member(alternative.clone(), member_id), - ), - } + match ty { + Type::FieldElement => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + FieldElementExpression::member(consequence.clone(), member_id.clone()), + FieldElementExpression::member(alternative.clone(), member_id), + ), + Type::Boolean => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + BooleanExpression::member(consequence.clone(), member_id.clone()), + BooleanExpression::member(alternative.clone(), member_id), + ), + Type::Struct(..) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + StructExpression::member(consequence.clone(), member_id.clone()), + StructExpression::member(alternative.clone(), member_id), + ), + Type::Array(..) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + ArrayExpression::member(consequence.clone(), member_id.clone()), + ArrayExpression::member(alternative.clone(), member_id), + ), } - StructExpressionInner::Member(box s0, m_id) => { - let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id); + } + StructExpressionInner::Member(box s0, m_id) => { + let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id); - let offset = members - .iter() - .take_while(|(id, _)| *id != member_id) - .map(|(_, ty)| ty.get_primitive_count()) - .sum(); + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); - // we also need the size of this member - let size = members - .iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1 - .get_primitive_count(); + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); - e[offset..(offset + size)].into() - } - }; + e[offset..(offset + size)].into() + } + }; assert_eq!(res.len(), expected_output_size); res diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 3e41c598..483d5ebc 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -208,6 +208,7 @@ impl<'ast> Checker<'ast> { let mut errors = vec![]; let mut fields: Vec<(_, _)> = vec![]; + let mut fields_set = HashSet::new(); for field in s.fields { let member_id = field.value.id.to_string(); @@ -215,9 +216,13 @@ impl<'ast> Checker<'ast> { .check_type(field.value.ty, module_id, &state.types) .map(|t| (member_id, t)) { - Ok(f) => { - fields.push(f); - } + Ok(f) => match fields_set.insert(f.0.clone()) { + true => fields.push(f), + false => errors.push(Error { + pos: Some(pos), + message: format!("Duplicate key {} in struct definition", f.0,), + }), + }, Err(e) => { errors.push(e); } @@ -282,6 +287,8 @@ impl<'ast> Checker<'ast> { match self.check_function(f, module_id, &state.types) { Ok(funct) => { + let funct = TypedFunctionSymbol::Here(funct); + let query = FunctionQuery::new( declaration.id.clone(), &funct.signature(&state.typed_modules).inputs, @@ -338,6 +345,8 @@ impl<'ast> Checker<'ast> { let pos = import.pos(); let import = import.value; + state.types.insert(module_id.to_string(), HashMap::new()); + match Checker::new().check_module(&import.module_id, state) { Ok(()) => { // find candidates in the checked module @@ -357,8 +366,8 @@ impl<'ast> Checker<'ast> { // find candidates in the types let type_candidate = state .types - .get(&import.module_id) - .unwrap() + .entry(import.module_id.clone()) + .or_insert_with(|| HashMap::new()) .get(import.symbol_id) .cloned(); @@ -371,7 +380,15 @@ impl<'ast> Checker<'ast> { .or_default() .insert(import.symbol_id.to_string(), t.clone()); } - (0, None) => unreachable!(), + (0, None) => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Could not find symbol {} in module {}", + import.symbol_id, import.module_id, + ), + }); + } _ => { ids.insert(declaration.id); for candidate in function_candidates { @@ -498,7 +515,7 @@ impl<'ast> Checker<'ast> { funct_node: FunctionNode<'ast, T>, module_id: &ModuleId, types: &TypeMap, - ) -> Result, Vec> { + ) -> Result, Vec> { let mut errors = vec![]; let funct = funct_node.value; let mut arguments_checked = vec![]; @@ -541,11 +558,11 @@ impl<'ast> Checker<'ast> { return Err(errors); } - Ok(TypedFunctionSymbol::Here(TypedFunction { + Ok(TypedFunction { arguments: arguments_checked, statements: statements_checked, signature: signature.unwrap(), - })) + }) } fn check_parameter( @@ -1758,8 +1775,8 @@ impl<'ast> Checker<'ast> { #[cfg(test)] mod tests { use super::*; + use absy; use typed_absy; - use types::Signature; use zokrates_field::field::FieldPrime; mod array { @@ -1767,13 +1784,17 @@ mod tests { #[test] fn element_type_mismatch() { + let types = HashMap::new(); + let module_id = String::from(""); // [3, true] let a = Expression::InlineArray(vec![ Expression::FieldConstant(FieldPrime::from(3)).mock().into(), Expression::BooleanConstant(true).mock().into(), ]) .mock(); - assert!(Checker::new().check_expression(a).is_err()); + assert!(Checker::new() + .check_expression(a, &module_id, &types) + .is_err()); // [[0], [0, 0]] let a = Expression::InlineArray(vec![ @@ -1790,7 +1811,9 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::new().check_expression(a).is_err()); + assert!(Checker::new() + .check_expression(a, &module_id, &types) + .is_err()); // [[0], true] let a = Expression::InlineArray(vec![ @@ -1804,13 +1827,14 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::new().check_expression(a).is_err()); + assert!(Checker::new() + .check_expression(a, &module_id, &types) + .is_err()); } } mod symbols { use super::*; - use crate::types::Signature; #[test] fn imported_symbol() { @@ -1824,9 +1848,9 @@ mod tests { // after semantic check, `bar` should import a checked function let foo: Module = Module { - functions: vec![FunctionDeclaration { + symbols: vec![SymbolDeclaration { id: "main", - symbol: FunctionSymbol::Here( + symbol: Symbol::HereFunction( Function { statements: vec![Statement::Return( ExpressionList { @@ -1838,7 +1862,8 @@ mod tests { .mock(), ) .mock()], - signature: Signature::new().outputs(vec![Type::FieldElement]), + signature: UnresolvedSignature::new() + .outputs(vec![UnresolvedType::FieldElement.mock()]), arguments: vec![], } .mock(), @@ -1849,28 +1874,27 @@ mod tests { }; let bar: Module = Module { - functions: vec![FunctionDeclaration { + symbols: vec![SymbolDeclaration { id: "main", - symbol: FunctionSymbol::There( - FunctionImport::with_id_in_module("main", "foo").mock(), - ), + symbol: Symbol::There(SymbolImport::with_id_in_module("main", "foo").mock()), } .mock()], imports: vec![], }; - let mut modules = vec![(String::from("foo"), foo), (String::from("bar"), bar)] - .into_iter() - .collect(); - let mut typed_modules = HashMap::new(); + let mut state = State::new( + vec![(String::from("foo"), foo), (String::from("bar"), bar)] + .into_iter() + .collect(), + ); let mut checker = Checker::new(); checker - .check_module(&String::from("bar"), &mut modules, &mut typed_modules) + .check_module(&String::from("bar"), &mut state) .unwrap(); assert_eq!( - typed_modules.get(&String::from("bar")), + state.typed_modules.get(&String::from("bar")), Some(&TypedModule { functions: vec![( FunctionKey::with_id("main") @@ -1910,9 +1934,12 @@ mod tests { ) .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = Checker::new(); assert_eq!( - checker.check_statement(statement, &vec![]), + checker.check_statement(statement, &vec![], &module_id, &types), Err(Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".to_string() @@ -1930,18 +1957,21 @@ mod tests { ) .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut scope = HashSet::new(); scope.insert(ScopedVariable { - id: Variable::field_element("a"), + id: Variable::field_element("a".into()), level: 0, }); scope.insert(ScopedVariable { - id: Variable::field_element("b"), + id: Variable::field_element("b".into()), level: 0, }); let mut checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( - checker.check_statement(statement, &vec![]), + checker.check_statement(statement, &vec![], &module_id, &types), Ok(TypedStatement::Definition( TypedAssignee::Identifier(typed_absy::Variable::field_element("a".into())), FieldElementExpression::Identifier("b".into()).into() @@ -1958,7 +1988,10 @@ mod tests { // should fail let foo_args = vec![]; let foo_statements = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::Definition( Assignee::Identifier("a").mock(), Expression::FieldConstant(FieldPrime::from(1)).mock(), @@ -1968,9 +2001,9 @@ mod tests { let foo = Function { arguments: foo_args, statements: foo_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -1987,35 +2020,35 @@ mod tests { let bar = Function { arguments: bar_args, statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); - let funcs = vec![ - FunctionDeclaration { + let symbols = vec![ + SymbolDeclaration { id: "foo", - symbol: FunctionSymbol::Here(foo), + symbol: Symbol::HereFunction(foo), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "bar", - symbol: FunctionSymbol::Here(bar), + symbol: Symbol::HereFunction(bar), } .mock(), ]; let module = Module { - functions: funcs, + symbols, imports: vec![], }; - let mut modules = vec![(String::from("main"), module)].into_iter().collect(); + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); let mut checker = Checker::new(); assert_eq!( - checker.check_module(&String::from("main"), &mut modules, &mut HashMap::new()), + checker.check_module(&String::from("main"), &mut state), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"a\" is undefined".to_string() @@ -2035,7 +2068,10 @@ mod tests { // should pass let foo_args = vec![]; let foo_statements = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::Definition( Assignee::Identifier("a").mock(), Expression::FieldConstant(FieldPrime::from(1)).mock(), @@ -2046,16 +2082,19 @@ mod tests { let foo = Function { arguments: foo_args, statements: foo_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); let bar_args = vec![]; let bar_statements = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::Definition( Assignee::Identifier("a").mock(), Expression::FieldConstant(FieldPrime::from(2)).mock(), @@ -2072,9 +2111,9 @@ mod tests { let bar = Function { arguments: bar_args, statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -2091,40 +2130,40 @@ mod tests { let main = Function { arguments: main_args, statements: main_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); - let funcs = vec![ - FunctionDeclaration { + let symbols = vec![ + SymbolDeclaration { id: "foo", - symbol: FunctionSymbol::Here(foo), + symbol: Symbol::HereFunction(foo), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "bar", - symbol: FunctionSymbol::Here(bar), + symbol: Symbol::HereFunction(bar), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "main", - symbol: FunctionSymbol::Here(main), + symbol: Symbol::HereFunction(main), } .mock(), ]; let module = Module { - functions: funcs, + symbols, imports: vec![], }; - let mut modules = vec![(String::from("main"), module)].into_iter().collect(); + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); let mut checker = Checker::new(); assert!(checker - .check_module(&String::from("main"), &mut modules, &mut HashMap::new()) + .check_module(&String::from("main"), &mut state) .is_ok()); } @@ -2137,7 +2176,7 @@ mod tests { // should fail let foo_statements = vec![ Statement::For( - Variable::field_element("i").mock(), + absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), FieldPrime::from(0), FieldPrime::from(10), vec![], @@ -2154,16 +2193,19 @@ mod tests { let foo = Function { arguments: vec![], statements: foo_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = Checker::new(); assert_eq!( - checker.check_function(foo), + checker.check_function(foo, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"i\" is undefined".to_string() @@ -2180,7 +2222,10 @@ mod tests { // should pass let for_statements = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::Definition( Assignee::Identifier("a").mock(), Expression::Identifier("i").mock(), @@ -2189,7 +2234,7 @@ mod tests { ]; let foo_statements = vec![Statement::For( - Variable::field_element("i").mock(), + absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), FieldPrime::from(0), FieldPrime::from(10), for_statements, @@ -2214,9 +2259,9 @@ mod tests { let foo = Function { arguments: vec![], statements: foo_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -2230,8 +2275,14 @@ mod tests { }, }; + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = Checker::new(); - assert_eq!(checker.check_function(foo), Ok(foo_checked)); + assert_eq!( + checker.check_function(foo, &module_id, &types), + Ok(foo_checked) + ); } #[test] @@ -2242,7 +2293,10 @@ mod tests { // field a = foo() // should fail let bar_statements: Vec> = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::MultipleDefinition( vec![Assignee::Identifier("a").mock()], Expression::FunctionCall("foo", vec![]).mock(), @@ -2263,16 +2317,19 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar), + checker.check_function(bar, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: @@ -2308,16 +2365,19 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar), + checker.check_function(bar, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Function definition for function foo with signature () -> (_) not found." @@ -2332,7 +2392,10 @@ mod tests { // field a = foo() // should fail let bar_statements: Vec> = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::MultipleDefinition( vec![Assignee::Identifier("a").mock()], Expression::FunctionCall("foo", vec![]).mock(), @@ -2343,16 +2406,19 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar), + checker.check_function(bar, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), @@ -2385,21 +2451,30 @@ mod tests { let foo = Function { arguments: vec![crate::absy::Parameter { - id: Variable::field_element("x").mock(), + id: absy::Variable::new("x", UnresolvedType::FieldElement.mock()).mock(), private: false, } .mock()], statements: foo_statements, - signature: Signature { - inputs: vec![Type::FieldElement], - outputs: vec![Type::FieldElement, Type::FieldElement], + signature: UnresolvedSignature { + inputs: vec![UnresolvedType::FieldElement.mock()], + outputs: vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ], }, } .mock(); let main_statements: Vec> = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), - Statement::Declaration(Variable::field_element("b").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), + Statement::Declaration( + absy::Variable::new("b", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::MultipleDefinition( vec![ Assignee::Identifier("a").mock(), @@ -2420,34 +2495,34 @@ mod tests { let main = Function { arguments: vec![], statements: main_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); let module = Module { - functions: vec![ - FunctionDeclaration { + symbols: vec![ + SymbolDeclaration { id: "foo", - symbol: FunctionSymbol::Here(foo), + symbol: Symbol::HereFunction(foo), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "main", - symbol: FunctionSymbol::Here(main), + symbol: Symbol::HereFunction(main), } .mock(), ], imports: vec![], }; - let mut modules = vec![(String::from("main"), module)].into_iter().collect(); + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_module(&String::from("main"), &mut modules, &mut HashMap::new()), + checker.check_module(&String::from("main"), &mut state), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"x\" is undefined".to_string() @@ -2469,16 +2544,19 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar), + checker.check_function(bar, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), @@ -2507,16 +2585,22 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement, Type::FieldElement], + outputs: vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ], }, } .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar), + checker.check_function(bar, &module_id, &types), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"a\" is undefined".to_string() @@ -2534,8 +2618,14 @@ mod tests { // // should pass let bar_statements: Vec> = vec![ - Statement::Declaration(Variable::field_element("a").mock()).mock(), - Statement::Declaration(Variable::field_element("b").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), + Statement::Declaration( + absy::Variable::new("b", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), Statement::MultipleDefinition( vec![ Assignee::Identifier("a").mock(), @@ -2594,9 +2684,9 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -2610,8 +2700,14 @@ mod tests { }, }; + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = new_with_args(HashSet::new(), 0, functions); - assert_eq!(checker.check_function(bar), Ok(bar_checked)); + assert_eq!( + checker.check_function(bar, &module_id, &types), + Ok(bar_checked) + ); } #[test] @@ -2633,12 +2729,12 @@ mod tests { let foo1_arguments = vec![ crate::absy::Parameter { - id: Variable::field_element("a").mock(), + id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), private: true, } .mock(), crate::absy::Parameter { - id: Variable::field_element("b").mock(), + id: absy::Variable::new("b", UnresolvedType::FieldElement.mock()).mock(), private: true, } .mock(), @@ -2654,12 +2750,12 @@ mod tests { let foo2_arguments = vec![ crate::absy::Parameter { - id: Variable::field_element("c").mock(), + id: absy::Variable::new("c", UnresolvedType::FieldElement.mock()).mock(), private: true, } .mock(), crate::absy::Parameter { - id: Variable::field_element("d").mock(), + id: absy::Variable::new("d", UnresolvedType::FieldElement.mock()).mock(), private: true, } .mock(), @@ -2668,9 +2764,12 @@ mod tests { let foo1 = Function { arguments: foo1_arguments, statements: foo1_statements, - signature: Signature { - inputs: vec![Type::FieldElement, Type::FieldElement], - outputs: vec![Type::FieldElement], + signature: UnresolvedSignature { + inputs: vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -2678,34 +2777,37 @@ mod tests { let foo2 = Function { arguments: foo2_arguments, statements: foo2_statements, - signature: Signature { - inputs: vec![Type::FieldElement, Type::FieldElement], - outputs: vec![Type::FieldElement], + signature: UnresolvedSignature { + inputs: vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); let module = Module { - functions: vec![ - FunctionDeclaration { + symbols: vec![ + SymbolDeclaration { id: "foo", - symbol: FunctionSymbol::Here(foo1), + symbol: Symbol::HereFunction(foo1), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "foo", - symbol: FunctionSymbol::Here(foo2), + symbol: Symbol::HereFunction(foo2), } .mock(), ], imports: vec![], }; - let mut modules = vec![(String::from("main"), module)].into_iter().collect(); + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); let mut checker = Checker::new(); assert_eq!( - checker.check_module(&String::from("main"), &mut modules, &mut HashMap::new()), + checker.check_module(&String::from("main"), &mut state), Err(vec![Error { pos: Some((Position::mock(), Position::mock())), @@ -2733,7 +2835,7 @@ mod tests { .mock()]; let main1_arguments = vec![crate::absy::Parameter { - id: Variable::field_element("a").mock(), + id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), private: false, } .mock()]; @@ -2751,9 +2853,9 @@ mod tests { let main1 = Function { arguments: main1_arguments, statements: main1_statements, - signature: Signature { - inputs: vec![Type::FieldElement], - outputs: vec![Type::FieldElement], + signature: UnresolvedSignature { + inputs: vec![UnresolvedType::FieldElement.mock()], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); @@ -2761,28 +2863,28 @@ mod tests { let main2 = Function { arguments: main2_arguments, statements: main2_statements, - signature: Signature { + signature: UnresolvedSignature { inputs: vec![], - outputs: vec![Type::FieldElement], + outputs: vec![UnresolvedType::FieldElement.mock()], }, } .mock(); - let functions = vec![ - FunctionDeclaration { + let symbols = vec![ + SymbolDeclaration { id: "main", - symbol: FunctionSymbol::Here(main1), + symbol: Symbol::HereFunction(main1), } .mock(), - FunctionDeclaration { + SymbolDeclaration { id: "main", - symbol: FunctionSymbol::Here(main2), + symbol: Symbol::HereFunction(main2), } .mock(), ]; let main_module = Module { - functions: functions, + symbols, imports: vec![], }; @@ -2810,14 +2912,26 @@ mod tests { // // should fail + let types = HashMap::new(); + let module_id = String::from(""); let mut checker = Checker::new(); let _: Result, Error> = checker.check_statement( - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), &vec![], + &module_id, + &types, ); let s2_checked: Result, Error> = checker.check_statement( - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), &vec![], + &module_id, + &types, ); assert_eq!( s2_checked, @@ -2835,14 +2949,25 @@ mod tests { // // should fail + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker = Checker::new(); let _: Result, Error> = checker.check_statement( - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), &vec![], + &module_id, + &types, ); let s2_checked: Result, Error> = checker.check_statement( - Statement::Declaration(Variable::boolean("a").mock()).mock(), + Statement::Declaration(absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock()) + .mock(), &vec![], + &module_id, + &types, ); assert_eq!( s2_checked, @@ -2861,16 +2986,23 @@ mod tests { // a = 42 let a = Assignee::Identifier::("a").mock(); + let types = HashMap::new(); + let module_id = String::from(""); let mut checker: Checker = Checker::new(); checker .check_statement::( - Statement::Declaration(Variable::field_element("a").mock()).mock(), + Statement::Declaration( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), &vec![], + &module_id, + &types, ) .unwrap(); assert_eq!( - checker.check_assignee(a), + checker.check_assignee(a, &module_id, &types), Ok(TypedAssignee::Identifier( typed_absy::Variable::field_element("a".into()) )) @@ -2889,16 +3021,28 @@ mod tests { ) .mock(); + let types = HashMap::new(); + let module_id = String::from(""); + let mut checker: Checker = Checker::new(); checker .check_statement::( - Statement::Declaration(Variable::field_array("a", 33).mock()).mock(), + Statement::Declaration( + absy::Variable::new( + "a", + UnresolvedType::array(UnresolvedType::FieldElement.mock(), 33).mock(), + ) + .mock(), + ) + .mock(), &vec![], + &module_id, + &types, ) .unwrap(); assert_eq!( - checker.check_assignee(a), + checker.check_assignee(a, &module_id, &types), Ok(TypedAssignee::ArrayElement( box TypedAssignee::Identifier(typed_absy::Variable::field_array( "a".into(), @@ -2927,19 +3071,32 @@ mod tests { ) .mock(); + let types = HashMap::new(); + let module_id = String::from(""); let mut checker: Checker = Checker::new(); checker .check_statement::( Statement::Declaration( - Variable::array("a", Type::array(Type::FieldElement, 33), 42).mock(), + absy::Variable::new( + "a", + UnresolvedType::array( + UnresolvedType::array(UnresolvedType::FieldElement.mock(), 33) + .mock(), + 42, + ) + .mock(), + ) + .mock(), ) .mock(), &vec![], + &module_id, + &types, ) .unwrap(); assert_eq!( - checker.check_assignee(a), + checker.check_assignee(a, &module_id, &types), Ok(TypedAssignee::ArrayElement( box TypedAssignee::ArrayElement( box TypedAssignee::Identifier(typed_absy::Variable::array( diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 46c3e825..63731d14 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -83,17 +83,21 @@ pub struct TypedModule<'ast, T: Field> { impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}_{}_{}", - self.stack - .iter() - .map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count)) - .collect::>() - .join("_"), - self.id, - self.version - ) + if self.version == 0 && self.stack.len() == 0 { + write!(f, "{}", self.id) + } else { + write!( + f, + "{}_{}_{}", + self.stack + .iter() + .map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count)) + .collect::>() + .join("_"), + self.id, + self.version + ) + } } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index e3a49e96..addabdf5 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -60,7 +60,14 @@ impl Type { Type::FieldElement => String::from("f"), Type::Boolean => String::from("b"), Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size), - Type::Struct(members) => unimplemented!(), + Type::Struct(members) => format!( + "{{{}}}", + members + .iter() + .map(|(id, ty)| format!("{}:{}", id, ty)) + .collect::>() + .join(",") + ), } } @@ -75,24 +82,6 @@ impl Type { } } -#[derive(Clone, PartialEq, Hash, Eq)] -pub struct Variable<'ast> { - pub id: Identifier<'ast>, - pub _type: Type, -} - -impl<'ast> fmt::Display for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) - } -} - -impl<'ast> fmt::Debug for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,) - } -} - pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone)] diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index b104cd9b..e0383a84 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -121,7 +121,7 @@ mod tests { "#, rule: Rule::statement, tokens: [ - statement(0, 10, [ + statement(0, 22, [ multi_assignment_statement(0, 9, [ optionally_typed_identifier(0, 1, [ identifier(0, 1) @@ -153,7 +153,7 @@ mod tests { ", rule: Rule::ty_struct_definition, tokens: [ - ty_struct_definition(0, 40, [ + ty_struct_definition(0, 41, [ identifier(7, 10), struct_field(13, 23, [ identifier(13, 16), @@ -167,8 +167,10 @@ mod tests { identifier(25, 28), ty(30, 39, [ ty_array(30, 39, [ - ty_basic(30, 35, [ - ty_field(30, 35) + ty_basic_or_struct(30, 35, [ + ty_basic(30, 35, [ + ty_field(30, 35) + ]) ]), expression(36, 37, [ term(36, 37, [ diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index d5e79556..395e0269 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -795,7 +795,9 @@ mod tests { span: Span::new(&source, 33, 37).unwrap() }, parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType {}))], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::add( Expression::Constant(ConstantExpression::DecimalNumber( @@ -816,14 +818,14 @@ mod tests { })], span: Span::new(&source, 29, source.len()).unwrap(), }], - imports: vec![ImportDirective { + imports: vec![ImportDirective::Main(MainImportDirective { source: ImportSource { value: String::from("foo"), span: Span::new(&source, 8, 11).unwrap() }, alias: None, span: Span::new(&source, 0, 29).unwrap() - }], + })], eoi: EOI {}, span: Span::new(&source, 0, 65).unwrap() }) @@ -845,7 +847,9 @@ mod tests { span: Span::new(&source, 33, 37).unwrap() }, parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType {}))], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::add( Expression::Constant(ConstantExpression::DecimalNumber( @@ -884,14 +888,14 @@ mod tests { })], span: Span::new(&source, 29, 74).unwrap(), }], - imports: vec![ImportDirective { + imports: vec![ImportDirective::Main(MainImportDirective { source: ImportSource { value: String::from("foo"), span: Span::new(&source, 8, 11).unwrap() }, alias: None, span: Span::new(&source, 0, 29).unwrap() - }], + })], eoi: EOI {}, span: Span::new(&source, 0, 74).unwrap() }) @@ -913,7 +917,9 @@ mod tests { span: Span::new(&source, 33, 37).unwrap() }, parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType {}))], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::if_else( Expression::Constant(ConstantExpression::DecimalNumber( @@ -940,14 +946,14 @@ mod tests { })], span: Span::new(&source, 29, 81).unwrap(), }], - imports: vec![ImportDirective { + imports: vec![ImportDirective::Main(MainImportDirective { source: ImportSource { value: String::from("foo"), span: Span::new(&source, 8, 11).unwrap() }, alias: None, span: Span::new(&source, 0, 29).unwrap() - }], + })], eoi: EOI {}, span: Span::new(&source, 0, 81).unwrap() }) @@ -968,7 +974,9 @@ mod tests { span: Span::new(&source, 4, 8).unwrap() }, parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType {}))], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 15, 20).unwrap() + }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::Constant(ConstantExpression::DecimalNumber( DecimalNumberExpression { @@ -1001,7 +1009,9 @@ mod tests { span: Span::new(&source, 4, 8).unwrap() }, parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType {}))], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 15, 20).unwrap() + }))], statements: vec![Statement::MultiAssignment(MultiAssignmentStatement { function_id: IdentifierExpression { value: String::from("foo"), @@ -1009,7 +1019,9 @@ mod tests { }, lhs: vec![ OptionallyTypedIdentifier { - ty: Some(Type::Basic(BasicType::Field(FieldType {}))), + ty: Some(Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 23, 28).unwrap() + }))), id: IdentifierExpression { value: String::from("a"), span: Span::new(&source, 29, 30).unwrap(), From dcc13af656f8d7a463d19b5317064fa98ecb0f0d Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 8 Aug 2019 02:15:20 +0200 Subject: [PATCH 16/35] try non preveiw fmt, fix test --- .circleci/config.yml | 2 +- zokrates_core/src/absy/from_ast.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 94bbd1a8..54ebe26f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -44,7 +44,7 @@ jobs: - v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }} - run: name: Check format - command: rustup component add rustfmt-preview; cargo fmt --all -- --check + command: rustup component add rustfmt; cargo fmt --all -- --check - run: name: Install libsnark prerequisites command: ./scripts/install_libsnark_prerequisites.sh diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index bd77b686..8cfa3255 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -745,7 +745,7 @@ mod tests { absy::Parameter::private( absy::Variable::new( &source[23..24], - UnresolvedType::FieldElement.mock(), + absy::UnresolvedType::FieldElement.mock(), ) .into(), ) @@ -753,7 +753,7 @@ mod tests { absy::Parameter::public( absy::Variable::new( &source[31..32], - UnresolvedType::Boolean.mock(), + absy::UnresolvedType::Boolean.mock(), ) .into(), ) From d45922c8f54d081733bc6bf6cbdb42486e5e9078 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 8 Aug 2019 02:17:51 +0200 Subject: [PATCH 17/35] fix format --- zokrates_core/src/typed_absy/types.rs | 1 - zokrates_field/src/field.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index addabdf5..176d6509 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -276,7 +276,6 @@ pub mod signature { assert_eq!(s.to_slug(), String::from("if[42]f[21]o")); } } - } #[cfg(test)] diff --git a/zokrates_field/src/field.rs b/zokrates_field/src/field.rs index 274515dc..a36e69eb 100644 --- a/zokrates_field/src/field.rs +++ b/zokrates_field/src/field.rs @@ -800,5 +800,4 @@ mod tests { assert_eq!(FieldPrime::from_bellman(a), cc); } } - } From cd79fa42e2cd63abcc8a2eb1dc31a26828ede6c7 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 8 Aug 2019 19:55:18 +0200 Subject: [PATCH 18/35] add tests for declaration --- zokrates_core/src/absy/mod.rs | 2 +- zokrates_core/src/semantics.rs | 744 ++++++++++++++++++++++++--------- 2 files changed, 537 insertions(+), 209 deletions(-) diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index e463a342..d1cc58b5 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -72,7 +72,7 @@ impl<'ast, T: Field> fmt::Display for SymbolDeclaration<'ast, T> { } } -type SymbolDeclarationNode<'ast, T> = Node>; +pub type SymbolDeclarationNode<'ast, T> = Node>; /// A module as a collection of `FunctionDeclaration`s #[derive(Clone, PartialEq)] diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 483d5ebc..61e6ca6c 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -30,11 +30,11 @@ type TypeMap = HashMap>; /// The global state of the program during semantic checks #[derive(Debug)] struct State<'ast, T: Field> { - /// The modules yet to be checked + /// The modules yet to be checked, which we consume as we explore the dependency tree modules: Modules<'ast, T>, - /// The already checked modules + /// The already checked modules, which we're returning at the end typed_modules: TypedModules<'ast, T>, - /// The user-defined types + /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof types: TypeMap, } @@ -197,11 +197,11 @@ impl<'ast> Checker<'ast> { }) } - fn check_struct_type_declaration( + fn check_struct_type_declaration( &mut self, s: StructTypeNode<'ast>, module_id: &ModuleId, - state: &mut State<'ast, T>, + types: &TypeMap, ) -> Result> { let pos = s.pos(); let s = s.value; @@ -213,7 +213,7 @@ impl<'ast> Checker<'ast> { for field in s.fields { let member_id = field.value.id.to_string(); match self - .check_type(field.value.ty, module_id, &state.types) + .check_type(field.value.ty, module_id, &types) .map(|t| (member_id, t)) { Ok(f) => match fields_set.insert(f.0.clone()) { @@ -236,6 +236,207 @@ impl<'ast> Checker<'ast> { Ok(Type::Struct(fields)) } + fn check_symbol_declaration( + &mut self, + declaration: SymbolDeclarationNode<'ast, T>, + module_id: &ModuleId, + state: &mut State<'ast, T>, + functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, + symbol_id_set: &mut HashSet>, + ) -> Result<(), Vec> { + let mut errors = vec![]; + + let pos = declaration.pos(); + let declaration = declaration.value; + + match declaration.symbol { + Symbol::HereType(t) => { + match self.check_struct_type_declaration(t.clone(), module_id, &state.types) { + Ok(ty) => { + match symbol_id_set.insert(declaration.id) { + false => errors.push(Error { + pos: Some(pos), + message: format!( + "Another symbol with id {} is already defined", + declaration.id, + ), + }), + true => {} + }; + state + .types + .entry(module_id.clone()) + .or_default() + .insert(declaration.id.to_string(), ty); + } + Err(e) => errors.extend(e), + } + } + Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) { + Ok(funct) => { + let funct = TypedFunctionSymbol::Here(funct); + + let query = FunctionQuery::new( + declaration.id.clone(), + &funct.signature(&state.typed_modules).inputs, + &funct + .signature(&state.typed_modules) + .outputs + .clone() + .into_iter() + .map(|o| Some(o)) + .collect(), + ); + + let candidates = self.find_candidates(&query); + + match candidates.len() { + 1 => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Duplicate definition for function {} with signature {}", + declaration.id, + funct.signature(&state.typed_modules) + ), + }); + } + 0 => {} + _ => panic!("duplicate function declaration should have been caught"), + } + + symbol_id_set.insert(declaration.id); + + self.functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature(&state.typed_modules).clone()), + ); + functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature(&state.typed_modules).clone()), + funct, + ); + } + Err(e) => { + errors.extend(e); + } + }, + Symbol::There(import) => { + let pos = import.pos(); + let import = import.value; + + match Checker::new().check_module(&import.module_id, state) { + Ok(()) => { + // find candidates in the checked module + let function_candidates: Vec<_> = state + .typed_modules + .get(&import.module_id) + .unwrap() + .functions + .iter() + .filter(|(k, _)| k.id == import.symbol_id) + .map(|(_, v)| FunctionKey { + id: import.symbol_id.clone(), + signature: v.signature(&state.typed_modules).clone(), + }) + .collect(); + + // find candidates in the types + let type_candidate = state + .types + .entry(import.module_id.clone()) + .or_default() + .get(import.symbol_id) + .cloned(); + + match (function_candidates.len(), type_candidate) { + (0, Some(t)) => { + symbol_id_set.insert(declaration.id); + state + .types + .entry(module_id.clone()) + .or_default() + .insert(import.symbol_id.to_string(), t.clone()); + } + (0, None) => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Could not find symbol {} in module {}", + import.symbol_id, import.module_id, + ), + }); + } + _ => { + symbol_id_set.insert(declaration.id); + for candidate in function_candidates { + self.functions.insert(candidate.clone().id(declaration.id)); + functions.insert( + candidate.clone().id(declaration.id), + TypedFunctionSymbol::There( + candidate, + import.module_id.clone(), + ), + ); + } + } + } + } + Err(e) => { + errors.extend(e); + } + }; + } + Symbol::Flat(funct) => { + let query = FunctionQuery::new( + declaration.id.clone(), + &funct.signature::().inputs, + &funct + .signature::() + .outputs + .clone() + .into_iter() + .map(|o| Some(o)) + .collect(), + ); + + let candidates = self.find_candidates(&query); + + match candidates.len() { + 1 => { + errors.push(Error { + pos: Some(pos), + message: format!( + "Duplicate definition for function {} with signature {}", + declaration.id, + funct.signature::() + ), + }); + } + 0 => {} + _ => panic!("duplicate function declaration should have been caught"), + } + symbol_id_set.insert(declaration.id); + self.functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature::().clone()), + ); + functions.insert( + FunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature::().clone()), + TypedFunctionSymbol::Flat(funct), + ); + } + }; + + // return if any errors occured + if errors.len() > 0 { + return Err(errors); + } + + Ok(()) + } + fn check_module( &mut self, module_id: &ModuleId, @@ -252,205 +453,24 @@ impl<'ast> Checker<'ast> { Some(module) => { assert_eq!(module.imports.len(), 0); - let mut ids = HashSet::new(); + // we need to create an entry in the types map to store types for this module + state.types.entry(module_id.clone()).or_default(); + // we keep track of the introduced symbols to avoid colisions between types and functions + let mut symbol_id_set = HashSet::new(); + + // we go through symbol declarations and check them for declaration in module.symbols { - let pos = declaration.pos(); - let declaration = declaration.value; - - match declaration.symbol { - Symbol::HereType(t) => { - match ids.insert(declaration.id) { - false => errors.push(Error { - pos: Some(pos), - message: format!( - "Another symbol with id {} is already defined", - declaration.id, - ), - }), - true => {} - }; - - match self.check_struct_type_declaration(t.clone(), module_id, state) { - Ok(ty) => { - state - .types - .entry(module_id.clone()) - .or_default() - .insert(declaration.id.to_string(), ty); - } - Err(e) => errors.extend(e), - } - } - Symbol::HereFunction(f) => { - self.enter_scope(); - - match self.check_function(f, module_id, &state.types) { - Ok(funct) => { - let funct = TypedFunctionSymbol::Here(funct); - - let query = FunctionQuery::new( - declaration.id.clone(), - &funct.signature(&state.typed_modules).inputs, - &funct - .signature(&state.typed_modules) - .outputs - .clone() - .into_iter() - .map(|o| Some(o)) - .collect(), - ); - - let candidates = self.find_candidates(&query); - - match candidates.len() { - 1 => { - errors.push(Error { - pos: Some(pos), - message: format!( - "Duplicate definition for function {} with signature {}", - declaration.id, - funct.signature(&state.typed_modules) - ), - }); - } - 0 => {} - _ => panic!( - "duplicate function declaration should have been caught" - ), - } - - ids.insert(declaration.id); - - self.functions.insert( - FunctionKey::with_id(declaration.id.clone()).signature( - funct.signature(&state.typed_modules).clone(), - ), - ); - checked_functions.insert( - FunctionKey::with_id(declaration.id.clone()).signature( - funct.signature(&state.typed_modules).clone(), - ), - funct, - ); - } - Err(e) => { - errors.extend(e); - } - } - - self.exit_scope(); - } - Symbol::There(import) => { - let pos = import.pos(); - let import = import.value; - - state.types.insert(module_id.to_string(), HashMap::new()); - - match Checker::new().check_module(&import.module_id, state) { - Ok(()) => { - // find candidates in the checked module - let function_candidates: Vec<_> = state - .typed_modules - .get(&import.module_id) - .unwrap() - .functions - .iter() - .filter(|(k, _)| k.id == import.symbol_id) - .map(|(_, v)| FunctionKey { - id: import.symbol_id.clone(), - signature: v.signature(&state.typed_modules).clone(), - }) - .collect(); - - // find candidates in the types - let type_candidate = state - .types - .entry(import.module_id.clone()) - .or_insert_with(|| HashMap::new()) - .get(import.symbol_id) - .cloned(); - - match (function_candidates.len(), type_candidate) { - (0, Some(t)) => { - ids.insert(declaration.id); - state - .types - .entry(module_id.clone()) - .or_default() - .insert(import.symbol_id.to_string(), t.clone()); - } - (0, None) => { - errors.push(Error { - pos: Some(pos), - message: format!( - "Could not find symbol {} in module {}", - import.symbol_id, import.module_id, - ), - }); - } - _ => { - ids.insert(declaration.id); - for candidate in function_candidates { - self.functions - .insert(candidate.clone().id(declaration.id)); - checked_functions.insert( - candidate.clone().id(declaration.id), - TypedFunctionSymbol::There( - candidate, - import.module_id.clone(), - ), - ); - } - } - } - } - Err(e) => { - errors.extend(e); - } - }; - } - Symbol::Flat(funct) => { - let query = FunctionQuery::new( - declaration.id.clone(), - &funct.signature::().inputs, - &funct - .signature::() - .outputs - .clone() - .into_iter() - .map(|o| Some(o)) - .collect(), - ); - - let candidates = self.find_candidates(&query); - - match candidates.len() { - 1 => { - errors.push(Error { - pos: Some(pos), - message: format!( - "Duplicate definition for function {} with signature {}", - declaration.id, - funct.signature::() - ), - }); - } - 0 => {} - _ => { - panic!("duplicate function declaration should have been caught") - } - } - ids.insert(declaration.id); - self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature::().clone()), - ); - checked_functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature::().clone()), - TypedFunctionSymbol::Flat(funct), - ); + match self.check_symbol_declaration( + declaration, + module_id, + state, + &mut checked_functions, + &mut symbol_id_set, + ) { + Ok(()) => {} + Err(e) => { + errors.extend(e); } } } @@ -469,15 +489,15 @@ impl<'ast> Checker<'ast> { // insert into typed_modules if we checked anything match to_insert { Some(typed_module) => { - state.typed_modules.insert(module_id.clone(), typed_module); + // there should be no checked module at that key just yet, if there is we have a colision or we checked something twice + assert!(state + .typed_modules + .insert(module_id.clone(), typed_module) + .is_none()); } None => {} }; - if errors.len() > 0 { - return Err(errors); - } - Ok(()) } @@ -516,6 +536,8 @@ impl<'ast> Checker<'ast> { module_id: &ModuleId, types: &TypeMap, ) -> Result, Vec> { + self.enter_scope(); + let mut errors = vec![]; let funct = funct_node.value; let mut arguments_checked = vec![]; @@ -558,6 +580,8 @@ impl<'ast> Checker<'ast> { return Err(errors); } + self.exit_scope(); + Ok(TypedFunction { arguments: arguments_checked, statements: statements_checked, @@ -2978,6 +3002,310 @@ mod tests { ); } + mod structs { + use super::*; + + mod declaration { + use super::*; + + #[test] + fn empty_def() { + // an empty struct should be allowed to be defined + let module_id = "".to_string(); + let types = HashMap::new(); + let declaration = StructType { fields: vec![] }.mock(); + + let expected_type = Type::Struct(vec![]); + + assert_eq!( + Checker::new().check_struct_type_declaration(declaration, &module_id, &types), + Ok(expected_type) + ); + } + + #[test] + fn valid_def() { + // a valid struct should be allowed to be defined + let module_id = "".to_string(); + let types = HashMap::new(); + let declaration = StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + } + .mock(); + + let expected_type = Type::Struct(vec![ + ("foo".to_string(), Type::FieldElement), + ("bar".to_string(), Type::Boolean), + ]); + + assert_eq!( + Checker::new().check_struct_type_declaration(declaration, &module_id, &types), + Ok(expected_type) + ); + } + + #[test] + fn preserve_order() { + // two structs with inverted members are not equal + let module_id = "".to_string(); + let types = HashMap::new(); + + let declaration0 = StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + } + .mock(); + + let declaration1 = StructType { + fields: vec![ + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + ], + } + .mock(); + + assert!( + Checker::new().check_struct_type_declaration(declaration0, &module_id, &types) + != Checker::new().check_struct_type_declaration( + declaration1, + &module_id, + &types + ) + ); + } + + #[test] + fn duplicate_member_def() { + // definition of a struct with a duplicate member should be rejected + let module_id = "".to_string(); + let types = HashMap::new(); + + let declaration = StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "foo", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + } + .mock(); + + assert_eq!( + Checker::new() + .check_struct_type_declaration(declaration, &module_id, &types) + .unwrap_err()[0] + .message, + "Duplicate key foo in struct definition" + ); + } + + #[test] + fn recursive() { + // a struct wrapping another struct should be allowed to be defined + + // struct Foo = { foo: field } + // struct Bar = { foo: Foo } + + let module_id = "".to_string(); + + let module: Module = Module { + imports: vec![], + symbols: vec![ + SymbolDeclaration { + id: "Foo", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + } + .mock(), + ), + } + .mock(), + SymbolDeclaration { + id: "Bar", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::User("Foo".to_string()).mock(), + } + .mock()], + } + .mock(), + ), + } + .mock(), + ], + }; + + let mut state = State::new(vec![(module_id.clone(), module)].into_iter().collect()); + + assert!(Checker::new().check_module(&module_id, &mut state).is_ok()); + assert_eq!( + state + .types + .get(&"".to_string()) + .unwrap() + .get(&"Bar".to_string()) + .unwrap(), + &Type::Struct(vec![( + "foo".to_string(), + Type::Struct(vec![("foo".to_string(), Type::FieldElement)]) + )]) + ); + } + + #[test] + fn recursive_undefined() { + // a struct wrapping an undefined struct should be rejected + + // struct Bar = { foo: Foo } + + let module_id = "".to_string(); + + let module: Module = Module { + imports: vec![], + symbols: vec![SymbolDeclaration { + id: "Bar", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::User("Foo".to_string()).mock(), + } + .mock()], + } + .mock(), + ), + } + .mock()], + }; + + let mut state = State::new(vec![(module_id.clone(), module)].into_iter().collect()); + + assert!(Checker::new().check_module(&module_id, &mut state).is_err()); + } + + #[test] + fn self_referential() { + // a struct wrapping itself should be rejected + + // struct Foo = { foo: Foo } + + let module_id = "".to_string(); + + let module: Module = Module { + imports: vec![], + symbols: vec![SymbolDeclaration { + id: "Foo", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::User("Foo".to_string()).mock(), + } + .mock()], + } + .mock(), + ), + } + .mock()], + }; + + let mut state = State::new(vec![(module_id.clone(), module)].into_iter().collect()); + + assert!(Checker::new().check_module(&module_id, &mut state).is_err()); + } + + #[test] + fn cyclic() { + // A wrapping B wrapping A should be rejected + + // struct Foo = { bar: Bar } + // struct Bar = { foo: Foo } + + let module_id = "".to_string(); + + let module: Module = Module { + imports: vec![], + symbols: vec![ + SymbolDeclaration { + id: "Foo", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "bar", + ty: UnresolvedType::User("Bar".to_string()).mock(), + } + .mock()], + } + .mock(), + ), + } + .mock(), + SymbolDeclaration { + id: "Bar", + symbol: Symbol::HereType( + StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::User("Foo".to_string()).mock(), + } + .mock()], + } + .mock(), + ), + } + .mock(), + ], + }; + + let mut state = State::new(vec![(module_id.clone(), module)].into_iter().collect()); + + assert!(Checker::new().check_module(&module_id, &mut state).is_err()); + } + } + } + mod assignee { use super::*; From 70529ef3e6b2517faaa5667b8a510adafc760081 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 9 Aug 2019 14:09:47 +0200 Subject: [PATCH 19/35] add more tests for structs, move return value checking out of statement check to function check --- zokrates_core/src/semantics.rs | 606 ++++++++++++++++++++++++++++++--- 1 file changed, 560 insertions(+), 46 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 61e6ca6c..d3c7c59d 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -560,8 +560,36 @@ impl<'ast> Checker<'ast> { match self.check_signature(funct.signature, module_id, types) { Ok(s) => { for stat in funct.statements.into_iter() { - match self.check_statement(stat, &s.outputs, module_id, types) { + let pos = stat.pos(); + + match self.check_statement(stat, module_id, types) { Ok(statement) => { + match &statement { + TypedStatement::Return(e) => { + match e.iter().map(|e| e.get_type()).collect::>() + == s.outputs + { + true => {} + false => errors.push(Error { + pos: Some(pos), + message: format!( + "Expected ({}) in return statement, found ({})", + s.outputs + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", "), + e.iter() + .map(|e| e.get_type()) + .map(|t| t.to_string()) + .collect::>() + .join(", ") + ), + }), + } + } + _ => {} + }; statements_checked.push(statement); } Err(e) => { @@ -688,7 +716,6 @@ impl<'ast> Checker<'ast> { fn check_statement( &mut self, stat: StatementNode<'ast, T>, - header_return_types: &Vec, module_id: &ModuleId, types: &TypeMap, ) -> Result, Error> { @@ -702,30 +729,7 @@ impl<'ast> Checker<'ast> { expression_list_checked.push(e_checked); } - let return_statement_types: Vec<_> = expression_list_checked - .iter() - .map(|e| e.get_type()) - .collect(); - - match return_statement_types == *header_return_types { - true => Ok(TypedStatement::Return(expression_list_checked)), - false => Err(Error { - pos: Some(pos), - message: format!( - "Expected ({}) in return statement, found ({})", - header_return_types - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(", "), - return_statement_types - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(", ") - ), - }), - } + Ok(TypedStatement::Return(expression_list_checked)) } Statement::Declaration(var) => { let var = self.check_variable(var, module_id, types).unwrap(); @@ -797,8 +801,7 @@ impl<'ast> Checker<'ast> { let mut checked_statements = vec![]; for stat in statements { - let checked_stat = - self.check_statement(stat, header_return_types, module_id, types)?; + let checked_stat = self.check_statement(stat, module_id, types)?; checked_statements.push(checked_stat); } @@ -1458,15 +1461,7 @@ impl<'ast> Checker<'ast> { }, None => Err(Error { pos: Some(pos), - message: format!( - "{} doesn't have member {}. Members are {}", - TypedExpression::Struct(s.clone()), - id, - s.ty.iter() - .map(|(member_id, _)| member_id.to_string()) - .collect::>() - .join(", ") - ), + message: format!("{} doesn't have member {}", s.get_type(), id,), }), } } @@ -1963,7 +1958,7 @@ mod tests { let mut checker = Checker::new(); assert_eq!( - checker.check_statement(statement, &vec![], &module_id, &types), + checker.check_statement(statement, &module_id, &types), Err(Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".to_string() @@ -1995,7 +1990,7 @@ mod tests { }); let mut checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( - checker.check_statement(statement, &vec![], &module_id, &types), + checker.check_statement(statement, &module_id, &types), Ok(TypedStatement::Definition( TypedAssignee::Identifier(typed_absy::Variable::field_element("a".into())), FieldElementExpression::Identifier("b".into()).into() @@ -2944,7 +2939,6 @@ mod tests { absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) .mock(), - &vec![], &module_id, &types, ); @@ -2953,7 +2947,6 @@ mod tests { absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) .mock(), - &vec![], &module_id, &types, ); @@ -2982,14 +2975,12 @@ mod tests { absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) .mock(), - &vec![], &module_id, &types, ); let s2_checked: Result, Error> = checker.check_statement( Statement::Declaration(absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock()) .mock(), - &vec![], &module_id, &types, ); @@ -3005,6 +2996,33 @@ mod tests { mod structs { use super::*; + const MODULE_ID: &str = ""; + + /// helper function to create a module at location "" with a single symbol `Foo { foo: field }` + fn create_module_with_foo( + s: StructType<'static>, + ) -> (Checker<'static>, State<'static, FieldPrime>) { + let module_id = "".to_string(); + + let module: Module = Module { + imports: vec![], + symbols: vec![SymbolDeclaration { + id: "Foo", + symbol: Symbol::HereType(s.mock()), + } + .mock()], + }; + + let mut state = State::new(vec![(module_id.clone(), module)].into_iter().collect()); + + let mut checker = Checker::new(); + + checker.check_module(&module_id, &mut state).unwrap(); + + (checker, state) + } + + /// tests about declaring a type mod declaration { use super::*; @@ -3304,6 +3322,505 @@ mod tests { assert!(Checker::new().check_module(&module_id, &mut state).is_err()); } } + + /// tests about using the defined type identifier + mod usage { + use super::*; + + #[test] + fn ty() { + // a defined type can be checked + // Foo { foo: field } + // Foo + + // an undefined type cannot be checked + // Bar + + let (checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker.check_type( + UnresolvedType::User("Foo".to_string()).mock(), + &MODULE_ID.to_string(), + &state.types + ), + Ok(Type::Struct(vec![("foo".to_string(), Type::FieldElement)])) + ); + + assert_eq!( + checker + .check_type( + UnresolvedType::User("Bar".to_string()).mock(), + &MODULE_ID.to_string(), + &state.types + ) + .unwrap_err() + .message, + "Undefined type Bar" + ); + } + + #[test] + fn parameter() { + // a defined type can be used as parameter + + // an undefined type cannot be used as parameter + + let (checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker.check_parameter( + absy::Parameter { + id: absy::Variable::new( + "a", + UnresolvedType::User("Foo".to_string()).mock(), + ) + .mock(), + private: true, + } + .mock(), + &MODULE_ID.to_string(), + &state.types, + ), + Ok(Parameter { + id: Variable::with_id_and_type( + "a".into(), + Type::Struct(vec![("foo".to_string(), Type::FieldElement)]) + ), + private: true + }) + ); + + assert_eq!( + checker + .check_parameter( + absy::Parameter { + id: absy::Variable::new( + "a", + UnresolvedType::User("Bar".to_string()).mock(), + ) + .mock(), + private: true, + } + .mock(), + &MODULE_ID.to_string(), + &state.types, + ) + .unwrap_err()[0] + .message, + "Undefined type Bar" + ); + } + + #[test] + fn variable_declaration() { + // a defined type can be used in a variable declaration + + // an undefined type cannot be used in a variable declaration + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker.check_statement::( + Statement::Declaration( + absy::Variable::new( + "a", + UnresolvedType::User("Foo".to_string()).mock(), + ) + .mock() + ) + .mock(), + &MODULE_ID.to_string(), + &state.types, + ), + Ok(TypedStatement::Declaration(Variable::with_id_and_type( + "a".into(), + Type::Struct(vec![("foo".to_string(), Type::FieldElement)]) + ))) + ); + + assert_eq!( + checker + .check_parameter( + absy::Parameter { + id: absy::Variable::new( + "a", + UnresolvedType::User("Bar".to_string()).mock(), + ) + .mock(), + private: true, + } + .mock(), + &MODULE_ID.to_string(), + &state.types, + ) + .unwrap_err()[0] + .message, + "Undefined type Bar" + ); + } + } + + /// tests about accessing members + mod member { + use super::*; + + #[test] + fn valid() { + // accessing a member on a struct should succeed and return the right type + + // struct Foo = { foo: field } + // Foo { foo: 42 }.foo + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker.check_expression( + Expression::Member( + box Expression::InlineStruct( + "Foo".to_string(), + vec![( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + )] + ) + .mock(), + "foo".into() + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ), + Ok(FieldElementExpression::Member( + box StructExpression { + ty: vec![("foo".to_string(), Type::FieldElement)], + inner: StructExpressionInner::Value(vec![ + FieldElementExpression::Number(FieldPrime::from(42)).into() + ]) + }, + "foo".to_string() + ) + .into()) + ); + } + + #[test] + fn invalid() { + // accessing an undefined member on a struct should fail + + // struct Foo = { foo: field } + // Foo { foo: 42 }.bar + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker + .check_expression( + Expression::Member( + box Expression::InlineStruct( + "Foo".to_string(), + vec![( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + )] + ) + .mock(), + "bar".into() + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ) + .unwrap_err() + .message, + "{foo: field} doesn\'t have member bar" + ); + } + } + + /// tests about defining struct instance inline + mod value { + use super::*; + + #[test] + fn wrong_name() { + // a A value cannot be defined with B as id, even if A and B have the same members + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + }); + + assert_eq!( + checker + .check_expression( + Expression::InlineStruct( + "Bar".to_string(), + vec![( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + )] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ) + .unwrap_err() + .message, + "Undefined type Bar" + ); + } + + #[test] + fn valid() { + // a A value can be defined with members ordered as in the declaration of A + + // struct Foo = { foo: field, bar: bool } + // Foo foo = Foo { foo: 42, bar: true } + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + }); + + assert_eq!( + checker.check_expression( + Expression::InlineStruct( + "Foo".to_string(), + vec![ + ( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + ), + ("bar", Expression::BooleanConstant(true).mock()) + ] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ), + Ok(StructExpression { + ty: vec![ + ("foo".to_string(), Type::FieldElement), + ("bar".to_string(), Type::Boolean) + ], + inner: StructExpressionInner::Value(vec![ + FieldElementExpression::Number(FieldPrime::from(42)).into(), + BooleanExpression::Value(true).into() + ]) + } + .into()) + ); + } + + #[test] + fn shuffled() { + // a A value can be defined with shuffled members compared to the declaration of A + + // struct Foo = { foo: field, bar: bool } + // Foo foo = Foo { bar: true, foo: 42 } + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + }); + + assert_eq!( + checker.check_expression( + Expression::InlineStruct( + "Foo".to_string(), + vec![ + ("bar", Expression::BooleanConstant(true).mock()), + ( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + ) + ] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ), + Ok(StructExpression { + ty: vec![ + ("foo".to_string(), Type::FieldElement), + ("bar".to_string(), Type::Boolean) + ], + inner: StructExpressionInner::Value(vec![ + FieldElementExpression::Number(FieldPrime::from(42)).into(), + BooleanExpression::Value(true).into() + ]) + } + .into()) + ); + } + + #[test] + fn subset() { + // a A value cannot be defined with A as id but members being a subset of the declaration + + // struct Foo = { foo: field, bar: bool } + // Foo foo = Foo { foo: 42 } + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + }); + + assert_eq!( + checker + .check_expression( + Expression::InlineStruct( + "Foo".to_string(), + vec![( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + )] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ) + .unwrap_err() + .message, + "Inline struct Foo {foo: 42} does not match Foo : {foo: field, bar: bool}" + ); + } + + #[test] + fn invalid() { + // a A value cannot be defined with A as id but members being different ids than the declaration + // a A value cannot be defined with A as id but members being different types than the declaration + + // struct Foo = { foo: field, bar: bool } + // Foo { foo: 42, baz: bool } // error + // Foo { foo: 42, baz: 42 } // error + + let (mut checker, state) = create_module_with_foo(StructType { + fields: vec![ + StructField { + id: "foo", + ty: UnresolvedType::FieldElement.mock(), + } + .mock(), + StructField { + id: "bar", + ty: UnresolvedType::Boolean.mock(), + } + .mock(), + ], + }); + + assert_eq!( + checker + .check_expression( + Expression::InlineStruct( + "Foo".to_string(), + vec![( + "baz", + Expression::BooleanConstant(true).mock() + ),( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + )] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ).unwrap_err() + .message, + "Member bar of struct Foo : {foo: field, bar: bool} not found in value Foo {baz: true, foo: 42}" + ); + + assert_eq!( + checker + .check_expression( + Expression::InlineStruct( + "Foo".to_string(), + vec![ + ( + "bar", + Expression::FieldConstant(FieldPrime::from(42)).mock() + ), + ( + "foo", + Expression::FieldConstant(FieldPrime::from(42)).mock() + ) + ] + ) + .mock(), + &MODULE_ID.to_string(), + &state.types + ) + .unwrap_err() + .message, + "Member bar of struct Foo has type bool, found 42 of type field" + ); + } + } } mod assignee { @@ -3323,7 +3840,6 @@ mod tests { absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) .mock(), - &vec![], &module_id, &types, ) @@ -3363,7 +3879,6 @@ mod tests { .mock(), ) .mock(), - &vec![], &module_id, &types, ) @@ -3417,7 +3932,6 @@ mod tests { .mock(), ) .mock(), - &vec![], &module_id, &types, ) From da57329b2b1b7fd178a7397ed14f47d0eeb0a5a3 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 10 Aug 2019 23:21:28 +0200 Subject: [PATCH 20/35] implement symbol unifier to handle symbol naming rules for functions and types --- zokrates_core/src/absy/from_ast.rs | 16 +- zokrates_core/src/absy/mod.rs | 14 + zokrates_core/src/semantics.rs | 642 +++++++++++++++++--------- zokrates_core/src/typed_absy/types.rs | 4 +- 4 files changed, 456 insertions(+), 220 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 8cfa3255..f3e9acfa 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -5,23 +5,17 @@ use zokrates_pest_ast as pest; impl<'ast, T: Field> From> for absy::Module<'ast, T> { fn from(prog: pest::File<'ast>) -> absy::Module { - absy::Module { - symbols: prog - .structs + absy::Module::with_symbols( + prog.structs .into_iter() .map(|t| absy::SymbolDeclarationNode::from(t)) .chain( prog.functions .into_iter() .map(|f| absy::SymbolDeclarationNode::from(f)), - ) - .collect(), - imports: prog - .imports - .into_iter() - .map(|i| absy::ImportNode::from(i)) - .collect(), - } + ), + ) + .imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i))) } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index d1cc58b5..3d5080e3 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -82,6 +82,20 @@ pub struct Module<'ast, T: Field> { pub imports: Vec>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty } +impl<'ast, T: Field> Module<'ast, T> { + pub fn with_symbols>>(i: I) -> Self { + Module { + symbols: i.into_iter().collect(), + imports: vec![], + } + } + + pub fn imports>>(mut self, i: I) -> Self { + self.imports = i.into_iter().collect(); + self + } +} + pub type UnresolvedTypeNode = Node; /// A struct type definition diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index d3c7c59d..9e9f30c2 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -8,7 +8,7 @@ use crate::absy::Identifier; use crate::absy::*; use crate::typed_absy::*; use crate::typed_absy::{Parameter, Variable}; -use std::collections::{HashMap, HashSet}; +use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet}; use std::fmt; use zokrates_field::field::Field; @@ -38,6 +38,54 @@ struct State<'ast, T: Field> { types: TypeMap, } +/// A symbol for a given name: either a type, or a group of functions. Not both! +#[derive(PartialEq, Hash, Eq, Debug)] +enum SymbolType { + Type, + Functions(BTreeSet), +} + +/// A data structure to keep track of all symbols in a module +#[derive(Default)] +struct SymbolUnifier { + symbols: HashMap, +} + +impl SymbolUnifier { + fn insert_type>(&mut self, id: S) -> bool { + let s_type = self.symbols.entry(id.into()); + match s_type { + // if anything is already called `id`, we cannot introduce this type + Entry::Occupied(..) => false, + // otherwise, we can! + Entry::Vacant(v) => { + v.insert(SymbolType::Type); + true + } + } + } + + fn insert_function>(&mut self, id: S, signature: Signature) -> bool { + let s_type = self.symbols.entry(id.into()); + match s_type { + // if anything is already called `id`, it depends what it is + Entry::Occupied(mut o) => { + match o.get_mut() { + // if it's a Type, then we can't introduce a function + SymbolType::Type => false, + // if it's a Function, we can introduce a new function only if it has a different signature + SymbolType::Functions(signatures) => signatures.insert(signature), + } + } + // otherwise, we can! + Entry::Vacant(v) => { + v.insert(SymbolType::Functions(vec![signature].into_iter().collect())); + true + } + } + } +} + impl<'ast, T: Field> State<'ast, T> { fn new(modules: Modules<'ast, T>) -> Self { State { @@ -242,7 +290,7 @@ impl<'ast> Checker<'ast> { module_id: &ModuleId, state: &mut State<'ast, T>, functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, - symbol_id_set: &mut HashSet>, + symbol_unifier: &mut SymbolUnifier, ) -> Result<(), Vec> { let mut errors = vec![]; @@ -253,11 +301,11 @@ impl<'ast> Checker<'ast> { Symbol::HereType(t) => { match self.check_struct_type_declaration(t.clone(), module_id, &state.types) { Ok(ty) => { - match symbol_id_set.insert(declaration.id) { + match symbol_unifier.insert_type(declaration.id) { false => errors.push(Error { pos: Some(pos), message: format!( - "Another symbol with id {} is already defined", + "{} conflicts with another symbol", declaration.id, ), }), @@ -274,47 +322,22 @@ impl<'ast> Checker<'ast> { } Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) { Ok(funct) => { - let funct = TypedFunctionSymbol::Here(funct); - - let query = FunctionQuery::new( - declaration.id.clone(), - &funct.signature(&state.typed_modules).inputs, - &funct - .signature(&state.typed_modules) - .outputs - .clone() - .into_iter() - .map(|o| Some(o)) - .collect(), - ); - - let candidates = self.find_candidates(&query); - - match candidates.len() { - 1 => { - errors.push(Error { - pos: Some(pos), - message: format!( - "Duplicate definition for function {} with signature {}", - declaration.id, - funct.signature(&state.typed_modules) - ), - }); - } - 0 => {} - _ => panic!("duplicate function declaration should have been caught"), - } - - symbol_id_set.insert(declaration.id); + match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) { + false => errors.push(Error { + pos: Some(pos), + message: format!("{} conflicts with another symbol", declaration.id,), + }), + true => {} + }; self.functions.insert( FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature(&state.typed_modules).clone()), + .signature(funct.signature.clone()), ); functions.insert( FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature(&state.typed_modules).clone()), - funct, + .signature(funct.signature.clone()), + TypedFunctionSymbol::Here(funct), ); } Err(e) => { @@ -351,7 +374,19 @@ impl<'ast> Checker<'ast> { match (function_candidates.len(), type_candidate) { (0, Some(t)) => { - symbol_id_set.insert(declaration.id); + // we imported a type, so the symbol it gets bound to should not already exist + match symbol_unifier.insert_type(declaration.id) { + false => { + errors.push(Error { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + }); + } + true => {} + }; state .types .entry(module_id.clone()) @@ -367,9 +402,23 @@ impl<'ast> Checker<'ast> { ), }); } + (_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"), _ => { - symbol_id_set.insert(declaration.id); for candidate in function_candidates { + + match symbol_unifier.insert_function(declaration.id, candidate.signature.clone()) { + false => { + errors.push(Error { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + }); + }, + true => {} + }; + self.functions.insert(candidate.clone().id(declaration.id)); functions.insert( candidate.clone().id(declaration.id), @@ -380,7 +429,7 @@ impl<'ast> Checker<'ast> { ); } } - } + }; } Err(e) => { errors.extend(e); @@ -388,35 +437,16 @@ impl<'ast> Checker<'ast> { }; } Symbol::Flat(funct) => { - let query = FunctionQuery::new( - declaration.id.clone(), - &funct.signature::().inputs, - &funct - .signature::() - .outputs - .clone() - .into_iter() - .map(|o| Some(o)) - .collect(), - ); - - let candidates = self.find_candidates(&query); - - match candidates.len() { - 1 => { + match symbol_unifier.insert_function(declaration.id, funct.signature::()) { + false => { errors.push(Error { pos: Some(pos), - message: format!( - "Duplicate definition for function {} with signature {}", - declaration.id, - funct.signature::() - ), + message: format!("{} conflicts with another symbol", declaration.id,), }); } - 0 => {} - _ => panic!("duplicate function declaration should have been caught"), - } - symbol_id_set.insert(declaration.id); + true => {} + }; + self.functions.insert( FunctionKey::with_id(declaration.id.clone()) .signature(funct.signature::().clone()), @@ -457,7 +487,7 @@ impl<'ast> Checker<'ast> { state.types.entry(module_id.clone()).or_default(); // we keep track of the introduced symbols to avoid colisions between types and functions - let mut symbol_id_set = HashSet::new(); + let mut symbol_unifier = SymbolUnifier::default(); // we go through symbol declarations and check them for declaration in module.symbols { @@ -466,7 +496,7 @@ impl<'ast> Checker<'ast> { module_id, state, &mut checked_functions, - &mut symbol_id_set, + &mut symbol_unifier, ) { Ok(()) => {} Err(e) => { @@ -1798,6 +1828,8 @@ mod tests { use typed_absy; use zokrates_field::field::FieldPrime; + const MODULE_ID: &str = ""; + mod array { use super::*; @@ -1855,11 +1887,75 @@ mod tests { mod symbols { use super::*; + /// Helper function to create (() -> (): return) + fn function0() -> FunctionNode<'static, FieldPrime> { + let statements: Vec> = vec![Statement::Return( + ExpressionList { + expressions: vec![], + } + .mock(), + ) + .mock()]; + + let arguments = vec![]; + + let signature = UnresolvedSignature::new(); + + Function { + arguments, + statements, + signature, + } + .mock() + } + + /// Helper function to create ((private field a) -> (): return) + fn function1() -> FunctionNode<'static, FieldPrime> { + let statements: Vec> = vec![Statement::Return( + ExpressionList { + expressions: vec![], + } + .mock(), + ) + .mock()]; + + let arguments = vec![absy::Parameter { + id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + private: true, + } + .mock()]; + + let signature = + UnresolvedSignature::new().inputs(vec![UnresolvedType::FieldElement.mock()]); + + Function { + arguments, + statements, + signature, + } + .mock() + } + + fn struct0() -> StructTypeNode<'static> { + StructType { fields: vec![] }.mock() + } + + fn struct1() -> StructTypeNode<'static> { + StructType { + fields: vec![StructField { + id: "foo".into(), + ty: UnresolvedType::FieldElement.mock(), + } + .mock()], + } + .mock() + } + #[test] - fn imported_symbol() { + fn imported_function() { // foo.code - // def main() -> (field): - // return 1 + // def main() -> (): + // return // bar.code // from "./foo.code" import main @@ -1869,24 +1965,7 @@ mod tests { let foo: Module = Module { symbols: vec![SymbolDeclaration { id: "main", - symbol: Symbol::HereFunction( - Function { - statements: vec![Statement::Return( - ExpressionList { - expressions: vec![Expression::FieldConstant(FieldPrime::from( - 1, - )) - .mock()], - } - .mock(), - ) - .mock()], - signature: UnresolvedSignature::new() - .outputs(vec![UnresolvedType::FieldElement.mock()]), - arguments: vec![], - } - .mock(), - ), + symbol: Symbol::HereFunction(function0()), } .mock()], imports: vec![], @@ -1909,18 +1988,17 @@ mod tests { let mut checker = Checker::new(); - checker - .check_module(&String::from("bar"), &mut state) - .unwrap(); + assert_eq!( + checker.check_module(&String::from("bar"), &mut state), + Ok(()) + ); assert_eq!( state.typed_modules.get(&String::from("bar")), Some(&TypedModule { functions: vec![( - FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])), + FunctionKey::with_id("main").signature(Signature::new()), TypedFunctionSymbol::There( - FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])), + FunctionKey::with_id("main").signature(Signature::new()), "foo".to_string() ) )] @@ -1929,6 +2007,266 @@ mod tests { }) ); } + + #[test] + fn duplicate_function_declaration() { + // def foo(): + // return + // def foo(): + // return + // + // should fail + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new(vec![(MODULE_ID.to_string(), module)].into_iter().collect()); + + let mut checker = Checker::new(); + assert_eq!( + checker + .check_module(&MODULE_ID.to_string(), &mut state) + .unwrap_err()[0] + .message, + "foo conflicts with another symbol" + ); + } + + #[test] + fn overloaded_function_declaration() { + // def foo(): + // return + // def foo(a): + // return + // + // should succeed as overloading is allowed + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(function1()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new(vec![(MODULE_ID.to_string(), module)].into_iter().collect()); + + let mut checker = Checker::new(); + assert_eq!( + checker.check_module(&MODULE_ID.to_string(), &mut state), + Ok(()) + ); + assert!(state + .typed_modules + .get(&MODULE_ID.to_string()) + .unwrap() + .functions + .contains_key(&FunctionKey::with_id("foo").signature(Signature::new()))); + assert!(state + .typed_modules + .get(&MODULE_ID.to_string()) + .unwrap() + .functions + .contains_key( + &FunctionKey::with_id("foo") + .signature(Signature::new().inputs(vec![Type::FieldElement])) + )) + } + + #[test] + fn duplicate_type_declaration() { + // struct Foo {} + // struct Foo { foo: field } + // + // should fail + + let module: Module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereType(struct0()), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereType(struct1()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); + + let mut checker = Checker::new(); + assert_eq!( + checker + .check_module(&String::from("main"), &mut state) + .unwrap_err()[0] + .message, + "foo conflicts with another symbol" + ); + } + + #[test] + fn type_function_conflict() { + // struct foo {} + // def foo(): + // return + // + // should fail + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereType(StructType { fields: vec![] }.mock()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); + + let mut checker = Checker::new(); + assert_eq!( + checker + .check_module(&String::from("main"), &mut state) + .unwrap_err()[0] + .message, + "foo conflicts with another symbol" + ); + } + + #[test] + fn type_imported_function_conflict() { + // import first + + // // bar.code + // def main() -> (): return + // + // // main.code + // import main from "bar" as foo + // struct foo {} + // + // should fail + + let bar = Module::with_symbols(vec![SymbolDeclaration { + id: "main", + symbol: Symbol::HereFunction(function0()), + } + .mock()]); + + let main = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::There( + SymbolImport::with_id_in_module("main", "bar".to_string()).mock(), + ), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereType(struct0()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new( + vec![(MODULE_ID.to_string(), main), ("bar".to_string(), bar)] + .into_iter() + .collect(), + ); + + let mut checker = Checker::new(); + assert_eq!( + checker + .check_module(&MODULE_ID.to_string(), &mut state) + .unwrap_err()[0] + .message, + "foo conflicts with another symbol" + ); + + // type declaration first + + // // bar.code + // def main() -> (): return + // + // // main.code + // struct foo {} + // import main from "bar" as foo + // + // should fail + + let bar = Module::with_symbols(vec![SymbolDeclaration { + id: "main", + symbol: Symbol::HereFunction(function0()), + } + .mock()]); + + let main = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereType(struct0()), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::There( + SymbolImport::with_id_in_module("main", "bar".to_string()).mock(), + ), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new( + vec![(MODULE_ID.to_string(), main), ("bar".to_string(), bar)] + .into_iter() + .collect(), + ); + + let mut checker = Checker::new(); + assert_eq!( + checker + .check_module(&MODULE_ID.to_string(), &mut state) + .unwrap_err()[0] + .message, + "foo conflicts with another symbol" + ); + } } pub fn new_with_args<'ast>( @@ -2729,114 +3067,6 @@ mod tests { ); } - #[test] - fn duplicate_function_declaration() { - // def foo(a, b): - // return 1 - // def foo(c, d): - // return 2 - // - // should fail - - let foo1_statements: Vec> = vec![Statement::Return( - ExpressionList { - expressions: vec![Expression::FieldConstant(FieldPrime::from(1)).mock()], - } - .mock(), - ) - .mock()]; - - let foo1_arguments = vec![ - crate::absy::Parameter { - id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), - private: true, - } - .mock(), - crate::absy::Parameter { - id: absy::Variable::new("b", UnresolvedType::FieldElement.mock()).mock(), - private: true, - } - .mock(), - ]; - - let foo2_statements: Vec> = vec![Statement::Return( - ExpressionList { - expressions: vec![Expression::FieldConstant(FieldPrime::from(1)).mock()], - } - .mock(), - ) - .mock()]; - - let foo2_arguments = vec![ - crate::absy::Parameter { - id: absy::Variable::new("c", UnresolvedType::FieldElement.mock()).mock(), - private: true, - } - .mock(), - crate::absy::Parameter { - id: absy::Variable::new("d", UnresolvedType::FieldElement.mock()).mock(), - private: true, - } - .mock(), - ]; - - let foo1 = Function { - arguments: foo1_arguments, - statements: foo1_statements, - signature: UnresolvedSignature { - inputs: vec![ - UnresolvedType::FieldElement.mock(), - UnresolvedType::FieldElement.mock(), - ], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, - } - .mock(); - - let foo2 = Function { - arguments: foo2_arguments, - statements: foo2_statements, - signature: UnresolvedSignature { - inputs: vec![ - UnresolvedType::FieldElement.mock(), - UnresolvedType::FieldElement.mock(), - ], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, - } - .mock(); - - let module = Module { - symbols: vec![ - SymbolDeclaration { - id: "foo", - symbol: Symbol::HereFunction(foo1), - } - .mock(), - SymbolDeclaration { - id: "foo", - symbol: Symbol::HereFunction(foo2), - } - .mock(), - ], - imports: vec![], - }; - - let mut state = State::new(vec![(String::from("main"), module)].into_iter().collect()); - - let mut checker = Checker::new(); - assert_eq!( - checker.check_module(&String::from("main"), &mut state), - Err(vec![Error { - pos: Some((Position::mock(), Position::mock())), - - message: - "Duplicate definition for function foo with signature (field, field) -> (field)" - .to_string() - }]) - ); - } - #[test] fn duplicate_main_function() { // def main(a): @@ -2996,8 +3226,6 @@ mod tests { mod structs { use super::*; - const MODULE_ID: &str = ""; - /// helper function to create a module at location "" with a single symbol `Foo { foo: field }` fn create_module_with_foo( s: StructType<'static>, diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 176d6509..24985526 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -4,7 +4,7 @@ pub type Identifier<'ast> = &'ast str; pub type MemberId = String; -#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub enum Type { FieldElement, Boolean, @@ -119,7 +119,7 @@ pub mod signature { use super::*; use std::fmt; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)] pub struct Signature { pub inputs: Vec, pub outputs: Vec, From 1f21b85389957df4a93805f70b0e0ac69a722beb Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 11 Aug 2019 16:15:05 +0200 Subject: [PATCH 21/35] add test for unifier --- zokrates_core/src/semantics.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 9e9f30c2..2876ca90 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1951,6 +1951,23 @@ mod tests { .mock() } + #[test] + fn unifier() { + // the unifier should only accept either a single type or many functions of different signatures for each symbol + + let mut unifier = SymbolUnifier::default(); + + assert!(unifier.insert_type("foo")); + assert!(!unifier.insert_type("foo")); + assert!(!unifier.insert_function("foo", Signature::new())); + assert!(unifier.insert_function("bar", Signature::new())); + assert!(!unifier.insert_function("bar", Signature::new())); + assert!( + unifier.insert_function("bar", Signature::new().inputs(vec![Type::FieldElement])) + ); + assert!(!unifier.insert_type("bar")); + } + #[test] fn imported_function() { // foo.code From 08bc0dbf7658a2f6d12cb245a807e8fb811ee2a6 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 27 Aug 2019 16:13:47 +0200 Subject: [PATCH 22/35] remove try --- zokrates_core/src/absy/mod.rs | 10 +++++----- zokrates_core/src/absy/types.rs | 12 ++++++------ zokrates_core/src/typed_absy/mod.rs | 10 +++++----- zokrates_core/src/typed_absy/types.rs | 12 ++++++------ 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 026aa52c..8ec70586 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -524,11 +524,11 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> { write!(f, "]") } Expression::InlineStruct(ref id, ref members) => { - r#try!(write!(f, "{} {{", id)); + write!(f, "{} {{", id)?; for (i, (member_id, e)) in members.iter().enumerate() { - r#try!(write!(f, "{}: {}", member_id, e)); + write!(f, "{}: {}", member_id, e)?; if i < members.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } write!(f, "}}") @@ -574,8 +574,8 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> { write!(f, "]") } Expression::InlineStruct(ref id, ref members) => { - r#try!(write!(f, "InlineStruct({:?}, [", id)); - r#try!(f.debug_list().entries(members.iter()).finish()); + write!(f, "InlineStruct({:?}, [", id)?; + f.debug_list().entries(members.iter()).finish()?; write!(f, "]") } Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index), diff --git a/zokrates_core/src/absy/types.rs b/zokrates_core/src/absy/types.rs index d5e22303..1dbbc718 100644 --- a/zokrates_core/src/absy/types.rs +++ b/zokrates_core/src/absy/types.rs @@ -59,18 +59,18 @@ mod signature { impl fmt::Display for UnresolvedSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - r#try!(write!(f, "(")); + write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { - r#try!(write!(f, "{}", t)); + write!(f, "{}", t)?; if i < self.inputs.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } - r#try!(write!(f, ") -> (")); + write!(f, ") -> (")?; for (i, t) in self.outputs.iter().enumerate() { - r#try!(write!(f, "{}", t)); + write!(f, "{}", t)?; if i < self.outputs.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } write!(f, ")") diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 1e1a9f56..636355f0 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -809,11 +809,11 @@ impl<'ast, T: Field> fmt::Display for StructExpressionInner<'ast, T> { .join(", ") ), StructExpressionInner::FunctionCall(ref key, ref p) => { - r#try!(write!(f, "{}(", key.id,)); + write!(f, "{}(", key.id,)?; for (i, param) in p.iter().enumerate() { - r#try!(write!(f, "{}", param)); + write!(f, "{}", param)?; if i < p.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } write!(f, ")") @@ -902,8 +902,8 @@ impl<'ast, T: Field> fmt::Debug for StructExpressionInner<'ast, T> { StructExpressionInner::Identifier(ref var) => write!(f, "{:?}", var), StructExpressionInner::Value(ref values) => write!(f, "{:?}", values), StructExpressionInner::FunctionCall(ref i, ref p) => { - r#try!(write!(f, "FunctionCall({:?}, (", i)); - r#try!(f.debug_list().entries(p.iter()).finish()); + write!(f, "FunctionCall({:?}, (", i)?; + f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => { diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 24985526..f4c4a7c0 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -137,18 +137,18 @@ pub mod signature { impl fmt::Display for Signature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - r#try!(write!(f, "(")); + write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { - r#try!(write!(f, "{}", t)); + write!(f, "{}", t)?; if i < self.inputs.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } - r#try!(write!(f, ") -> (")); + write!(f, ") -> (")?; for (i, t) in self.outputs.iter().enumerate() { - r#try!(write!(f, "{}", t)); + write!(f, "{}", t)?; if i < self.outputs.len() - 1 { - r#try!(write!(f, ", ")); + write!(f, ", ")?; } } write!(f, ")") From fd69f75e4a4017d25006b7ed6e25c160f6160450 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 27 Aug 2019 16:28:24 +0200 Subject: [PATCH 23/35] remove commented out impl --- zokrates_core/src/typed_absy/parameter.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/zokrates_core/src/typed_absy/parameter.rs b/zokrates_core/src/typed_absy/parameter.rs index a4d42bbc..277ac537 100644 --- a/zokrates_core/src/typed_absy/parameter.rs +++ b/zokrates_core/src/typed_absy/parameter.rs @@ -29,12 +29,3 @@ impl<'ast> fmt::Debug for Parameter<'ast> { write!(f, "Parameter(variable: {:?})", self.id) } } - -// impl<'ast> From> for Parameter<'ast> { -// fn from(p: absy::Parameter<'ast>) -> Parameter { -// Parameter { -// private: p.private, -// id: p.id.value.into(), -// } -// } -// } From 6ab1a96ca646662078d299378af4ace8bfaff2eb Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 11 Sep 2019 19:11:39 +0200 Subject: [PATCH 24/35] add assignment to struct members --- zokrates_cli/examples/structs/add.code | 13 + zokrates_cli/examples/structs/set_member.code | 14 + zokrates_core/src/absy/from_ast.rs | 18 +- zokrates_core/src/absy/mod.rs | 2 + zokrates_core/src/compile.rs | 4 - zokrates_core/src/flatten/mod.rs | 127 +----- zokrates_core/src/semantics.rs | 132 ++++-- zokrates_core/src/static_analysis/mod.rs | 1 + .../src/static_analysis/propagation.rs | 79 ++++ zokrates_core/src/static_analysis/unroll.rs | 393 +++++++++++------- zokrates_core/src/typed_absy/folder.rs | 1 + zokrates_core/src/typed_absy/mod.rs | 11 + zokrates_parser/src/zokrates.pest | 3 +- zokrates_pest_ast/src/lib.rs | 25 +- 14 files changed, 509 insertions(+), 314 deletions(-) create mode 100644 zokrates_cli/examples/structs/add.code create mode 100644 zokrates_cli/examples/structs/set_member.code diff --git a/zokrates_cli/examples/structs/add.code b/zokrates_cli/examples/structs/add.code new file mode 100644 index 00000000..3941f08f --- /dev/null +++ b/zokrates_cli/examples/structs/add.code @@ -0,0 +1,13 @@ +struct Point { + x: field, + y: field +} + +def main(Point p, Point q) -> (Point): + + field a = 42 + field d = 21 + + field dpxpyqxqy = d * p.x * p.y * q.x * q.y + + return Point { x: (p.x * q.y + q.x * p.y) / (1 + dpxpyqxqy) , y: (q.x * q.y - a * p.x * p.y) / (1 - dpxpyqxqy) } diff --git a/zokrates_cli/examples/structs/set_member.code b/zokrates_cli/examples/structs/set_member.code new file mode 100644 index 00000000..99872793 --- /dev/null +++ b/zokrates_cli/examples/structs/set_member.code @@ -0,0 +1,14 @@ +struct Bar { + c: field[2], + d: bool +} + +struct Foo { + a: Bar, + b: bool +} + +def main() -> (Foo): + Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}] + f[0].a.c = [42, 43] + return f[0] diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 50fbc640..050861bc 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -588,13 +588,17 @@ impl<'ast, T: Field> From> for absy::AssigneeNode<'ast, T> let a = absy::AssigneeNode::from(assignee.id); let span = assignee.span; - assignee - .indices - .into_iter() - .map(|i| absy::RangeOrExpression::from(i)) - .fold(a, |acc, s| { - absy::Assignee::Select(box acc, box s).span(span.clone()) - }) + assignee.accesses.into_iter().fold(a, |acc, s| { + match s { + pest::AssigneeAccess::Select(s) => { + absy::Assignee::Select(box acc, box absy::RangeOrExpression::from(s.expression)) + } + pest::AssigneeAccess::Member(m) => { + absy::Assignee::Member(box acc, box m.id.span.as_str()) + } + } + .span(span.clone()) + }) } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 4fba23f5..39028d6c 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -254,6 +254,7 @@ impl<'ast, T: Field> fmt::Debug for Function<'ast, T> { pub enum Assignee<'ast, T: Field> { Identifier(Identifier<'ast>), Select(Box>, Box>), + Member(Box>, Box>), } pub type AssigneeNode<'ast, T> = Node>; @@ -263,6 +264,7 @@ impl<'ast, T: Field> fmt::Debug for Assignee<'ast, T> { match *self { Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s), Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e), + Assignee::Member(ref s, ref m) => write!(f, "Member({:?}.{:?})", s, m), } } } diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index fb215000..82990bcf 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -152,13 +152,9 @@ pub fn compile>( ) })?; - println!("{}", typed_ast); - // analyse (unroll and constant propagation) let typed_ast = typed_ast.analyse(); - println!("{}", typed_ast); - // flatten input program let program_flattened = Flattener::flatten(typed_ast); diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 68371f88..bd1e4543 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -279,115 +279,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { .collect() } StructExpressionInner::Select(box array, box index) => { - // If the struct is an array element `array[index]`, we're accessing `array[index].member` - // We construct `array := array.map(|e| e.member)` and access `array[index]` - let ty = members - .clone() - .into_iter() + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() .find(|(id, _)| *id == member_id) .unwrap() - .1; + .1 + .get_primitive_count(); - match ty { - Type::FieldElement => { - let array = ArrayExpressionInner::Value( - (0..array.size()) - .map(|i| { - FieldElementExpression::Member( - box StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()), - member_id.clone(), - ) - .into() - }) - .collect(), - ) - .annotate(Type::FieldElement, array.size()); - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Boolean => { - let array = ArrayExpressionInner::Value( - (0..array.size()) - .map(|i| { - BooleanExpression::Member( - box StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()), - member_id.clone(), - ) - .into() - }) - .collect(), - ) - .annotate(Type::Boolean, array.size()); - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Struct(m) => { - let array = ArrayExpressionInner::Value( - (0..array.size()) - .map(|i| { - StructExpressionInner::Member( - box StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()), - member_id.clone(), - ) - .annotate(m.clone()) - .into() - }) - .collect(), - ) - .annotate(Type::Struct(m.clone()), array.size()); - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - Type::Array(box ty, size) => { - let array = ArrayExpressionInner::Value( - (0..array.size()) - .map(|i| { - ArrayExpressionInner::Member( - box StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()), - member_id.clone(), - ) - .annotate(ty.clone(), size) - .into() - }) - .collect(), - ) - .annotate(Type::Array(box ty.clone(), size), array.size()); - self.flatten_select_expression::>( - symbols, - statements_flattened, - array, - index, - ) - } - } + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + )[offset..offset + size] + .to_vec() } StructExpressionInner::FunctionCall(..) => unreachable!(), StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { @@ -1550,6 +1462,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { TypedAssignee::Select(..) => unreachable!( "array element redefs should have been replaced by array redefs in unroll" ), + TypedAssignee::Member(..) => unreachable!( + "struct member redefs should have been replaced by struct redef in unroll" + ), } } TypedStatement::Condition(lhs, rhs) => { diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 230f8586..fe9a3f53 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -623,7 +623,7 @@ impl<'ast> Checker<'ast> { statements_checked.push(statement); } Err(e) => { - errors.push(e); + errors.extend(e); } } } @@ -748,21 +748,23 @@ impl<'ast> Checker<'ast> { stat: StatementNode<'ast, T>, module_id: &ModuleId, types: &TypeMap, - ) -> Result, Error> { + ) -> Result, Vec> { let pos = stat.pos(); match stat.value { Statement::Return(list) => { let mut expression_list_checked = vec![]; for e in list.value.expressions { - let e_checked = self.check_expression(e, module_id, &types)?; + let e_checked = self + .check_expression(e, module_id, &types) + .map_err(|e| vec![e])?; expression_list_checked.push(e_checked); } Ok(TypedStatement::Return(expression_list_checked)) } Statement::Declaration(var) => { - let var = self.check_variable(var, module_id, types).unwrap(); + let var = self.check_variable(var, module_id, types)?; match self.insert_into_scope(var.clone()) { true => Ok(TypedStatement::Declaration(var)), false => Err(Error { @@ -770,6 +772,7 @@ impl<'ast> Checker<'ast> { message: format!("Duplicate declaration for variable named {}", var.id), }), } + .map_err(|e| vec![e]) } Statement::Definition(assignee, expr) => { // we create multidef when rhs is a function call to benefit from inference @@ -780,11 +783,15 @@ impl<'ast> Checker<'ast> { } // check the expression to be assigned - let checked_expr = self.check_expression(expr, module_id, &types)?; + let checked_expr = self + .check_expression(expr, module_id, &types) + .map_err(|e| vec![e])?; let expression_type = checked_expr.get_type(); // check that the assignee is declared and is well formed - let var = self.check_assignee(assignee, module_id, &types)?; + let var = self + .check_assignee(assignee, module_id, &types) + .map_err(|e| vec![e])?; let var_type = var.get_type(); @@ -799,10 +806,15 @@ impl<'ast> Checker<'ast> { ), }), } + .map_err(|e| vec![e]) } Statement::Condition(lhs, rhs) => { - let checked_lhs = self.check_expression(lhs, module_id, &types)?; - let checked_rhs = self.check_expression(rhs, module_id, &types)?; + let checked_lhs = self + .check_expression(lhs, module_id, &types) + .map_err(|e| vec![e])?; + let checked_rhs = self + .check_expression(rhs, module_id, &types) + .map_err(|e| vec![e])?; if checked_lhs.get_type() == checked_rhs.get_type() { Ok(TypedStatement::Condition(checked_lhs, checked_rhs)) @@ -818,11 +830,12 @@ impl<'ast> Checker<'ast> { ), }) } + .map_err(|e| vec![e]) } Statement::For(var, from, to, statements) => { self.enter_scope(); - self.check_for_var(&var)?; + self.check_for_var(&var).map_err(|e| vec![e])?; let var = self.check_variable(var, module_id, types).unwrap(); @@ -856,14 +869,14 @@ impl<'ast> Checker<'ast> { ref a => Err(Error { pos: Some(pos), message: format!("Left hand side of function return assignment must be a list of identifiers, found {}", a)}) - }?; + }.map_err(|e| vec![e])?; vars_types.push(t); var_names.push(name); } // find arguments types let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, &types)?; + let arg_checked = self.check_expression(arg, module_id, &types).map_err(|e| vec![e])?; arguments_checked.push(arg_checked); } @@ -903,7 +916,7 @@ impl<'ast> Checker<'ast> { pos: Some(pos), message: format!("{} should be a FunctionCall", rhs), }), - } + }.map_err(|e| vec![e]) } } } @@ -930,33 +943,68 @@ impl<'ast> Checker<'ast> { Assignee::Select(box assignee, box index) => { let checked_assignee = self.check_assignee(assignee, module_id, &types)?; - let checked_index = match index { - RangeOrExpression::Expression(e) => { - self.check_expression(e, module_id, &types)? - } - r => unimplemented!( - "Using slices in assignments is not supported yet, found {}", - r - ), - }; + let ty = checked_assignee.get_type(); + match ty { + Type::Array(..) => { + let checked_index = match index { + RangeOrExpression::Expression(e) => { + self.check_expression(e, module_id, &types)? + } + r => unimplemented!( + "Using slices in assignments is not supported yet, found {}", + r + ), + }; - let checked_typed_index = match checked_index { - TypedExpression::FieldElement(e) => Ok(e), - e => Err(Error { + let checked_typed_index = match checked_index { + TypedExpression::FieldElement(e) => Ok(e), + e => Err(Error { + pos: Some(pos), + + message: format!( + "Expected array {} index to have type field, found {}", + checked_assignee, + e.get_type() + ), + }), + }?; + + Ok(TypedAssignee::Select( + box checked_assignee, + box checked_typed_index, + )) + } + ty => Err(Error { pos: Some(pos), message: format!( - "Expected array {} index to have type field, found {}", - checked_assignee, - e.get_type() + "Cannot access element at index {} on {} of type {}", + index, checked_assignee, ty, ), }), - }?; + } + } + Assignee::Member(box assignee, box member) => { + let checked_assignee = self.check_assignee(assignee, module_id, &types)?; - Ok(TypedAssignee::Select( - box checked_assignee, - box checked_typed_index, - )) + let ty = checked_assignee.get_type(); + match &ty { + Type::Struct(members) => match members.iter().find(|(id, _)| id == member) { + Some(_) => Ok(TypedAssignee::Member(box checked_assignee, member.into())), + None => Err(Error { + pos: Some(pos), + message: format!("{} doesn't have member {}", ty, member), + }), + }, + ty => Err(Error { + pos: Some(pos), + + message: format!( + "Cannot access field {} on {} as of type {}", + member, checked_assignee, ty, + ), + }), + } } } } @@ -2322,10 +2370,10 @@ mod tests { let mut checker = Checker::new(); assert_eq!( checker.check_statement(statement, &module_id, &types), - Err(Error { + Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".to_string() - }) + }]) ); } @@ -3189,7 +3237,7 @@ mod tests { let types = HashMap::new(); let module_id = String::from(""); let mut checker = Checker::new(); - let _: Result, Error> = checker.check_statement( + let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) @@ -3197,7 +3245,7 @@ mod tests { &module_id, &types, ); - let s2_checked: Result, Error> = checker.check_statement( + let s2_checked: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) @@ -3207,10 +3255,10 @@ mod tests { ); assert_eq!( s2_checked, - Err(Error { + Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Duplicate declaration for variable named a".to_string() - }) + }]) ); } @@ -3225,7 +3273,7 @@ mod tests { let module_id = String::from(""); let mut checker = Checker::new(); - let _: Result, Error> = checker.check_statement( + let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) @@ -3233,7 +3281,7 @@ mod tests { &module_id, &types, ); - let s2_checked: Result, Error> = checker.check_statement( + let s2_checked: Result, Vec> = checker.check_statement( Statement::Declaration(absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock()) .mock(), &module_id, @@ -3241,10 +3289,10 @@ mod tests { ); assert_eq!( s2_checked, - Err(Error { + Err(vec![Error { pos: Some((Position::mock(), Position::mock())), message: "Duplicate declaration for variable named a".to_string() - }) + }]) ); } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 879dc65c..386d4c49 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -24,6 +24,7 @@ impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> { fn analyse(self) -> Self { // unroll let r = Unroller::unroll(self); + println!("{}", r); // inline let r = Inliner::inline(r); // propagate diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 084d40d5..a6ef4fb3 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -35,6 +35,10 @@ fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool { ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), _ => false, }, + TypedExpression::Struct(a) => match a.as_inner() { + StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), + _ => false, + }, _ => false, } } @@ -71,6 +75,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { TypedStatement::Definition(TypedAssignee::Select(..), _) => { unreachable!("array updates should have been replaced with full array redef") } + TypedStatement::Definition(TypedAssignee::Member(..), _) => { + unreachable!("struct update should have been replaced with full struct redef") + } // propagate lhs and rhs for conditions TypedStatement::Condition(e1, e2) => { // could stop execution here if condition is known to fail @@ -224,6 +231,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } } + FieldElementExpression::Member(box s, m) => { + let s = self.fold_struct_expression(s); + + let members = match s.get_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + match s.into_inner() { + StructExpressionInner::Value(v) => { + match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 { + TypedExpression::FieldElement(s) => s, + _ => unreachable!(), + } + } + inner => FieldElementExpression::Member(box inner.annotate(members), m), + } + } e => fold_field_expression(self, e), } } @@ -302,6 +327,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative), } } + ArrayExpressionInner::Member(box s, m) => { + let s = self.fold_struct_expression(s); + + let members = match s.get_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + match s.into_inner() { + StructExpressionInner::Value(v) => { + match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 { + TypedExpression::Array(a) => a.into_inner(), + _ => unreachable!(), + } + } + inner => ArrayExpressionInner::Member(box inner.annotate(members), m), + } + } e => fold_array_expression_inner(self, ty, size, e), } } @@ -380,6 +423,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { c => StructExpressionInner::IfElse(box c, box consequence, box alternative), } } + StructExpressionInner::Member(box s, m) => { + let s = self.fold_struct_expression(s); + + let members = match s.get_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + match s.into_inner() { + StructExpressionInner::Value(v) => { + match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 { + TypedExpression::Struct(s) => s.into_inner(), + _ => unreachable!(), + } + } + inner => StructExpressionInner::Member(box inner.annotate(members), m), + } + } e => fold_struct_expression_inner(self, ty, e), } } @@ -508,6 +569,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { c => BooleanExpression::IfElse(box c, box consequence, box alternative), } } + BooleanExpression::Member(box s, m) => { + let s = self.fold_struct_expression(s); + + let members = match s.get_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + match s.into_inner() { + StructExpressionInner::Value(v) => { + match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 { + TypedExpression::Boolean(s) => s, + _ => unreachable!(), + } + } + inner => BooleanExpression::Member(box inner.annotate(members), m), + } + } e => fold_boolean_expression(self, e), } } diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs index 2e294041..dcee2bf0 100644 --- a/zokrates_core/src/static_analysis/unroll.rs +++ b/zokrates_core/src/static_analysis/unroll.rs @@ -5,7 +5,7 @@ //! @date 2018 use crate::typed_absy::folder::*; -use crate::typed_absy::types::Type; +use crate::typed_absy::types::{MemberId, Type}; use crate::typed_absy::*; use std::collections::HashMap; use std::collections::HashSet; @@ -47,7 +47,7 @@ impl<'ast> Unroller<'ast> { fn choose_many( base: TypedExpression<'ast, T>, - indices: Vec>, + indices: Vec>, new_expression: TypedExpression<'ast, T>, statements: &mut HashSet>, ) -> TypedExpression<'ast, T> { @@ -55,158 +55,256 @@ impl<'ast> Unroller<'ast> { match indices.len() { 0 => new_expression, - _ => { - let base = match base { - TypedExpression::Array(e) => e, - e => unreachable!("can't take an element on a {}", e.get_type()), - }; + _ => match base { + TypedExpression::Array(base) => { + let inner_ty = base.inner_type(); + let size = base.size(); - let inner_ty = base.inner_type(); - let size = base.size(); + let head = indices.remove(0); + let tail = indices; - let head = indices.pop().unwrap(); - let tail = indices; + match head { + Access::Select(head) => { + statements.insert(TypedStatement::Condition( + BooleanExpression::Lt( + box head.clone(), + box FieldElementExpression::Number(T::from(size)), + ) + .into(), + BooleanExpression::Value(true).into(), + )); - statements.insert(TypedStatement::Condition( - BooleanExpression::Lt( - box head.clone(), - box FieldElementExpression::Number(T::from(size)), - ) - .into(), - BooleanExpression::Value(true).into(), - )); + ArrayExpressionInner::Value( + (0..size) + .map(|i| match inner_ty { + Type::Array(..) => ArrayExpression::if_else( + BooleanExpression::Eq( + box FieldElementExpression::Number(T::from(i)), + box head.clone(), + ), + match Self::choose_many( + ArrayExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) { + TypedExpression::Array(e) => e, + e => unreachable!( + "the interior was expected to be an array, was {}", + e.get_type() + ), + }, + ArrayExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ), + ) + .into(), + Type::Struct(..) => StructExpression::if_else( + BooleanExpression::Eq( + box FieldElementExpression::Number(T::from(i)), + box head.clone(), + ), + match Self::choose_many( + StructExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) { + TypedExpression::Struct(e) => e, + e => unreachable!( + "the interior was expected to be a struct, was {}", + e.get_type() + ), + }, + StructExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ), + ) + .into(), + Type::FieldElement => FieldElementExpression::if_else( + BooleanExpression::Eq( + box FieldElementExpression::Number(T::from(i)), + box head.clone(), + ), + match Self::choose_many( + FieldElementExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) { + TypedExpression::FieldElement(e) => e, + e => unreachable!( + "the interior was expected to be a field, was {}", + e.get_type() + ), + }, + FieldElementExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ), + ) + .into(), + Type::Boolean => BooleanExpression::if_else( + BooleanExpression::Eq( + box FieldElementExpression::Number(T::from(i)), + box head.clone(), + ), + match Self::choose_many( + BooleanExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) { + TypedExpression::Boolean(e) => e, + e => unreachable!( + "the interior was expected to be a boolean, was {}", + e.get_type() + ), + }, + BooleanExpression::select( + base.clone(), + FieldElementExpression::Number(T::from(i)), + ), + ) + .into(), + }) + .collect(), + ) + .annotate(inner_ty.clone(), size) + .into() + } + Access::Member(..) => unreachable!("can't get a member from an array"), + } + } + TypedExpression::Struct(base) => { + let members = match base.get_type() { + Type::Struct(members) => members.clone(), + _ => unreachable!(), + }; - ArrayExpressionInner::Value( - (0..size) - .map(|i| match inner_ty { - Type::Array(..) => ArrayExpression::if_else( - BooleanExpression::Eq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Array(e) => e, - e => unreachable!( - "the interior was expected to be an array, was {}", - e.get_type() - ), - }, - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::Struct(..) => StructExpression::if_else( - BooleanExpression::Eq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Struct(e) => e, - e => unreachable!( - "the interior was expected to be a struct, was {}", - e.get_type() - ), - }, - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::FieldElement => FieldElementExpression::if_else( - BooleanExpression::Eq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::FieldElement(e) => e, - e => unreachable!( - "the interior was expected to be a field, was {}", - e.get_type() - ), - }, - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::Boolean => BooleanExpression::if_else( - BooleanExpression::Eq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Boolean(e) => e, - e => unreachable!( - "the interior was expected to be a boolean, was {}", - e.get_type() - ), - }, - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - }) - .collect(), - ) - .annotate(inner_ty.clone(), size) - .into() - } + let head = indices.remove(0); + let tail = indices; + + match head { + Access::Member(head) => StructExpressionInner::Value( + members + .clone() + .into_iter() + .map(|(id, t)| match t { + Type::FieldElement => { + if id == head { + Self::choose_many( + FieldElementExpression::member( + base.clone(), + head.clone(), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) + } else { + FieldElementExpression::member(base.clone(), id.clone()) + .into() + } + } + Type::Boolean => { + if id == head { + Self::choose_many( + BooleanExpression::member( + base.clone(), + head.clone(), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) + } else { + BooleanExpression::member(base.clone(), id.clone()) + .into() + } + } + Type::Array(..) => { + if id == head { + Self::choose_many( + ArrayExpression::member(base.clone(), head.clone()) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) + } else { + ArrayExpression::member(base.clone(), id.clone()).into() + } + } + Type::Struct(..) => { + if id == head { + Self::choose_many( + StructExpression::member( + base.clone(), + head.clone(), + ) + .into(), + tail.clone(), + new_expression.clone(), + statements, + ) + } else { + StructExpression::member(base.clone(), id.clone()) + .into() + } + } + }) + .collect(), + ) + .annotate(members) + .into(), + Access::Select(..) => unreachable!("can't get a element from a struct"), + } + } + e => unreachable!("can't make an access on a {}", e.get_type()), + }, } } } -/// Turn an assignee into its representation as a base variable and a list of indices +#[derive(Clone, Debug)] +enum Access<'ast, T: Field> { + Select(FieldElementExpression<'ast, T>), + Member(MemberId), +} +/// Turn an assignee into its representation as a base variable and a list accesses /// a[2][3][4] -> (a, [2, 3, 4]) -fn linear<'ast, T: Field>( - a: TypedAssignee<'ast, T>, -) -> (Variable, Vec>) { +fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec>) { match a { TypedAssignee::Identifier(v) => (v, vec![]), TypedAssignee::Select(box array, box index) => { let (v, mut indices) = linear(array); - indices.push(index); + indices.push(Access::Select(index)); + (v, indices) + } + TypedAssignee::Member(box s, m) => { + let (v, mut indices) = linear(s); + indices.push(Access::Member(m)); (v, indices) } } @@ -324,7 +422,12 @@ mod tests { let index = FieldElementExpression::Number(FieldPrime::from(1)); - let a1 = Unroller::choose_many(a0.clone().into(), vec![index], e, &mut HashSet::new()); + let a1 = Unroller::choose_many( + a0.clone().into(), + vec![Access::Select(index)], + e, + &mut HashSet::new(), + ); // a[1] = 42 // -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]] @@ -382,7 +485,7 @@ mod tests { let a1 = Unroller::choose_many( a0.clone().into(), - vec![index], + vec![Access::Select(index)], e.clone().into(), &mut HashSet::new(), ); @@ -440,8 +543,8 @@ mod tests { let e = FieldElementExpression::Number(FieldPrime::from(42)); let indices = vec![ - FieldElementExpression::Number(FieldPrime::from(0)), - FieldElementExpression::Number(FieldPrime::from(0)), + Access::Select(FieldElementExpression::Number(FieldPrime::from(0))), + Access::Select(FieldElementExpression::Number(FieldPrime::from(0))), ]; let a1 = Unroller::choose_many( diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index f9274f08..180bed48 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -48,6 +48,7 @@ pub trait Folder<'ast, T: Field>: Sized { box self.fold_assignee(a), box self.fold_field_expression(index), ), + TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m), } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 43a87506..7b6c3195 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -241,6 +241,7 @@ pub enum TypedAssignee<'ast, T: Field> { Box>, Box>, ), + Member(Box>, MemberId), } impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> { @@ -254,6 +255,15 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> { _ => unreachable!("an array element should only be defined over arrays"), } } + TypedAssignee::Member(ref s, ref m) => { + let s_type = s.get_type(); + match s_type { + Type::Struct(members) => { + members.iter().find(|(id, _)| id == m).unwrap().1.clone() + } + _ => unreachable!("a struct access should only be defined over structs"), + } + } } } } @@ -263,6 +273,7 @@ impl<'ast, T: Field> fmt::Debug for TypedAssignee<'ast, T> { match *self { TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id), TypedAssignee::Select(ref a, ref e) => write!(f, "{}[{}]", a, e), + TypedAssignee::Member(ref s, ref m) => write!(f, "{}.{}", s, m), } } } diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 417d201c..0c054478 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -89,7 +89,8 @@ unary_expression = { op_unary ~ term } // End Expressions -assignee = { identifier ~ ("[" ~ range_or_expression ~ "]")* } +assignee = { identifier ~ assignee_access* } +assignee_access = { array_access | member_access } identifier = @{ ((!keyword ~ ASCII_ALPHA) | (keyword ~ (ASCII_ALPHANUMERIC | "_"))) ~ (ASCII_ALPHANUMERIC | "_")* } constant = { decimal_number | boolean_literal } decimal_number = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index a537f197..164d7962 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -9,13 +9,13 @@ extern crate lazy_static; pub use ast::{ Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, - AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, - CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression, - Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, - InlineStructExpression, InlineStructMember, IterationStatement, MultiAssignmentStatement, - Parameter, PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, - SpreadOrExpression, Statement, StructDefinition, StructField, TernaryExpression, ToExpression, - Type, UnaryExpression, UnaryOperator, Visibility, + AssigneeAccess, AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, + BinaryOperator, CallAccess, ConstantExpression, DefinitionStatement, Expression, File, + FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, + InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement, + MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression, + ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, + TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, }; mod ast { @@ -529,6 +529,13 @@ mod ast { Member(MemberAccess<'ast>), } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::assignee_access))] + pub enum AssigneeAccess<'ast> { + Select(ArrayAccess<'ast>), + Member(MemberAccess<'ast>), + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::call_access))] pub struct CallAccess<'ast> { @@ -684,8 +691,8 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::assignee))] pub struct Assignee<'ast> { - pub id: IdentifierExpression<'ast>, // a - pub indices: Vec>, // [42 + x][31][7] + pub id: IdentifierExpression<'ast>, // a + pub accesses: Vec>, // [42 + x].foo[7] #[pest_ast(outer())] pub span: Span<'ast>, } From 79fea57be8232466ddf72b27def7630bbcfc6806 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 23 Sep 2019 18:33:08 +0200 Subject: [PATCH 25/35] clean, add boolean check to all boolean user input --- t.code | 16 --- u.code | 4 - zokrates_core/src/compile.rs | 2 - zokrates_core/src/flatten/mod.rs | 30 +---- .../src/static_analysis/constrain_inputs.rs | 120 ++++++++++++++++++ zokrates_core/src/static_analysis/mod.rs | 5 +- .../tests/tests/arrays/identity.code | 2 + .../tests/tests/arrays/identity.json | 38 ++++++ .../tests/tests/structs/identity.code | 7 + .../tests/tests/structs/identity.json | 38 ++++++ 10 files changed, 211 insertions(+), 51 deletions(-) delete mode 100644 t.code delete mode 100644 u.code create mode 100644 zokrates_core/src/static_analysis/constrain_inputs.rs create mode 100644 zokrates_core_test/tests/tests/arrays/identity.code create mode 100644 zokrates_core_test/tests/tests/arrays/identity.json create mode 100644 zokrates_core_test/tests/tests/structs/identity.code create mode 100644 zokrates_core_test/tests/tests/structs/identity.json diff --git a/t.code b/t.code deleted file mode 100644 index a38f522b..00000000 --- a/t.code +++ /dev/null @@ -1,16 +0,0 @@ -from "./u.code" import Fooo - -struct Bar { - a: field, - a: field, - c: field, -} - -struct Baz { - a: Bar -} - -def main(Bar a, Bar b, bool c) -> (Bar): - Bar bar = Bar { a: 1, b: 1, c: 1 } - return if false then a else bar fi - diff --git a/u.code b/u.code deleted file mode 100644 index f420f0f2..00000000 --- a/u.code +++ /dev/null @@ -1,4 +0,0 @@ -struct Foo { - a: field, - b: field[2], -} \ No newline at end of file diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 82990bcf..861b76ca 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -138,8 +138,6 @@ pub fn compile>( let source = arena.alloc(source); - println!("{:?}", source); - let compiled = compile_program(source, location.clone(), resolve_option, &arena)?; // check semantics diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index f4a7d972..c92eb68b 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1548,7 +1548,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let arguments_flattened = funct .arguments .into_iter() - .flat_map(|p| self.use_parameter(&p, &mut statements_flattened)) + .flat_map(|p| self.use_parameter(&p)) .collect(); // flatten statements in functions and apply substitution @@ -1602,19 +1602,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { vars } - fn use_parameter( - &mut self, - parameter: &Parameter<'ast>, - statements: &mut Vec>, - ) -> Vec { + fn use_parameter(&mut self, parameter: &Parameter<'ast>) -> Vec { let variables = self.use_variable(¶meter.id); - match parameter.id.get_type() { - Type::Boolean => statements.extend(Self::boolean_constraint(&variables)), - Type::Array(box Type::Boolean, _) => { - statements.extend(Self::boolean_constraint(&variables)) - } - _ => {} - }; variables .into_iter() @@ -1635,21 +1624,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..count).map(|_| self.issue_new_variable()).collect() } - fn boolean_constraint(variables: &Vec) -> Vec> { - variables - .iter() - .map(|v| { - FlatStatement::Condition( - FlatExpression::Identifier(*v), - FlatExpression::Mult( - box FlatExpression::Identifier(*v), - box FlatExpression::Identifier(*v), - ), - ) - }) - .collect() - } - // create an internal variable. We do not register it in the layout fn use_sym(&mut self) -> FlatVariable { self.issue_new_variable() diff --git a/zokrates_core/src/static_analysis/constrain_inputs.rs b/zokrates_core/src/static_analysis/constrain_inputs.rs new file mode 100644 index 00000000..711024f7 --- /dev/null +++ b/zokrates_core/src/static_analysis/constrain_inputs.rs @@ -0,0 +1,120 @@ +//! Add runtime boolean checks on user inputs +//! Example: +//! ``` +//! struct Foo { +//! bar: bool +//! } +//! +//! def main(Foo f) -> (): +//! f.bar == f.bar && f.bar +//! return +//! ``` +//! @file unroll.rs +//! @author Thibaut Schaeffer +//! @date 2018 + +use crate::typed_absy::folder::Folder; +use crate::typed_absy::types::Type; +use crate::typed_absy::*; +use zokrates_field::field::Field; + +pub struct InputConstrainer<'ast, T: Field> { + constraints: Vec>, +} + +impl<'ast, T: Field> InputConstrainer<'ast, T> { + fn new() -> Self { + InputConstrainer { + constraints: vec![], + } + } + + pub fn constrain(p: TypedProgram) -> TypedProgram { + InputConstrainer::new().fold_program(p) + } + + fn constrain_expression(&mut self, e: TypedExpression<'ast, T>) { + match e { + TypedExpression::FieldElement(_) => {} + TypedExpression::Boolean(b) => self.constraints.push(TypedStatement::Condition( + b.clone().into(), + BooleanExpression::And(box b.clone(), box b).into(), + )), + TypedExpression::Array(a) => { + for i in 0..a.size() { + let e = match a.inner_type() { + Type::FieldElement => FieldElementExpression::select( + a.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + Type::Boolean => BooleanExpression::select( + a.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + Type::Array(..) => ArrayExpression::select( + a.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + Type::Struct(..) => StructExpression::select( + a.clone(), + FieldElementExpression::Number(T::from(i)), + ) + .into(), + }; + + self.constrain_expression(e); + } + } + TypedExpression::Struct(s) => { + for (id, ty) in s.ty() { + let e = match ty { + Type::FieldElement => { + FieldElementExpression::member(s.clone(), id.clone()).into() + } + Type::Boolean => BooleanExpression::member(s.clone(), id.clone()).into(), + Type::Array(..) => ArrayExpression::member(s.clone(), id.clone()).into(), + Type::Struct(..) => StructExpression::member(s.clone(), id.clone()).into(), + }; + + self.constrain_expression(e); + } + } + } + } +} + +impl<'ast, T: Field> Folder<'ast, T> for InputConstrainer<'ast, T> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let v = p.id.clone(); + + let e = match v.get_type() { + Type::FieldElement => FieldElementExpression::Identifier(v.id).into(), + Type::Boolean => BooleanExpression::Identifier(v.id).into(), + Type::Struct(members) => StructExpressionInner::Identifier(v.id) + .annotate(members) + .into(), + Type::Array(box ty, size) => ArrayExpressionInner::Identifier(v.id) + .annotate(ty, size) + .into(), + }; + + self.constrain_expression(e); + + p + } + + fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { + TypedFunction { + arguments: f + .arguments + .into_iter() + .map(|a| self.fold_parameter(a)) + .collect(), + statements: self.constraints.drain(..).chain(f.statements).collect(), + ..f + } + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 386d4c49..e2c2e10d 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -4,11 +4,13 @@ //! @author Thibaut Schaeffer //! @date 2018 +mod constrain_inputs; mod flat_propagation; mod inline; mod propagation; mod unroll; +use self::constrain_inputs::InputConstrainer; use self::inline::Inliner; use self::propagation::Propagator; use self::unroll::Unroller; @@ -24,11 +26,12 @@ impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> { fn analyse(self) -> Self { // unroll let r = Unroller::unroll(self); - println!("{}", r); // inline let r = Inliner::inline(r); // propagate let r = Propagator::propagate(r); + // constrain inputs + let r = InputConstrainer::constrain(r); r } } diff --git a/zokrates_core_test/tests/tests/arrays/identity.code b/zokrates_core_test/tests/tests/arrays/identity.code new file mode 100644 index 00000000..8c97930b --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/identity.code @@ -0,0 +1,2 @@ +def main(bool[3] a) -> (bool[3]): + return a \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/arrays/identity.json b/zokrates_core_test/tests/tests/arrays/identity.json new file mode 100644 index 00000000..e6a36236 --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/identity.json @@ -0,0 +1,38 @@ +{ + "entry_point": "./tests/tests/arrays/identity.code", + "tests": [ + { + "input": { + "values": ["0", "0", "0"] + }, + "output": { + "Ok": { + "values": ["0", "0", "0"] + } + } + }, + { + "input": { + "values": ["1", "0", "1"] + }, + "output": { + "Ok": { + "values": ["1", "0", "1"] + } + } + }, + { + "input": { + "values": ["2", "1", "1"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "4", + "right": "2" + } + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/structs/identity.code b/zokrates_core_test/tests/tests/structs/identity.code new file mode 100644 index 00000000..041e79ed --- /dev/null +++ b/zokrates_core_test/tests/tests/structs/identity.code @@ -0,0 +1,7 @@ +struct A { + a: field, + b: bool +} + +def main(A a) -> (A): + return a \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/structs/identity.json b/zokrates_core_test/tests/tests/structs/identity.json new file mode 100644 index 00000000..db4b136e --- /dev/null +++ b/zokrates_core_test/tests/tests/structs/identity.json @@ -0,0 +1,38 @@ +{ + "entry_point": "./tests/tests/structs/identity.code", + "tests": [ + { + "input": { + "values": ["42", "0"] + }, + "output": { + "Ok": { + "values": ["42", "0"] + } + } + }, + { + "input": { + "values": ["42", "1"] + }, + "output": { + "Ok": { + "values": ["42", "1"] + } + } + }, + { + "input": { + "values": ["42", "3"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "9", + "right": "3" + } + } + } + } + ] +} \ No newline at end of file From c815a32574a7cb256a749f8bbf28c32c5991886d Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 23 Sep 2019 18:48:59 +0200 Subject: [PATCH 26/35] remove test, fix comment --- zokrates_core/src/flatten/mod.rs | 52 ------------------- .../src/static_analysis/constrain_inputs.rs | 2 +- 2 files changed, 1 insertion(+), 53 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index c92eb68b..6e55180e 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1652,58 +1652,6 @@ mod tests { use crate::typed_absy::types::Type; use zokrates_field::field::FieldPrime; - mod boolean_checks { - use super::*; - - #[test] - fn boolean_arg() { - // def main(bool a): - // return a - // - // -> should flatten to - // - // def main(_0) -> (1): - // _0 * _0 == _0 - // return _0 - - let function: TypedFunction = TypedFunction { - arguments: vec![Parameter::private(Variable::boolean("a".into()))], - statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( - "a".into(), - ) - .into()])], - signature: Signature::new() - .inputs(vec![Type::Boolean]) - .outputs(vec![Type::Boolean]), - }; - - let expected = FlatFunction { - arguments: vec![FlatParameter::private(FlatVariable::new(0))], - statements: vec![ - FlatStatement::Condition( - FlatExpression::Identifier(FlatVariable::new(0)), - FlatExpression::Mult( - box FlatExpression::Identifier(FlatVariable::new(0)), - box FlatExpression::Identifier(FlatVariable::new(0)), - ), - ), - FlatStatement::Return(FlatExpressionList { - expressions: vec![FlatExpression::Identifier(FlatVariable::new(0))], - }), - ], - signature: Signature::new() - .inputs(vec![Type::Boolean]) - .outputs(vec![Type::Boolean]), - }; - - let mut flattener = Flattener::new(); - - let flat_function = flattener.flatten_function(&mut HashMap::new(), function); - - assert_eq!(flat_function, expected); - } - } - #[test] fn powers_zero() { // def main(): diff --git a/zokrates_core/src/static_analysis/constrain_inputs.rs b/zokrates_core/src/static_analysis/constrain_inputs.rs index 711024f7..38995c30 100644 --- a/zokrates_core/src/static_analysis/constrain_inputs.rs +++ b/zokrates_core/src/static_analysis/constrain_inputs.rs @@ -1,6 +1,6 @@ //! Add runtime boolean checks on user inputs //! Example: -//! ``` +//! ```zokrates //! struct Foo { //! bar: bool //! } From 20f3a980fe947c60f54c34d9a1212e863534916c Mon Sep 17 00:00:00 2001 From: JacobEberhardt Date: Tue, 24 Sep 2019 22:14:42 +0200 Subject: [PATCH 27/35] Struct doc with examples. Break current preliminary impl. as they have spec character. --- zokrates_book/src/concepts/types.md | 66 +++++++++++++++---- zokrates_cli/examples/book/struct_assign.code | 10 +++ zokrates_cli/examples/book/struct_init.code | 8 +++ zokrates_cli/examples/book/structs.code | 14 ++++ 4 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 zokrates_cli/examples/book/struct_assign.code create mode 100644 zokrates_cli/examples/book/struct_init.code create mode 100644 zokrates_cli/examples/book/structs.code diff --git a/zokrates_book/src/concepts/types.md b/zokrates_book/src/concepts/types.md index 23264b1a..5e64e1eb 100644 --- a/zokrates_book/src/concepts/types.md +++ b/zokrates_book/src/concepts/types.md @@ -1,10 +1,10 @@ -## Types +# Types -ZoKrates currently exposes two primitive types and a complex array type: +ZoKrates currently exposes two primitive types and two complex types: -### Primitive Types +## Primitive Types -#### `field` +### `field` This is the most basic type in ZoKrates, and it represents a positive integer in `[0, p - 1]` where `p` is a (large) prime number. @@ -16,7 +16,7 @@ While `field` values mostly behave like unsigned integers, one should keep in mi {{#include ../../../zokrates_cli/examples/book/field_overflow.code}} ``` -#### `bool` +### `bool` ZoKrates has limited support for booleans, to the extent that they can only be used as the condition in `if ... else ... endif` expressions. @@ -24,9 +24,11 @@ You can use them for equality checks, inequality checks and inequality checks be Note that while equality checks are cheap, inequality checks should be use wisely as they are orders of magnitude more expensive. -### Complex Types +## Complex Types -#### Arrays +ZoKrates provides two complex types, Arrays and Structs. + +### Arrays ZoKrates supports static arrays, i.e., their length needs to be known at compile time. Arrays can contain elements of any type and have arbitrary dimensions. @@ -37,10 +39,10 @@ The following examples code shows examples of how to use arrays: {{#include ../../../zokrates_cli/examples/book/array.code}} ``` -##### Declaration and Initialization +#### Declaration and Initialization An array is defined by appending `[]` to a type literal representing the type of the array's elements. -Initialization always needs to happen in the same statement than declaration, unless the array is declared within a function's signature. +Initialization always needs to happen in the same statement as declaration, unless the array is declared within a function's signature. For initialization, a list of comma-separated values is provided within brackets `[]`. @@ -54,7 +56,7 @@ The following code provides examples for declaration and initialization: bool[13] b = [false; 13] // initialize a bool array with value false ``` -##### Multidimensional Arrays +#### Multidimensional Arrays As an array can contain any type of elements, it can contain arrays again. There is a special syntax to declare such multi-dimensional arrays, i.e., arrays of arrays. @@ -67,10 +69,10 @@ Consider the following example: {{#include ../../../zokrates_cli/examples/book/multidim_array.code}} ``` -##### Spreads and Slices +#### Spreads and Slices ZoKrates provides some syntactic sugar to retrieve subsets of arrays. -###### Spreads +##### Spreads The spread operator `...` applied to an copies the elements of an existing array. This can be used to conveniently compose new arrays, as shown in the following example: ``` @@ -78,10 +80,48 @@ field[3] = [1, 2, 3] field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4 ``` -###### Slices +##### Slices An array can also be assigned to by creating a copy of a subset of an existing array. This operation is called slicing, and the following example shows how to slice in ZoKrates: ``` field[3] a = [1, 2, 3] field[2] b = a[1..3] // initialize an array copying a slice from `a` ``` + +### Structs +A struct is a composite datatype representing a named collection of variables. +The contained variables can be of any type. + +The following code shows an example of how to use structs. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/structs.code}} +``` + +#### Definition +Before a struct data type can be used, it needs to be defined. +A struct definition starts with the `struct` keyword followed by a name. Afterwards, a new-line separated list of variables is declared in curly braces `{}`. For example: + +```zokrates +struct Point { + field x + field y +} +``` + +#### Declaration and Initialization + +Initialization of a variable of a struct type always needs to happen in the same statement as declaration, unless the struct-typed variable is declared within a function's signature. + +The following example shows declaration and initialization of a variable of the `Point` struct type: + +```zokrates +{{#include ../../../zokrates_cli/examples/book/struct_init.code}} +``` + +#### Assignment +The variables within a struct instance, the so called members, can be accessed through the `.` operator as shown in the following extended example: + +```zokrates +{{#include ../../../zokrates_cli/examples/book/struct_assign.code}} +``` \ No newline at end of file diff --git a/zokrates_cli/examples/book/struct_assign.code b/zokrates_cli/examples/book/struct_assign.code new file mode 100644 index 00000000..c52794d1 --- /dev/null +++ b/zokrates_cli/examples/book/struct_assign.code @@ -0,0 +1,10 @@ +struct Point { + field x + field y +} + +def main(field a) -> (Point): + Point p = Point {x: 1, y: 0} + p.x = a + p.y = p.x + return p diff --git a/zokrates_cli/examples/book/struct_init.code b/zokrates_cli/examples/book/struct_init.code new file mode 100644 index 00000000..837afc84 --- /dev/null +++ b/zokrates_cli/examples/book/struct_init.code @@ -0,0 +1,8 @@ +struct Point { + field x + field y +} + +def main() -> (Point): + Point p = Point {x: 1, y: 0} + return p diff --git a/zokrates_cli/examples/book/structs.code b/zokrates_cli/examples/book/structs.code new file mode 100644 index 00000000..6abdccce --- /dev/null +++ b/zokrates_cli/examples/book/structs.code @@ -0,0 +1,14 @@ +struct Bar { + field[2] c + bool b +} + +struct Foo { + Bar a + bool b +} + +def main() -> (Foo): + Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}] + f[0].a.c = [42, 43] + return f[0] From 47395d0b4b773c6e6b6b20c83e0f43a3535e30db Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 25 Sep 2019 12:06:11 +0200 Subject: [PATCH 28/35] change struct declaration syntax --- zokrates_cli/examples/book/structs.code | 2 +- zokrates_cli/examples/structs/add.code | 4 +- zokrates_cli/examples/structs/set_member.code | 8 ++-- .../tests/tests/structs/identity.code | 4 +- zokrates_parser/src/lib.rs | 42 +++++++++---------- zokrates_parser/src/zokrates.pest | 4 +- zokrates_pest_ast/src/lib.rs | 6 +-- 7 files changed, 35 insertions(+), 35 deletions(-) diff --git a/zokrates_cli/examples/book/structs.code b/zokrates_cli/examples/book/structs.code index 6abdccce..a08399ba 100644 --- a/zokrates_cli/examples/book/structs.code +++ b/zokrates_cli/examples/book/structs.code @@ -1,6 +1,6 @@ struct Bar { field[2] c - bool b + bool d } struct Foo { diff --git a/zokrates_cli/examples/structs/add.code b/zokrates_cli/examples/structs/add.code index 3941f08f..5d1fd856 100644 --- a/zokrates_cli/examples/structs/add.code +++ b/zokrates_cli/examples/structs/add.code @@ -1,6 +1,6 @@ struct Point { - x: field, - y: field + field x + field y } def main(Point p, Point q) -> (Point): diff --git a/zokrates_cli/examples/structs/set_member.code b/zokrates_cli/examples/structs/set_member.code index 99872793..7cca5707 100644 --- a/zokrates_cli/examples/structs/set_member.code +++ b/zokrates_cli/examples/structs/set_member.code @@ -1,11 +1,11 @@ struct Bar { - c: field[2], - d: bool + field[2] c + bool d } struct Foo { - a: Bar, - b: bool + Bar a + bool b } def main() -> (Foo): diff --git a/zokrates_core_test/tests/tests/structs/identity.code b/zokrates_core_test/tests/tests/structs/identity.code index 041e79ed..e5441761 100644 --- a/zokrates_core_test/tests/tests/structs/identity.code +++ b/zokrates_core_test/tests/tests/structs/identity.code @@ -1,6 +1,6 @@ struct A { - a: field, - b: bool + field a + bool b } def main(A a) -> (A): diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index e0383a84..be385565 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -149,40 +149,40 @@ mod tests { fn parse_struct_def() { parses_to! { parser: ZoKratesParser, - input: "struct Foo { foo: field, bar: field[2] } + input: "struct Foo { field foo\n field[2] bar } ", rule: Rule::ty_struct_definition, tokens: [ - ty_struct_definition(0, 41, [ + ty_struct_definition(0, 39, [ identifier(7, 10), - struct_field(13, 23, [ - identifier(13, 16), - ty(18, 23, [ - ty_basic(18, 23, [ - ty_field(18, 23) + struct_field(13, 22, [ + ty(13, 18, [ + ty_basic(13, 18, [ + ty_field(13, 18) ]) - ]) + ]), + identifier(19, 22) ]), - struct_field(25, 39, [ - identifier(25, 28), - ty(30, 39, [ - ty_array(30, 39, [ - ty_basic_or_struct(30, 35, [ - ty_basic(30, 35, [ - ty_field(30, 35) + struct_field(24, 36, [ + ty(24, 33, [ + ty_array(24, 33, [ + ty_basic_or_struct(24, 29, [ + ty_basic(24, 29, [ + ty_field(24, 29) ]) ]), - expression(36, 37, [ - term(36, 37, [ - primary_expression(36, 37, [ - constant(36, 37, [ - decimal_number(36, 37) + expression(30, 31, [ + term(30, 31, [ + primary_expression(30, 31, [ + constant(30, 31, [ + decimal_number(30, 31) ]) ]) ]) ]) ]) - ]) + ]), + identifier(33, 36) ]) ]) ] diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 0c054478..40bc09dc 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -26,8 +26,8 @@ type_list = _{(ty ~ ("," ~ ty)*)?} ty_struct = { identifier } // type definitions ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* } -struct_field_list = _{(struct_field ~ ("," ~ NEWLINE* ~ struct_field)*)? ~ ","? } -struct_field = { identifier ~ ":" ~ ty } +struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? } +struct_field = { ty ~ identifier } vis_private = {"private"} vis_public = {"public"} diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 164d7962..e86541be 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -178,8 +178,8 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::struct_field))] pub struct StructField<'ast> { - pub id: IdentifierExpression<'ast>, pub ty: Type<'ast>, + pub id: IdentifierExpression<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -1083,8 +1083,8 @@ mod tests { let source = r#"import "heyman" as yo struct Foo { - foo: field[2], - bar: Bar + field[2] foo + Bar bar } def main(private field[23] a) -> (bool[234 + 6]): From 3b0b6959e343c6050cfd7470d7e42f473d1b2ef2 Mon Sep 17 00:00:00 2001 From: Thibaut Date: Sat, 28 Sep 2019 18:24:29 +0200 Subject: [PATCH 29/35] change struct declaration syntax --- zokrates_cli/examples/book/structs.code | 2 +- zokrates_cli/examples/structs/add.code | 4 +- zokrates_cli/examples/structs/set_member.code | 8 ++-- .../tests/tests/structs/identity.code | 4 +- zokrates_parser/src/lib.rs | 42 +++++++++---------- zokrates_parser/src/zokrates.pest | 4 +- zokrates_pest_ast/src/lib.rs | 6 +-- 7 files changed, 35 insertions(+), 35 deletions(-) diff --git a/zokrates_cli/examples/book/structs.code b/zokrates_cli/examples/book/structs.code index 6abdccce..a08399ba 100644 --- a/zokrates_cli/examples/book/structs.code +++ b/zokrates_cli/examples/book/structs.code @@ -1,6 +1,6 @@ struct Bar { field[2] c - bool b + bool d } struct Foo { diff --git a/zokrates_cli/examples/structs/add.code b/zokrates_cli/examples/structs/add.code index 3941f08f..5d1fd856 100644 --- a/zokrates_cli/examples/structs/add.code +++ b/zokrates_cli/examples/structs/add.code @@ -1,6 +1,6 @@ struct Point { - x: field, - y: field + field x + field y } def main(Point p, Point q) -> (Point): diff --git a/zokrates_cli/examples/structs/set_member.code b/zokrates_cli/examples/structs/set_member.code index 99872793..7cca5707 100644 --- a/zokrates_cli/examples/structs/set_member.code +++ b/zokrates_cli/examples/structs/set_member.code @@ -1,11 +1,11 @@ struct Bar { - c: field[2], - d: bool + field[2] c + bool d } struct Foo { - a: Bar, - b: bool + Bar a + bool b } def main() -> (Foo): diff --git a/zokrates_core_test/tests/tests/structs/identity.code b/zokrates_core_test/tests/tests/structs/identity.code index 041e79ed..e5441761 100644 --- a/zokrates_core_test/tests/tests/structs/identity.code +++ b/zokrates_core_test/tests/tests/structs/identity.code @@ -1,6 +1,6 @@ struct A { - a: field, - b: bool + field a + bool b } def main(A a) -> (A): diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index e0383a84..be385565 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -149,40 +149,40 @@ mod tests { fn parse_struct_def() { parses_to! { parser: ZoKratesParser, - input: "struct Foo { foo: field, bar: field[2] } + input: "struct Foo { field foo\n field[2] bar } ", rule: Rule::ty_struct_definition, tokens: [ - ty_struct_definition(0, 41, [ + ty_struct_definition(0, 39, [ identifier(7, 10), - struct_field(13, 23, [ - identifier(13, 16), - ty(18, 23, [ - ty_basic(18, 23, [ - ty_field(18, 23) + struct_field(13, 22, [ + ty(13, 18, [ + ty_basic(13, 18, [ + ty_field(13, 18) ]) - ]) + ]), + identifier(19, 22) ]), - struct_field(25, 39, [ - identifier(25, 28), - ty(30, 39, [ - ty_array(30, 39, [ - ty_basic_or_struct(30, 35, [ - ty_basic(30, 35, [ - ty_field(30, 35) + struct_field(24, 36, [ + ty(24, 33, [ + ty_array(24, 33, [ + ty_basic_or_struct(24, 29, [ + ty_basic(24, 29, [ + ty_field(24, 29) ]) ]), - expression(36, 37, [ - term(36, 37, [ - primary_expression(36, 37, [ - constant(36, 37, [ - decimal_number(36, 37) + expression(30, 31, [ + term(30, 31, [ + primary_expression(30, 31, [ + constant(30, 31, [ + decimal_number(30, 31) ]) ]) ]) ]) ]) - ]) + ]), + identifier(33, 36) ]) ]) ] diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 0c054478..40bc09dc 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -26,8 +26,8 @@ type_list = _{(ty ~ ("," ~ ty)*)?} ty_struct = { identifier } // type definitions ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* } -struct_field_list = _{(struct_field ~ ("," ~ NEWLINE* ~ struct_field)*)? ~ ","? } -struct_field = { identifier ~ ":" ~ ty } +struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? } +struct_field = { ty ~ identifier } vis_private = {"private"} vis_public = {"public"} diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 164d7962..e86541be 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -178,8 +178,8 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::struct_field))] pub struct StructField<'ast> { - pub id: IdentifierExpression<'ast>, pub ty: Type<'ast>, + pub id: IdentifierExpression<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -1083,8 +1083,8 @@ mod tests { let source = r#"import "heyman" as yo struct Foo { - foo: field[2], - bar: Bar + field[2] foo + Bar bar } def main(private field[23] a) -> (bool[234 + 6]): From f305ded6464fd74dc7e77756587041d623e4706f Mon Sep 17 00:00:00 2001 From: Thibaut Date: Mon, 30 Sep 2019 13:49:54 +0200 Subject: [PATCH 30/35] allow newlines in inline defs --- zokrates_cli/examples/structs/add.code | 5 ++++- zokrates_cli/examples/structs/set_member.code | 17 ++++++++++++++++- zokrates_parser/src/zokrates.pest | 4 ++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/zokrates_cli/examples/structs/add.code b/zokrates_cli/examples/structs/add.code index 5d1fd856..388ce158 100644 --- a/zokrates_cli/examples/structs/add.code +++ b/zokrates_cli/examples/structs/add.code @@ -10,4 +10,7 @@ def main(Point p, Point q) -> (Point): field dpxpyqxqy = d * p.x * p.y * q.x * q.y - return Point { x: (p.x * q.y + q.x * p.y) / (1 + dpxpyqxqy) , y: (q.x * q.y - a * p.x * p.y) / (1 - dpxpyqxqy) } + return Point { + x: (p.x * q.y + q.x * p.y) / (1 + dpxpyqxqy), + y: (q.x * q.y - a * p.x * p.y) / (1 - dpxpyqxqy) + } diff --git a/zokrates_cli/examples/structs/set_member.code b/zokrates_cli/examples/structs/set_member.code index 7cca5707..be326580 100644 --- a/zokrates_cli/examples/structs/set_member.code +++ b/zokrates_cli/examples/structs/set_member.code @@ -9,6 +9,21 @@ struct Foo { } def main() -> (Foo): - Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}] + Foo[2] f = [ + Foo { + a: Bar { + c: [0, 0], + d: false + }, + b: true + }, + Foo { + a: Bar { + c: [0, 0], + d: false + }, + b: true + } + ] f[0].a.c = [42, 43] return f[0] diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 40bc09dc..7b62f2e6 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -79,8 +79,8 @@ inline_struct_expression = { identifier ~ "{" ~ NEWLINE* ~ inline_struct_member_ inline_struct_member_list = _{(inline_struct_member ~ ("," ~ NEWLINE* ~ inline_struct_member)*)? ~ ","? } inline_struct_member = { identifier ~ ":" ~ expression } -inline_array_expression = { "[" ~ inline_array_inner ~ "]" } -inline_array_inner = _{(spread_or_expression ~ ("," ~ spread_or_expression)*)?} +inline_array_expression = { "[" ~ NEWLINE* ~ inline_array_inner ~ NEWLINE* ~ "]" } +inline_array_inner = _{(spread_or_expression ~ ("," ~ NEWLINE* ~ spread_or_expression)*)?} spread_or_expression = { spread | expression } range_or_expression = { range | expression } array_initializer_expression = { "[" ~ expression ~ ";" ~ constant ~ "]" } From 5ea2fe92ea1c8c3d060e041e806b76b4307f00cc Mon Sep 17 00:00:00 2001 From: Thibaut Date: Wed, 2 Oct 2019 19:39:51 +0200 Subject: [PATCH 31/35] add import docs --- zokrates_book/src/concepts/imports.md | 57 ++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/zokrates_book/src/concepts/imports.md b/zokrates_book/src/concepts/imports.md index 25d3dcb4..2472030a 100644 --- a/zokrates_book/src/concepts/imports.md +++ b/zokrates_book/src/concepts/imports.md @@ -1,27 +1,66 @@ ## Imports -You can separate your code into multiple ZoKrates files using `import` statements, ignoring the `.zok` extension of the imported file: +You can separate your code into multiple ZoKrates files using `import` statements to import symbols, ignoring the `.zok` extension of the imported file. + +### Import syntax + +#### Symbol selection + +The preferred way to import a symbol is by module and name: +```zokrates +from "./path/to/my/module" import MySymbol + +// `MySymbol` is now in scope. +``` + +#### Aliasing + +The `as` keyword enables renaming symbols: + +```zokrates +from "./path/to/my/module" import MySymbol as MyAlias + +// `MySymbol` is now in scope under the alias MyAlias. +``` +#### Legacy + +The legacy way to import a symbol is by only specifying a module: +``` +import "./path/to/my/module" +``` +In this case, the name of the symbol is assumed to be `main` and the alias is assumed to be the module's filename so that the above is equivalent to +```zokrates +from "./path/to/my/module" import main as module + +// `main` is now in scope under the alias `module`. +``` + +Note that this legacy method is likely to be become deprecated, so it is recommended to use the preferred way instead. +### Symbols + +Two type of symbols can be imported + +#### Functions +Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is infered. + +#### User-defined types +User-defined types declared with the `struct` keyword are imported by name. ### Relative Imports You can import a resource in the same folder directly, like this: ```zokrates -import "./mycode" +from "./mycode" import foo ``` There also is a handy syntax to import from the parent directory: ```zokrates -import "../mycode" +from "../mycode" import foo ``` Also imports further up the file-system are supported: ```zokrates -import "../../../mycode" -``` - -You can also choose to rename the imported resource, like so: -```zokrates -import "./mycode" as abc +from "../../../mycode" import foo ``` ### Absolute Imports From bab73384c17422057768dac2b2e1e68007f6139d Mon Sep 17 00:00:00 2001 From: Thibaut Date: Wed, 2 Oct 2019 20:08:06 +0200 Subject: [PATCH 32/35] fix access ssa --- zokrates_core/src/static_analysis/unroll.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs index d93179fd..5180045c 100644 --- a/zokrates_core/src/static_analysis/unroll.rs +++ b/zokrates_core/src/static_analysis/unroll.rs @@ -341,7 +341,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { let base = self.fold_expression(base); let indices = indices .into_iter() - .map(|i| self.fold_field_expression(i)) + .map(|a| match a { + Access::Select(i) => Access::Select(self.fold_field_expression(i)), + a => a, + }) .collect(); let mut range_checks = HashSet::new(); From 5a5dba7e72dab142c9fb55d90a6d6d9480bbbdbc Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Mon, 7 Oct 2019 11:49:03 +0900 Subject: [PATCH 33/35] Apply suggestions from code review Co-Authored-By: Stefan --- zokrates_book/src/concepts/types.md | 4 ++-- zokrates_core/src/absy/mod.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/zokrates_book/src/concepts/types.md b/zokrates_book/src/concepts/types.md index 27f2259d..67ad89f9 100644 --- a/zokrates_book/src/concepts/types.md +++ b/zokrates_book/src/concepts/types.md @@ -73,7 +73,7 @@ Consider the following example: ZoKrates provides some syntactic sugar to retrieve subsets of arrays. ##### Spreads -The spread operator `...` applied to an copies the elements of an existing array. +The spread operator `...` applied to an array copies the elements of the existing array. This can be used to conveniently compose new arrays, as shown in the following example: ``` field[3] = [1, 2, 3] @@ -124,4 +124,4 @@ The variables within a struct instance, the so called members, can be accessed t ```zokrates {{#include ../../../zokrates_cli/examples/book/struct_assign.code}} -``` \ No newline at end of file +``` diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 39028d6c..c2e1e283 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -32,7 +32,7 @@ pub type ModuleId = String; /// A collection of `Module`s pub type Modules<'ast, T> = HashMap>; -/// A collection of `SymbolDeclaration`. Duplicates are allowed here as they are fine syntatically. +/// A collection of `SymbolDeclaration`. Duplicates are allowed here as they are fine syntactically. pub type Declarations<'ast, T> = Vec>; /// A `Program` is a collection of `Module`s and an id of the main `Module` From 5040e5b7bf4303c8b5294756c073e8ee11b343a1 Mon Sep 17 00:00:00 2001 From: Thibaut Date: Mon, 7 Oct 2019 11:49:41 +0900 Subject: [PATCH 34/35] tweaks following review --- out.code | 6 ------ zokrates_core/src/absy/from_ast.rs | 2 +- .../src/static_analysis/constrain_inputs.rs | 18 ++++++++++++++++-- zokrates_core/src/typed_absy/mod.rs | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) delete mode 100644 out.code diff --git a/out.code b/out.code deleted file mode 100644 index 9680b0df..00000000 --- a/out.code +++ /dev/null @@ -1,6 +0,0 @@ -def main() -> (4): - (1 * ~one) * (42 * ~one) == 1 * ~out_0 - (1 * ~one) * (43 * ~one) == 1 * ~out_1 - (1 * ~one) * (0) == 1 * ~out_2 - (1 * ~one) * (1 * ~one) == 1 * ~out_3 - return ~out_0, ~out_1, ~out_2, ~out_3 diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index c1852208..44532609 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -54,7 +54,7 @@ impl<'ast, T: Field> From> for absy::SymbolDeclarat .map(|f| absy::StructFieldNode::from(f)) .collect(), } - .span(span.clone()); // TODO check + .span(span.clone()); absy::SymbolDeclaration { id, diff --git a/zokrates_core/src/static_analysis/constrain_inputs.rs b/zokrates_core/src/static_analysis/constrain_inputs.rs index 38995c30..0fa6dfbd 100644 --- a/zokrates_core/src/static_analysis/constrain_inputs.rs +++ b/zokrates_core/src/static_analysis/constrain_inputs.rs @@ -1,4 +1,5 @@ //! Add runtime boolean checks on user inputs +//! //! Example: //! ```zokrates //! struct Foo { @@ -9,9 +10,22 @@ //! f.bar == f.bar && f.bar //! return //! ``` -//! @file unroll.rs +//! +//! Becomes +//! +//! ```zokrates +//! struct Foo { +//! bar: bool +//! } +//! +//! def main(Foo f) -> (): +//! f.bar == f.bar && f.bar +//! return +//! ``` +//! +//! @file constrain_inputs.rs //! @author Thibaut Schaeffer -//! @date 2018 +//! @date 2019 use crate::typed_absy::folder::Folder; use crate::typed_absy::types::Type; diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 820743d1..0602649f 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -74,7 +74,7 @@ impl<'ast, T: Field> fmt::Display for TypedProgram<'ast, T> { } } -/// A +/// A typed program as a collection of functions. Types have been resolved during semantic checking. #[derive(PartialEq, Clone)] pub struct TypedModule<'ast, T: Field> { /// Functions of the program From 65177026069a8f10a7d6c6db9d1658e4fd0a33a4 Mon Sep 17 00:00:00 2001 From: Thibaut Date: Mon, 7 Oct 2019 11:56:16 +0900 Subject: [PATCH 35/35] remove rustfmt temporarilly as it didnt ship --- .circleci/config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c2235d45..1c8419a5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -42,9 +42,9 @@ jobs: - restore_cache: keys: - v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }} - - run: - name: Check format - command: rustup component add rustfmt; cargo fmt --all -- --check + # - run: + # name: Check format + # command: rustup component add rustfmt; cargo fmt --all -- --check - run: name: Install libsnark prerequisites command: ./scripts/install_libsnark_prerequisites.sh