From 2a94af6ff07252d9cf3bf821848cdedca7085c48 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 24 Aug 2021 14:49:08 +0200 Subject: [PATCH 01/19] implement type aliasing --- .../examples/alias/basic_aliasing.zok | 13 + zokrates_cli/examples/alias/import_alias.zok | 14 + .../examples/alias/struct_aliasing.zok | 15 ++ zokrates_core/src/absy/from_ast.rs | 27 ++ zokrates_core/src/absy/mod.rs | 69 ++++- zokrates_core/src/absy/node.rs | 1 + zokrates_core/src/semantics.rs | 241 ++++++++++++++++-- zokrates_parser/src/zokrates.pest | 5 +- zokrates_pest_ast/src/lib.rs | 15 +- 9 files changed, 362 insertions(+), 38 deletions(-) create mode 100644 zokrates_cli/examples/alias/basic_aliasing.zok create mode 100644 zokrates_cli/examples/alias/import_alias.zok create mode 100644 zokrates_cli/examples/alias/struct_aliasing.zok diff --git a/zokrates_cli/examples/alias/basic_aliasing.zok b/zokrates_cli/examples/alias/basic_aliasing.zok new file mode 100644 index 00000000..c12362e5 --- /dev/null +++ b/zokrates_cli/examples/alias/basic_aliasing.zok @@ -0,0 +1,13 @@ +type byte = u8 +type uint32 = u32 +type UInt32Array = uint32[N] + +type matrix = field[R][C] + +def fill(field v) -> matrix: + return [[v; C]; R] + +def main(uint32 a, uint32 b) -> (UInt32Array<2>, matrix<2, 4>): + UInt32Array<2> res = [a, b] + matrix<2, 4> m = fill(1) + return res, m \ No newline at end of file diff --git a/zokrates_cli/examples/alias/import_alias.zok b/zokrates_cli/examples/alias/import_alias.zok new file mode 100644 index 00000000..6e7e54e4 --- /dev/null +++ b/zokrates_cli/examples/alias/import_alias.zok @@ -0,0 +1,14 @@ +from "./basic_aliasing.zok" import matrix +from "./struct_aliasing.zok" import Buzz + +const u32 R = 2 +const u32 C = 4 + +type matrix_2x4 = matrix + +def buzz() -> Buzz: + return Buzz { a: [0; N], b: [0; N] } + +def main(matrix_2x4 m) -> (Buzz<2>, matrix_2x4): + Buzz<2> b = buzz::<2>() + return b, m \ No newline at end of file diff --git a/zokrates_cli/examples/alias/struct_aliasing.zok b/zokrates_cli/examples/alias/struct_aliasing.zok new file mode 100644 index 00000000..932587b6 --- /dev/null +++ b/zokrates_cli/examples/alias/struct_aliasing.zok @@ -0,0 +1,15 @@ +type FieldArray = field[N] + +struct Foo { + FieldArray a + FieldArray b +} + +type Bar = Foo<2, 2> +type Buzz = Foo + +def main(Bar a) -> Buzz<2>: + Bar bar = Bar { a: [1, 2], b: [1, 2] } + Buzz<2> buzz = Buzz { a: [1, 2], b: [1, 2] } + assert(bar == buzz) + return buzz \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index a968b9c2..48bdd35e 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,5 +1,6 @@ use crate::absy; +use crate::absy::SymbolDefinition; use num_bigint::BigUint; use std::path::Path; use zokrates_pest_ast as pest; @@ -10,6 +11,7 @@ impl<'ast> From> for absy::Module<'ast> { pest::SymbolDeclaration::Import(i) => import_directive_to_symbol_vec(i), pest::SymbolDeclaration::Constant(c) => vec![c.into()], pest::SymbolDeclaration::Struct(s) => vec![s.into()], + pest::SymbolDeclaration::Type(t) => vec![t.into()], pest::SymbolDeclaration::Function(f) => vec![f.into()], })) } @@ -135,6 +137,31 @@ impl<'ast> From> for absy::SymbolDeclarationNode< } } +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(definition: pest::TypeDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { + use crate::absy::NodeValue; + + let span = definition.span; + let id = definition.id.span.as_str(); + + let ty = absy::TypeDefinition { + generics: definition + .generics + .into_iter() + .map(absy::ConstantGenericNode::from) + .collect(), + ty: definition.ty.into(), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id, + symbol: absy::Symbol::Here(SymbolDefinition::Type(ty)), + } + .span(span) + } +} + impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 70fc76d0..26fa114a 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -133,6 +133,7 @@ pub enum SymbolDefinition<'ast> { Import(CanonicalImportNode<'ast>), Struct(StructDefinitionNode<'ast>), Constant(ConstantDefinitionNode<'ast>), + Type(TypeDefinitionNode<'ast>), Function(FunctionNode<'ast>), } @@ -153,12 +154,28 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> { i.value.source.display(), i.value.id ), - SymbolDefinition::Struct(ref t) => write!(f, "struct {}{}", self.id, t), + SymbolDefinition::Struct(ref s) => write!(f, "struct {}{}", self.id, s), SymbolDefinition::Constant(ref c) => write!( f, "const {} {} = {}", c.value.ty, self.id, c.value.expression ), + SymbolDefinition::Type(ref t) => { + write!(f, "type {}", self.id)?; + if !t.value.generics.is_empty() { + write!( + f, + "<{}>", + t.value + .generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, " = {}", t.value.ty) + } SymbolDefinition::Function(ref func) => { write!(f, "def {}{}", self.id, func) } @@ -205,15 +222,18 @@ pub struct StructDefinition<'ast> { impl<'ast> fmt::Display for StructDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!( - f, - "<{}> {{", - self.generics - .iter() - .map(|g| g.to_string()) - .collect::>() - .join(", "), - )?; + if !self.generics.is_empty() { + write!( + f, + "<{}> ", + self.generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + writeln!(f, "{{")?; for field in &self.fields { writeln!(f, " {}", field)?; } @@ -248,7 +268,34 @@ pub type ConstantDefinitionNode<'ast> = Node>; impl<'ast> fmt::Display for ConstantDefinition<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "const {}({})", self.ty, self.expression) + write!(f, "const {} _ = {}", self.ty, self.expression) + } +} + +/// A type definition +#[derive(Debug, Clone, PartialEq)] +pub struct TypeDefinition<'ast> { + pub generics: Vec>, + pub ty: UnresolvedTypeNode<'ast>, +} + +pub type TypeDefinitionNode<'ast> = Node>; + +impl<'ast> fmt::Display for TypeDefinition<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "type _")?; + if !self.generics.is_empty() { + write!( + f, + "<{}>", + self.generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, " = {}", self.ty) } } diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 8a5745c0..222966f1 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -84,6 +84,7 @@ impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast> NodeValue for ConstantDefinition<'ast> {} +impl<'ast> NodeValue for TypeDefinition<'ast> {} impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Module<'ast> {} impl<'ast> NodeValue for CanonicalImport<'ast> {} diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 937fbf17..0aeb6f02 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -55,7 +55,9 @@ impl ErrorInner { } } -type TypeMap<'ast> = HashMap>>; +type GenericDeclarations<'ast> = Option>>>; +type TypeMap<'ast> = + HashMap, GenericDeclarations<'ast>)>>; type ConstantMap<'ast> = HashMap, DeclarationType<'ast>>>; @@ -349,6 +351,85 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } + fn check_type_definition( + &mut self, + ty: TypeDefinitionNode<'ast>, + module_id: &ModuleId, + state: &State<'ast, T>, + ) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec> { + let pos = ty.pos(); + let ty = ty.value; + + let mut errors = vec![]; + + let mut generics = vec![]; + let mut generics_map = HashMap::new(); + + for (index, g) in ty.generics.iter().enumerate() { + if state + .constants + .get(module_id) + .and_then(|m| m.get(g.value)) + .is_some() + { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!( + "Generic parameter {p} conflicts with constant symbol {p}", + p = g.value + ), + }); + } else { + match generics_map.insert(g.value, index).is_none() { + true => { + generics.push(Some(DeclarationConstant::Generic(GenericIdentifier { + name: g.value, + index, + }))); + } + false => { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!("Generic parameter {} is already declared", g.value), + }); + } + } + } + } + + let mut used_generics = HashSet::new(); + + match self.check_declaration_type( + ty.ty, + module_id, + state, + &generics_map, + &mut used_generics, + ) { + Ok(ty) => { + // check that all declared generics were used + for declared_generic in generics_map.keys() { + if !used_generics.contains(declared_generic) { + errors.push(ErrorInner { + pos: Some(pos), + message: format!("Generic parameter {} must be used", declared_generic), + }); + } + } + + if !errors.is_empty() { + return Err(errors); + } + + Ok((ty, Some(generics))) + } + Err(e) => { + errors.push(e); + Err(errors) + } + } + } + fn check_constant_definition( &mut self, id: ConstantIdentifier<'ast>, @@ -541,7 +622,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), ty) + .insert(declaration.id.to_string(), (ty, None)) .is_none()); } }; @@ -593,6 +674,35 @@ impl<'ast, T: Field> Checker<'ast, T> { } } } + Symbol::Here(SymbolDefinition::Type(t)) => { + match self.check_type_definition(t, module_id, state) { + Ok(ty) => { + match symbol_unifier.insert_type(declaration.id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration.id, + ), + } + .in_file(module_id), + ), + true => { + assert!(state + .types + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration.id.to_string(), ty) + .is_none()); + } + }; + } + Err(e) => { + errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); + } + } + } Symbol::Here(SymbolDefinition::Function(f)) => { match self.check_function(f, module_id, state) { Ok(funct) => { @@ -673,7 +783,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { - (0, Some(t), None) => { + (0, Some((t, alias_generics)), None) => { // rename the type to the declared symbol let t = match t { @@ -684,7 +794,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), ..t }), - _ => unreachable!() + _ => t // type alias }; // we imported a type, so the symbol it gets bound to should not already exist @@ -706,7 +816,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), t); + .insert(declaration.id.to_string(), (t, alias_generics)); } (0, None, Some(ty)) => { match symbol_unifier.insert_constant(declaration.id) { @@ -1187,23 +1297,22 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let declaration_type = - types - .get(module_id) - .unwrap() - .get(&id) - .cloned() - .ok_or_else(|| ErrorInner { - pos: Some(pos), - message: format!("Undefined type {}", id), - })?; + let (declaration_type, alias_generics) = types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + })?; // absence of generics is treated as 0 generics, as we do not provide inference for now let generics = generics.unwrap_or_default(); // check generics - match declaration_type { - DeclarationType::Struct(struct_type) => { + match (declaration_type, alias_generics) { + (DeclarationType::Struct(struct_type), None) => { match struct_type.generics.len() == generics.len() { true => { // downcast the generics to identifiers, as this is the only possibility here @@ -1263,7 +1372,58 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - _ => unreachable!("user defined types should always be structs"), + (declaration_type, Some(alias_generics)) => { + match alias_generics.len() == generics.len() { + true => { + let generic_identifiers = + alias_generics.iter().map(|c| match c.as_ref().unwrap() { + DeclarationConstant::Generic(g) => g.clone(), + _ => unreachable!(), + }); + + // build the generic assignment for this type + let assignment = GGenericsAssignment(generics + .into_iter() + .zip(generic_identifiers) + .map(|(e, g)| match e { + Some(e) => { + self + .check_expression(e, module_id, types) + .and_then(|e| { + UExpression::try_from_typed(e, &UBitwidth::B32) + .map(|e| (g, e)) + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), + }) + }) + }, + None => Err(ErrorInner { + pos: Some(pos), + message: + "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." + .into(), + }) + }) + .collect::>()?); + + // specialize the declared type using the generic assignment + Ok(specialize_declaration_type(declaration_type, &assignment) + .unwrap()) + } + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + alias_generics.len(), + if alias_generics.len() == 1 { "" } else { "s" }, + id, + generics.len() + ), + }), + } + } + _ => unreachable!(), } } } @@ -1358,7 +1518,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let declared_ty = state + let (declared_ty, alias_generics) = state .types .get(module_id) .unwrap() @@ -1369,8 +1529,43 @@ impl<'ast, T: Field> Checker<'ast, T> { message: format!("Undefined type {}", id), })?; - match declared_ty { - DeclarationType::Struct(declared_struct_ty) => { + match (declared_ty, alias_generics) { + (ty, Some(alias_generics)) => { + let generics = generics.unwrap_or_default(); + let checked_generics: Vec<_> = generics + .into_iter() + .map(|e| match e { + Some(e) => self + .check_generic_expression( + e, + module_id, + state.constants.get(module_id).unwrap_or(&HashMap::new()), + generics_map, + used_generics, + ) + .map(Some), + None => Err(ErrorInner { + pos: Some(pos), + message: "Expected u32 constant or identifier, but found `_`" + .into(), + }), + }) + .collect::>()?; + + let mut assignment = GGenericsAssignment::default(); + + assignment.0.extend( + alias_generics.iter().zip(checked_generics.iter()).map( + |(decl_g, g_val)| match decl_g.clone().unwrap() { + DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), + _ => unreachable!(), + }, + ), + ); + + Ok(specialize_declaration_type(ty, &assignment).unwrap()) + } + (DeclarationType::Struct(declared_struct_ty), None) => { let generics = generics.unwrap_or_default(); match declared_struct_ty.generics.len() == generics.len() { true => { @@ -1441,7 +1636,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - _ => Ok(declared_ty), + (declared_ty, _) => Ok(declared_ty), } } } @@ -2910,7 +3105,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::InlineStruct(id, inline_members) => { - let ty = match types.get(module_id).unwrap().get(&id).cloned() { + let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { pos: Some(pos), message: format!("Undefined type `{}`", id), diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index ce7499e7..c833a844 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -4,7 +4,7 @@ file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ symbol_declaration* ~ EOI } pragma = { "#pragma" ~ "curve" ~ curve } curve = @{ (ASCII_ALPHANUMERIC | "_") * } -symbol_declaration = { (import_directive | ty_struct_definition | const_definition | function_definition) ~ NEWLINE* } +symbol_declaration = { (import_directive | ty_struct_definition | const_definition | type_definition | function_definition) ~ NEWLINE* } import_directive = { main_import_directive | from_import_directive } from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ import_symbol_list ~ NEWLINE* } @@ -14,6 +14,7 @@ import_symbol = { identifier ~ ("as" ~ identifier)? } import_symbol_list = _{ import_symbol ~ ("," ~ import_symbol)* } function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* } const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*} +type_definition = {"type" ~ identifier ~ constant_generics_declaration? ~ "=" ~ ty ~ NEWLINE*} return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? } constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" } constant_generics_list = _{ identifier ~ ("," ~ identifier)* } @@ -163,6 +164,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) } // the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly // accepted -keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| +keyword = @{"assert"|"as"|"bool"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"| "in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"|"u64" } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 22b4d0d7..7f4d5490 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -17,8 +17,8 @@ pub use ast::{ InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, Parameter, PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, SymbolDeclaration, TernaryExpression, ToExpression, - Type, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression, UnaryOperator, Underscore, - Visibility, + Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression, + UnaryOperator, Underscore, Visibility, }; mod ast { @@ -140,6 +140,7 @@ mod ast { Import(ImportDirective<'ast>), Constant(ConstantDefinition<'ast>), Struct(StructDefinition<'ast>), + Type(TypeDefinition<'ast>), Function(FunctionDefinition<'ast>), } @@ -184,6 +185,16 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::type_definition))] + pub struct TypeDefinition<'ast> { + pub id: IdentifierExpression<'ast>, + pub generics: Vec>, + pub ty: Type<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_directive))] pub enum ImportDirective<'ast> { From 5c5410915c018b88b7fd971c06591983c04e3fe7 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 25 Aug 2021 18:25:55 +0200 Subject: [PATCH 02/19] add changelog --- changelogs/unreleased/982-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/982-dark64 diff --git a/changelogs/unreleased/982-dark64 b/changelogs/unreleased/982-dark64 new file mode 100644 index 00000000..68919800 --- /dev/null +++ b/changelogs/unreleased/982-dark64 @@ -0,0 +1 @@ +Implement type aliasing \ No newline at end of file From 18b599696bf0e33082d1dbc8015edffa54837842 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 25 Aug 2021 18:29:14 +0200 Subject: [PATCH 03/19] fix test in semantic checker --- zokrates_core/src/semantics.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 0aeb6f02..2572dc65 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -5679,6 +5679,7 @@ mod tests { .get(&*MODULE_ID) .unwrap() .get(&"Bar".to_string()) + .map(|(ty, _)| ty) .unwrap(), &DeclarationType::Struct(DeclarationStructType::new( (*MODULE_ID).clone(), From 6e066f869cfaab199c8997f210b15e764e589289 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 15 Sep 2021 12:38:26 +0200 Subject: [PATCH 04/19] keep original declaration type intact and specialize on the fly --- zokrates_core/src/compile.rs | 2 + zokrates_core/src/semantics.rs | 27 +++----- zokrates_core/src/typed_absy/types.rs | 99 ++++++++++++++++++--------- 3 files changed, 77 insertions(+), 51 deletions(-) diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 172ad747..de33ea26 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -247,6 +247,8 @@ fn check_with_arena<'ast, T: Field, E: Into>( let typed_ast = Checker::check(compiled) .map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?; + log::trace!("\n{}", typed_ast); + let main_module = typed_ast.main.clone(); log::debug!("Run static analysis"); diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 2572dc65..ce4a97a8 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1600,24 +1600,9 @@ impl<'ast, T: Field> Checker<'ast, T> { _ => unreachable!("generic on declaration struct types must be generic identifiers") })); - // generate actual type based on generic type and concrete generics - let members = declared_struct_ty - .members - .into_iter() - .map(|m| { - Ok(DeclarationStructMember { - ty: box specialize_declaration_type(*m.ty, &assignment) - .unwrap(), - ..m - }) - }) - .collect::, _>>()?; - Ok(DeclarationType::Struct(DeclarationStructType { - canonical_location: declared_struct_ty.canonical_location, - location: declared_struct_ty.location, generics: checked_generics, - members, + ..declared_struct_ty })) } false => Err(ErrorInner { @@ -3113,11 +3098,19 @@ impl<'ast, T: Field> Checker<'ast, T> { Some(ty) => Ok(ty), }?; - let declared_struct_type = match ty { + let mut declared_struct_type = match ty { DeclarationType::Struct(struct_type) => struct_type, _ => unreachable!(), }; + declared_struct_type.generics = (0..declared_struct_type.generics.len()) + .map(|index| { + Some(DeclarationConstant::Generic( + GenericIdentifier::with_name("DUMMY").index(index), + )) + }) + .collect(); + // check that we provided the required number of values if declared_struct_type.members_count() != inline_members.len() { return Err(ErrorInner { diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 24d125c4..4cdf41f2 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -57,7 +57,7 @@ impl<'ast, T> Types<'ast, T> { } } -#[derive(Debug, Clone, Eq, Ord)] +#[derive(Debug, Clone, Eq)] pub struct GenericIdentifier<'ast> { pub name: &'ast str, pub index: usize, @@ -86,6 +86,12 @@ impl<'ast> PartialOrd for GenericIdentifier<'ast> { } } +impl<'ast> Ord for GenericIdentifier<'ast> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() + } +} + impl<'ast> Hash for GenericIdentifier<'ast> { fn hash(&self, state: &mut H) { self.index.hash(state); @@ -995,9 +1001,8 @@ pub fn specialize_declaration_type< Ok(match decl_ty { DeclarationType::Int => unreachable!(), DeclarationType::Array(t0) => { - // let s1 = t1.size.clone(); - let ty = box specialize_declaration_type(*t0.ty, &generics)?; + let size = match t0.size { DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s), DeclarationConstant::Concrete(s) => Ok(s.into()), @@ -1009,37 +1014,63 @@ pub fn specialize_declaration_type< DeclarationType::FieldElement => GType::FieldElement, DeclarationType::Boolean => GType::Boolean, DeclarationType::Uint(b0) => GType::Uint(b0), - DeclarationType::Struct(s0) => GType::Struct(GStructType { - members: s0 - .members - .into_iter() - .map(|m| { - let id = m.id; - specialize_declaration_type(*m.ty, generics) - .map(|ty| GStructMember { ty: box ty, id }) - }) - .collect::>()?, - generics: s0 - .generics - .into_iter() - .map(|g| match g { - Some(constant) => match constant { - DeclarationConstant::Generic(s) => { - generics.0.get(&s).cloned().ok_or(s).map(Some) - } - DeclarationConstant::Concrete(s) => Ok(Some(s.into())), - DeclarationConstant::Constant(..) => { - unreachable!( - "identifiers should have been removed in constant inlining" - ) - } - }, - _ => Ok(None), - }) - .collect::>()?, - canonical_location: s0.canonical_location, - location: s0.location, - }), + DeclarationType::Struct(s0) => { + // here we specialize Foo {FooDef} with some values for Generics + // we need to remap these values for InsideGenerics to then visit the members + + let inside_generics = GGenericsAssignment( + s0.generics + .clone() + .into_iter() + .enumerate() + .map(|(index, g)| { + ( + GenericIdentifier::with_name("dummy").index(index), + g.map(|g| match g { + DeclarationConstant::Generic(s) => { + generics.0.get(&s).cloned().unwrap() + } + DeclarationConstant::Concrete(s) => s.into(), + DeclarationConstant::Constant(c) => c.into(), + }) + .unwrap(), + ) + }) + .collect(), + ); + + GType::Struct(GStructType { + members: s0 + .members + .into_iter() + .map(|m| { + let id = m.id; + specialize_declaration_type(*m.ty, &inside_generics) + .map(|ty| GStructMember { ty: box ty, id }) + }) + .collect::>()?, + generics: s0 + .generics + .into_iter() + .map(|g| match g { + Some(constant) => match constant { + DeclarationConstant::Generic(s) => { + generics.0.get(&s).cloned().ok_or(s).map(Some) + } + DeclarationConstant::Concrete(s) => Ok(Some(s.into())), + DeclarationConstant::Constant(..) => { + unreachable!( + "identifiers should have been removed in constant inlining" + ) + } + }, + _ => Ok(None), + }) + .collect::>()?, + canonical_location: s0.canonical_location, + location: s0.location, + }) + } }) } From 6f8b73ab0f4e00e3617880421bea117699891c25 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 16 Sep 2021 18:36:17 +0300 Subject: [PATCH 05/19] fix abi with concretization, fix StructType --- zokrates_core/src/semantics.rs | 32 +++++-- .../static_analysis/flatten_complex_types.rs | 2 + zokrates_core/src/static_analysis/mod.rs | 10 ++ .../src/static_analysis/propagation.rs | 18 ++-- .../src/static_analysis/struct_concretizer.rs | 96 +++++++++++++++++++ zokrates_core/src/typed_absy/mod.rs | 14 ++- zokrates_core/src/typed_absy/types.rs | 3 +- 7 files changed, 153 insertions(+), 22 deletions(-) create mode 100644 zokrates_core/src/static_analysis/struct_concretizer.rs diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index ce4a97a8..85421404 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -6,7 +6,7 @@ use crate::absy::Identifier; use crate::absy::*; -use crate::typed_absy::types::GGenericsAssignment; +use crate::typed_absy::types::{GGenericsAssignment, GenericsAssignment}; use crate::typed_absy::*; use crate::typed_absy::{DeclarationParameter, DeclarationVariable, Variable}; use num_bigint::BigUint; @@ -1032,17 +1032,28 @@ impl<'ast, T: Field> Checker<'ast, T> { match self.check_signature(funct.signature, module_id, state) { Ok(s) => { + // initialise generics map + let mut generics: GenericsAssignment<'ast, T> = GGenericsAssignment::default(); + // define variables for the constants for generic in &s.generics { - let generic = generic.clone().unwrap(); // for declaration signatures, generics cannot be ignored + let generic = match generic.clone().unwrap() { + DeclarationConstant::Generic(g) => g, + _ => unreachable!(), + }; + + // for declaration signatures, generics cannot be ignored let v = Variable::with_id_and_type( - match generic { - DeclarationConstant::Generic(g) => g.name, - _ => unreachable!(), - }, + generic.name.clone(), Type::Uint(UBitwidth::B32), ); + + generics.0.insert( + generic.clone(), + UExpressionInner::Identifier(generic.name.into()).annotate(UBitwidth::B32), + ); + // we don't have to check for conflicts here, because this was done when checking the signature self.insert_into_scope(v.clone()); } @@ -1055,9 +1066,12 @@ impl<'ast, T: Field> Checker<'ast, T> { let decl_v = DeclarationVariable::with_id_and_type(arg.id.value.id, decl_ty.clone()); - match self.insert_into_scope( - crate::typed_absy::variable::try_from_g_variable(decl_v.clone()).unwrap(), - ) { + let ty = specialize_declaration_type(decl_v.clone()._type, &generics).unwrap(); + + match self.insert_into_scope(crate::typed_absy::variable::Variable { + id: decl_v.clone().id, + _type: ty, + }) { true => {} false => { errors.push(ErrorInner { diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index f3973297..db213ca6 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -1125,6 +1125,8 @@ fn fold_struct_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed_absy::StructExpression<'ast, T>, ) -> Vec> { + println!("{:#?}", e.ty()); + f.fold_struct_expression_inner( statements_buffer, &typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 2285115b..ae46821f 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -11,6 +11,7 @@ mod flat_propagation; mod flatten_complex_types; mod propagation; mod reducer; +mod struct_concretizer; mod uint_optimizer; mod unconstrained_vars; mod variable_write_remover; @@ -20,6 +21,7 @@ use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::reducer::reduce_program; +use self::struct_concretizer::StructConcretizer; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; @@ -101,6 +103,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = reduce_program(r).map_err(Error::from)?; log::trace!("\n{}", r); + log::debug!("Static analyser: Propagate"); + let r = Propagator::propagate(r)?; + log::trace!("\n{}", r); + + log::debug!("Static analyser: Concretize structs"); + let r = StructConcretizer::concretize(r); + log::trace!("\n{}", r); + // generate abi log::debug!("Static analyser: Generate abi"); let abi = r.abi(); diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index b90fc99d..1450e077 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1194,16 +1194,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { let e1 = self.fold_struct_expression(e1)?; let e2 = self.fold_struct_expression(e2)?; - if let (Ok(t1), Ok(t2)) = ( - ConcreteType::try_from(e1.get_type()), - ConcreteType::try_from(e2.get_type()), - ) { - if t1 != t2 { - return Err(Error::Type(format!( - "Cannot compare {} of type {} to {} of type {}", - e1, t1, e2, t2 - ))); - } + let t1 = e1.get_type(); + let t2 = e2.get_type(); + + if t1 != t2 { + return Err(Error::Type(format!( + "Cannot compare {} of type {} to {} of type {}", + e1, t1, e2, t2 + ))); }; Ok(BooleanExpression::StructEq(box e1, box e2)) diff --git a/zokrates_core/src/static_analysis/struct_concretizer.rs b/zokrates_core/src/static_analysis/struct_concretizer.rs new file mode 100644 index 00000000..874a4b36 --- /dev/null +++ b/zokrates_core/src/static_analysis/struct_concretizer.rs @@ -0,0 +1,96 @@ +// After all generics are inlined, a program should be completely "concrete", which means that all types must only contain +// litterals for array sizes. This is especially important to generate the ABI of the program. +// It is direct to ensure that with most types, however the way structs are implemented requires a slightly different process: +// Where for an array, `field[N]` ends up being propagated to `field[42]` which is direct to turn into a concrete type, +// for structs, `Foo { field[N] a }` is propagated to `Foo<42> { field[N] a }`. The missing step is replacing `N` by `42` +// *inside* the canonical type, so that it can be concretized in the same way arrays are. +// We apply this transformation only to the main function. + +use crate::typed_absy::folder::*; +use crate::typed_absy::{ + types::{ + ConcreteGenericsAssignment, DeclarationArrayType, DeclarationConstant, + DeclarationStructMember, GGenericsAssignment, + }, + DeclarationStructType, GenericIdentifier, TypedProgram, +}; +use zokrates_field::Field; + +#[derive(Default)] +pub struct StructConcretizer<'ast> { + generics: ConcreteGenericsAssignment<'ast>, +} + +impl<'ast> StructConcretizer<'ast> { + pub fn concretize(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + StructConcretizer::default().fold_program(p) + } + + pub fn with_generics(generics: ConcreteGenericsAssignment<'ast>) -> Self { + Self { generics } + } +} + +impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { + fn fold_declaration_struct_type( + &mut self, + ty: DeclarationStructType<'ast>, + ) -> DeclarationStructType<'ast> { + let concrete_generics: Vec<_> = ty + .generics + .iter() + .map(|g| match g.as_ref().unwrap() { + DeclarationConstant::Generic(s) => self.generics.0.get(&s).cloned().unwrap(), + DeclarationConstant::Concrete(s) => *s as usize, + DeclarationConstant::Constant(..) => unreachable!(), + }) + .collect(); + + let concrete_generics_map: ConcreteGenericsAssignment = GGenericsAssignment( + concrete_generics + .iter() + .enumerate() + .map(|(index, g)| (GenericIdentifier::with_name("DUMMY").index(index), *g)) + .collect(), + ); + + let mut internal_concretizer = StructConcretizer::with_generics(concrete_generics_map); + + DeclarationStructType { + members: ty + .members + .into_iter() + .map(|member| { + DeclarationStructMember::new( + member.id, + >::fold_declaration_type( + &mut internal_concretizer, + *member.ty, + ), + ) + }) + .collect(), + generics: concrete_generics + .into_iter() + .map(|g| Some(DeclarationConstant::Concrete(g as u32))) + .collect(), + ..ty + } + } + + fn fold_declaration_array_type( + &mut self, + ty: DeclarationArrayType<'ast>, + ) -> DeclarationArrayType<'ast> { + let size = match ty.size { + DeclarationConstant::Generic(s) => self.generics.0.get(&s).cloned().unwrap() as u32, + DeclarationConstant::Concrete(s) => s, + DeclarationConstant::Constant(..) => unreachable!(), + }; + + DeclarationArrayType { + size: DeclarationConstant::Concrete(size), + ty: box >::fold_declaration_type(self, *ty.ty), + } + } +} diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index b31dfeee..29a8dfb6 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -663,7 +663,19 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { StructExpressionInner::IfElse(ref c) => write!(f, "{}", c), StructExpressionInner::Member(ref m) => write!(f, "{}", m), StructExpressionInner::Select(ref select) => write!(f, "{}", select), - } + }?; + + write!( + f, + "/* {} {{{}}} */", + self.ty, + self.ty + .members + .iter() + .map(|m| format!("{}: {}", m.id, m.ty)) + .collect::>() + .join(", ") + ) } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 4cdf41f2..94a6acc1 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -371,8 +371,7 @@ impl<'ast, S, R: PartialEq> PartialEq> for GStructType { .zip(other.generics.iter()) .all(|(a, b)| match (a, b) { (Some(a), Some(b)) => a == b, - (None, None) => true, - _ => false, + _ => true, }) } } From a3bfa9f6d6573fc55ab70a779de5b3895e03d728 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 17 Sep 2021 16:21:52 +0300 Subject: [PATCH 06/19] clippy --- zokrates_core/src/semantics.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 85421404..213a69a7 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1044,10 +1044,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // for declaration signatures, generics cannot be ignored - let v = Variable::with_id_and_type( - generic.name.clone(), - Type::Uint(UBitwidth::B32), - ); + let v = Variable::with_id_and_type(generic.name, Type::Uint(UBitwidth::B32)); generics.0.insert( generic.clone(), From ae825b7f36334974d57b298c182c1aed326ac86d Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 14:01:49 +0300 Subject: [PATCH 07/19] update highlighters --- zokrates_parser/src/ace_mode/index.js | 2 +- .../src/textmate/zokrates.tmLanguage.yaml | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/zokrates_parser/src/ace_mode/index.js b/zokrates_parser/src/ace_mode/index.js index 9778609b..b576695f 100644 --- a/zokrates_parser/src/ace_mode/index.js +++ b/zokrates_parser/src/ace_mode/index.js @@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac var ZoKratesHighlightRules = function () { var keywords = ( - "assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64" + "assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|type|u8|u16|u32|u64" ); var keywordMapper = this.createKeywordMapper({ diff --git a/zokrates_parser/src/textmate/zokrates.tmLanguage.yaml b/zokrates_parser/src/textmate/zokrates.tmLanguage.yaml index 2a2e9fb6..bb200057 100644 --- a/zokrates_parser/src/textmate/zokrates.tmLanguage.yaml +++ b/zokrates_parser/src/textmate/zokrates.tmLanguage.yaml @@ -202,19 +202,23 @@ repository: - comment: 'control flow keywords' name: keyword.control.zokrates - match: \b(do|else|for|do|endfor|if|then|fi|return|assert)\b + match: \b(for|in|do|endfor|if|then|else|fi|return|assert)\b - comment: 'storage keywords' name: storage.type.zokrates match: \b(struct)\b - - comment: const + comment: 'const keyword' name: keyword.other.const.zokrates - match: \bconst\b + match: \b(const)\b - - comment: def + comment: 'type keyword' + name: keyword.other.type.zokrates + match: \b(type)\b + - + comment: 'def keyword' name: keyword.other.def.zokrates - match: \bdef\b + match: \b(def)\b - comment: 'import keywords' name: keyword.other.import.zokrates From fad1b2fa3773e8b3952d683904f2f34e77b5021d Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 15:14:26 +0300 Subject: [PATCH 08/19] clean, use u32 for concrete array sizes, introduce functions to map using generic assignments --- zokrates_core/src/embed.rs | 24 ++-- .../static_analysis/flatten_complex_types.rs | 15 +-- .../src/static_analysis/struct_concretizer.rs | 11 +- .../static_analysis/variable_write_remover.rs | 7 +- zokrates_core/src/typed_absy/mod.rs | 14 +-- zokrates_core/src/typed_absy/types.rs | 112 ++++++++---------- zokrates_core/src/typed_absy/uint.rs | 10 +- zokrates_core/src/typed_absy/variable.rs | 2 +- zokrates_core/src/zir/identifier.rs | 2 +- 9 files changed, 80 insertions(+), 117 deletions(-) diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 3ee07660..88f0604c 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -94,59 +94,59 @@ impl FlatEmbed { .inputs(vec![DeclarationType::uint(8)]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 8usize, + 8u32, ))]), FlatEmbed::U16ToBits => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(16)]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 16usize, + 16u32, ))]), FlatEmbed::U32ToBits => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(32)]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 32usize, + 32u32, ))]), FlatEmbed::U64ToBits => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(64)]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 64usize, + 64u32, ))]), FlatEmbed::U8FromBits => DeclarationSignature::new() .outputs(vec![DeclarationType::uint(8)]) .inputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 8usize, + 8u32, ))]), FlatEmbed::U16FromBits => DeclarationSignature::new() .outputs(vec![DeclarationType::uint(16)]) .inputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 16usize, + 16u32, ))]), FlatEmbed::U32FromBits => DeclarationSignature::new() .outputs(vec![DeclarationType::uint(32)]) .inputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 32usize, + 32u32, ))]), FlatEmbed::U64FromBits => DeclarationSignature::new() .outputs(vec![DeclarationType::uint(64)]) .inputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 64usize, + 64u32, ))]), #[cfg(feature = "bellman")] FlatEmbed::Sha256Round => DeclarationSignature::new() .inputs(vec![ - DeclarationType::array((DeclarationType::Boolean, 512usize)), - DeclarationType::array((DeclarationType::Boolean, 256usize)), + DeclarationType::array((DeclarationType::Boolean, 512u32)), + DeclarationType::array((DeclarationType::Boolean, 256u32)), ]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - 256usize, + 256u32, ))]), #[cfg(feature = "ark")] FlatEmbed::SnarkVerifyBls12377 => DeclarationSignature::new() @@ -168,7 +168,7 @@ impl FlatEmbed { index: 0, }, )), // inputs - DeclarationType::array((DeclarationType::FieldElement, 8usize)), // proof + DeclarationType::array((DeclarationType::FieldElement, 8u32)), // proof DeclarationType::array(( DeclarationType::FieldElement, GenericIdentifier { diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index db213ca6..c9fcfea3 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -340,7 +340,7 @@ impl<'ast, T: Field> Flattener { &mut self, statements_buffer: &mut Vec>, ty: &typed_absy::types::ConcreteType, - size: usize, + size: u32, e: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { fold_array_expression_inner(self, statements_buffer, ty, size, e) @@ -404,7 +404,7 @@ fn fold_array_expression_inner<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, ty: &typed_absy::types::ConcreteType, - size: usize, + size: u32, array: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { match array { @@ -437,7 +437,7 @@ fn fold_array_expression_inner<'ast, T: Field>( .flat_map(|e| f.fold_expression_or_spread(statements_buffer, e)) .collect(); - assert_eq!(exprs.len(), size * ty.get_primitive_count()); + assert_eq!(exprs.len(), size as usize * ty.get_primitive_count()); exprs } @@ -458,7 +458,7 @@ fn fold_array_expression_inner<'ast, T: Field>( match (from.into_inner(), to.into_inner()) { (zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => { - assert_eq!(size, to.saturating_sub(from) as usize); + assert_eq!(size, to.saturating_sub(from) as u32); let element_size = ty.get_primitive_count(); let start = from as usize * element_size; @@ -1108,10 +1108,7 @@ fn fold_array_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed_absy::ArrayExpression<'ast, T>, ) -> Vec> { - let size = match e.size().into_inner() { - typed_absy::UExpressionInner::Value(v) => v, - _ => unreachable!(), - } as usize; + let size: u32 = e.size().try_into().unwrap(); f.fold_array_expression_inner( statements_buffer, &typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(), @@ -1125,8 +1122,6 @@ fn fold_struct_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed_absy::StructExpression<'ast, T>, ) -> Vec> { - println!("{:#?}", e.ty()); - f.fold_struct_expression_inner( statements_buffer, &typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), diff --git a/zokrates_core/src/static_analysis/struct_concretizer.rs b/zokrates_core/src/static_analysis/struct_concretizer.rs index 874a4b36..2b137357 100644 --- a/zokrates_core/src/static_analysis/struct_concretizer.rs +++ b/zokrates_core/src/static_analysis/struct_concretizer.rs @@ -36,14 +36,11 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { &mut self, ty: DeclarationStructType<'ast>, ) -> DeclarationStructType<'ast> { - let concrete_generics: Vec<_> = ty + let concrete_generics: Vec = ty .generics - .iter() - .map(|g| match g.as_ref().unwrap() { - DeclarationConstant::Generic(s) => self.generics.0.get(&s).cloned().unwrap(), - DeclarationConstant::Concrete(s) => *s as usize, - DeclarationConstant::Constant(..) => unreachable!(), - }) + .clone() + .into_iter() + .map(|g| g.unwrap().map_concrete(&self.generics).unwrap()) .collect(); let concrete_generics_map: ConcreteGenericsAssignment = GGenericsAssignment( diff --git a/zokrates_core/src/static_analysis/variable_write_remover.rs b/zokrates_core/src/static_analysis/variable_write_remover.rs index 82cc206e..e66a4461 100644 --- a/zokrates_core/src/static_analysis/variable_write_remover.rs +++ b/zokrates_core/src/static_analysis/variable_write_remover.rs @@ -37,10 +37,9 @@ impl<'ast> VariableWriteRemover { let inner_ty = base.inner_type(); let size = base.size(); - let size = match size.as_inner() { - UExpressionInner::Value(v) => *v as u32, - _ => unreachable!(), - }; + use std::convert::TryInto; + + let size: u32 = size.try_into().unwrap(); let head = indices.remove(0); let tail = indices; diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 29a8dfb6..b31dfeee 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -663,19 +663,7 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { StructExpressionInner::IfElse(ref c) => write!(f, "{}", c), StructExpressionInner::Member(ref m) => write!(f, "{}", m), StructExpressionInner::Select(ref select) => write!(f, "{}", select), - }?; - - write!( - f, - "/* {} {{{}}} */", - self.ty, - self.ty - .members - .iter() - .map(|m| format!("{}: {}", m.id, m.ty)) - .collect::>() - .join(", ") - ) + } } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 94a6acc1..608cdb88 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -137,6 +137,32 @@ pub enum DeclarationConstant<'ast> { Constant(CanonicalConstantIdentifier<'ast>), } +impl<'ast> DeclarationConstant<'ast> { + pub fn map> + From + Clone>( + self, + generics: &GGenericsAssignment<'ast, S>, + ) -> Result> { + match self { + DeclarationConstant::Generic(g) => generics.0.get(&g).cloned().ok_or(g), + DeclarationConstant::Concrete(v) => Ok(v.into()), + DeclarationConstant::Constant(c) => Ok(c.into()), + } + } + + pub fn map_concrete + Clone>( + self, + generics: &GGenericsAssignment<'ast, S>, + ) -> Result> { + match self { + DeclarationConstant::Constant(_) => unreachable!( + "called map_concrete on a constant, it should have been resolved before" + ), + DeclarationConstant::Generic(g) => generics.0.get(&g).cloned().ok_or(g), + DeclarationConstant::Concrete(v) => Ok(v.into()), + } + } +} + impl<'ast, T> PartialEq> for DeclarationConstant<'ast> { fn eq(&self, other: &UExpression<'ast, T>) -> bool { match (self, other.as_inner()) { @@ -158,12 +184,6 @@ impl<'ast> From for DeclarationConstant<'ast> { } } -impl<'ast> From for DeclarationConstant<'ast> { - fn from(e: usize) -> Self { - DeclarationConstant::Concrete(e as u32) - } -} - impl<'ast> From> for DeclarationConstant<'ast> { fn from(e: GenericIdentifier<'ast>) -> Self { DeclarationConstant::Generic(e) @@ -180,8 +200,8 @@ impl<'ast> fmt::Display for DeclarationConstant<'ast> { } } -impl<'ast, T> From for UExpression<'ast, T> { - fn from(i: usize) -> Self { +impl<'ast, T> From for UExpression<'ast, T> { + fn from(i: u32) -> Self { UExpressionInner::Value(i as u128).annotate(UBitwidth::B32) } } @@ -202,25 +222,14 @@ impl<'ast, T> From> for UExpression<'ast, T> { } } -impl<'ast, T> TryInto for UExpression<'ast, T> { +impl<'ast, T> TryInto for UExpression<'ast, T> { type Error = SpecializationError; - fn try_into(self) -> Result { + fn try_into(self) -> Result { assert_eq!(self.bitwidth, UBitwidth::B32); match self.into_inner() { - UExpressionInner::Value(v) => Ok(v as usize), - _ => Err(SpecializationError), - } - } -} - -impl<'ast> TryInto for DeclarationConstant<'ast> { - type Error = SpecializationError; - - fn try_into(self) -> Result { - match self { - DeclarationConstant::Concrete(v) => Ok(v as usize), + UExpressionInner::Value(v) => Ok(v as u32), _ => Err(SpecializationError), } } @@ -237,7 +246,7 @@ pub struct GStructMember { } pub type DeclarationStructMember<'ast> = GStructMember>; -pub type ConcreteStructMember = GStructMember; +pub type ConcreteStructMember = GStructMember; pub type StructMember<'ast, T> = GStructMember>; impl<'ast, S, R: PartialEq> PartialEq> for GStructMember { @@ -277,7 +286,7 @@ pub struct GArrayType { } pub type DeclarationArrayType<'ast> = GArrayType>; -pub type ConcreteArrayType = GArrayType; +pub type ConcreteArrayType = GArrayType; pub type ArrayType<'ast, T> = GArrayType>; impl<'ast, S, R: PartialEq> PartialEq> for GArrayType { @@ -359,7 +368,7 @@ pub struct GStructType { } pub type DeclarationStructType<'ast> = GStructType>; -pub type ConcreteStructType = GStructType; +pub type ConcreteStructType = GStructType; pub type StructType<'ast, T> = GStructType>; impl<'ast, S, R: PartialEq> PartialEq> for GStructType { @@ -615,7 +624,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType { } pub type DeclarationType<'ast> = GType>; -pub type ConcreteType = GType; +pub type ConcreteType = GType; pub type Type<'ast, T> = GType>; impl<'ast, S, R: PartialEq> PartialEq> for GType { @@ -804,7 +813,9 @@ impl ConcreteType { GType::FieldElement => 1, GType::Boolean => 1, GType::Uint(_) => 1, - GType::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), + GType::Array(array_type) => { + array_type.size as usize * array_type.ty.get_primitive_count() + } GType::Int => unreachable!(), GType::Struct(struct_type) => struct_type .iter() @@ -824,7 +835,7 @@ pub struct GFunctionKey<'ast, S> { } pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>; -pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>; +pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, u32>; pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>; impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> { @@ -836,7 +847,7 @@ impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> { #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct GGenericsAssignment<'ast, S>(pub BTreeMap, S>); -pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, usize>; +pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, u32>; pub type GenericsAssignment<'ast, T> = GGenericsAssignment<'ast, UExpression<'ast, T>>; impl<'ast, S> Default for GGenericsAssignment<'ast, S> { @@ -936,7 +947,7 @@ impl<'ast> ConcreteFunctionKey<'ast> { use std::collections::btree_map::Entry; -pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( +pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( decl_ty: &DeclarationType<'ast>, ty: >ype, constants: &mut GGenericsAssignment<'ast, S>, @@ -957,7 +968,7 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( true } }, - DeclarationConstant::Concrete(s0) => s1 == *s0 as usize, + DeclarationConstant::Concrete(s0) => s1 == *s0 as u32, // in the case of a constant, we do not know the value yet, so we optimistically assume it's correct // if it does not match, it will be caught during inlining DeclarationConstant::Constant(..) => true, @@ -1002,11 +1013,7 @@ pub fn specialize_declaration_type< DeclarationType::Array(t0) => { let ty = box specialize_declaration_type(*t0.ty, &generics)?; - let size = match t0.size { - DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s), - DeclarationConstant::Concrete(s) => Ok(s.into()), - DeclarationConstant::Constant(c) => Ok(c.into()), - }?; + let size = t0.size.map(generics)?; GType::Array(GArrayType { size, ty }) } @@ -1025,14 +1032,7 @@ pub fn specialize_declaration_type< .map(|(index, g)| { ( GenericIdentifier::with_name("dummy").index(index), - g.map(|g| match g { - DeclarationConstant::Generic(s) => { - generics.0.get(&s).cloned().unwrap() - } - DeclarationConstant::Concrete(s) => s.into(), - DeclarationConstant::Constant(c) => c.into(), - }) - .unwrap(), + g.map(|g| g.map(generics).unwrap()).unwrap(), ) }) .collect(), @@ -1052,17 +1052,7 @@ pub fn specialize_declaration_type< .generics .into_iter() .map(|g| match g { - Some(constant) => match constant { - DeclarationConstant::Generic(s) => { - generics.0.get(&s).cloned().ok_or(s).map(Some) - } - DeclarationConstant::Concrete(s) => Ok(Some(s.into())), - DeclarationConstant::Constant(..) => { - unreachable!( - "identifiers should have been removed in constant inlining" - ) - } - }, + Some(constant) => constant.map(generics).map(Some), _ => Ok(None), }) .collect::>()?, @@ -1127,7 +1117,7 @@ pub mod signature { } pub type DeclarationSignature<'ast> = GSignature>; - pub type ConcreteSignature = GSignature; + pub type ConcreteSignature = GSignature; pub type Signature<'ast, T> = GSignature>; impl<'ast> PartialEq> for ConcreteSignature { @@ -1140,7 +1130,7 @@ pub mod signature { .iter() .chain(other.outputs.iter()) .zip(self.inputs.iter().chain(self.outputs.iter())) - .all(|(decl_ty, ty)| check_type::(decl_ty, ty, &mut constants)) + .all(|(decl_ty, ty)| check_type::(decl_ty, ty, &mut constants)) } } @@ -1165,7 +1155,7 @@ pub mod signature { constants.0.extend( decl_generics .zip(values.into_iter()) - .filter_map(|(g, v)| v.map(|v| (g, v as usize))), + .filter_map(|(g, v)| v.map(|v| (g, v))), ); let condition = self @@ -1488,8 +1478,8 @@ pub mod signature { fn array_slug() { let s = ConcreteSignature::new() .inputs(vec![ - ConcreteType::array((ConcreteType::FieldElement, 42usize)), - ConcreteType::array((ConcreteType::FieldElement, 21usize)), + ConcreteType::array((ConcreteType::FieldElement, 42u32)), + ConcreteType::array((ConcreteType::FieldElement, 21u32)), ]) .outputs(vec![]); @@ -1512,7 +1502,7 @@ mod tests { fn array_display() { // field[1][2] let t = ConcreteType::Array(ConcreteArrayType::new( - ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2usize)), + ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2u32)), 1usize, )); assert_eq!(format!("{}", t), "field[1][2]"); diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 4620dd7c..3fab93a8 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -146,12 +146,6 @@ pub struct UExpression<'ast, T> { pub inner: UExpressionInner<'ast, T>, } -impl<'ast, T> From for UExpression<'ast, T> { - fn from(u: u32) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B32) - } -} - impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u16) -> Self { UExpressionInner::Value(u as u128).annotate(UBitwidth::B16) @@ -164,8 +158,8 @@ impl<'ast, T> From for UExpression<'ast, T> { } } -impl<'ast, T> PartialEq for UExpression<'ast, T> { - fn eq(&self, other: &usize) -> bool { +impl<'ast, T> PartialEq for UExpression<'ast, T> { + fn eq(&self, other: &u32) -> bool { match self.as_inner() { UExpressionInner::Value(v) => *v == *other as u128, _ => true, diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index 2d19a95e..281b9a49 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -12,7 +12,7 @@ pub struct GVariable<'ast, S> { } pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>; -pub type ConcreteVariable<'ast> = GVariable<'ast, usize>; +pub type ConcreteVariable<'ast> = GVariable<'ast, u32>; pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; impl<'ast, T> TryFrom> for ConcreteVariable<'ast> { diff --git a/zokrates_core/src/zir/identifier.rs b/zokrates_core/src/zir/identifier.rs index 87eea34d..9f7b6a9e 100644 --- a/zokrates_core/src/zir/identifier.rs +++ b/zokrates_core/src/zir/identifier.rs @@ -11,7 +11,7 @@ pub enum Identifier<'ast> { #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub enum SourceIdentifier<'ast> { Basic(CoreIdentifier<'ast>), - Select(Box>, usize), + Select(Box>, u32), Member(Box>, MemberId), } From 2e589796d7404232a9267b413eb70147d447bd3b Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 15:34:51 +0300 Subject: [PATCH 09/19] refactor GenericIdentifier to avoid dummy names --- zokrates_core/src/embed.rs | 49 +++++-------------- zokrates_core/src/semantics.rs | 44 ++++++++--------- .../static_analysis/reducer/shallow_ssa.rs | 2 +- .../src/static_analysis/struct_concretizer.rs | 36 +++++++------- zokrates_core/src/typed_absy/types.rs | 34 ++++++++++--- 5 files changed, 78 insertions(+), 87 deletions(-) diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 88f0604c..496ceaa8 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -50,25 +50,16 @@ impl FlatEmbed { match self { FlatEmbed::BitArrayLe => DeclarationSignature::new() .generics(vec![Some(DeclarationConstant::Generic( - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), ))]) .inputs(vec![ DeclarationType::array(( DeclarationType::Boolean, - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), )), DeclarationType::array(( DeclarationType::Boolean, - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), )), ]) .outputs(vec![DeclarationType::Boolean]), @@ -77,18 +68,12 @@ impl FlatEmbed { .outputs(vec![DeclarationType::FieldElement]), FlatEmbed::Unpack => DeclarationSignature::new() .generics(vec![Some(DeclarationConstant::Generic( - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), ))]) .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::array(( DeclarationType::Boolean, - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), ))]), FlatEmbed::U8ToBits => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(8)]) @@ -151,30 +136,22 @@ impl FlatEmbed { #[cfg(feature = "ark")] FlatEmbed::SnarkVerifyBls12377 => DeclarationSignature::new() .generics(vec![ - Some(DeclarationConstant::Generic(GenericIdentifier { - name: "N", - index: 0, - })), - Some(DeclarationConstant::Generic(GenericIdentifier { - name: "V", - index: 1, - })), + Some(DeclarationConstant::Generic( + GenericIdentifier::with_name("N").with_index(0), + )), + Some(DeclarationConstant::Generic( + GenericIdentifier::with_name("V").with_index(1), + )), ]) .inputs(vec![ DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier { - name: "N", - index: 0, - }, + GenericIdentifier::with_name("N").with_index(0), )), // inputs DeclarationType::array((DeclarationType::FieldElement, 8u32)), // proof DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier { - name: "V", - index: 1, - }, + GenericIdentifier::with_name("V").with_index(1), )), // 18 + (2 * n) // vk ]) .outputs(vec![DeclarationType::Boolean]), diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 213a69a7..ff019481 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -382,10 +382,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } else { match generics_map.insert(g.value, index).is_none() { true => { - generics.push(Some(DeclarationConstant::Generic(GenericIdentifier { - name: g.value, - index, - }))); + generics.push(Some(DeclarationConstant::Generic( + GenericIdentifier::with_name(g.value).with_index(index), + ))); } false => { errors.push(ErrorInner { @@ -515,10 +514,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } else { match generics_map.insert(g.value, index).is_none() { true => { - generics.push(Some(DeclarationConstant::Generic(GenericIdentifier { - name: g.value, - index, - }))); + generics.push(Some(DeclarationConstant::Generic( + GenericIdentifier::with_name(g.value).with_index(index), + ))); } false => { errors.push(ErrorInner { @@ -1044,11 +1042,12 @@ impl<'ast, T: Field> Checker<'ast, T> { // for declaration signatures, generics cannot be ignored - let v = Variable::with_id_and_type(generic.name, Type::Uint(UBitwidth::B32)); + let v = Variable::with_id_and_type(generic.name(), Type::Uint(UBitwidth::B32)); generics.0.insert( generic.clone(), - UExpressionInner::Identifier(generic.name.into()).annotate(UBitwidth::B32), + UExpressionInner::Identifier(generic.name().into()) + .annotate(UBitwidth::B32), ); // we don't have to check for conflicts here, because this was done when checking the signature @@ -1191,10 +1190,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } else { match generics_map.insert(g.value, index).is_none() { true => { - generics.push(Some(DeclarationConstant::Generic(GenericIdentifier { - name: g.value, - index, - }))); + generics.push(Some(DeclarationConstant::Generic( + GenericIdentifier::with_name(g.value).with_index(index), + ))); } false => { errors.push(ErrorInner { @@ -1482,7 +1480,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } } - (None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier { name, index: *index })), + (None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier::with_name(name).with_index(*index))), _ => Err(ErrorInner { pos: Some(pos), message: format!("Undeclared symbol `{}`", name) @@ -3117,7 +3115,7 @@ impl<'ast, T: Field> Checker<'ast, T> { declared_struct_type.generics = (0..declared_struct_type.generics.len()) .map(|index| { Some(DeclarationConstant::Generic( - GenericIdentifier::with_name("DUMMY").index(index), + GenericIdentifier::without_name().with_index(index), )) }) .collect(); @@ -3676,7 +3674,7 @@ mod tests { "bar", DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into() + GenericIdentifier::with_name("K").with_index(0).into() )]) .inputs(vec![DeclarationType::FieldElement]) )); @@ -3685,11 +3683,11 @@ mod tests { "bar", DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into() + GenericIdentifier::with_name("K").with_index(0).into() )]) .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0) + GenericIdentifier::with_name("K").with_index(0) ))]) )); // a `bar` function with an equivalent signature, just renaming generic parameters @@ -3697,11 +3695,11 @@ mod tests { "bar", DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("L").index(0).into() + GenericIdentifier::with_name("L").with_index(0).into() )]) .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("L").index(0) + GenericIdentifier::with_name("L").with_index(0) ))]) )); // a `bar` type isn't allowed as the name is already taken by at least one function @@ -4265,7 +4263,7 @@ mod tests { .inputs(vec![DeclarationType::array(( DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0) + GenericIdentifier::with_name("K").with_index(0) )), GenericIdentifier::with_name("L").index(1) ))]) @@ -4274,7 +4272,7 @@ mod tests { DeclarationType::FieldElement, GenericIdentifier::with_name("L").index(1) )), - GenericIdentifier::with_name("K").index(0) + GenericIdentifier::with_name("K").with_index(0) ))])) ); } diff --git a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs index 5ae603bc..a06453b7 100644 --- a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs +++ b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs @@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { .map(|(g, v)| { TypedStatement::Definition( TypedAssignee::Identifier(Variable::with_id_and_type( - g.name, + g.name(), Type::Uint(UBitwidth::B32), )), UExpression::from(*v as u32).into(), diff --git a/zokrates_core/src/static_analysis/struct_concretizer.rs b/zokrates_core/src/static_analysis/struct_concretizer.rs index 2b137357..b5fb5407 100644 --- a/zokrates_core/src/static_analysis/struct_concretizer.rs +++ b/zokrates_core/src/static_analysis/struct_concretizer.rs @@ -14,24 +14,28 @@ use crate::typed_absy::{ }, DeclarationStructType, GenericIdentifier, TypedProgram, }; +use std::marker::PhantomData; use zokrates_field::Field; -#[derive(Default)] -pub struct StructConcretizer<'ast> { +pub struct StructConcretizer<'ast, T> { generics: ConcreteGenericsAssignment<'ast>, + marker: PhantomData, } -impl<'ast> StructConcretizer<'ast> { - pub fn concretize(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - StructConcretizer::default().fold_program(p) +impl<'ast, T: Field> StructConcretizer<'ast, T> { + pub fn concretize(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + StructConcretizer::with_generics(ConcreteGenericsAssignment::default()).fold_program(p) } pub fn with_generics(generics: ConcreteGenericsAssignment<'ast>) -> Self { - Self { generics } + Self { + generics, + marker: PhantomData, + } } } -impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { +impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast, T> { fn fold_declaration_struct_type( &mut self, ty: DeclarationStructType<'ast>, @@ -47,11 +51,12 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { concrete_generics .iter() .enumerate() - .map(|(index, g)| (GenericIdentifier::with_name("DUMMY").index(index), *g)) + .map(|(index, g)| (GenericIdentifier::without_name().with_index(index), *g)) .collect(), ); - let mut internal_concretizer = StructConcretizer::with_generics(concrete_generics_map); + let mut internal_concretizer: StructConcretizer<'ast, T> = + StructConcretizer::with_generics(concrete_generics_map); DeclarationStructType { members: ty @@ -60,10 +65,7 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { .map(|member| { DeclarationStructMember::new( member.id, - >::fold_declaration_type( - &mut internal_concretizer, - *member.ty, - ), + internal_concretizer.fold_declaration_type(*member.ty), ) }) .collect(), @@ -79,15 +81,11 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast> { &mut self, ty: DeclarationArrayType<'ast>, ) -> DeclarationArrayType<'ast> { - let size = match ty.size { - DeclarationConstant::Generic(s) => self.generics.0.get(&s).cloned().unwrap() as u32, - DeclarationConstant::Concrete(s) => s, - DeclarationConstant::Constant(..) => unreachable!(), - }; + let size = ty.size.map_concrete(&self.generics).unwrap(); DeclarationArrayType { size: DeclarationConstant::Concrete(size), - ty: box >::fold_declaration_type(self, *ty.ty), + ty: box self.fold_declaration_type(*ty.ty), } } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 608cdb88..44c6b2e5 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -59,19 +59,37 @@ impl<'ast, T> Types<'ast, T> { #[derive(Debug, Clone, Eq)] pub struct GenericIdentifier<'ast> { - pub name: &'ast str, - pub index: usize, + name: Option<&'ast str>, + index: usize, } impl<'ast> GenericIdentifier<'ast> { - pub fn with_name(name: &'ast str) -> Self { - Self { name, index: 0 } + pub fn without_name() -> Self { + Self { + name: None, + index: 0, + } } - pub fn index(mut self, index: usize) -> Self { + pub fn with_name(name: &'ast str) -> Self { + Self { + name: Some(name), + index: 0, + } + } + + pub fn with_index(mut self, index: usize) -> Self { self.index = index; self } + + pub fn name(&self) -> &'ast str { + self.name.unwrap() + } + + pub fn index(&self) -> usize { + self.index + } } impl<'ast> PartialEq for GenericIdentifier<'ast> { @@ -100,7 +118,7 @@ impl<'ast> Hash for GenericIdentifier<'ast> { impl<'ast> fmt::Display for GenericIdentifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}", self.name()) } } @@ -210,7 +228,7 @@ impl<'ast, T> From> for UExpression<'ast, T> { fn from(c: DeclarationConstant<'ast>) -> Self { match c { DeclarationConstant::Generic(i) => { - UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32) + UExpressionInner::Identifier(i.name().into()).annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) @@ -1031,7 +1049,7 @@ pub fn specialize_declaration_type< .enumerate() .map(|(index, g)| { ( - GenericIdentifier::with_name("dummy").index(index), + GenericIdentifier::without_name().with_index(index), g.map(|g| g.map(generics).unwrap()).unwrap(), ) }) From 1ffbc024b2a751dd010ed88ea1fbd78547682ea8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 17:39:13 +0300 Subject: [PATCH 10/19] fix tests --- zokrates_abi/src/lib.rs | 7 +--- zokrates_core/src/semantics.rs | 4 +- .../src/static_analysis/constant_inliner.rs | 10 ++--- .../src/static_analysis/propagation.rs | 2 +- .../src/static_analysis/reducer/mod.rs | 42 +++++++++---------- .../static_analysis/reducer/shallow_ssa.rs | 12 +++--- zokrates_core/src/typed_absy/abi.rs | 10 ++--- zokrates_core/src/typed_absy/types.rs | 26 +++--------- 8 files changed, 48 insertions(+), 65 deletions(-) diff --git a/zokrates_abi/src/lib.rs b/zokrates_abi/src/lib.rs index 6c07891e..16cdca31 100644 --- a/zokrates_abi/src/lib.rs +++ b/zokrates_abi/src/lib.rs @@ -413,11 +413,8 @@ mod tests { fn array() { let s = "[[true, false]]"; assert_eq!( - parse_strict::( - s, - vec![ConcreteType::array((ConcreteType::Boolean, 2usize))] - ) - .unwrap(), + parse_strict::(s, vec![ConcreteType::array((ConcreteType::Boolean, 2u32))]) + .unwrap(), Values(vec![Value::Array(vec![ Value::Boolean(true), Value::Boolean(false) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index ff019481..36f7af1d 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -4265,12 +4265,12 @@ mod tests { DeclarationType::FieldElement, GenericIdentifier::with_name("K").with_index(0) )), - GenericIdentifier::with_name("L").index(1) + GenericIdentifier::with_name("L").with_index(1) ))]) .outputs(vec![DeclarationType::array(( DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("L").index(1) + GenericIdentifier::with_name("L").with_index(1) )), GenericIdentifier::with_name("K").with_index(0) ))])) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 589ace6f..d2b803e8 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -574,13 +574,13 @@ mod tests { statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( FieldElementExpression::select( ArrayExpressionInner::Identifier(Identifier::from(const_id)) - .annotate(GType::FieldElement, 2usize), + .annotate(GType::FieldElement, 2u32), UExpressionInner::Value(0u128).annotate(UBitwidth::B32), ) .into(), FieldElementExpression::select( ArrayExpressionInner::Identifier(Identifier::from(const_id)) - .annotate(GType::FieldElement, 2usize), + .annotate(GType::FieldElement, 2u32), UExpressionInner::Value(1u128).annotate(UBitwidth::B32), ) .into(), @@ -608,7 +608,7 @@ mod tests { ] .into(), ) - .annotate(GType::FieldElement, 2usize), + .annotate(GType::FieldElement, 2u32), ))), )] .into_iter() @@ -649,7 +649,7 @@ mod tests { ] .into(), ) - .annotate(GType::FieldElement, 2usize), + .annotate(GType::FieldElement, 2u32), UExpressionInner::Value(0u128).annotate(UBitwidth::B32), ) .into(), @@ -661,7 +661,7 @@ mod tests { ] .into(), ) - .annotate(GType::FieldElement, 2usize), + .annotate(GType::FieldElement, 2u32), UExpressionInner::Value(1u128).annotate(UBitwidth::B32), ) .into(), diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 1450e077..9424ee35 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1463,7 +1463,7 @@ mod tests { ] .into(), ) - .annotate(Type::FieldElement, 3usize), + .annotate(Type::FieldElement, 3u32), UExpressionInner::Add(box 1u32.into(), box 1u32.into()) .annotate(UBitwidth::B32), ); diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 84790f29..b5b8fd24 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -830,22 +830,22 @@ mod tests { let foo_signature = DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]) .outputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]); let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::array( "a", DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0), + GenericIdentifier::with_name("K").with_index(0), ) .into()], statements: vec![TypedStatement::Return(vec![ @@ -954,7 +954,7 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 1)] + vec![(GenericIdentifier::with_name("K").with_index(0), 1)] .into_iter() .collect(), ), @@ -1049,22 +1049,22 @@ mod tests { let foo_signature = DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]) .outputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]); let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::array( "a", DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0), + GenericIdentifier::with_name("K").with_index(0), ) .into()], statements: vec![TypedStatement::Return(vec![ @@ -1182,7 +1182,7 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 1)] + vec![(GenericIdentifier::with_name("K").with_index(0), 1)] .into_iter() .collect(), ), @@ -1280,21 +1280,21 @@ mod tests { let foo_signature = DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]) .outputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ))]) .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]); let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::array( "a", DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ) .into()], statements: vec![ @@ -1358,7 +1358,7 @@ mod tests { arguments: vec![DeclarationVariable::array( "a", DeclarationType::FieldElement, - DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)), + DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ) .into()], statements: vec![TypedStatement::Return(vec![ @@ -1433,7 +1433,7 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 1)] + vec![(GenericIdentifier::with_name("K").with_index(0), 1)] .into_iter() .collect(), ), @@ -1442,7 +1442,7 @@ mod tests { DeclarationFunctionKey::with_location("main", "bar") .signature(foo_signature.clone()), GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 2)] + vec![(GenericIdentifier::with_name("K").with_index(0), 2)] .into_iter() .collect(), ), @@ -1489,22 +1489,22 @@ mod tests { let foo_signature = DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0), + GenericIdentifier::with_name("K").with_index(0), ))]) .outputs(vec![DeclarationType::array(( DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0), + GenericIdentifier::with_name("K").with_index(0), ))]); let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::array( "a", DeclarationType::FieldElement, - GenericIdentifier::with_name("K").index(0), + GenericIdentifier::with_name("K").with_index(0), ) .into()], statements: vec![TypedStatement::Return(vec![ diff --git a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs index a06453b7..aa744e8d 100644 --- a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs +++ b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs @@ -662,7 +662,7 @@ mod tests { ], signature: DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -673,7 +673,7 @@ mod tests { let ssa = ShallowTransformer::transform( f, &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 1)] + vec![(GenericIdentifier::with_name("K").with_index(0), 1)] .into_iter() .collect(), ), @@ -742,7 +742,7 @@ mod tests { ], signature: DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -851,7 +851,7 @@ mod tests { ], signature: DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), @@ -862,7 +862,7 @@ mod tests { let ssa = ShallowTransformer::transform( f, &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").index(0), 1)] + vec![(GenericIdentifier::with_name("K").with_index(0), 1)] .into_iter() .collect(), ), @@ -934,7 +934,7 @@ mod tests { ], signature: DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier::with_name("K").index(0).into(), + GenericIdentifier::with_name("K").with_index(0).into(), )]) .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 15554a73..0dfbfade 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -231,12 +231,12 @@ mod tests { ty: ConcreteType::Struct(ConcreteStructType::new( "".into(), "Bar".into(), - vec![Some(1usize)], + vec![Some(1u32)], vec![ConcreteStructMember::new( String::from("a"), ConcreteType::Array(ConcreteArrayType::new( ConcreteType::FieldElement, - 1usize, + 1u32, )), )], )), @@ -400,7 +400,7 @@ mod tests { ConcreteStructMember::new(String::from("c"), ConcreteType::Boolean), ], )), - 2usize, + 2u32, )), }], outputs: vec![ConcreteType::Boolean], @@ -454,8 +454,8 @@ mod tests { name: String::from("a"), public: false, ty: ConcreteType::Array(ConcreteArrayType::new( - ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2usize)), - 2usize, + ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2u32)), + 2u32, )), }], outputs: vec![ConcreteType::FieldElement], diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 44c6b2e5..750de024 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1405,33 +1405,19 @@ pub mod signature { let generic1 = DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier { - name: "P", - index: 0, - } - .into(), + GenericIdentifier::with_name("P").with_index(0).into(), )]) .inputs(vec![DeclarationType::array(DeclarationArrayType::new( DeclarationType::FieldElement, - GenericIdentifier { - name: "P", - index: 0, - }, + GenericIdentifier::with_name("P").with_index(0), ))]); let generic2 = DeclarationSignature::new() .generics(vec![Some( - GenericIdentifier { - name: "Q", - index: 0, - } - .into(), + GenericIdentifier::with_name("Q").with_index(0).into(), )]) .inputs(vec![DeclarationType::array(DeclarationArrayType::new( DeclarationType::FieldElement, - GenericIdentifier { - name: "Q", - index: 0, - }, + GenericIdentifier::with_name("Q").with_index(0), ))]); assert_eq!(generic1, generic2); @@ -1512,7 +1498,7 @@ mod tests { #[test] fn array() { - let t = ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 42usize)); + let t = ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 42u32)); assert_eq!(t.get_primitive_count(), 42); } @@ -1521,7 +1507,7 @@ mod tests { // field[1][2] let t = ConcreteType::Array(ConcreteArrayType::new( ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2u32)), - 1usize, + 1u32, )); assert_eq!(format!("{}", t), "field[1][2]"); } From 30ebec6f375af5a243679a912b45d5894157f8a2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Sep 2021 11:14:18 +0300 Subject: [PATCH 11/19] simplify treatment of user defined types --- zokrates_core/src/semantics.rs | 399 +++++++++++++-------------------- 1 file changed, 159 insertions(+), 240 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 36f7af1d..2c28e0b3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -55,9 +55,33 @@ impl ErrorInner { } } -type GenericDeclarations<'ast> = Option>>>; -type TypeMap<'ast> = - HashMap, GenericDeclarations<'ast>)>>; +// a single struct to cover all cases of user-defined types +#[derive(Debug, Clone)] +struct UserDeclarationType<'ast> { + generics: Vec>, + ty: DeclarationType<'ast>, +} + +impl<'ast> UserDeclarationType<'ast> { + // returns the declared generics for this user type + // for alias of basic types this is empty + // for structs this is the same as the used generics + // for aliases of structs this is the names of the generics declared on the left side of the type declaration + fn declaration_generics(&self) -> Vec<&'ast str> { + self.generics + .iter() + .filter_map(|g| match g { + DeclarationConstant::Generic(g) => Some(g), + _ => None, + }) + .collect::>() // we collect into a BTreeSet because draining it after yields the element in the right order defined by Ord + .into_iter() + .map(|g| g.name()) + .collect() + } +} + +type TypeMap<'ast> = HashMap>>; type ConstantMap<'ast> = HashMap, DeclarationType<'ast>>>; @@ -356,7 +380,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ty: TypeDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec> { + ) -> Result, Vec> { let pos = ty.pos(); let ty = ty.value; @@ -382,9 +406,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } else { match generics_map.insert(g.value, index).is_none() { true => { - generics.push(Some(DeclarationConstant::Generic( + generics.push(DeclarationConstant::Generic( GenericIdentifier::with_name(g.value).with_index(index), - ))); + )); } false => { errors.push(ErrorInner { @@ -420,7 +444,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok((ty, Some(generics))) + Ok(UserDeclarationType { generics, ty }) } Err(e) => { errors.push(e); @@ -486,7 +510,7 @@ impl<'ast, T: Field> Checker<'ast, T> { s: StructDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result, Vec> { + ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -569,7 +593,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok(DeclarationType::Struct(DeclarationStructType::new( + Ok(DeclarationStructType::new( module_id.to_path_buf(), id, generics, @@ -577,7 +601,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .iter() .map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone())) .collect(), - ))) + )) } fn check_symbol_declaration( @@ -620,7 +644,18 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), (ty, None)) + .insert( + declaration.id.to_string(), + UserDeclarationType { + generics: ty + .generics + .clone() + .into_iter() + .map(|g| g.unwrap()) + .collect(), + ty: DeclarationType::Struct(ty) + } + ) .is_none()); } }; @@ -781,18 +816,20 @@ impl<'ast, T: Field> Checker<'ast, T> { .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { - (0, Some((t, alias_generics)), None) => { - + (0, Some(t), None) => { // rename the type to the declared symbol - let t = match t { - DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { - location: Some(StructLocation { - name: declaration.id.into(), - module: module_id.to_path_buf() + let t = UserDeclarationType { + ty: match t.ty { + DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { + location: Some(StructLocation { + name: declaration.id.into(), + module: module_id.to_path_buf() + }), + ..t }), - ..t - }), - _ => t // type alias + _ => t.ty // all other cases + }, + ..t }; // we imported a type, so the symbol it gets bound to should not already exist @@ -814,7 +851,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), (t, alias_generics)); + .insert(declaration.id.to_string(), t); } (0, None, Some(ty)) => { match symbol_unifier.insert_constant(declaration.id) { @@ -1306,133 +1343,71 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let (declaration_type, alias_generics) = types - .get(module_id) - .unwrap() - .get(&id) - .cloned() - .ok_or_else(|| ErrorInner { - pos: Some(pos), - message: format!("Undefined type {}", id), - })?; + let declared_ty = + types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + })?; + + let generic_identifiers = declared_ty.declaration_generics(); + + let declaration_type = declared_ty.ty; // absence of generics is treated as 0 generics, as we do not provide inference for now let generics = generics.unwrap_or_default(); // check generics - match (declaration_type, alias_generics) { - (DeclarationType::Struct(struct_type), None) => { - match struct_type.generics.len() == generics.len() { - true => { - // downcast the generics to identifiers, as this is the only possibility here - let generic_identifiers = struct_type.generics.iter().map(|c| { - match c.as_ref().unwrap() { - DeclarationConstant::Generic(g) => g.clone(), - _ => unreachable!(), - } - }); - - // build the generic assignment for this type - let assignment = GGenericsAssignment(generics - .into_iter() - .zip(generic_identifiers) - .map(|(e, g)| match e { - Some(e) => { - self - .check_expression(e, module_id, types) - .and_then(|e| { - UExpression::try_from_typed(e, &UBitwidth::B32) - .map(|e| (g, e)) - .map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), - }) + match generic_identifiers.len() == generics.len() { + true => { + // build the generic assignment for this type + let assignment = GGenericsAssignment(generics + .into_iter() + .zip(generic_identifiers) + .enumerate() + .map(|(i, (e, g))| match e { + Some(e) => { + self + .check_expression(e, module_id, types) + .and_then(|e| { + UExpression::try_from_typed(e, &UBitwidth::B32) + .map(|e| (GenericIdentifier::with_name(g).with_index(i), e)) + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), }) - }, - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." - .into(), }) - }) - .collect::>()?); + }, + None => Err(ErrorInner { + pos: Some(pos), + message: + "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." + .into(), + }) + }) + .collect::>()?); - // specialize the declared type using the generic assignment - Ok(specialize_declaration_type( - DeclarationType::Struct(struct_type), - &assignment, - ) - .unwrap()) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - struct_type.generics.len(), - if struct_type.generics.len() == 1 { - "" - } else { - "s" - }, - id, - generics.len() - ), - }), - } + // specialize the declared type using the generic assignment + Ok(specialize_declaration_type(declaration_type, &assignment).unwrap()) } - (declaration_type, Some(alias_generics)) => { - match alias_generics.len() == generics.len() { - true => { - let generic_identifiers = - alias_generics.iter().map(|c| match c.as_ref().unwrap() { - DeclarationConstant::Generic(g) => g.clone(), - _ => unreachable!(), - }); - - // build the generic assignment for this type - let assignment = GGenericsAssignment(generics - .into_iter() - .zip(generic_identifiers) - .map(|(e, g)| match e { - Some(e) => { - self - .check_expression(e, module_id, types) - .and_then(|e| { - UExpression::try_from_typed(e, &UBitwidth::B32) - .map(|e| (g, e)) - .map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), - }) - }) - }, - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." - .into(), - }) - }) - .collect::>()?); - - // specialize the declared type using the generic assignment - Ok(specialize_declaration_type(declaration_type, &assignment) - .unwrap()) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - alias_generics.len(), - if alias_generics.len() == 1 { "" } else { "s" }, - id, - generics.len() - ), - }), - } - } - _ => unreachable!(), + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + generic_identifiers.len(), + if generic_identifiers.len() == 1 { + "" + } else { + "s" + }, + id, + generics.len() + ), + }), } } } @@ -1527,7 +1502,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let (declared_ty, alias_generics) = state + let ty = state .types .get(module_id) .unwrap() @@ -1538,99 +1513,47 @@ impl<'ast, T: Field> Checker<'ast, T> { message: format!("Undefined type {}", id), })?; - match (declared_ty, alias_generics) { - (ty, Some(alias_generics)) => { - let generics = generics.unwrap_or_default(); - let checked_generics: Vec<_> = generics - .into_iter() - .map(|e| match e { - Some(e) => self - .check_generic_expression( - e, - module_id, - state.constants.get(module_id).unwrap_or(&HashMap::new()), - generics_map, - used_generics, - ) - .map(Some), - None => Err(ErrorInner { - pos: Some(pos), - message: "Expected u32 constant or identifier, but found `_`" - .into(), - }), - }) - .collect::>()?; + let generics = generics.unwrap_or_default(); + let checked_generics: Vec<_> = generics + .into_iter() + .map(|e| match e { + Some(e) => self + .check_generic_expression( + e, + module_id, + state.constants.get(module_id).unwrap_or(&HashMap::new()), + generics_map, + used_generics, + ) + .map(Some), + None => Err(ErrorInner { + pos: Some(pos), + message: "Expected u32 constant or identifier, but found `_`".into(), + }), + }) + .collect::>()?; + match ty.generics.len() == checked_generics.len() { + true => { let mut assignment = GGenericsAssignment::default(); - assignment.0.extend( - alias_generics.iter().zip(checked_generics.iter()).map( - |(decl_g, g_val)| match decl_g.clone().unwrap() { - DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), - _ => unreachable!(), - }, - ), - ); + assignment.0.extend(ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone() { + DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), + _ => unreachable!("generic on declaration struct types must be generic identifiers") + })); - Ok(specialize_declaration_type(ty, &assignment).unwrap()) + Ok(specialize_declaration_type(ty.ty, &assignment).unwrap()) } - (DeclarationType::Struct(declared_struct_ty), None) => { - let generics = generics.unwrap_or_default(); - match declared_struct_ty.generics.len() == generics.len() { - true => { - let checked_generics: Vec<_> = generics - .into_iter() - .map(|e| match e { - Some(e) => self - .check_generic_expression( - e, - module_id, - state - .constants - .get(module_id) - .unwrap_or(&HashMap::new()), - generics_map, - used_generics, - ) - .map(Some), - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`" - .into(), - }), - }) - .collect::>()?; - - let mut assignment = GGenericsAssignment::default(); - - assignment.0.extend(declared_struct_ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone().unwrap() { - DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), - _ => unreachable!("generic on declaration struct types must be generic identifiers") - })); - - Ok(DeclarationType::Struct(DeclarationStructType { - generics: checked_generics, - ..declared_struct_ty - })) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - declared_struct_ty.generics.len(), - if declared_struct_ty.generics.len() == 1 { - "" - } else { - "s" - }, - id, - generics.len() - ), - }), - } - } - (declared_ty, _) => Ok(declared_ty), + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + ty.generics.len(), + if ty.generics.len() == 1 { "" } else { "s" }, + id, + checked_generics.len() + ), + }), } } } @@ -3099,7 +3022,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::InlineStruct(id, inline_members) => { - let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() { + let ty = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { pos: Some(pos), message: format!("Undefined type `{}`", id), @@ -3107,7 +3030,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Some(ty) => Ok(ty), }?; - let mut declared_struct_type = match ty { + let mut declared_struct_type = match ty.ty { DeclarationType::Struct(struct_type) => struct_type, _ => unreachable!(), }; @@ -5529,12 +5452,8 @@ mod tests { } .mock(); - let expected_type = DeclarationType::Struct(DeclarationStructType::new( - "".into(), - "Foo".into(), - vec![], - vec![], - )); + let expected_type = + DeclarationStructType::new("".into(), "Foo".into(), vec![], vec![]); assert_eq!( Checker::::new().check_struct_type_declaration( @@ -5570,7 +5489,7 @@ mod tests { } .mock(); - let expected_type = DeclarationType::Struct(DeclarationStructType::new( + let expected_type = DeclarationStructType::new( "".into(), "Foo".into(), vec![], @@ -5578,7 +5497,7 @@ mod tests { DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement), DeclarationStructMember::new("bar".into(), DeclarationType::Boolean), ], - )); + ); assert_eq!( Checker::::new().check_struct_type_declaration( @@ -5681,7 +5600,7 @@ mod tests { .get(&*MODULE_ID) .unwrap() .get(&"Bar".to_string()) - .map(|(ty, _)| ty) + .map(|ty| &ty.ty) .unwrap(), &DeclarationType::Struct(DeclarationStructType::new( (*MODULE_ID).clone(), From b8e831ed5b8efc00881b2f89fcd86ac3f643061f Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Sep 2021 14:13:21 +0300 Subject: [PATCH 12/19] avoid visiting members --- zokrates_core/src/semantics.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 2c28e0b3..fb4687e6 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1542,7 +1542,19 @@ impl<'ast, T: Field> Checker<'ast, T> { _ => unreachable!("generic on declaration struct types must be generic identifiers") })); - Ok(specialize_declaration_type(ty.ty, &assignment).unwrap()) + let res = match ty.ty { + // if the type is a struct, we do not specialize in the members. + // we only remap the generics + DeclarationType::Struct(declared_struct_ty) => { + DeclarationType::Struct(DeclarationStructType { + generics: checked_generics, + ..declared_struct_ty + }) + } + ty => specialize_declaration_type(ty, &assignment).unwrap(), + }; + + Ok(res) } false => Err(ErrorInner { pos: Some(pos), From 4b99a99815c29d9db827171a6465843ff43203d9 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Sep 2021 14:45:45 +0300 Subject: [PATCH 13/19] remove explicit generics where applicable --- zokrates_cli/examples/compile_errors/variable_constant_lt.zok | 2 +- zokrates_stdlib/stdlib/hashes/keccak/256bit.zok | 2 +- zokrates_stdlib/stdlib/hashes/keccak/384bit.zok | 2 +- zokrates_stdlib/stdlib/hashes/keccak/512bit.zok | 2 +- zokrates_stdlib/stdlib/hashes/sha3/256bit.zok | 2 +- zokrates_stdlib/stdlib/hashes/sha3/384bit.zok | 2 +- zokrates_stdlib/stdlib/hashes/sha3/512bit.zok | 2 +- zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok | 2 +- zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok | 2 +- zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok | 2 +- zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok | 2 +- zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok | 2 +- zokrates_stdlib/tests/tests/hashes/keccak/384bit.zok | 2 +- zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok | 4 ++-- zokrates_stdlib/tests/tests/hashes/sha3/256bit.zok | 2 +- zokrates_stdlib/tests/tests/hashes/sha3/512bit.zok | 2 +- 16 files changed, 17 insertions(+), 17 deletions(-) diff --git a/zokrates_cli/examples/compile_errors/variable_constant_lt.zok b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok index 062d8295..86c3979c 100644 --- a/zokrates_cli/examples/compile_errors/variable_constant_lt.zok +++ b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok @@ -2,4 +2,4 @@ from "EMBED" import bit_array_le // Calling the `bit_array_le` embed on a non-constant second argument should fail at compile-time def main(bool[1] a, bool[1] b) -> bool: - return bit_array_le::<1>(a, b) \ No newline at end of file + return bit_array_le(a, b) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/keccak/256bit.zok b/zokrates_stdlib/stdlib/hashes/keccak/256bit.zok index 59d800fe..a066170c 100644 --- a/zokrates_stdlib/stdlib/hashes/keccak/256bit.zok +++ b/zokrates_stdlib/stdlib/hashes/keccak/256bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> u64[4]: - return keccak::(input, 0x0000000000000001)[..4] \ No newline at end of file + return keccak::<_, 256>(input, 0x0000000000000001)[..4] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/keccak/384bit.zok b/zokrates_stdlib/stdlib/hashes/keccak/384bit.zok index f261ebcc..0378c6dd 100644 --- a/zokrates_stdlib/stdlib/hashes/keccak/384bit.zok +++ b/zokrates_stdlib/stdlib/hashes/keccak/384bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> u64[6]: - return keccak::(input, 0x0000000000000001)[..6] \ No newline at end of file + return keccak::<_, 384>(input, 0x0000000000000001)[..6] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/keccak/512bit.zok b/zokrates_stdlib/stdlib/hashes/keccak/512bit.zok index 8345df52..00678fcb 100644 --- a/zokrates_stdlib/stdlib/hashes/keccak/512bit.zok +++ b/zokrates_stdlib/stdlib/hashes/keccak/512bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> u64[8]: - return keccak::(input, 0x0000000000000001)[..8] \ No newline at end of file + return keccak::<_, 512>(input, 0x0000000000000001)[..8] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/sha3/256bit.zok b/zokrates_stdlib/stdlib/hashes/sha3/256bit.zok index 99d213fa..2dfed2cc 100644 --- a/zokrates_stdlib/stdlib/hashes/sha3/256bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha3/256bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> (u64[4]): - return keccak::(input, 0x0000000000000006)[..4] \ No newline at end of file + return keccak::<_, 256>(input, 0x0000000000000006)[..4] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/sha3/384bit.zok b/zokrates_stdlib/stdlib/hashes/sha3/384bit.zok index 1b6dfeff..95c8dafe 100644 --- a/zokrates_stdlib/stdlib/hashes/sha3/384bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha3/384bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> (u64[6]): - return keccak::(input, 0x0000000000000006)[..6] \ No newline at end of file + return keccak::<_, 384>(input, 0x0000000000000006)[..6] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/sha3/512bit.zok b/zokrates_stdlib/stdlib/hashes/sha3/512bit.zok index 6c37836e..5dbd8d5e 100644 --- a/zokrates_stdlib/stdlib/hashes/sha3/512bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha3/512bit.zok @@ -1,4 +1,4 @@ import "hashes/keccak/keccak" as keccak def main(u64[N] input) -> (u64[8]): - return keccak::(input, 0x0000000000000006)[..8] \ No newline at end of file + return keccak::<_, 512>(input, 0x0000000000000006)[..8] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok index e31dece4..754fc659 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok @@ -7,6 +7,6 @@ import "./unpack_unchecked" // For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)` def main(field i) -> bool[256]: - bool[254] b = unpack_unchecked::<254>(i) + bool[254] b = unpack_unchecked(i) return [false, false, ...b] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok index 8f0b1203..1998bede 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok @@ -3,5 +3,5 @@ import "./unpack" as unpack // Unpack a field element as 128 big-endian bits // If the input is larger than `2**128 - 1`, the output is truncated. def main(field i) -> bool[128]: - bool[128] res = unpack::<128>(i) + bool[128] res = unpack(i) return res \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok index 05340e3c..2115a182 100644 --- a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok @@ -9,7 +9,7 @@ import "hashes/blake2/blake2s" // '879043503b04cab2f3c0d7a4bb01c1db74c238c49887da84e8a619893092b6e2' def main(): - u32[8] h = blake2s::<3>([[0x12345678; 16]; 3]) // 3 * 16 * 32 = 1536 bit input + u32[8] h = blake2s([[0x12345678; 16]; 3]) // 3 * 16 * 32 = 1536 bit input assert(h == [ 0x87904350, 0x3B04CAB2, 0xF3C0D7A4, 0xBB01C1DB, 0x74C238C4, 0x9887DA84, 0xE8A61989, 0x3092B6E2 diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok index 2398c608..62c0c425 100644 --- a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok @@ -9,7 +9,7 @@ import "hashes/blake2/blake2s" // '52af1aec3e6663bcc759d55fc7557fbb2f710219f0de138b1b52c919f5c94415' def main(): - u32[8] h = blake2s::<1>([[0x12345678; 16]; 1]) // 16 * 32 = 512 bit input + u32[8] h = blake2s([[0x12345678; 16]; 1]) // 16 * 32 = 512 bit input assert(h == [ 0x52AF1AEC, 0x3E6663BC, 0xC759D55F, 0xC7557FBB, 0x2F710219, 0xF0DE138B, 0x1B52C919, 0xF5C94415 diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok index ecea1f26..91d3da97 100644 --- a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok @@ -9,7 +9,7 @@ import "hashes/blake2/blake2s_p" as blake2s // '780105bc9ca7633b1f289b3d1558dece65e04ac23f88e711dc29600fa3e0258a' def main(): - u32[8] h = blake2s::<1>([[0x12345678; 16]; 1], [0x12345678, 0]) + u32[8] h = blake2s([[0x12345678; 16]; 1], [0x12345678, 0]) assert(h == [ 0x780105BC, 0x9CA7633B, 0x1F289B3D, 0x1558DECE, 0x65E04AC2, 0x3F88E711, 0xDC29600F, 0xA3E0258A diff --git a/zokrates_stdlib/tests/tests/hashes/keccak/384bit.zok b/zokrates_stdlib/tests/tests/hashes/keccak/384bit.zok index a12a3f96..3220ca0b 100644 --- a/zokrates_stdlib/tests/tests/hashes/keccak/384bit.zok +++ b/zokrates_stdlib/tests/tests/hashes/keccak/384bit.zok @@ -9,7 +9,7 @@ import "hashes/keccak/384bit" as keccak384 // 'a944b9b859c1e69d66b52d4cf1f678b24ed8a9ccb0a32bbe882af8a3a1acbd3b68eed9c628307e5d3789f1a64a50e8e7' def main(): - u64[6] h = keccak384::<20>([42; 20]) + u64[6] h = keccak384([42; 20]) assert(h == [ 0xA944B9B859C1E69D, 0x66B52D4CF1F678B2, 0x4ED8A9CCB0A32BBE, 0x882AF8A3A1ACBD3B, 0x68EED9C628307E5D, 0x3789F1A64A50E8E7 diff --git a/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok b/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok index 4924b8df..7e61a472 100644 --- a/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok +++ b/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok @@ -1,12 +1,12 @@ import "hashes/mimcSponge/mimcSponge" as mimcSponge def main(): - assert(mimcSponge::<2, 3>([1, 2], 3) == [ + assert(mimcSponge::<_, 3>([1, 2], 3) == [ 20225509322021146255705869525264566735642015554514977326536820959638320229084, 13871743498877225461925335509899475799121918157213219438898506786048812913771, 21633608428713573518356618235457250173701815120501233429160399974209848779097 ]) - assert(mimcSponge::<2, 3>([0, 0], 0) == [ + assert(mimcSponge::<_, 3>([0, 0], 0) == [ 20636625426020718969131298365984859231982649550971729229988535915544421356929, 6046202021237334713296073963481784771443313518730771623154467767602059802325, 16227963524034219233279650312501310147918176407385833422019760797222680144279 diff --git a/zokrates_stdlib/tests/tests/hashes/sha3/256bit.zok b/zokrates_stdlib/tests/tests/hashes/sha3/256bit.zok index 203bb970..243dea41 100644 --- a/zokrates_stdlib/tests/tests/hashes/sha3/256bit.zok +++ b/zokrates_stdlib/tests/tests/hashes/sha3/256bit.zok @@ -9,6 +9,6 @@ import "hashes/sha3/256bit" as sha3_256 // '18d00c9e97cd5516243b67b243ede9e2cf0d45d3a844d33340bfc4efc9165100' def main(): - u64[4] h = sha3_256::<20>([42; 20]) + u64[4] h = sha3_256([42; 20]) assert(h == [0x18D00C9E97CD5516, 0x243B67B243EDE9E2, 0xCF0D45D3A844D333, 0x40BFC4EFC9165100]) return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/sha3/512bit.zok b/zokrates_stdlib/tests/tests/hashes/sha3/512bit.zok index 9e65810a..69badc5f 100644 --- a/zokrates_stdlib/tests/tests/hashes/sha3/512bit.zok +++ b/zokrates_stdlib/tests/tests/hashes/sha3/512bit.zok @@ -9,7 +9,7 @@ import "hashes/sha3/512bit" as sha3_512 // '73a0967b68de5ce1093cbd7482fd4de9ccc9c782e2edc71b583d26fe16fb19e3322a2a024b7f6e163fbb1a15161686dd3a39233f9cf8616e7c74e91fa1aa3b2b' def main(): - u64[8] h = sha3_512::<20>([42; 20]) + u64[8] h = sha3_512([42; 20]) assert(h == [ 0x73A0967B68DE5CE1, 0x093CBD7482FD4DE9, 0xCCC9C782E2EDC71B, 0x583D26FE16FB19E3, 0x322A2A024B7F6E16, 0x3FBB1A15161686DD, 0x3A39233F9CF8616E, 0x7C74E91FA1AA3B2B From 9895e85ea6ff168f215510e5a8d72f78b9ededf4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Sep 2021 16:49:20 +0300 Subject: [PATCH 14/19] remap generics when checking struct alias --- zokrates_core/src/semantics.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index fb4687e6..cabf4155 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1547,7 +1547,11 @@ impl<'ast, T: Field> Checker<'ast, T> { // we only remap the generics DeclarationType::Struct(declared_struct_ty) => { DeclarationType::Struct(DeclarationStructType { - generics: checked_generics, + generics: declared_struct_ty + .generics + .into_iter() + .map(|g| g.map(|g| g.map(&assignment).unwrap())) + .collect(), ..declared_struct_ty }) } From 87115c8cd1dc69b68e57107335dc41ab57bd9ac2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 24 Sep 2021 15:15:58 +0300 Subject: [PATCH 15/19] check generics on structs and infer them --- zokrates_core/src/typed_absy/types.rs | 50 ++++++++++++++-------- zokrates_stdlib/tests/tests/snark/gm17.zok | 2 +- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 750de024..c21031b2 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -965,6 +965,33 @@ impl<'ast> ConcreteFunctionKey<'ast> { use std::collections::btree_map::Entry; +// check an optional generic value against the corresponding declaration constant +// if None is provided, return true +// if some value is provided, insert it into the map or check that it doesn't conflict if a value is already thereq +pub fn check_generic<'ast, S: Clone + PartialEq + PartialEq>( + generic: &DeclarationConstant<'ast>, + value: Option<&S>, + constants: &mut GGenericsAssignment<'ast, S>, +) -> bool { + value + .map(|value| match generic { + // if the generic is an identifier, we insert into the map, or check if the concrete size + // matches if this identifier is already in the map + DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) { + Entry::Occupied(e) => *e.get() == *value, + Entry::Vacant(e) => { + e.insert(value.clone()); + true + } + }, + DeclarationConstant::Concrete(generic) => *value == *generic, + // in the case of a constant, we do not know the value yet, so we optimistically assume it's correct + // if it does not match, it will be caught during inlining + DeclarationConstant::Constant(..) => true, + }) + .unwrap_or(true) +} + pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( decl_ty: &DeclarationType<'ast>, ty: >ype, @@ -972,31 +999,20 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( ) -> bool { match (decl_ty, ty) { (DeclarationType::Array(t0), GType::Array(t1)) => { - let s1 = t1.size.clone(); - // both the inner type and the size must match check_type(&t0.ty, &t1.ty, constants) - && match &t0.size { - // if the declared size is an identifier, we insert into the map, or check if the concrete size - // matches if this identifier is already in the map - DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) { - Entry::Occupied(e) => *e.get() == s1, - Entry::Vacant(e) => { - e.insert(s1); - true - } - }, - DeclarationConstant::Concrete(s0) => s1 == *s0 as u32, - // in the case of a constant, we do not know the value yet, so we optimistically assume it's correct - // if it does not match, it will be caught during inlining - DeclarationConstant::Constant(..) => true, - } + && check_generic(&t0.size, Some(&t1.size), constants) } (DeclarationType::FieldElement, GType::FieldElement) | (DeclarationType::Boolean, GType::Boolean) => true, (DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1, (DeclarationType::Struct(s0), GType::Struct(s1)) => { s0.canonical_location == s1.canonical_location + && s0 + .generics + .iter() + .zip(s1.generics.iter()) + .all(|(g0, g1)| check_generic(g0.as_ref().unwrap(), g1.as_ref(), constants)) } _ => false, } diff --git a/zokrates_stdlib/tests/tests/snark/gm17.zok b/zokrates_stdlib/tests/tests/snark/gm17.zok index d09a2473..6efc36b1 100644 --- a/zokrates_stdlib/tests/tests/snark/gm17.zok +++ b/zokrates_stdlib/tests/tests/snark/gm17.zok @@ -54,4 +54,4 @@ from "snark/gm17" import main as verify, Proof, VerificationKey def main(Proof<3> proof, VerificationKey<4> vk) -> bool: - return verify::<3, 4>(proof, vk) \ No newline at end of file + return verify(proof, vk) \ No newline at end of file From 995851a03c8ebdd139c73604de526e532b3c618e Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Mon, 27 Sep 2021 20:58:04 +0300 Subject: [PATCH 16/19] Create 1016-schaeff --- changelogs/unreleased/1016-schaeff | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1016-schaeff diff --git a/changelogs/unreleased/1016-schaeff b/changelogs/unreleased/1016-schaeff new file mode 100644 index 00000000..d944de96 --- /dev/null +++ b/changelogs/unreleased/1016-schaeff @@ -0,0 +1 @@ +Fix false positives and false negatives in struct generic inference From 3371df28cc31de39b9ef5e888b73a67c46cc95d9 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 4 Oct 2021 21:37:23 +0300 Subject: [PATCH 17/19] fmt clippy --- zokrates_core/src/static_analysis/propagation.rs | 6 +----- zokrates_core/src/typed_absy/types.rs | 7 +++---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index a41d16c0..8d2d08ed 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1075,11 +1075,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }) // ignore spreads over empty arrays .filter_map(|e| match e { - TypedExpressionOrSpread::Spread(s) - if s.array.size() == UExpression::from(0u32) => - { - None - } + TypedExpressionOrSpread::Spread(s) if s.array.size() == 0u32 => None, e => Some(e), }) .collect(), diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 7213227e..a741b98d 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -165,7 +165,6 @@ impl<'ast, T> DeclarationConstant<'ast, T> { DeclarationConstant::Concrete(v) => Ok(v.into()), DeclarationConstant::Constant(c) => Ok(c.into()), DeclarationConstant::Expression(_) => unreachable!(), - } } @@ -1023,10 +1022,10 @@ pub fn check_generic<'ast, T, S: Clone + PartialEq + PartialEq>( DeclarationConstant::Expression(e) => match e { TypedExpression::Uint(e) => match e.as_inner() { UExpressionInner::Value(v) => *value == *v as u32, - _ => true + _ => true, }, - _ => unreachable!() - } + _ => unreachable!(), + }, }) .unwrap_or(true) } From 6ca342720fbf6a339390d5f21366535af7ea12c2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 5 Oct 2021 12:30:53 +0300 Subject: [PATCH 18/19] simplify expression_at making the return value generic, reverting wrong clippy suggestion --- .../src/static_analysis/propagation.rs | 35 ++++++++++--------- zokrates_core/src/typed_absy/mod.rs | 27 +++++++------- zokrates_core/src/typed_absy/result_folder.rs | 5 ++- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 8d2d08ed..3a93991a 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -971,7 +971,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } fn fold_select_expression< - E: Expr<'ast, T> + Select<'ast, T> + From>, + E: Expr<'ast, T> + + Select<'ast, T> + + TryFrom> + + Into>, >( &mut self, _: &E::Ty, @@ -988,12 +991,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { if n < size { Ok(SelectOrExpression::Expression( - E::from( - v.expression_at::>(n as usize) - .unwrap() - .clone(), - ) - .into_inner(), + v.expression_at::(n as usize).unwrap().into_inner(), )) } else { Err(Error::OutOfBounds(n, size)) @@ -1005,14 +1003,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { TypedExpression::Array(a) => match a.as_inner() { ArrayExpressionInner::Value(v) => { Ok(SelectOrExpression::Expression( - E::from( - v.expression_at::>( - n as usize, - ) - .unwrap() - .clone(), - ) - .into_inner(), + v.expression_at::(n as usize).unwrap().into_inner(), )) } _ => unreachable!("should be an array value"), @@ -1075,7 +1066,19 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }) // ignore spreads over empty arrays .filter_map(|e| match e { - TypedExpressionOrSpread::Spread(s) if s.array.size() == 0u32 => None, + // clippy makes a wrong suggestion here: + // ``` + // this creates an owned instance just for comparison + // UExpression::from(0u32) + // help: try: `0u32` + // ``` + // But for `UExpression`, `PartialEq` is different from `PartialEq` (the latter is too optimistic in this case) + #[allow(clippy::cmp_owned)] + TypedExpressionOrSpread::Spread(s) + if s.array.size() == UExpression::from(0u32) => + { + None + } e => Some(e), }) .collect(), diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index f9d40185..cf5ab54a 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1142,11 +1142,13 @@ impl<'ast, T> IntoIterator for ArrayValue<'ast, T> { } impl<'ast, T: Clone> ArrayValue<'ast, T> { - fn expression_at_aux + Into>>( + fn expression_at_aux< + U: Select<'ast, T> + From> + Into>, + >( v: TypedExpressionOrSpread<'ast, T>, - ) -> Vec>> { + ) -> Vec> { match v { - TypedExpressionOrSpread::Expression(e) => vec![Some(e.clone())], + TypedExpressionOrSpread::Expression(e) => vec![Some(U::from(e))], TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() { UExpressionInner::Value(size) => { let array_ty = s.array.ty().clone(); @@ -1158,14 +1160,11 @@ impl<'ast, T: Clone> ArrayValue<'ast, T> { .collect(), a => (0..size) .map(|i| { - Some( - U::select( - a.clone() - .annotate(*array_ty.ty.clone(), array_ty.size.clone()), - i as u32, - ) - .into(), - ) + Some(U::select( + a.clone() + .annotate(*array_ty.ty.clone(), array_ty.size.clone()), + i as u32, + )) }) .collect(), } @@ -1175,10 +1174,12 @@ impl<'ast, T: Clone> ArrayValue<'ast, T> { } } - pub fn expression_at + Into>>( + pub fn expression_at< + U: Select<'ast, T> + From> + Into>, + >( &self, index: usize, - ) -> Option> { + ) -> Option { self.0 .iter() .map(|v| Self::expression_at_aux::(v.clone())) diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 738f026c..e6290342 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -212,7 +212,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_select_expression< - E: Expr<'ast, T> + Select<'ast, T> + From>, + E: Expr<'ast, T> + + Select<'ast, T> + + Into> + + From>, >( &mut self, ty: &E::Ty, From eb492153f4de240e6d60a5b67885f323fc6e49da Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 12 Oct 2021 12:10:02 +0300 Subject: [PATCH 19/19] add docs, use alias in example --- zokrates_book/src/language/types.md | 12 ++++++++++++ zokrates_cli/examples/book/type_aliases.zok | 11 +++++++++++ .../examples/sudoku/prime_sudoku_checker.zok | 17 +++++++++++++---- .../src/static_analysis/struct_concretizer.rs | 1 - 4 files changed, 36 insertions(+), 5 deletions(-) create mode 100644 zokrates_cli/examples/book/type_aliases.zok diff --git a/zokrates_book/src/language/types.md b/zokrates_book/src/language/types.md index b97c1913..8d4bbee2 100644 --- a/zokrates_book/src/language/types.md +++ b/zokrates_book/src/language/types.md @@ -128,6 +128,8 @@ struct Point { } ``` +Note that two struct definitions with the same members still introduce two entirely different types. For example, they cannot be compared with each other. + #### Declaration and Initialization Initialization of a variable of a struct type always needs to happen in the same statement as a declaration, unless the struct-typed variable is declared within a function's signature. @@ -144,3 +146,13 @@ The variables within a struct instance, the so called members, can be accessed t ```zokrates {{#include ../../../zokrates_cli/examples/book/struct_assign.zok}} ``` + +### Type aliases + +Type aliases can be defined for any existing type. This can be useful for readability, or to specialize generic types. + +Note that type aliases are just syntactic sugar: in the type system, a type and its alias are exactly equivalent. For example, they can be compared. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/type_aliases.zok}} +``` diff --git a/zokrates_cli/examples/book/type_aliases.zok b/zokrates_cli/examples/book/type_aliases.zok new file mode 100644 index 00000000..fd500a4e --- /dev/null +++ b/zokrates_cli/examples/book/type_aliases.zok @@ -0,0 +1,11 @@ +type MyField = field + +type Rectangle = bool[L][W] + +type Square = Rectangle + +def main(): + MyField f = 42 + Rectangle<2, 2> r = [[true; 2]; 2] + Square<2> s = r + return \ No newline at end of file diff --git a/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok b/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok index a5cd37d6..aec77917 100644 --- a/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok @@ -8,6 +8,9 @@ // -------------------------- // | c21 | c22 || d21 | d22 | +type Grid = field[N][N] +const field[4] PRIMES = [2, 3, 5, 7] + // We encode values in the following way: // 1 -> 2 // 2 -> 3 @@ -18,16 +21,22 @@ // assumption: `a, b, c, d` are all in `{ 2, 3, 5, 7 }` def checkNoDuplicates(field a, field b, field c, field d) -> bool: // as `{ 2, 3, 5, 7 }` are primes, the set `{ a, b, c, d }` is equal to the set `{ 2, 3, 5, 7}` if and only if the products match - return a * b * c * d == 2 * 3 * 5 * 7 + return a * b * c * d == PRIMES[0] * PRIMES[1] * PRIMES[2] * PRIMES[3] -// returns `0` if and only if `x` in `{ 2, 3, 5, 7 }` +// returns true if and only if `x` is one of the `4` primes def validateInput(field x) -> bool: - return (x-2) * (x-3) * (x-5) * (x-7) == 0 + field res = 1 + + for u32 i in 0..4 do + res = res * (x - PRIMES[i]) + endfor + + return res == 0 // variables naming: box'row''column' def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) -> bool: - field[4][4] a = [[a11, a12, b11, b12], [a21, a22, b21, b22], [c11, c12, d11, d12], [c21, c22, d21, d22]] + Grid<4> a = [[a11, a12, b11, b12], [a21, a22, b21, b22], [c11, c12, d11, d12], [c21, c22, d21, d22]] bool res = true diff --git a/zokrates_core/src/static_analysis/struct_concretizer.rs b/zokrates_core/src/static_analysis/struct_concretizer.rs index c62a31ae..71502cc1 100644 --- a/zokrates_core/src/static_analysis/struct_concretizer.rs +++ b/zokrates_core/src/static_analysis/struct_concretizer.rs @@ -4,7 +4,6 @@ // Where for an array, `field[N]` ends up being propagated to `field[42]` which is direct to turn into a concrete type, // for structs, `Foo { field[N] a }` is propagated to `Foo<42> { field[N] a }`. The missing step is replacing `N` by `42` // *inside* the canonical type, so that it can be concretized in the same way arrays are. -// We apply this transformation only to the main function. use crate::typed_absy::folder::*; use crate::typed_absy::{