diff --git a/zokrates_abi/src/lib.rs b/zokrates_abi/src/lib.rs index 1a158de9..20a6e9da 100644 --- a/zokrates_abi/src/lib.rs +++ b/zokrates_abi/src/lib.rs @@ -404,7 +404,7 @@ mod tests { mod strict { use super::*; - use zokrates_core::typed_absy::types::StructMember; + use zokrates_core::typed_absy::types::{StructMember, StructType}; #[test] fn fields() { @@ -449,10 +449,11 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(vec![StructMember::new( - "a".into(), - Type::FieldElement - )])] + vec![Type::Struct(StructType::new( + "".into(), + "".into(), + vec![StructMember::new("a".into(), Type::FieldElement)] + ))] ) .unwrap(), CheckedValues(vec![CheckedValue::Struct( @@ -466,10 +467,11 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(vec![StructMember::new( - "a".into(), - Type::FieldElement - )])] + vec![Type::Struct(StructType::new( + "".into(), + "".into(), + vec![StructMember::new("a".into(), Type::FieldElement)] + ))] ) .unwrap_err(), Error::Type("Member with id `a` not found".into()) @@ -479,10 +481,11 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(vec![StructMember::new( - "a".into(), - Type::FieldElement - )])] + vec![Type::Struct(StructType::new( + "".into(), + "".into(), + vec![StructMember::new("a".into(), Type::FieldElement)] + ))] ) .unwrap_err(), Error::Type("Expected 1 member(s), found 0".into()) @@ -492,10 +495,11 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(vec![StructMember::new( - "a".into(), - Type::FieldElement - )])] + vec![Type::Struct(StructType::new( + "".into(), + "".into(), + vec![StructMember::new("a".into(), Type::FieldElement)] + ))] ) .unwrap_err(), Error::Type("Value `false` doesn't match expected type `field`".into()) diff --git a/zokrates_book/src/reference/abi.md b/zokrates_book/src/reference/abi.md index bfd03c05..0e0c9b07 100644 --- a/zokrates_book/src/reference/abi.md +++ b/zokrates_book/src/reference/abi.md @@ -5,17 +5,7 @@ In order to interact programatically with compiled ZoKrates programs, ZoKrates s To illustrate this, we'll use the following example program: ``` -struct Bar { - field a -} - -struct Foo { - field a - Bar b -} - -def main(private Foo foo, bool[2] bar, field num) -> (field): - return 42 +{{#include ../../../zokrates_cli/examples/book/abi.zok}} ``` ## ABI specification @@ -26,48 +16,54 @@ In this example, the ABI specification is: ```json { - "inputs": [ - { - "name": "foo", - "public": false, - "type": "struct", - "components": [ - { - "name": "a", - "type": "field" - }, - { - "name": "b", - "type": "struct", - "components": [ + "inputs":[ + { + "name":"foo", + "public":true, + "type":"struct", + "components":{ + "name":"Foo", + "members":[ + { + "name":"a", + "type":"field" + }, + { + "name":"b", + "type":"struct", + "components":{ + "name":"Bar", + "members":[ { - "name": "a", - "type": "field" + "name":"a", + "type":"field" } - ] - } + ] + } + } ] - }, - { - "name": "bar", - "public": "true", - "type": "array", - "components": { - "size": 2, - "type": "bool" - } - }, - { - "name": "num", - "public": "true", - "type": "field" - } - ], - "outputs": [ - { - "type": "field" - } - ] + } + }, + { + "name":"bar", + "public":true, + "type":"array", + "components":{ + "size":2, + "type":"bool" + } + }, + { + "name":"num", + "public":true, + "type":"field" + } + ], + "outputs":[ + { + "type":"field" + } + ] } ``` @@ -78,19 +74,20 @@ When executing a program, arguments can be passed as a JSON object of the follow ```json [ - { - "a": "42", - "b": - { - "a": "42" - } - }, - [ - true, - false - ], - "42" + { + "a":"42", + "b":{ + "a":"42" + } + }, + [ + true, + false + ], + "42" ] ``` -Note that field elements are passed as JSON strings in order to support arbitrary large numbers. \ No newline at end of file +Note the following: +- Field elements are passed as JSON strings in order to support arbitrary large numbers. +- Structs are passed as JSON objects, ignoring the struct name \ No newline at end of file diff --git a/zokrates_cli/examples/book/abi.zok b/zokrates_cli/examples/book/abi.zok new file mode 100644 index 00000000..84969792 --- /dev/null +++ b/zokrates_cli/examples/book/abi.zok @@ -0,0 +1,11 @@ +struct Bar { + field a +} + +struct Foo { + field a + Bar b +} + +def main(Foo foo, bool[2] bar, field num) -> (field): + return 42 \ No newline at end of file diff --git a/zokrates_cli/examples/error/struct_if_else.zok b/zokrates_cli/examples/error/struct_if_else.zok new file mode 100644 index 00000000..73f81980 --- /dev/null +++ b/zokrates_cli/examples/error/struct_if_else.zok @@ -0,0 +1,5 @@ +struct Foo {} +struct Bar {} + +def main() -> (Foo): + return Bar {} \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 0e6eb860..ecfae3b4 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -33,7 +33,12 @@ impl<'ast> From> for absy::ImportNode<'ast> { Some(import.symbol.span.as_str()), std::path::Path::new(import.source.span.as_str()), ) - .alias(import.alias.map(|a| a.span.as_str())) + .alias( + import + .alias + .map(|a| a.span.as_str()) + .or(Some(import.symbol.span.as_str())), + ) .span(import.span), } } @@ -47,11 +52,11 @@ impl<'ast, T: Field> From> for absy::SymbolDeclarat let id = definition.id.span.as_str(); - let ty = absy::StructType { + let ty = absy::StructDefinition { fields: definition .fields .into_iter() - .map(|f| absy::StructFieldNode::from(f)) + .map(|f| absy::StructDefinitionFieldNode::from(f)) .collect(), } .span(span.clone()); @@ -64,8 +69,8 @@ impl<'ast, T: Field> From> for absy::SymbolDeclarat } } -impl<'ast> From> for absy::StructFieldNode<'ast> { - fn from(field: pest::StructField<'ast>) -> absy::StructFieldNode { +impl<'ast> From> for absy::StructDefinitionFieldNode<'ast> { + fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode { use absy::NodeValue; let span = field.span; @@ -74,7 +79,7 @@ impl<'ast> From> for absy::StructFieldNode<'ast> { let ty = absy::UnresolvedTypeNode::from(field.ty); - absy::StructField { id, ty }.span(span) + absy::StructDefinitionField { id, ty }.span(span) } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 16ccc9c2..0101e387 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -51,7 +51,7 @@ pub struct SymbolDeclaration<'ast, T> { #[derive(PartialEq, Clone)] pub enum Symbol<'ast, T> { - HereType(StructTypeNode<'ast>), + HereType(StructDefinitionNode<'ast>), HereFunction(FunctionNode<'ast, T>), There(SymbolImportNode<'ast>), Flat(FlatEmbed), @@ -109,11 +109,11 @@ pub type UnresolvedTypeNode = Node; /// A struct type definition #[derive(Debug, Clone, PartialEq)] -pub struct StructType<'ast> { - pub fields: Vec>, +pub struct StructDefinition<'ast> { + pub fields: Vec>, } -impl<'ast> fmt::Display for StructType<'ast> { +impl<'ast> fmt::Display for StructDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -127,22 +127,22 @@ impl<'ast> fmt::Display for StructType<'ast> { } } -pub type StructTypeNode<'ast> = Node>; +pub type StructDefinitionNode<'ast> = Node>; /// A struct type definition #[derive(Debug, Clone, PartialEq)] -pub struct StructField<'ast> { +pub struct StructDefinitionField<'ast> { pub id: Identifier<'ast>, pub ty: UnresolvedTypeNode, } -impl<'ast> fmt::Display for StructField<'ast> { +impl<'ast> fmt::Display for StructDefinitionField<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}: {},", self.id, self.ty) } } -type StructFieldNode<'ast> = Node>; +type StructDefinitionFieldNode<'ast> = Node>; /// An import #[derive(Debug, Clone, PartialEq)] diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 21837f39..1cbc059d 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -82,8 +82,8 @@ impl<'ast, T: fmt::Display + fmt::Debug + PartialEq> NodeValue for Assignee<'ast impl<'ast, T: fmt::Display + fmt::Debug + PartialEq> NodeValue for Statement<'ast, T> {} impl<'ast, T: Field> NodeValue for SymbolDeclaration<'ast, T> {} impl NodeValue for UnresolvedType {} -impl<'ast> NodeValue for StructType<'ast> {} -impl<'ast> NodeValue for StructField<'ast> {} +impl<'ast> NodeValue for StructDefinition<'ast> {} +impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast, T: fmt::Display + fmt::Debug + PartialEq> NodeValue for Function<'ast, T> {} impl<'ast, T: Field> NodeValue for Module<'ast, T> {} impl<'ast> NodeValue for SymbolImport<'ast> {} diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index cc082470..977f9a56 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -285,4 +285,111 @@ mod test { ); assert!(res.is_ok()); } + + mod abi { + use super::*; + use typed_absy::abi::*; + use typed_absy::types::*; + + #[test] + fn use_struct_declaration_types() { + // when importing types and renaming them, we use the top-most renaming in the ABI + + // // main.zok + // from foo import Foo as FooMain + // + // // foo.zok + // from bar import Bar as BarFoo + // struct Foo { BarFoo b } + // + // // bar.zok + // struct Bar { field a } + + // Expected resolved type for FooMain: + // FooMain { BarFoo b } + + let main = r#" +from "foo" import Foo as FooMain +def main(FooMain f) -> (): + return +"#; + + struct CustomResolver; + + impl Resolver for CustomResolver { + fn resolve( + &self, + _: PathBuf, + import_location: PathBuf, + ) -> Result<(String, PathBuf), E> { + let loc = import_location.display().to_string(); + if loc == "main" { + Ok(( + r#" +from "foo" import Foo as FooMain +def main(FooMain f) -> (): + return +"# + .into(), + "main".into(), + )) + } else if loc == "foo" { + Ok(( + r#" +from "bar" import Bar as BarFoo +struct Foo { + BarFoo b +} +"# + .into(), + "foo".into(), + )) + } else if loc == "bar" { + Ok(( + r#" +struct Bar { field a } +"# + .into(), + "bar".into(), + )) + } else { + unreachable!() + } + } + } + + let artifacts = compile::( + main.to_string(), + "main".into(), + Some(&CustomResolver), + ) + .unwrap(); + + assert_eq!( + artifacts.abi, + Abi { + inputs: vec![AbiInput { + name: "f".into(), + public: true, + ty: Type::Struct(StructType { + module: "main".into(), + name: "FooMain".into(), + members: vec![StructMember { + id: "b".into(), + ty: box Type::Struct(StructType { + module: "foo".into(), + name: "BarFoo".into(), + members: vec![StructMember { + id: "a".into(), + ty: box Type::FieldElement + }] + }) + }] + }) + }], + outputs: vec![] + } + ); + } + } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8e481828..e6ebe060 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -268,7 +268,8 @@ impl<'ast> Checker<'ast> { fn check_struct_type_declaration( &mut self, - s: StructTypeNode<'ast>, + id: String, + s: StructDefinitionNode<'ast>, module_id: &ModuleId, types: &TypeMap, ) -> Result> { @@ -302,12 +303,14 @@ impl<'ast> Checker<'ast> { return Err(errors); } - Ok(Type::Struct( + Ok(Type::Struct(StructType::new( + module_id.into(), + id, fields .iter() .map(|f| StructMember::new(f.0.clone(), f.1.clone())) .collect(), - )) + ))) } fn check_symbol_declaration( @@ -323,9 +326,14 @@ impl<'ast> Checker<'ast> { let pos = declaration.pos(); let declaration = declaration.value; - match declaration.symbol { + match declaration.symbol.clone() { Symbol::HereType(t) => { - match self.check_struct_type_declaration(t.clone(), module_id, &state.types) { + match self.check_struct_type_declaration( + declaration.id.to_string(), + t.clone(), + module_id, + &state.types, + ) { Ok(ty) => { match symbol_unifier.insert_type(declaration.id) { false => errors.push( @@ -412,6 +420,17 @@ impl<'ast> Checker<'ast> { match (function_candidates.len(), type_candidate) { (0, Some(t)) => { + + // rename the type to the declared symbol + let t = match t { + Type::Struct(t) => Type::Struct(StructType { + module: module_id.clone(), + name: declaration.id.into(), + ..t + }), + _ => unreachable!() + }; + // we imported a type, so the symbol it gets bound to should not already exist match symbol_unifier.insert_type(declaration.id) { false => { @@ -431,7 +450,7 @@ impl<'ast> Checker<'ast> { .types .entry(module_id.clone()) .or_default() - .insert(import.symbol_id.to_string(), t.clone()); + .insert(declaration.id.to_string(), t.clone()); } (0, None) => { errors.push(ErrorInner { @@ -1359,12 +1378,8 @@ impl<'ast> Checker<'ast> { Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into()) }, (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { - if consequence.get_type() == alternative.get_type() { - let ty = consequence.ty().clone(); - Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into()) - } else { - unimplemented!("handle consequence alternative inner type mismatch") - } + let ty = consequence.ty().clone(); + Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into()) }, (TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => { let bitwidth = consequence.bitwidth(); @@ -1937,21 +1952,20 @@ impl<'ast> Checker<'ast> { module_id, &types, )?; - let members = match ty { - Type::Struct(members) => members, + let struct_type = match ty { + Type::Struct(struct_type) => struct_type, _ => unreachable!(), }; // check that we provided the required number of values - if members.len() != inline_members.len() { + if struct_type.len() != inline_members.len() { return Err(ErrorInner { pos: Some(pos), message: format!( - "Inline struct {} does not match {} : {}", + "Inline struct {} does not match {}", Expression::InlineStruct(id.clone(), inline_members), - id, - Type::Struct(members) + Type::Struct(struct_type) ), }); } @@ -1966,7 +1980,7 @@ impl<'ast> Checker<'ast> { .collect::>(); let mut result: Vec> = vec![]; - for member in &members { + for member in struct_type.iter() { match inline_members_map.remove(member.id.as_str()) { Some(value) => { let expression_checked = @@ -1992,10 +2006,9 @@ impl<'ast> Checker<'ast> { return Err(ErrorInner { pos: Some(pos), message: format!( - "Member {} of struct {} : {} not found in value {}", + "Member {} of struct {} not found in value {}", member.id, - id.clone(), - Type::Struct(members.clone()), + Type::Struct(struct_type.clone()), Expression::InlineStruct(id.clone(), inline_members), ), }) @@ -2004,7 +2017,7 @@ impl<'ast> Checker<'ast> { } Ok(StructExpressionInner::Value(result) - .annotate(members) + .annotate(struct_type) .into()) } Expression::And(box e1, box e2) => { @@ -2320,13 +2333,13 @@ mod tests { .mock() } - fn struct0() -> StructTypeNode<'static> { - StructType { fields: vec![] }.mock() + fn struct0() -> StructDefinitionNode<'static> { + StructDefinition { fields: vec![] }.mock() } - fn struct1() -> StructTypeNode<'static> { - StructType { - fields: vec![StructField { + fn struct1() -> StructDefinitionNode<'static> { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo".into(), ty: UnresolvedType::FieldElement.mock(), } @@ -2554,7 +2567,7 @@ mod tests { .mock(), SymbolDeclaration { id: "foo", - symbol: Symbol::HereType(StructType { fields: vec![] }.mock()), + symbol: Symbol::HereType(StructDefinition { fields: vec![] }.mock()), } .mock(), ], @@ -3855,7 +3868,7 @@ mod tests { /// solver function to create a module at location "" with a single symbol `Foo { foo: field }` fn create_module_with_foo( - s: StructType<'static>, + s: StructDefinition<'static>, ) -> (Checker<'static>, State<'static, Bn128Field>) { let module_id: PathBuf = "".into(); @@ -3886,12 +3899,17 @@ mod tests { // an empty struct should be allowed to be defined let module_id = "".into(); let types = HashMap::new(); - let declaration = StructType { fields: vec![] }.mock(); + let declaration = StructDefinition { fields: vec![] }.mock(); - let expected_type = Type::Struct(vec![]); + let expected_type = Type::Struct(StructType::new("".into(), "Foo".into(), vec![])); assert_eq!( - Checker::new().check_struct_type_declaration(declaration, &module_id, &types), + Checker::new().check_struct_type_declaration( + "Foo".into(), + declaration, + &module_id, + &types + ), Ok(expected_type) ); } @@ -3901,14 +3919,14 @@ mod tests { // a valid struct should be allowed to be defined let module_id = "".into(); let types = HashMap::new(); - let declaration = StructType { + let declaration = StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -3917,13 +3935,22 @@ mod tests { } .mock(); - let expected_type = Type::Struct(vec![ - StructMember::new("foo".into(), Type::FieldElement), - StructMember::new("bar".into(), Type::Boolean), - ]); + let expected_type = Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![ + StructMember::new("foo".into(), Type::FieldElement), + StructMember::new("bar".into(), Type::Boolean), + ], + )); assert_eq!( - Checker::new().check_struct_type_declaration(declaration, &module_id, &types), + Checker::new().check_struct_type_declaration( + "Foo".into(), + declaration, + &module_id, + &types + ), Ok(expected_type) ); } @@ -3934,14 +3961,14 @@ mod tests { let module_id = "".into(); let types = HashMap::new(); - let declaration0 = StructType { + let declaration0 = StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -3950,14 +3977,14 @@ mod tests { } .mock(); - let declaration1 = StructType { + let declaration1 = StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } .mock(), - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -3967,8 +3994,18 @@ mod tests { .mock(); assert_ne!( - Checker::new().check_struct_type_declaration(declaration0, &module_id, &types), - Checker::new().check_struct_type_declaration(declaration1, &module_id, &types) + Checker::new().check_struct_type_declaration( + "Foo".into(), + declaration0, + &module_id, + &types + ), + Checker::new().check_struct_type_declaration( + "Foo".into(), + declaration1, + &module_id, + &types + ) ); } @@ -3978,14 +4015,14 @@ mod tests { let module_id = "".into(); let types = HashMap::new(); - let declaration = StructType { + let declaration = StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::Boolean.mock(), } @@ -3996,7 +4033,12 @@ mod tests { assert_eq!( Checker::new() - .check_struct_type_declaration(declaration, &module_id, &types) + .check_struct_type_declaration( + "Foo".into(), + declaration, + &module_id, + &types + ) .unwrap_err()[0] .message, "Duplicate key foo in struct definition" @@ -4018,8 +4060,8 @@ mod tests { SymbolDeclaration { id: "Foo", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4032,8 +4074,8 @@ mod tests { SymbolDeclaration { id: "Bar", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::User("Foo".into()).mock(), } @@ -4056,10 +4098,18 @@ mod tests { .unwrap() .get(&"Bar".to_string()) .unwrap(), - &Type::Struct(vec![StructMember::new( - "foo".into(), - Type::Struct(vec![StructMember::new("foo".into(), Type::FieldElement)]) - )]) + &Type::Struct(StructType::new( + module_id.clone(), + "Bar".into(), + vec![StructMember::new( + "foo".into(), + Type::Struct(StructType::new( + module_id, + "Foo".into(), + vec![StructMember::new("foo".into(), Type::FieldElement)] + )) + )] + )) ); } @@ -4076,8 +4126,8 @@ mod tests { symbols: vec![SymbolDeclaration { id: "Bar", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::User("Foo".into()).mock(), } @@ -4107,8 +4157,8 @@ mod tests { symbols: vec![SymbolDeclaration { id: "Foo", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::User("Foo".into()).mock(), } @@ -4140,8 +4190,8 @@ mod tests { SymbolDeclaration { id: "Foo", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "bar", ty: UnresolvedType::User("Bar".into()).mock(), } @@ -4154,8 +4204,8 @@ mod tests { SymbolDeclaration { id: "Bar", symbol: Symbol::HereType( - StructType { - fields: vec![StructField { + StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::User("Foo".into()).mock(), } @@ -4187,8 +4237,8 @@ mod tests { // an undefined type cannot be checked // Bar - let (checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4201,10 +4251,11 @@ mod tests { &PathBuf::from(MODULE_ID).into(), &state.types ), - Ok(Type::Struct(vec![StructMember::new( - "foo".into(), - Type::FieldElement - )])) + Ok(Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![StructMember::new("foo".into(), Type::FieldElement)] + ))) ); assert_eq!( @@ -4226,8 +4277,8 @@ mod tests { // an undefined type cannot be used as parameter - let (checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4249,10 +4300,11 @@ mod tests { Ok(Parameter { id: Variable::with_id_and_type( "a", - Type::Struct(vec![StructMember::new( - "foo".to_string(), - Type::FieldElement - )]) + Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![StructMember::new("foo".into(), Type::FieldElement)] + )) ), private: true }) @@ -4285,8 +4337,8 @@ mod tests { // an undefined type cannot be used in a variable declaration - let (mut checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (mut checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4305,10 +4357,11 @@ mod tests { ), Ok(TypedStatement::Declaration(Variable::with_id_and_type( "a", - Type::Struct(vec![StructMember::new( - "foo".to_string(), - Type::FieldElement - )]) + Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![StructMember::new("foo".into(), Type::FieldElement)] + )) ))) ); @@ -4345,8 +4398,8 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.foo - let (mut checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (mut checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4375,7 +4428,11 @@ mod tests { Bn128Field::from(42) ) .into()]) - .annotate(vec![StructMember::new("foo".into(), Type::FieldElement)]), + .annotate(StructType::new( + "".into(), + "Foo".into(), + vec![StructMember::new("foo".into(), Type::FieldElement)] + )), "foo".into() ) .into()) @@ -4389,8 +4446,8 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.bar - let (mut checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (mut checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4417,7 +4474,7 @@ mod tests { ) .unwrap_err() .message, - "{foo: field} doesn\'t have member bar" + "Foo {foo: field} doesn\'t have member bar" ); } } @@ -4430,8 +4487,8 @@ mod tests { fn wrong_name() { // a A value cannot be defined with B as id, even if A and B have the same members - let (mut checker, state) = create_module_with_foo(StructType { - fields: vec![StructField { + let (mut checker, state) = create_module_with_foo(StructDefinition { + fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } @@ -4465,14 +4522,14 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42, bar: true } - let (mut checker, state) = create_module_with_foo(StructType { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -4500,10 +4557,14 @@ mod tests { FieldElementExpression::Number(Bn128Field::from(42)).into(), BooleanExpression::Value(true).into() ]) - .annotate(vec![ - StructMember::new("foo".into(), Type::FieldElement), - StructMember::new("bar".into(), Type::Boolean) - ]) + .annotate(StructType::new( + "".into(), + "Foo".into(), + vec![ + StructMember::new("foo".into(), Type::FieldElement), + StructMember::new("bar".into(), Type::Boolean) + ] + )) .into()) ); } @@ -4515,14 +4576,14 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { bar: true, foo: 42 } - let (mut checker, state) = create_module_with_foo(StructType { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -4550,10 +4611,14 @@ mod tests { FieldElementExpression::Number(Bn128Field::from(42)).into(), BooleanExpression::Value(true).into() ]) - .annotate(vec![ - StructMember::new("foo".into(), Type::FieldElement), - StructMember::new("bar".into(), Type::Boolean) - ]) + .annotate(StructType::new( + "".into(), + "Foo".into(), + vec![ + StructMember::new("foo".into(), Type::FieldElement), + StructMember::new("bar".into(), Type::Boolean) + ] + )) .into()) ); } @@ -4565,14 +4630,14 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42 } - let (mut checker, state) = create_module_with_foo(StructType { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -4596,7 +4661,7 @@ mod tests { ) .unwrap_err() .message, - "Inline struct Foo {foo: 42} does not match Foo : {foo: field, bar: bool}" + "Inline struct Foo {foo: 42} does not match Foo {foo: field, bar: bool}" ); } @@ -4609,14 +4674,14 @@ mod tests { // Foo { foo: 42, baz: bool } // error // Foo { foo: 42, baz: 42 } // error - let (mut checker, state) = create_module_with_foo(StructType { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ - StructField { + StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock(), - StructField { + StructDefinitionField { id: "bar", ty: UnresolvedType::Boolean.mock(), } @@ -4642,7 +4707,7 @@ mod tests { &state.types ).unwrap_err() .message, - "Member bar of struct Foo : {foo: field, bar: bool} not found in value Foo {baz: true, foo: 42}" + "Member bar of struct Foo {foo: field, bar: bool} not found in value Foo {baz: true, foo: 42}" ); assert_eq!( diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 05e05035..65dac21d 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; use typed_absy; +use typed_absy::types::StructType; use zir; use zokrates_field::Field; @@ -201,7 +202,7 @@ impl<'ast, T: Field> Flattener { } fn fold_struct_expression_inner( &mut self, - ty: &Vec<(typed_absy::types::MemberId, typed_absy::Type)>, + ty: &StructType, e: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { fold_struct_expression_inner(self, ty, e) @@ -360,7 +361,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>( pub fn fold_struct_expression_inner<'ast, T: Field>( f: &mut Flattener, - t: &Vec<(typed_absy::types::MemberId, typed_absy::Type)>, + t: &StructType, e: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { match e { @@ -427,9 +428,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( // we also need the size of this member let size = t .iter() - .find(|(id, _)| id == id) + .find(|member| member.id == id) .unwrap() - .1 + .ty .get_primitive_count(); s[offset..offset + size].to_vec() @@ -442,7 +443,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( zir::FieldElementExpression::Number(i) => { let size = t .iter() - .map(|(_, t)| t.get_primitive_count()) + .map(|m| m.ty.get_primitive_count()) .fold(0, |acc, current| acc + current); let start = i.to_dec_string().parse::().unwrap() * size; let end = start + size; @@ -777,10 +778,7 @@ pub fn fold_struct_expression<'ast, T: Field>( f: &mut Flattener, e: typed_absy::StructExpression<'ast, T>, ) -> Vec> { - f.fold_struct_expression_inner( - &e.ty().clone().into_iter().map(|m| (m.id, *m.ty)).collect(), - e.into_inner(), - ) + f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner()) } pub fn fold_function_symbol<'ast, T: Field>( diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 7573417f..8e2ed8f6 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -17,7 +17,7 @@ //! where any call in `main` must be to `_SHA_256_ROUND` or `_UNPACK` use std::collections::HashMap; -use typed_absy::types::{FunctionKey, StructMember, Type}; +use typed_absy::types::{FunctionKey, Type}; use typed_absy::{folder::*, *}; use zokrates_field::Field; @@ -483,7 +483,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { fn fold_struct_expression_inner( &mut self, - ty: &Vec, + ty: &StructType, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index f2cfd5d1..5f11d48a 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -14,7 +14,7 @@ use crate::typed_absy::folder::*; use crate::typed_absy::*; use std::collections::HashMap; use std::convert::TryFrom; -use typed_absy::types::{StructMember, Type}; +use typed_absy::types::Type; use zokrates_field::Field; pub struct Propagator<'ast, T: Field> { @@ -816,7 +816,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { fn fold_struct_expression_inner( &mut self, - ty: &Vec, + ty: &StructType, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { diff --git a/zokrates_core/src/static_analysis/variable_access_remover.rs b/zokrates_core/src/static_analysis/variable_access_remover.rs index 169a6dc2..1dad2a6c 100644 --- a/zokrates_core/src/static_analysis/variable_access_remover.rs +++ b/zokrates_core/src/static_analysis/variable_access_remover.rs @@ -1,4 +1,3 @@ -use typed_absy::types::StructMember; use typed_absy::{folder::*, *}; use zokrates_field::Field; @@ -103,14 +102,14 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableAccessRemover<'ast, T> { fn fold_struct_expression_inner( &mut self, - members: &Vec, + ty: &StructType, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { StructExpressionInner::Select(box a, box i) => { self.select::>(a, i).into_inner() } - e => fold_struct_expression_inner(self, members, e), + e => fold_struct_expression_inner(self, ty, e), } } diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 9dc9ed75..206e802a 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -30,7 +30,7 @@ impl Abi { mod tests { use super::*; use std::collections::HashMap; - use typed_absy::types::{ArrayType, FunctionKey, StructMember}; + use typed_absy::types::{ArrayType, FunctionKey, StructMember, StructType}; use typed_absy::{ Parameter, Type, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, Variable, }; @@ -147,15 +147,23 @@ mod tests { inputs: vec![AbiInput { name: String::from("foo"), public: true, - ty: Type::Struct(vec![ + ty: Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![ + StructMember::new(String::from("a"), Type::FieldElement), + StructMember::new(String::from("b"), Type::Boolean), + ], + )), + }], + outputs: vec![Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![ StructMember::new(String::from("a"), Type::FieldElement), StructMember::new(String::from("b"), Type::Boolean), - ]), - }], - outputs: vec![Type::Struct(vec![ - StructMember::new(String::from("a"), Type::FieldElement), - StructMember::new(String::from("b"), Type::Boolean), - ])], + ], + ))], }; let json = serde_json::to_string_pretty(&abi).unwrap(); @@ -167,31 +175,37 @@ mod tests { "name": "foo", "public": true, "type": "struct", - "components": [ - { - "name": "a", - "type": "field" - }, - { - "name": "b", - "type": "bool" - } - ] + "components": { + "name": "Foo", + "members": [ + { + "name": "a", + "type": "field" + }, + { + "name": "b", + "type": "bool" + } + ] + } } ], "outputs": [ { "type": "struct", - "components": [ - { - "name": "a", - "type": "field" - }, - { - "name": "b", - "type": "bool" - } - ] + "components": { + "name": "Foo", + "members": [ + { + "name": "a", + "type": "field" + }, + { + "name": "b", + "type": "bool" + } + ] + } } ] }"# @@ -204,13 +218,21 @@ mod tests { inputs: vec![AbiInput { name: String::from("foo"), public: true, - ty: Type::Struct(vec![StructMember::new( - String::from("bar"), - Type::Struct(vec![ - StructMember::new(String::from("a"), Type::FieldElement), - StructMember::new(String::from("b"), Type::FieldElement), - ]), - )]), + ty: Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![StructMember::new( + String::from("bar"), + Type::Struct(StructType::new( + "".into(), + "Bar".into(), + vec![ + StructMember::new(String::from("a"), Type::FieldElement), + StructMember::new(String::from("b"), Type::FieldElement), + ], + )), + )], + )), }], outputs: vec![], }; @@ -224,22 +246,28 @@ mod tests { "name": "foo", "public": true, "type": "struct", - "components": [ - { - "name": "bar", - "type": "struct", - "components": [ - { - "name": "a", - "type": "field" - }, - { - "name": "b", - "type": "field" + "components": { + "name": "Foo", + "members": [ + { + "name": "bar", + "type": "struct", + "components": { + "name": "Bar", + "members": [ + { + "name": "a", + "type": "field" + }, + { + "name": "b", + "type": "field" + } + ] } - ] - } - ] + } + ] + } } ], "outputs": [] @@ -254,10 +282,14 @@ mod tests { name: String::from("a"), public: false, ty: Type::Array(ArrayType::new( - Type::Struct(vec![ - StructMember::new(String::from("b"), Type::FieldElement), - StructMember::new(String::from("c"), Type::Boolean), - ]), + Type::Struct(StructType::new( + "".into(), + "Foo".into(), + vec![ + StructMember::new(String::from("b"), Type::FieldElement), + StructMember::new(String::from("c"), Type::Boolean), + ], + )), 2, )), }], @@ -276,16 +308,19 @@ mod tests { "components": { "size": 2, "type": "struct", - "components": [ - { - "name": "b", - "type": "field" - }, - { - "name": "c", - "type": "bool" - } - ] + "components": { + "name": "Foo", + "members": [ + { + "name": "b", + "type": "field" + }, + { + "name": "c", + "type": "bool" + } + ] + } } } ], diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 4c09542f..69b04487 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -1,7 +1,6 @@ // Generic walk through a typed AST. Not mutating in place use crate::typed_absy::*; -use typed_absy::types::StructMember; use zokrates_field::Field; pub trait Folder<'ast, T: Field>: Sized { @@ -130,7 +129,7 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_struct_expression_inner( &mut self, - ty: &Vec, + ty: &StructType, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { fold_struct_expression_inner(self, ty, e) @@ -222,7 +221,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - _: &Vec, + _: &StructType, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 4f82c38c..70c1bc96 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -16,7 +16,7 @@ mod variable; pub use self::identifier::CoreIdentifier; pub use self::parameter::Parameter; -pub use self::types::{Signature, Type}; +pub use self::types::{Signature, StructType, Type}; pub use self::variable::Variable; use std::path::PathBuf; pub use typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; @@ -30,7 +30,6 @@ use zokrates_field::Field; pub use self::folder::Folder; use typed_absy::abi::{Abi, AbiInput}; -use typed_absy::types::StructMember; pub use self::identifier::Identifier; @@ -729,12 +728,12 @@ impl<'ast, T> ArrayExpression<'ast, T> { #[derive(Clone, PartialEq, Hash, Eq)] pub struct StructExpression<'ast, T> { - ty: Vec, + ty: StructType, inner: StructExpressionInner<'ast, T>, } impl<'ast, T> StructExpression<'ast, T> { - pub fn ty(&self) -> &Vec { + pub fn ty(&self) -> &StructType { &self.ty } @@ -765,7 +764,7 @@ pub enum StructExpressionInner<'ast, T> { } impl<'ast, T> StructExpressionInner<'ast, T> { - pub fn annotate(self, ty: Vec) -> StructExpression<'ast, T> { + pub fn annotate(self, ty: StructType) -> StructExpression<'ast, T> { StructExpression { ty, inner: self } } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 134cca48..cee6b9dd 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::path::PathBuf; pub type Identifier<'ast> = &'ast str; @@ -19,6 +20,49 @@ pub struct ArrayType { pub ty: Box, } +#[derive(Clone, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct StructType { + #[serde(skip)] + pub module: PathBuf, + pub name: String, + pub members: Vec, +} + +impl PartialEq for StructType { + fn eq(&self, other: &Self) -> bool { + self.members.eq(&other.members) + } +} + +impl Eq for StructType {} + +impl StructType { + pub fn new(module: PathBuf, name: String, members: Vec) -> Self { + StructType { + module, + name, + members, + } + } + + pub fn len(&self) -> usize { + self.members.len() + } + + pub fn iter(&self) -> std::slice::Iter { + self.members.iter() + } +} + +impl IntoIterator for StructType { + type Item = StructMember; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.members.into_iter() + } +} + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] #[serde(tag = "type", content = "components")] pub enum Type { @@ -29,7 +73,7 @@ pub enum Type { #[serde(rename = "array")] Array(ArrayType), #[serde(rename = "struct")] - Struct(Vec), + Struct(StructType), #[serde(rename = "u")] Uint(usize), } @@ -59,10 +103,12 @@ impl fmt::Display for Type { Type::Boolean => write!(f, "bool"), Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref members) => write!( + Type::Struct(ref struct_type) => write!( f, - "{{{}}}", - members + "{} {{{}}}", + struct_type.name, + struct_type + .members .iter() .map(|member| format!("{}: {}", member.id, member.ty)) .collect::>() @@ -79,10 +125,12 @@ impl fmt::Debug for Type { Type::Boolean => write!(f, "bool"), Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref members) => write!( + Type::Struct(ref struct_type) => write!( f, - "{{{}}}", - members + "{} {{{}}}", + struct_type.name, + struct_type + .members .iter() .map(|member| format!("{}: {}", member.id, member.ty)) .collect::>() @@ -97,12 +145,8 @@ impl Type { Type::Array(ArrayType::new(ty, size)) } - pub fn struc(ty: Vec<(MemberId, Type)>) -> Self { - Type::Struct( - ty.into_iter() - .map(|(id, ty)| StructMember { id, ty: box ty }) - .collect(), - ) + pub fn struc(struct_ty: StructType) -> Self { + Type::Struct(struct_ty) } fn to_slug(&self) -> String { @@ -111,9 +155,9 @@ impl Type { Type::Boolean => String::from("b"), Type::Uint(bitwidth) => format!("u{}", bitwidth), Type::Array(array_type) => format!("{}[{}]", array_type.ty.to_slug(), array_type.size), - Type::Struct(members) => format!( + Type::Struct(struct_type) => format!( "{{{}}}", - members + struct_type .iter() .map(|member| format!("{}:{}", member.id, member.ty)) .collect::>() @@ -128,8 +172,11 @@ impl Type { Type::FieldElement => 1, Type::Boolean => 1, Type::Uint(_) => 1, - Type::Struct(members) => members.iter().map(|m| m.ty.get_primitive_count()).sum(), Type::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), + Type::Struct(struct_type) => struct_type + .iter() + .map(|member| member.ty.get_primitive_count()) + .sum(), } } } diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index 8c3c8532..573a108a 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -1,7 +1,7 @@ use crate::typed_absy::types::Type; use crate::typed_absy::Identifier; use std::fmt; -use typed_absy::types::StructMember; +use typed_absy::types::StructType; #[derive(Clone, PartialEq, Hash, Eq)] pub struct Variable<'ast> { @@ -31,7 +31,7 @@ impl<'ast> Variable<'ast> { Self::with_id_and_type(id, Type::array(ty, size)) } - pub fn struc>>(id: I, ty: Vec) -> Variable<'ast> { + pub fn struc>>(id: I, ty: StructType) -> Variable<'ast> { Self::with_id_and_type(id, Type::Struct(ty)) } diff --git a/zokrates_core_test/tests/tests/structs/if_else.zok b/zokrates_core_test/tests/tests/structs/if_else.zok new file mode 100644 index 00000000..e69de29b