From 2a94af6ff07252d9cf3bf821848cdedca7085c48 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 24 Aug 2021 14:49:08 +0200 Subject: [PATCH] implement type aliasing --- .../examples/alias/basic_aliasing.zok | 13 + zokrates_cli/examples/alias/import_alias.zok | 14 + .../examples/alias/struct_aliasing.zok | 15 ++ zokrates_core/src/absy/from_ast.rs | 27 ++ zokrates_core/src/absy/mod.rs | 69 ++++- zokrates_core/src/absy/node.rs | 1 + zokrates_core/src/semantics.rs | 241 ++++++++++++++++-- zokrates_parser/src/zokrates.pest | 5 +- zokrates_pest_ast/src/lib.rs | 15 +- 9 files changed, 362 insertions(+), 38 deletions(-) create mode 100644 zokrates_cli/examples/alias/basic_aliasing.zok create mode 100644 zokrates_cli/examples/alias/import_alias.zok create mode 100644 zokrates_cli/examples/alias/struct_aliasing.zok diff --git a/zokrates_cli/examples/alias/basic_aliasing.zok b/zokrates_cli/examples/alias/basic_aliasing.zok new file mode 100644 index 00000000..c12362e5 --- /dev/null +++ b/zokrates_cli/examples/alias/basic_aliasing.zok @@ -0,0 +1,13 @@ +type byte = u8 +type uint32 = u32 +type UInt32Array = uint32[N] + +type matrix = field[R][C] + +def fill(field v) -> matrix: + return [[v; C]; R] + +def main(uint32 a, uint32 b) -> (UInt32Array<2>, matrix<2, 4>): + UInt32Array<2> res = [a, b] + matrix<2, 4> m = fill(1) + return res, m \ No newline at end of file diff --git a/zokrates_cli/examples/alias/import_alias.zok b/zokrates_cli/examples/alias/import_alias.zok new file mode 100644 index 00000000..6e7e54e4 --- /dev/null +++ b/zokrates_cli/examples/alias/import_alias.zok @@ -0,0 +1,14 @@ +from "./basic_aliasing.zok" import matrix +from "./struct_aliasing.zok" import Buzz + +const u32 R = 2 +const u32 C = 4 + +type matrix_2x4 = matrix + +def buzz() -> Buzz: + return Buzz { a: [0; N], b: [0; N] } + +def main(matrix_2x4 m) -> (Buzz<2>, matrix_2x4): + Buzz<2> b = buzz::<2>() + return b, m \ No newline at end of file diff --git a/zokrates_cli/examples/alias/struct_aliasing.zok b/zokrates_cli/examples/alias/struct_aliasing.zok new file mode 100644 index 00000000..932587b6 --- /dev/null +++ b/zokrates_cli/examples/alias/struct_aliasing.zok @@ -0,0 +1,15 @@ +type FieldArray = field[N] + +struct Foo { + FieldArray a + FieldArray b +} + +type Bar = Foo<2, 2> +type Buzz = Foo + +def main(Bar a) -> Buzz<2>: + Bar bar = Bar { a: [1, 2], b: [1, 2] } + Buzz<2> buzz = Buzz { a: [1, 2], b: [1, 2] } + assert(bar == buzz) + return buzz \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index a968b9c2..48bdd35e 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,5 +1,6 @@ use crate::absy; +use crate::absy::SymbolDefinition; use num_bigint::BigUint; use std::path::Path; use zokrates_pest_ast as pest; @@ -10,6 +11,7 @@ impl<'ast> From> for absy::Module<'ast> { pest::SymbolDeclaration::Import(i) => import_directive_to_symbol_vec(i), pest::SymbolDeclaration::Constant(c) => vec![c.into()], pest::SymbolDeclaration::Struct(s) => vec![s.into()], + pest::SymbolDeclaration::Type(t) => vec![t.into()], pest::SymbolDeclaration::Function(f) => vec![f.into()], })) } @@ -135,6 +137,31 @@ impl<'ast> From> for absy::SymbolDeclarationNode< } } +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(definition: pest::TypeDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { + use crate::absy::NodeValue; + + let span = definition.span; + let id = definition.id.span.as_str(); + + let ty = absy::TypeDefinition { + generics: definition + .generics + .into_iter() + .map(absy::ConstantGenericNode::from) + .collect(), + ty: definition.ty.into(), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id, + symbol: absy::Symbol::Here(SymbolDefinition::Type(ty)), + } + .span(span) + } +} + impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 70fc76d0..26fa114a 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -133,6 +133,7 @@ pub enum SymbolDefinition<'ast> { Import(CanonicalImportNode<'ast>), Struct(StructDefinitionNode<'ast>), Constant(ConstantDefinitionNode<'ast>), + Type(TypeDefinitionNode<'ast>), Function(FunctionNode<'ast>), } @@ -153,12 +154,28 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> { i.value.source.display(), i.value.id ), - SymbolDefinition::Struct(ref t) => write!(f, "struct {}{}", self.id, t), + SymbolDefinition::Struct(ref s) => write!(f, "struct {}{}", self.id, s), SymbolDefinition::Constant(ref c) => write!( f, "const {} {} = {}", c.value.ty, self.id, c.value.expression ), + SymbolDefinition::Type(ref t) => { + write!(f, "type {}", self.id)?; + if !t.value.generics.is_empty() { + write!( + f, + "<{}>", + t.value + .generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, " = {}", t.value.ty) + } SymbolDefinition::Function(ref func) => { write!(f, "def {}{}", self.id, func) } @@ -205,15 +222,18 @@ pub struct StructDefinition<'ast> { impl<'ast> fmt::Display for StructDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!( - f, - "<{}> {{", - self.generics - .iter() - .map(|g| g.to_string()) - .collect::>() - .join(", "), - )?; + if !self.generics.is_empty() { + write!( + f, + "<{}> ", + self.generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + writeln!(f, "{{")?; for field in &self.fields { writeln!(f, " {}", field)?; } @@ -248,7 +268,34 @@ pub type ConstantDefinitionNode<'ast> = Node>; impl<'ast> fmt::Display for ConstantDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "const {}({})", self.ty, self.expression) + write!(f, "const {} _ = {}", self.ty, self.expression) + } +} + +/// A type definition +#[derive(Debug, Clone, PartialEq)] +pub struct TypeDefinition<'ast> { + pub generics: Vec>, + pub ty: UnresolvedTypeNode<'ast>, +} + +pub type TypeDefinitionNode<'ast> = Node>; + +impl<'ast> fmt::Display for TypeDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "type _")?; + if !self.generics.is_empty() { + write!( + f, + "<{}>", + self.generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, " = {}", self.ty) } } diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 8a5745c0..222966f1 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -84,6 +84,7 @@ impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast> NodeValue for ConstantDefinition<'ast> {} +impl<'ast> NodeValue for TypeDefinition<'ast> {} impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Module<'ast> {} impl<'ast> NodeValue for CanonicalImport<'ast> {} diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 937fbf17..0aeb6f02 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -55,7 +55,9 @@ impl ErrorInner { } } -type TypeMap<'ast> = HashMap>>; +type GenericDeclarations<'ast> = Option>>>; +type TypeMap<'ast> = + HashMap, GenericDeclarations<'ast>)>>; type ConstantMap<'ast> = HashMap, DeclarationType<'ast>>>; @@ -349,6 +351,85 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } + fn check_type_definition( + &mut self, + ty: TypeDefinitionNode<'ast>, + module_id: &ModuleId, + state: &State<'ast, T>, + ) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec> { + let pos = ty.pos(); + let ty = ty.value; + + let mut errors = vec![]; + + let mut generics = vec![]; + let mut generics_map = HashMap::new(); + + for (index, g) in ty.generics.iter().enumerate() { + if state + .constants + .get(module_id) + .and_then(|m| m.get(g.value)) + .is_some() + { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!( + "Generic parameter {p} conflicts with constant symbol {p}", + p = g.value + ), + }); + } else { + match generics_map.insert(g.value, index).is_none() { + true => { + generics.push(Some(DeclarationConstant::Generic(GenericIdentifier { + name: g.value, + index, + }))); + } + false => { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!("Generic parameter {} is already declared", g.value), + }); + } + } + } + } + + let mut used_generics = HashSet::new(); + + match self.check_declaration_type( + ty.ty, + module_id, + state, + &generics_map, + &mut used_generics, + ) { + Ok(ty) => { + // check that all declared generics were used + for declared_generic in generics_map.keys() { + if !used_generics.contains(declared_generic) { + errors.push(ErrorInner { + pos: Some(pos), + message: format!("Generic parameter {} must be used", declared_generic), + }); + } + } + + if !errors.is_empty() { + return Err(errors); + } + + Ok((ty, Some(generics))) + } + Err(e) => { + errors.push(e); + Err(errors) + } + } + } + fn check_constant_definition( &mut self, id: ConstantIdentifier<'ast>, @@ -541,7 +622,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), ty) + .insert(declaration.id.to_string(), (ty, None)) .is_none()); } }; @@ -593,6 +674,35 @@ impl<'ast, T: Field> Checker<'ast, T> { } } } + Symbol::Here(SymbolDefinition::Type(t)) => { + match self.check_type_definition(t, module_id, state) { + Ok(ty) => { + match symbol_unifier.insert_type(declaration.id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => { + assert!(state + .types + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id.to_string(), ty) + .is_none()); + } + }; + } + Err(e) => { + errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + } + } + } Symbol::Here(SymbolDefinition::Function(f)) => { match self.check_function(f, module_id, state) { Ok(funct) => { @@ -673,7 +783,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { - (0, Some(t), None) => { + (0, Some((t, alias_generics)), None) => { // rename the type to the declared symbol let t = match t { @@ -684,7 +794,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), ..t }), - _ => unreachable!() + _ => t // type alias }; // we imported a type, so the symbol it gets bound to should not already exist @@ -706,7 +816,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), t); + .insert(declaration.id.to_string(), (t, alias_generics)); } (0, None, Some(ty)) => { match symbol_unifier.insert_constant(declaration.id) { @@ -1187,23 +1297,22 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let declaration_type = - types - .get(module_id) - .unwrap() - .get(&id) - .cloned() - .ok_or_else(|| ErrorInner { - pos: Some(pos), - message: format!("Undefined type {}", id), - })?; + let (declaration_type, alias_generics) = types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + })?; // absence of generics is treated as 0 generics, as we do not provide inference for now let generics = generics.unwrap_or_default(); // check generics - match declaration_type { - DeclarationType::Struct(struct_type) => { + match (declaration_type, alias_generics) { + (DeclarationType::Struct(struct_type), None) => { match struct_type.generics.len() == generics.len() { true => { // downcast the generics to identifiers, as this is the only possibility here @@ -1263,7 +1372,58 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - _ => unreachable!("user defined types should always be structs"), + (declaration_type, Some(alias_generics)) => { + match alias_generics.len() == generics.len() { + true => { + let generic_identifiers = + alias_generics.iter().map(|c| match c.as_ref().unwrap() { + DeclarationConstant::Generic(g) => g.clone(), + _ => unreachable!(), + }); + + // build the generic assignment for this type + let assignment = GGenericsAssignment(generics + .into_iter() + .zip(generic_identifiers) + .map(|(e, g)| match e { + Some(e) => { + self + .check_expression(e, module_id, types) + .and_then(|e| { + UExpression::try_from_typed(e, &UBitwidth::B32) + .map(|e| (g, e)) + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), + }) + }) + }, + None => Err(ErrorInner { + pos: Some(pos), + message: + "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." + .into(), + }) + }) + .collect::>()?); + + // specialize the declared type using the generic assignment + Ok(specialize_declaration_type(declaration_type, &assignment) + .unwrap()) + } + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + alias_generics.len(), + if alias_generics.len() == 1 { "" } else { "s" }, + id, + generics.len() + ), + }), + } + } + _ => unreachable!(), } } } @@ -1358,7 +1518,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let declared_ty = state + let (declared_ty, alias_generics) = state .types .get(module_id) .unwrap() @@ -1369,8 +1529,43 @@ impl<'ast, T: Field> Checker<'ast, T> { message: format!("Undefined type {}", id), })?; - match declared_ty { - DeclarationType::Struct(declared_struct_ty) => { + match (declared_ty, alias_generics) { + (ty, Some(alias_generics)) => { + let generics = generics.unwrap_or_default(); + let checked_generics: Vec<_> = generics + .into_iter() + .map(|e| match e { + Some(e) => self + .check_generic_expression( + e, + module_id, + state.constants.get(module_id).unwrap_or(&HashMap::new()), + generics_map, + used_generics, + ) + .map(Some), + None => Err(ErrorInner { + pos: Some(pos), + message: "Expected u32 constant or identifier, but found `_`" + .into(), + }), + }) + .collect::>()?; + + let mut assignment = GGenericsAssignment::default(); + + assignment.0.extend( + alias_generics.iter().zip(checked_generics.iter()).map( + |(decl_g, g_val)| match decl_g.clone().unwrap() { + DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), + _ => unreachable!(), + }, + ), + ); + + Ok(specialize_declaration_type(ty, &assignment).unwrap()) + } + (DeclarationType::Struct(declared_struct_ty), None) => { let generics = generics.unwrap_or_default(); match declared_struct_ty.generics.len() == generics.len() { true => { @@ -1441,7 +1636,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - _ => Ok(declared_ty), + (declared_ty, _) => Ok(declared_ty), } } } @@ -2910,7 +3105,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::InlineStruct(id, inline_members) => { - let ty = match types.get(module_id).unwrap().get(&id).cloned() { + let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { pos: Some(pos), message: format!("Undefined type `{}`", id), diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index ce7499e7..c833a844 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -4,7 +4,7 @@ file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ symbol_declaration* ~ EOI } pragma = { "#pragma" ~ "curve" ~ curve } curve = @{ (ASCII_ALPHANUMERIC | "_") * } -symbol_declaration = { (import_directive | ty_struct_definition | const_definition | function_definition) ~ NEWLINE* } +symbol_declaration = { (import_directive | ty_struct_definition | const_definition | type_definition | function_definition) ~ NEWLINE* } import_directive = { main_import_directive | from_import_directive } from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ import_symbol_list ~ NEWLINE* } @@ -14,6 +14,7 @@ import_symbol = { identifier ~ ("as" ~ identifier)? } import_symbol_list = _{ import_symbol ~ ("," ~ import_symbol)* } function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* } const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*} +type_definition = {"type" ~ identifier ~ constant_generics_declaration? ~ "=" ~ ty ~ NEWLINE*} return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? } constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" } constant_generics_list = _{ identifier ~ ("," ~ identifier)* } @@ -163,6 +164,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) } // the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly // accepted -keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| +keyword = @{"assert"|"as"|"bool"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| "in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"|"u64" } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 22b4d0d7..7f4d5490 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -17,8 +17,8 @@ pub use ast::{ InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, Parameter, PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, SymbolDeclaration, TernaryExpression, ToExpression, - Type, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression, UnaryOperator, Underscore, - Visibility, + Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression, + UnaryOperator, Underscore, Visibility, }; mod ast { @@ -140,6 +140,7 @@ mod ast { Import(ImportDirective<'ast>), Constant(ConstantDefinition<'ast>), Struct(StructDefinition<'ast>), + Type(TypeDefinition<'ast>), Function(FunctionDefinition<'ast>), } @@ -184,6 +185,16 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::type_definition))] + pub struct TypeDefinition<'ast> { + pub id: IdentifierExpression<'ast>, + pub generics: Vec>, + pub ty: Type<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_directive))] pub enum ImportDirective<'ast> {