From 5539edc16b9579be8346ace0675c2b8596c22631 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 12 May 2021 20:03:08 +0200 Subject: [PATCH] remove strict ordering of declarations, refactor import logic --- zokrates_core/src/absy/from_ast.rs | 145 +++-- zokrates_core/src/absy/mod.rs | 475 ++++++--------- zokrates_core/src/absy/node.rs | 4 +- zokrates_core/src/compile.rs | 2 +- zokrates_core/src/imports.rs | 570 ++++++++---------- zokrates_core/src/semantics.rs | 421 ++++++------- .../tests/tests/generics/embed.zok | 2 +- .../tests/tests/left_rotation_bits.zok | 4 +- .../tests/tests/right_rotation_bits.zok | 4 +- zokrates_core_test/tests/tests/split_bls.zok | 2 +- zokrates_core_test/tests/tests/split_bn.zok | 2 +- .../tests/tests/uint/extend.zok | 4 +- .../tests/tests/uint/from_to_bits.zok | 16 +- .../tests/tests/uint/operations.zok | 4 +- .../tests/tests/uint/propagation/rotate.zok | 73 +-- .../tests/tests/uint/rotate.zok | 4 +- .../tests/tests/uint/sha256.zok | 52 +- zokrates_core_test/tests/tests/uint/temp1.zok | 18 +- zokrates_core_test/tests/tests/uint/temp2.zok | 18 +- zokrates_parser/src/zokrates.pest | 4 +- zokrates_pest_ast/src/lib.rs | 350 +++++------ .../stdlib/hashes/blake2/blake2s_p.zok | 4 +- .../stdlib/hashes/pedersen/512bit.zok | 4 +- .../sha256/embed/shaRoundNoBoolCheck.zok | 2 +- .../utils/casts/bool_array_to_u32_array.zok | 4 +- .../stdlib/utils/casts/field_to_u16.zok | 5 +- .../stdlib/utils/casts/field_to_u32.zok | 5 +- .../stdlib/utils/casts/field_to_u64.zok | 5 +- .../stdlib/utils/casts/field_to_u8.zok | 5 +- .../stdlib/utils/casts/u16_from_bits.zok | 4 +- .../stdlib/utils/casts/u16_to_bits.zok | 4 +- .../stdlib/utils/casts/u16_to_field.zok | 4 +- .../utils/casts/u32_array_to_bool_array.zok | 4 +- .../stdlib/utils/casts/u32_from_bits.zok | 4 +- .../stdlib/utils/casts/u32_to_bits.zok | 4 +- .../stdlib/utils/casts/u32_to_field.zok | 4 +- .../stdlib/utils/casts/u64_from_bits.zok | 4 +- .../stdlib/utils/casts/u64_to_bits.zok | 4 +- .../stdlib/utils/casts/u64_to_field.zok | 4 +- .../stdlib/utils/casts/u8_from_bits.zok | 4 +- .../stdlib/utils/casts/u8_to_bits.zok | 4 +- .../stdlib/utils/casts/u8_to_field.zok | 4 +- .../utils/pack/bool/nonStrictUnpack256.zok | 4 +- .../stdlib/utils/pack/bool/unpack.zok | 2 +- .../stdlib/utils/pack/u32/pack256.zok | 15 +- 45 files changed, 988 insertions(+), 1293 deletions(-) diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 652932c6..4b7a7852 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,66 +1,68 @@ use crate::absy; -use crate::imports; -use crate::absy::SymbolDefinition; use num_bigint::BigUint; +use std::path::Path; use zokrates_pest_ast as pest; impl<'ast> From> for absy::Module<'ast> { - fn from(prog: pest::File<'ast>) -> absy::Module<'ast> { - absy::Module::with_symbols( - prog.structs - .into_iter() - .map(absy::SymbolDeclarationNode::from) - .chain( - prog.constants - .into_iter() - .map(absy::SymbolDeclarationNode::from), - ) - .chain( - prog.functions - .into_iter() - .map(absy::SymbolDeclarationNode::from), - ), - ) - .imports( - prog.imports - .into_iter() - .map(absy::ImportDirective::from) - .flatten(), - ) + fn from(file: pest::File<'ast>) -> absy::Module<'ast> { + absy::Module::with_symbols(file.declarations.into_iter().map(|d| match d { + pest::SymbolDeclaration::Import(i) => i.into(), + pest::SymbolDeclaration::Constant(c) => c.into(), + pest::SymbolDeclaration::Struct(s) => s.into(), + pest::SymbolDeclaration::Function(f) => f.into(), + })) } } -impl<'ast> From> for absy::ImportDirective<'ast> { - fn from(import: pest::ImportDirective<'ast>) -> absy::ImportDirective<'ast> { +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(import: pest::ImportDirective<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; match import { - pest::ImportDirective::Main(import) => absy::ImportDirective::Main( - imports::Import::new(None, std::path::Path::new(import.source.span.as_str())) - .alias(import.alias.map(|a| a.span.as_str())) - .span(import.span), - ), - pest::ImportDirective::From(import) => absy::ImportDirective::From( - import - .symbols - .iter() - .map(|symbol| { - imports::Import::new( - Some(symbol.symbol.span.as_str()), - std::path::Path::new(import.source.span.as_str()), - ) - .alias( - symbol - .alias - .as_ref() - .map(|a| a.span.as_str()) - .or_else(|| Some(symbol.symbol.span.as_str())), - ) - .span(symbol.span.clone()) - }) - .collect(), - ), + pest::ImportDirective::Main(import) => { + let span = import.span; + let source = Path::new(import.source.span.as_str()); + + let import = absy::MainImport { + source, + alias: import.alias.map(|a| a.span.as_str()), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id: None, + symbol: absy::Symbol::Here(absy::SymbolDefinition::Import( + absy::ImportDirective::Main(import), + )), + } + .span(span) + } + pest::ImportDirective::From(import) => { + let span = import.span; + let source = Path::new(import.source.span.as_str()); + + let import = absy::FromImport { + source, + symbols: import + .symbols + .into_iter() + .map(|symbol| absy::SymbolIdentifier { + id: symbol.id.span.as_str(), + alias: symbol.alias.map(|a| a.span.as_str()), + }) + .collect(), + } + .span(span.clone()); + + absy::SymbolDeclaration { + id: None, + symbol: absy::Symbol::Here(absy::SymbolDefinition::Import( + absy::ImportDirective::From(import), + )), + } + .span(span) + } } } } @@ -83,8 +85,8 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'a .span(span.clone()); absy::SymbolDeclaration { - id, - symbol: absy::Symbol::Here(SymbolDefinition::Struct(ty)), + id: Some(id), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Struct(ty)), } .span(span) } @@ -118,15 +120,15 @@ impl<'ast> From> for absy::SymbolDeclarationNode< .span(span.clone()); absy::SymbolDeclaration { - id, - symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)), + id: Some(id), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Constant(ty)), } .span(span) } } -impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { - fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> { +impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { + fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> { use crate::absy::NodeValue; let span = function.span; @@ -174,8 +176,8 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { .span(span.clone()); absy::SymbolDeclaration { - id, - symbol: absy::Symbol::Here(SymbolDefinition::Function(function)), + id: Some(id), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(function)), } .span(span) } @@ -780,8 +782,8 @@ mod tests { let ast = pest::generate_ast(&source).unwrap(); let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { - id: &source[4..8], - symbol: absy::Symbol::Here(SymbolDefinition::Function( + id: Some(&source[4..8]), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -801,7 +803,6 @@ mod tests { )), } .into()], - imports: vec![], }; assert_eq!(absy::Module::from(ast), expected); } @@ -812,8 +813,8 @@ mod tests { let ast = pest::generate_ast(&source).unwrap(); let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { - id: &source[4..8], - symbol: absy::Symbol::Here(SymbolDefinition::Function( + id: Some(&source[4..8]), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -831,7 +832,6 @@ mod tests { )), } .into()], - imports: vec![], }; assert_eq!(absy::Module::from(ast), expected); } @@ -843,8 +843,8 @@ mod tests { let expected: absy::Module = absy::Module { symbols: vec![absy::SymbolDeclaration { - id: &source[4..8], - symbol: absy::Symbol::Here(SymbolDefinition::Function( + id: Some(&source[4..8]), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function( absy::Function { arguments: vec![ absy::Parameter::private( @@ -884,7 +884,6 @@ mod tests { )), } .into()], - imports: vec![], }; assert_eq!(absy::Module::from(ast), expected); @@ -897,8 +896,8 @@ mod tests { fn wrap(ty: UnresolvedType<'static>) -> absy::Module<'static> { absy::Module { symbols: vec![absy::SymbolDeclaration { - id: "main", - symbol: absy::Symbol::Here(SymbolDefinition::Function( + id: Some("main"), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function( absy::Function { arguments: vec![absy::Parameter::private( absy::Variable::new("a", ty.clone().mock()).into(), @@ -917,7 +916,6 @@ mod tests { )), } .into()], - imports: vec![], } } @@ -971,8 +969,8 @@ mod tests { fn wrap(expression: absy::Expression<'static>) -> absy::Module { absy::Module { symbols: vec![absy::SymbolDeclaration { - id: "main", - symbol: absy::Symbol::Here(SymbolDefinition::Function( + id: Some("main"), + symbol: absy::Symbol::Here(absy::SymbolDefinition::Function( absy::Function { arguments: vec![], statements: vec![absy::Statement::Return( @@ -988,7 +986,6 @@ mod tests { )), } .into()], - imports: vec![], } } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index d070860d..f3f986ef 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -18,8 +18,6 @@ pub use crate::absy::variable::{Variable, VariableNode}; use crate::embed::FlatEmbed; use std::path::{Path, PathBuf}; -use crate::imports::ImportDirective; -use crate::imports::ImportNode; use std::fmt; use num_bigint::BigUint; @@ -44,63 +42,173 @@ pub struct Program<'ast> { pub main: OwnedModuleId, } -/// A declaration of a `FunctionSymbol`, be it from an import or a function definition -#[derive(PartialEq, Clone, Debug)] -pub struct SymbolDeclaration<'ast> { +#[derive(Debug, PartialEq, Clone)] +pub struct SymbolIdentifier<'ast> { pub id: Identifier<'ast>, + pub alias: Option>, +} + +impl<'ast> From> for SymbolIdentifier<'ast> { + fn from(id: &'ast str) -> Self { + SymbolIdentifier { id, alias: None } + } +} + +impl<'ast> SymbolIdentifier<'ast> { + pub fn alias(mut self, alias: Option>) -> Self { + self.alias = alias; + self + } + pub fn get_alias(&self) -> Identifier<'ast> { + self.alias.unwrap_or(self.id) + } +} + +impl<'ast> fmt::Display for SymbolIdentifier<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}{}", + self.id, + self.alias.map(|a| format!(" as {}", a)).unwrap_or_default() + ) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct MainImport<'ast> { + pub source: &'ast Path, + pub alias: Option>, +} + +pub type MainImportNode<'ast> = Node>; + +impl<'ast> fmt::Display for MainImport<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.alias { + Some(ref alias) => write!(f, "import \"{}\" as {}", self.source.display(), alias), + None => write!(f, "import \"{}\"", self.source.display()), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FromImport<'ast> { + pub source: &'ast Path, + pub symbols: Vec>, +} + +pub type FromImportNode<'ast> = Node>; + +impl<'ast> fmt::Display for FromImport<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "from \"{}\" import {}", + self.source.display(), + self.symbols + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ImportDirective<'ast> { + Main(MainImportNode<'ast>), + From(FromImportNode<'ast>), +} + +impl<'ast> fmt::Display for ImportDirective<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ImportDirective::Main(main) => write!(f, "{}", main), + ImportDirective::From(from) => write!(f, "{}", from), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SymbolImport<'ast> { + pub module_id: OwnedModuleId, + pub symbol_id: SymbolIdentifier<'ast>, +} + +pub type SymbolImportNode<'ast> = Node>; + +impl<'ast> SymbolImport<'ast> { + pub fn with_id_in_module>, U: Into>( + symbol_id: S, + module_id: U, + ) -> Self { + SymbolImport { + symbol_id: symbol_id.into(), + module_id: module_id.into(), + } + } +} + +impl<'ast> fmt::Display for SymbolImport<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "from \"{}\" import {}", + self.module_id.display(), + self.symbol_id + ) + } +} + +/// A declaration of a symbol +#[derive(Debug, PartialEq, Clone)] +pub struct SymbolDeclaration<'ast> { + pub id: Option>, pub symbol: Symbol<'ast>, } #[allow(clippy::large_enum_variant)] -#[derive(PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone)] pub enum SymbolDefinition<'ast> { + Import(ImportDirective<'ast>), Struct(StructDefinitionNode<'ast>), Constant(ConstantDefinitionNode<'ast>), Function(FunctionNode<'ast>), } -impl<'ast> fmt::Debug for SymbolDefinition<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - SymbolDefinition::Struct(s) => write!(f, "Struct({:?})", s), - SymbolDefinition::Constant(c) => write!(f, "Constant({:?})", c), - SymbolDefinition::Function(func) => write!(f, "Function({:?})", func), - } - } -} - -#[derive(PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone)] pub enum Symbol<'ast> { Here(SymbolDefinition<'ast>), There(SymbolImportNode<'ast>), Flat(FlatEmbed), } -impl<'ast> fmt::Debug for Symbol<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Symbol::Here(k) => write!(f, "Here({:?})", k), - Symbol::There(i) => write!(f, "There({:?})", i), - Symbol::Flat(flat) => write!(f, "Flat({:?})", flat), - } - } -} - impl<'ast> fmt::Display for SymbolDeclaration<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.symbol { - Symbol::Here(ref kind) => match kind { - SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t), - SymbolDefinition::Constant(c) => write!( + match &self.symbol { + Symbol::Here(ref symbol) => match symbol { + SymbolDefinition::Import(ref i) => write!(f, "{}", i), + SymbolDefinition::Struct(ref t) => write!(f, "struct {} {}", self.id.unwrap(), t), + SymbolDefinition::Constant(ref c) => write!( f, "const {} {} = {}", - c.value.ty, self.id, c.value.expression + c.value.ty, + self.id.unwrap(), + c.value.expression ), - SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func), + SymbolDefinition::Function(ref func) => { + write!(f, "def {}{}", self.id.unwrap(), func) + } }, - Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id), + Symbol::There(ref i) => write!(f, "{}", i), Symbol::Flat(ref flat_fun) => { - write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature()) + write!( + f, + "def {}{}:\n\t// hidden", + self.id.unwrap(), + flat_fun.signature() + ) } } } @@ -109,25 +217,18 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> { pub type SymbolDeclarationNode<'ast> = Node>; /// A module as a collection of `FunctionDeclaration`s -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct Module<'ast> { /// Symbols of the module pub symbols: Declarations<'ast>, - pub imports: Vec>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty } impl<'ast> Module<'ast> { pub fn with_symbols>>(i: I) -> Self { Module { symbols: i.into_iter().collect(), - imports: vec![], } } - - pub fn imports>>(mut self, i: I) -> Self { - self.imports = i.into_iter().collect(); - self - } } pub type UnresolvedTypeNode<'ast> = Node>; @@ -169,7 +270,7 @@ impl<'ast> fmt::Display for StructDefinitionField<'ast> { type StructDefinitionFieldNode<'ast> = Node>; -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct ConstantDefinition<'ast> { pub ty: UnresolvedTypeNode<'ast>, pub expression: ExpressionNode<'ast>, @@ -183,92 +284,55 @@ impl<'ast> fmt::Display for ConstantDefinition<'ast> { } } -impl<'ast> fmt::Debug for ConstantDefinition<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "ConstantDefinition({:?}, {:?})", - self.ty, self.expression - ) - } -} - -/// An import -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolImport<'ast> { - /// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however. - pub symbol_id: Identifier<'ast>, - /// the id of the module to import from - pub module_id: OwnedModuleId, -} - -type SymbolImportNode<'ast> = Node>; - -impl<'ast> SymbolImport<'ast> { - pub fn with_id_in_module>, U: Into>( - symbol_id: S, - module_id: U, - ) -> Self { - SymbolImport { - symbol_id: symbol_id.into(), - module_id: module_id.into(), - } - } -} - -impl<'ast> fmt::Display for SymbolImport<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} from {}", - self.symbol_id, - self.module_id.display().to_string() - ) - } -} +// /// An import +// #[derive(Debug, Clone, PartialEq)] +// pub struct SymbolImport<'ast> { +// /// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however. +// pub symbol_id: Identifier<'ast>, +// /// the id of the module to import from +// pub module_id: OwnedModuleId, +// } +// +// type SymbolImportNode<'ast> = Node>; +// +// impl<'ast> SymbolImport<'ast> { +// pub fn with_id_in_module>, U: Into>( +// symbol_id: S, +// module_id: U, +// ) -> Self { +// SymbolImport { +// symbol_id: symbol_id.into(), +// module_id: module_id.into(), +// } +// } +// } +// +// impl<'ast> fmt::Display for SymbolImport<'ast> { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// write!( +// f, +// "from {} import {}", +// self.module_id.display().to_string(), +// self.symbol_id, +// ) +// } +// } impl<'ast> fmt::Display for Module<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut res = vec![]; - res.extend( - self.imports - .iter() - .map(|x| format!("{}", x)) - .collect::>(), - ); - res.extend( - self.symbols - .iter() - .map(|x| format!("{}", x)) - .collect::>(), - ); + let res = self + .symbols + .iter() + .map(|x| format!("{}", x)) + .collect::>(); write!(f, "{}", res.join("\n")) } } -impl<'ast> fmt::Debug for Module<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "module(\n\timports:\n\t\t{}\n\tsymbols:\n\t\t{}\n)", - self.imports - .iter() - .map(|x| format!("{:?}", x)) - .collect::>() - .join("\n\t\t"), - self.symbols - .iter() - .map(|x| format!("{:?}", x)) - .collect::>() - .join("\n\t\t") - ) - } -} - pub type ConstantGenericNode<'ast> = Node>; /// A function defined locally -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct Function<'ast> { /// Arguments of the function pub arguments: Vec>, @@ -312,23 +376,8 @@ impl<'ast> fmt::Display for Function<'ast> { } } -impl<'ast> fmt::Debug for Function<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Function(arguments: {:?}, ...):\n{}", - self.arguments, - self.statements - .iter() - .map(|x| format!("\t{:?}", x)) - .collect::>() - .join("\n") - ) - } -} - /// Something that we can assign to -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Assignee<'ast> { Identifier(Identifier<'ast>), Select(Box>, Box>), @@ -337,16 +386,6 @@ pub enum Assignee<'ast> { pub type AssigneeNode<'ast> = Node>; -impl<'ast> fmt::Debug for Assignee<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s), - Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e), - Assignee::Member(ref s, ref m) => write!(f, "Member({:?}.{:?})", s, m), - } - } -} - impl<'ast> fmt::Display for Assignee<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -359,7 +398,7 @@ impl<'ast> fmt::Display for Assignee<'ast> { /// A statement in a `Function` #[allow(clippy::large_enum_variant)] -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Statement<'ast> { Return(ExpressionListNode<'ast>), Declaration(VariableNode<'ast>), @@ -403,31 +442,8 @@ impl<'ast> fmt::Display for Statement<'ast> { } } -impl<'ast> fmt::Debug for Statement<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Statement::Return(ref expr) => write!(f, "Return({:?})", expr), - Statement::Declaration(ref var) => write!(f, "Declaration({:?})", var), - Statement::Definition(ref lhs, ref rhs) => { - write!(f, "Definition({:?}, {:?})", lhs, rhs) - } - Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e), - Statement::For(ref var, ref start, ref stop, ref list) => { - writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?; - for l in list { - writeln!(f, "\t\t{:?}", l)?; - } - write!(f, "\tendfor") - } - Statement::MultipleDefinition(ref lhs, ref rhs) => { - write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs) - } - } - } -} - /// An element of an inline array, can be a spread `...a` or an expression `a` -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum SpreadOrExpression<'ast> { Spread(SpreadNode<'ast>), Expression(ExpressionNode<'ast>), @@ -448,17 +464,8 @@ impl<'ast> fmt::Display for SpreadOrExpression<'ast> { } } -impl<'ast> fmt::Debug for SpreadOrExpression<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - SpreadOrExpression::Spread(ref s) => write!(f, "{:?}", s), - SpreadOrExpression::Expression(ref e) => write!(f, "{:?}", e), - } - } -} - /// The index in an array selector. Can be a range or an expression. -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum RangeOrExpression<'ast> { Range(RangeNode<'ast>), Expression(ExpressionNode<'ast>), @@ -473,13 +480,10 @@ impl<'ast> fmt::Display for RangeOrExpression<'ast> { } } -impl<'ast> fmt::Debug for RangeOrExpression<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - RangeOrExpression::Range(ref s) => write!(f, "{:?}", s), - RangeOrExpression::Expression(ref e) => write!(f, "{:?}", e), - } - } +/// A spread +#[derive(Debug, Clone, PartialEq)] +pub struct Spread<'ast> { + pub expression: ExpressionNode<'ast>, } pub type SpreadNode<'ast> = Node>; @@ -490,20 +494,8 @@ impl<'ast> fmt::Display for Spread<'ast> { } } -impl<'ast> fmt::Debug for Spread<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Spread({:?})", self.expression) - } -} - -/// A spread -#[derive(Clone, PartialEq)] -pub struct Spread<'ast> { - pub expression: ExpressionNode<'ast>, -} - /// A range -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct Range<'ast> { pub from: Option>, pub to: Option>, @@ -528,14 +520,8 @@ impl<'ast> fmt::Display for Range<'ast> { } } -impl<'ast> fmt::Debug for Range<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Range({:?}, {:?})", self.from, self.to) - } -} - /// An expression -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Expression<'ast> { IntConstant(BigUint), FieldConstant(BigUint), @@ -672,73 +658,8 @@ impl<'ast> fmt::Display for Expression<'ast> { } } -impl<'ast> fmt::Debug for Expression<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Expression::U8Constant(ref i) => write!(f, "U8({:x})", i), - Expression::U16Constant(ref i) => write!(f, "U16({:x})", i), - Expression::U32Constant(ref i) => write!(f, "U32({:x})", i), - Expression::U64Constant(ref i) => write!(f, "U64({:x})", i), - Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i), - Expression::IntConstant(ref i) => write!(f, "Int({:?})", i), - Expression::Identifier(ref var) => write!(f, "Ide({})", var), - Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs), - Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs), - Expression::Mult(ref lhs, ref rhs) => write!(f, "Mult({:?}, {:?})", lhs, rhs), - Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs), - Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs), - Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs), - Expression::Neg(ref e) => write!(f, "Neg({:?})", e), - Expression::Pos(ref e) => write!(f, "Pos({:?})", e), - Expression::BooleanConstant(b) => write!(f, "{}", b), - Expression::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "IfElse({:?}, {:?}, {:?})", - condition, consequent, alternative - ), - Expression::FunctionCall(ref g, ref i, ref p) => { - write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; - f.debug_list().entries(p.iter()).finish()?; - write!(f, ")") - } - Expression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs), - Expression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs), - Expression::Eq(ref lhs, ref rhs) => write!(f, "Eq({:?}, {:?})", lhs, rhs), - Expression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs), - Expression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs), - Expression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs), - Expression::Not(ref exp) => write!(f, "Not({:?})", exp), - Expression::InlineArray(ref exprs) => { - write!(f, "InlineArray([")?; - f.debug_list().entries(exprs.iter()).finish()?; - write!(f, "]") - } - Expression::ArrayInitializer(ref e, ref count) => { - write!(f, "ArrayInitializer({:?}, {:?})", e, count) - } - Expression::InlineStruct(ref id, ref members) => { - write!(f, "InlineStruct({:?}, [", id)?; - f.debug_list().entries(members.iter()).finish()?; - write!(f, "]") - } - Expression::Select(ref array, ref index) => { - write!(f, "Select({:?}, {:?})", array, index) - } - Expression::Member(ref struc, ref id) => write!(f, "Member({:?}, {:?})", struc, id), - Expression::Or(ref lhs, ref rhs) => write!(f, "Or({:?}, {:?})", lhs, rhs), - Expression::BitXor(ref lhs, ref rhs) => write!(f, "BitXor({:?}, {:?})", lhs, rhs), - Expression::BitAnd(ref lhs, ref rhs) => write!(f, "BitAnd({:?}, {:?})", lhs, rhs), - Expression::BitOr(ref lhs, ref rhs) => write!(f, "BitOr({:?}, {:?})", lhs, rhs), - Expression::LeftShift(ref lhs, ref rhs) => write!(f, "LeftShift({:?}, {:?})", lhs, rhs), - Expression::RightShift(ref lhs, ref rhs) => { - write!(f, "RightShift({:?}, {:?})", lhs, rhs) - } - } - } -} - /// A list of expressions, used in return statements -#[derive(Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Default)] pub struct ExpressionList<'ast> { pub expressions: Vec>, } @@ -756,9 +677,3 @@ impl<'ast> fmt::Display for ExpressionList<'ast> { write!(f, "") } } - -impl<'ast> fmt::Debug for ExpressionList<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ExpressionList({:?})", self.expressions) - } -} diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 5b949d49..cd63fb1f 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -74,7 +74,6 @@ impl From for Node { use crate::absy::types::UnresolvedType; use crate::absy::*; -use crate::imports::*; impl<'ast> NodeValue for Expression<'ast> {} impl<'ast> NodeValue for ExpressionList<'ast> {} @@ -87,10 +86,11 @@ impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast> NodeValue for ConstantDefinition<'ast> {} impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Module<'ast> {} +impl<'ast> NodeValue for MainImport<'ast> {} +impl<'ast> NodeValue for FromImport<'ast> {} impl<'ast> NodeValue for SymbolImport<'ast> {} impl<'ast> NodeValue for Variable<'ast> {} impl<'ast> NodeValue for Parameter<'ast> {} -impl<'ast> NodeValue for Import<'ast> {} impl<'ast> NodeValue for Spread<'ast> {} impl<'ast> NodeValue for Range<'ast> {} impl<'ast> NodeValue for Identifier<'ast> {} diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 87d27429..c9812eea 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -289,7 +289,7 @@ mod test { assert!(res.unwrap_err().0[0] .value() .to_string() - .contains(&"Can't resolve import without a resolver")); + .contains(&"Cannot resolve import without a resolver")); } #[test] diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 44ed8c79..996dacd0 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -12,7 +12,7 @@ use crate::parser::Position; use std::collections::HashMap; use std::fmt; use std::io; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use typed_arena::Arena; use zokrates_common::Resolver; @@ -56,94 +56,6 @@ impl From for Error { } } -#[derive(PartialEq, Clone)] -pub enum ImportDirective<'ast> { - Main(ImportNode<'ast>), - From(Vec>), -} - -impl<'ast> IntoIterator for ImportDirective<'ast> { - type Item = ImportNode<'ast>; - type IntoIter = std::vec::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - let vec = match self { - ImportDirective::Main(v) => vec![v], - ImportDirective::From(v) => v, - }; - vec.into_iter() - } -} - -type ImportPath<'ast> = &'ast Path; - -#[derive(PartialEq, Clone)] -pub struct Import<'ast> { - source: ImportPath<'ast>, - symbol: Option>, - alias: Option>, -} - -pub type ImportNode<'ast> = Node>; - -impl<'ast> Import<'ast> { - pub fn new(symbol: Option>, source: ImportPath<'ast>) -> Import<'ast> { - Import { - symbol, - source, - alias: None, - } - } - - pub fn get_alias(&self) -> &Option> { - &self.alias - } - - pub fn new_with_alias( - symbol: Option>, - source: ImportPath<'ast>, - alias: Identifier<'ast>, - ) -> Import<'ast> { - Import { - symbol, - source, - alias: Some(alias), - } - } - - pub fn alias(mut self, alias: Option>) -> Self { - self.alias = alias; - self - } - - pub fn get_source(&self) -> &ImportPath<'ast> { - &self.source - } -} - -impl<'ast> fmt::Display for Import<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.alias { - Some(ref alias) => write!(f, "import {} as {}", self.source.display(), alias), - None => write!(f, "import {}", self.source.display()), - } - } -} - -impl<'ast> fmt::Debug for Import<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.alias { - Some(ref alias) => write!( - f, - "import(source: {}, alias: {})", - self.source.display(), - alias - ), - None => write!(f, "import(source: {})", self.source.display()), - } - } -} - pub struct Importer; impl Importer { @@ -156,253 +68,261 @@ impl Importer { ) -> Result, CompileErrors> { let mut symbols: Vec<_> = vec![]; - for import in destination.imports { - let pos = import.pos(); - let import = import.value; - let alias = import.alias; - // handle the case of special bellman and packing imports - if import.source.starts_with("EMBED") { - match import.source.to_str().unwrap() { - #[cfg(feature = "bellman")] - "EMBED/sha256round" => { - if T::id() != Bn128Field::id() { - return Err(CompileErrorInner::ImportError( - Error::new(format!( - "Embed sha256round cannot be used with curve {}", - T::name() + for symbol in destination.symbols { + match symbol.value.symbol { + Symbol::Here(ref s) => match s { + SymbolDefinition::Import(ImportDirective::Main(main)) => { + let pos = main.pos(); + let module_id = &main.value.source; + + match resolver { + Some(res) => { + match res.resolve(location.clone(), module_id.to_path_buf()) { + Ok((source, new_location)) => { + // generate an alias from the imported path if none was given explicitly + let alias = main.value.alias.or( + module_id + .file_stem() + .ok_or_else(|| { + CompileErrors::from( + CompileErrorInner::ImportError(Error::new( + format!( + "Could not determine alias for import {}", + module_id.display() + ), + )) + .in_file(&location), + ) + })? + .to_str() + ); + + match modules.get(&new_location) { + Some(_) => {} + None => { + let source = arena.alloc(source); + + let compiled = compile_module::( + source, + new_location.clone(), + resolver, + modules, + &arena, + )?; + + assert!(modules + .insert(new_location.clone(), compiled) + .is_none()); + } + }; + + symbols.push( + SymbolDeclaration { + id: alias, + symbol: Symbol::There( + SymbolImport::with_id_in_module( + "main", + new_location, + ) + .start_end(pos.0, pos.1), + ), + } + .start_end(pos.0, pos.1), + ); + } + Err(err) => { + return Err(CompileErrorInner::ImportError( + err.into().with_pos(Some(pos)), + ) + .in_file(&location) + .into()); + } + } + } + None => { + return Err(CompileErrorInner::from(Error::new( + "Cannot resolve import without a resolver", )) - .with_pos(Some(pos)), - ) - .in_file(&location) - .into()); - } else { - let alias = alias.unwrap_or("sha256round"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::Sha256Round), - } - .start_end(pos.0, pos.1), - ); + .in_file(&location) + .into()); + } } } - "EMBED/unpack" => { - let alias = alias.unwrap_or("unpack"); + SymbolDefinition::Import(ImportDirective::From(from)) => { + let pos = from.pos(); + let module_id = &from.value.source; - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::Unpack), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u64_to_bits" => { - let alias = alias.unwrap_or("u64_to_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U64ToBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u32_to_bits" => { - let alias = alias.unwrap_or("u32_to_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U32ToBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u16_to_bits" => { - let alias = alias.unwrap_or("u16_to_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U16ToBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u8_to_bits" => { - let alias = alias.unwrap_or("u8_to_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U8ToBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u64_from_bits" => { - let alias = alias.unwrap_or("u64_from_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U64FromBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u32_from_bits" => { - let alias = alias.unwrap_or("u32_from_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U32FromBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u16_from_bits" => { - let alias = alias.unwrap_or("u16_from_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U16FromBits), - } - .start_end(pos.0, pos.1), - ); - } - "EMBED/u8_from_bits" => { - let alias = alias.unwrap_or("u8_from_bits"); - - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::Flat(FlatEmbed::U8FromBits), - } - .start_end(pos.0, pos.1), - ); - } - s => { - return Err(CompileErrorInner::ImportError( - Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)), - ) - .in_file(&location) - .into()); - } - } - } else { - // to resolve imports, we need a resolver - match resolver { - Some(res) => match res.resolve(location.clone(), import.source.to_path_buf()) { - Ok((source, new_location)) => { - // generate an alias from the imported path if none was given explicitely - let alias = import.alias.unwrap_or( - std::path::Path::new(import.source) - .file_stem() - .ok_or_else(|| { - CompileErrors::from( - CompileErrorInner::ImportError(Error::new(format!( - "Could not determine alias for import {}", - import.source.display() - ))) - .in_file(&location), - ) - })? - .to_str() - .unwrap(), - ); - - match modules.get(&new_location) { - Some(_) => {} - None => { - let source = arena.alloc(source); - - let compiled = compile_module::( - source, - new_location.clone(), - resolver, - modules, - &arena, - )?; - - assert!(modules - .insert(new_location.clone(), compiled) - .is_none()); + match module_id.to_str().unwrap() { + "EMBED" => { + for symbol in &from.value.symbols { + match symbol.id { + #[cfg(feature = "bellman")] + "sha256round" => { + if T::id() != Bn128Field::id() { + return Err(CompileErrorInner::ImportError( + Error::new(format!( + "Embed sha256round cannot be used with curve {}", + T::name() + )) + .with_pos(Some(pos)), + ).in_file(&location).into()); + } else { + symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat( + FlatEmbed::Sha256Round, + ), + } + .start_end(pos.0, pos.1), + ) + } + } + "unpack" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::Unpack), + } + .start_end(pos.0, pos.1), + ), + "u64_to_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U64ToBits), + } + .start_end(pos.0, pos.1), + ), + "u32_to_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U32ToBits), + } + .start_end(pos.0, pos.1), + ), + "u16_to_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U16ToBits), + } + .start_end(pos.0, pos.1), + ), + "u8_to_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U8ToBits), + } + .start_end(pos.0, pos.1), + ), + "u64_from_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U64FromBits), + } + .start_end(pos.0, pos.1), + ), + "u32_from_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U32FromBits), + } + .start_end(pos.0, pos.1), + ), + "u16_from_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U16FromBits), + } + .start_end(pos.0, pos.1), + ), + "u8_from_bits" => symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::Flat(FlatEmbed::U8FromBits), + } + .start_end(pos.0, pos.1), + ), + s => { + return Err(CompileErrorInner::ImportError( + Error::new(format!("Embed {} not found", s)) + .with_pos(Some(pos)), + ) + .in_file(&location) + .into()) + } + } } - }; + } + _ => { + for symbol in &from.value.symbols { + match resolver { + Some(res) => { + match res + .resolve(location.clone(), module_id.to_path_buf()) + { + Ok((source, new_location)) => { + match modules.get(&new_location) { + Some(_) => {} + None => { + let source = arena.alloc(source); - symbols.push( - SymbolDeclaration { - id: &alias, - symbol: Symbol::There( - SymbolImport::with_id_in_module( - import.symbol.unwrap_or("main"), - new_location.display().to_string(), - ) - .start_end(pos.0, pos.1), - ), + let compiled = compile_module::( + source, + new_location.clone(), + resolver, + modules, + &arena, + )?; + + assert!(modules + .insert( + new_location.clone(), + compiled + ) + .is_none()); + } + }; + + symbols.push( + SymbolDeclaration { + id: Some(symbol.get_alias()), + symbol: Symbol::There( + SymbolImport::with_id_in_module( + symbol.id, + new_location, + ) + .start_end(pos.0, pos.1), + ), + } + .start_end(pos.0, pos.1), + ); + } + Err(err) => { + return Err(CompileErrorInner::ImportError( + err.into().with_pos(Some(pos)), + ) + .in_file(&location) + .into()); + } + } + } + None => { + return Err(CompileErrorInner::from(Error::new( + "Cannot resolve import without a resolver", + )) + .in_file(&location) + .into()); + } + } } - .start_end(pos.0, pos.1), - ); + } } - Err(err) => { - return Err(CompileErrorInner::ImportError( - err.into().with_pos(Some(pos)), - ) - .in_file(&location) - .into()); - } - }, - None => { - return Err(CompileErrorInner::from(Error::new( - "Can't resolve import without a resolver", - )) - .in_file(&location) - .into()); } - } + _ => symbols.push(symbol), + }, + _ => unreachable!(), } } - symbols.extend(destination.symbols); - - Ok(Module { - imports: vec![], - symbols, - }) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn create_with_no_alias() { - assert_eq!( - Import::new(None, Path::new("./foo/bar/baz.zok")), - Import { - symbol: None, - source: Path::new("./foo/bar/baz.zok"), - alias: None, - } - ); - } - - #[test] - fn create_with_alias() { - assert_eq!( - Import::new_with_alias(None, Path::new("./foo/bar/baz.zok"), &"myalias"), - Import { - symbol: None, - source: Path::new("./foo/bar/baz.zok"), - alias: Some("myalias"), - } - ); + Ok(Module::with_symbols(symbols)) } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8047c497..0796a6c3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -452,130 +452,128 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = declaration.pos(); let declaration = declaration.value; + let declaration_id = declaration.id.unwrap(); match declaration.symbol.clone() { - Symbol::Here(kind) => match kind { - SymbolDefinition::Struct(t) => { - match self.check_struct_type_declaration( - declaration.id.to_string(), - t.clone(), - module_id, - &state.types, - ) { - Ok(ty) => { - match symbol_unifier.insert_type(declaration.id) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - } - .in_file(module_id), - ), - true => { - // there should be no entry in the map for this type yet - assert!(state - .types - .entry(module_id.to_path_buf()) - .or_default() - .insert(declaration.id.to_string(), ty) - .is_none()); + Symbol::Here(SymbolDefinition::Struct(t)) => { + match self.check_struct_type_declaration( + declaration_id.to_string(), + t.clone(), + module_id, + &state.types, + ) { + Ok(ty) => { + match symbol_unifier.insert_type(declaration_id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration_id + ), } - }; - } - Err(e) => errors.extend(e.into_iter().map(|inner| Error { - inner, - module_id: module_id.to_path_buf(), - })), + .in_file(module_id), + ), + true => { + // there should be no entry in the map for this type yet + assert!(state + .types + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration_id.to_string(), ty) + .is_none()); + } + }; + } + Err(e) => errors.extend(e.into_iter().map(|inner| Error { + inner, + module_id: module_id.to_path_buf(), + })), + } + } + Symbol::Here(SymbolDefinition::Constant(c)) => { + match self.check_constant_definition(declaration_id, c, module_id, &state.types) { + Ok(c) => { + match symbol_unifier.insert_constant(declaration_id) { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration_id + ), + } + .in_file(module_id), + ), + true => { + constants + .insert(declaration_id, TypedConstantSymbol::Here(c.clone())); + self.insert_into_scope(Variable::with_id_and_type( + declaration_id, + c.get_type(), + )); + assert!(state + .constants + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration_id, c.get_type()) + .is_none()); + } + }; + } + Err(e) => { + errors.push(e.in_file(module_id)); } } - SymbolDefinition::Constant(c) => { - match self.check_constant_definition(declaration.id, c, module_id, &state.types) - { - Ok(c) => { - match symbol_unifier.insert_constant(declaration.id) { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - } - .in_file(module_id), - ), - true => { - constants.insert( - declaration.id, - TypedConstantSymbol::Here(c.clone()), - ); - self.insert_into_scope(Variable::with_id_and_type( - declaration.id, - c.get_type(), - )); - assert!(state - .constants - .entry(module_id.to_path_buf()) - .or_default() - .insert(declaration.id, c.get_type()) - .is_none()); + } + Symbol::Here(SymbolDefinition::Function(f)) => { + match self.check_function(f, module_id, &state.types) { + Ok(funct) => { + match symbol_unifier + .insert_function(declaration_id, funct.signature.clone()) + { + false => errors.push( + ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration_id + ), } - }; - } - Err(e) => { - errors.push(e.in_file(module_id)); - } - } - } - SymbolDefinition::Function(f) => { - match self.check_function(f, module_id, &state.types) { - Ok(funct) => { - match symbol_unifier - .insert_function(declaration.id, funct.signature.clone()) - { - false => errors.push( - ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - } - .in_file(module_id), - ), - true => {} - }; + .in_file(module_id), + ), + true => {} + }; - self.functions.insert( - DeclarationFunctionKey::with_location( - module_id.to_path_buf(), - declaration.id, - ) - .signature(funct.signature.clone()), - ); - functions.insert( - DeclarationFunctionKey::with_location( - module_id.to_path_buf(), - declaration.id, - ) - .signature(funct.signature.clone()), - TypedFunctionSymbol::Here(funct), - ); - } - Err(e) => { - errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); - } + self.functions.insert( + DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration_id, + ) + .signature(funct.signature.clone()), + ); + functions.insert( + DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration_id, + ) + .signature(funct.signature.clone()), + TypedFunctionSymbol::Here(funct), + ); + } + Err(e) => { + errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); } } - }, + } Symbol::There(import) => { let pos = import.pos(); let import = import.value; match Checker::new().check_module(&import.module_id, state) { Ok(()) => { + let symbol_id = import.symbol_id.get_alias(); + // find candidates in the checked module let function_candidates: Vec<_> = state .typed_modules @@ -583,10 +581,10 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap() .functions .iter() - .filter(|(k, _)| k.id == import.symbol_id) + .filter(|(k, _)| k.id == symbol_id) .map(|(_, v)| DeclarationFunctionKey { module: import.module_id.to_path_buf(), - id: import.symbol_id, + id: symbol_id, signature: v.signature(&state.typed_modules).clone(), }) .collect(); @@ -596,7 +594,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(import.module_id.to_path_buf()) .or_default() - .get(import.symbol_id) + .get(symbol_id) .cloned(); // find constant definition candidate @@ -604,7 +602,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .constants .entry(import.module_id.to_path_buf()) .or_default() - .get(import.symbol_id) + .get(symbol_id) .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { @@ -614,7 +612,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let t = match t { DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { location: Some(StructLocation { - name: declaration.id.into(), + name: declaration_id.into(), module: module_id.to_path_buf() }), ..t @@ -623,28 +621,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }; // we imported a type, so the symbol it gets bound to should not already exist - match symbol_unifier.insert_type(declaration.id) { - false => { - errors.push(Error { - module_id: module_id.to_path_buf(), - inner: ErrorInner { - pos: Some(pos), - message: format!( - "{} conflicts with another symbol", - declaration.id, - ), - }}); - } - true => {} - }; - state - .types - .entry(module_id.to_path_buf()) - .or_default() - .insert(declaration.id.to_string(), t); - } - (0, None, Some(ty)) => { - match symbol_unifier.insert_constant(declaration.id) { + match symbol_unifier.insert_type(declaration_id) { false => { errors.push(Error { module_id: module_id.to_path_buf(), @@ -652,19 +629,40 @@ impl<'ast, T: Field> Checker<'ast, T> { pos: Some(pos), message: format!( "{} conflicts with another symbol", - declaration.id, + declaration_id, + ), + }}); + } + true => {} + }; + state + .types + .entry(module_id.to_path_buf()) + .or_default() + .insert(declaration_id.to_string(), t); + } + (0, None, Some(ty)) => { + match symbol_unifier.insert_constant(declaration_id) { + false => { + errors.push(Error { + module_id: module_id.to_path_buf(), + inner: ErrorInner { + pos: Some(pos), + message: format!( + "{} conflicts with another symbol", + declaration_id, ), }}); } true => { - constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, import.symbol_id)); - self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone())); + constants.insert(declaration_id, TypedConstantSymbol::There(import.module_id.to_path_buf(), symbol_id)); + self.insert_into_scope(Variable::with_id_and_type(declaration_id, ty.clone())); state .constants .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id, ty); + .insert(declaration_id, ty); } }; } @@ -673,7 +671,7 @@ impl<'ast, T: Field> Checker<'ast, T> { pos: Some(pos), message: format!( "Could not find symbol {} in module {}", - import.symbol_id, import.module_id.display(), + symbol_id, import.module_id.display(), ), }.in_file(module_id)); } @@ -681,20 +679,20 @@ impl<'ast, T: Field> Checker<'ast, T> { _ => { for candidate in function_candidates { - match symbol_unifier.insert_function(declaration.id, candidate.signature.clone()) { + match symbol_unifier.insert_function(declaration_id, candidate.signature.clone()) { false => { errors.push(ErrorInner { pos: Some(pos), message: format!( "{} conflicts with another symbol", - declaration.id, + declaration_id, ), }.in_file(module_id)); }, true => {} }; - let local_key = candidate.clone().id(declaration.id).module(module_id.to_path_buf()); + let local_key = candidate.clone().id(declaration_id).module(module_id.to_path_buf()); self.functions.insert(local_key.clone()); functions.insert( @@ -712,14 +710,14 @@ impl<'ast, T: Field> Checker<'ast, T> { }; } Symbol::Flat(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature()) { + match symbol_unifier.insert_function(declaration_id, funct.signature()) { false => { errors.push( ErrorInner { pos: Some(pos), message: format!( "{} conflicts with another symbol", - declaration.id, + declaration_id ), } .in_file(module_id), @@ -729,15 +727,16 @@ impl<'ast, T: Field> Checker<'ast, T> { }; self.functions.insert( - DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) + DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration_id) .signature(funct.signature()), ); functions.insert( - DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) + DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration_id) .signature(funct.signature()), TypedFunctionSymbol::Flat(funct), ); } + _ => unreachable!(), }; // return if any errors occured @@ -762,8 +761,6 @@ impl<'ast, T: Field> Checker<'ast, T> { None => None, // if it was not, check it Some(module) => { - assert_eq!(module.imports.len(), 0); - // we need to create an entry in the types map to store types for this module state.types.entry(module_id.to_path_buf()).or_default(); @@ -3245,20 +3242,18 @@ mod tests { let foo: Module = Module { symbols: vec![SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()], - imports: vec![], }; let bar: Module = Module { symbols: vec![SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::There(SymbolImport::with_id_in_module("main", "foo").mock()), } .mock()], - imports: vec![], }; let mut state = State::::new( @@ -3303,17 +3298,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3382,17 +3376,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(f0)), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(f1)), } .mock(), ], - imports: vec![], }; let mut state = State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); @@ -3420,17 +3413,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], - imports: vec![], }; let mut state = @@ -3473,17 +3465,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), ], - imports: vec![], }; let mut state = @@ -3513,17 +3504,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(function1())), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3563,17 +3553,16 @@ mod tests { let module: Module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Struct(struct1())), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3600,19 +3589,18 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![] }.mock(), )), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3642,7 +3630,7 @@ mod tests { // should fail let bar = Module::with_symbols(vec![SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -3650,19 +3638,18 @@ mod tests { let main = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::There( SymbolImport::with_id_in_module("main", "bar").mock(), ), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3691,7 +3678,7 @@ mod tests { // should fail let bar = Module::with_symbols(vec![SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(function0())), } .mock()]); @@ -3699,19 +3686,18 @@ mod tests { let main = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Struct(struct0())), } .mock(), SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::There( SymbolImport::with_id_in_module("main", "bar").mock(), ), } .mock(), ], - imports: vec![], }; let mut state = State::::new( @@ -3919,20 +3905,17 @@ mod tests { let symbols = vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "bar", + id: Some("bar"), symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), ]; - let module = Module { - symbols, - imports: vec![], - }; + let module = Module { symbols }; let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); @@ -4034,25 +4017,22 @@ mod tests { let symbols = vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "bar", + id: Some("bar"), symbol: Symbol::Here(SymbolDefinition::Function(bar)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ]; - let module = Module { - symbols, - imports: vec![], - }; + let module = Module { symbols }; let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); @@ -4422,17 +4402,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], - imports: vec![], }; let mut state = @@ -4509,17 +4488,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], - imports: vec![], }; let mut state = @@ -4625,17 +4603,16 @@ mod tests { let module = Module { symbols: vec![ SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), ], - imports: vec![], }; let mut state = @@ -4918,21 +4895,18 @@ mod tests { let symbols = vec![ SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main1)), } .mock(), SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main2)), } .mock(), ]; - let main_module = Module { - symbols, - imports: vec![], - }; + let main_module = Module { symbols }; let program = Program { modules: vec![((*MODULE_ID).clone(), main_module)] @@ -5034,9 +5008,8 @@ mod tests { s: StructDefinition<'static>, ) -> (Checker, State) { let module: Module = Module { - imports: vec![], symbols: vec![SymbolDeclaration { - id: "Foo", + id: Some("Foo"), symbol: Symbol::Here(SymbolDefinition::Struct(s.mock())), } .mock()], @@ -5163,10 +5136,9 @@ mod tests { // struct Bar = { foo: Foo } let module: Module = Module { - imports: vec![], symbols: vec![ SymbolDeclaration { - id: "Foo", + id: Some("Foo"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5180,7 +5152,7 @@ mod tests { } .mock(), SymbolDeclaration { - id: "Bar", + id: Some("Bar"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5233,9 +5205,8 @@ mod tests { // struct Bar = { foo: Foo } let module: Module = Module { - imports: vec![], symbols: vec![SymbolDeclaration { - id: "Bar", + id: Some("Bar"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5266,9 +5237,8 @@ mod tests { // struct Foo = { foo: Foo } let module: Module = Module { - imports: vec![], symbols: vec![SymbolDeclaration { - id: "Foo", + id: Some("Foo"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5300,10 +5270,9 @@ mod tests { // struct Bar = { foo: Foo } let module: Module = Module { - imports: vec![], symbols: vec![ SymbolDeclaration { - id: "Foo", + id: Some("Foo"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5317,7 +5286,7 @@ mod tests { } .mock(), SymbolDeclaration { - id: "Bar", + id: Some("Bar"), symbol: Symbol::Here(SymbolDefinition::Struct( StructDefinition { fields: vec![StructDefinitionField { @@ -5795,17 +5764,17 @@ mod tests { let m = Module::with_symbols(vec![ absy::SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo_field)), } .mock(), absy::SymbolDeclaration { - id: "foo", + id: Some("foo"), symbol: Symbol::Here(SymbolDefinition::Function(foo_u32)), } .mock(), absy::SymbolDeclaration { - id: "main", + id: Some("main"), symbol: Symbol::Here(SymbolDefinition::Function(main)), } .mock(), diff --git a/zokrates_core_test/tests/tests/generics/embed.zok b/zokrates_core_test/tests/tests/generics/embed.zok index 88160bba..08abf49d 100644 --- a/zokrates_core_test/tests/tests/generics/embed.zok +++ b/zokrates_core_test/tests/tests/generics/embed.zok @@ -1,4 +1,4 @@ -import "EMBED/unpack" as unpack +from "EMBED" import unpack def main(field x): bool[1] bits = unpack(x) diff --git a/zokrates_core_test/tests/tests/left_rotation_bits.zok b/zokrates_core_test/tests/tests/left_rotation_bits.zok index 49de4de6..66e60ea6 100644 --- a/zokrates_core_test/tests/tests/left_rotation_bits.zok +++ b/zokrates_core_test/tests/tests/left_rotation_bits.zok @@ -1,5 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def rotl32(u32 e) -> u32: bool[32] b = to_bits(e) diff --git a/zokrates_core_test/tests/tests/right_rotation_bits.zok b/zokrates_core_test/tests/tests/right_rotation_bits.zok index d18bc080..58e2f8b5 100644 --- a/zokrates_core_test/tests/tests/right_rotation_bits.zok +++ b/zokrates_core_test/tests/tests/right_rotation_bits.zok @@ -1,5 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def rotr32(u32 e) -> u32: bool[32] b = to_bits(e) diff --git a/zokrates_core_test/tests/tests/split_bls.zok b/zokrates_core_test/tests/tests/split_bls.zok index 924c5c6c..fbf59197 100644 --- a/zokrates_core_test/tests/tests/split_bls.zok +++ b/zokrates_core_test/tests/tests/split_bls.zok @@ -1,4 +1,4 @@ -import "EMBED/unpack" +from "EMBED" import unpack def main(field a) -> (bool[255]): diff --git a/zokrates_core_test/tests/tests/split_bn.zok b/zokrates_core_test/tests/tests/split_bn.zok index 253bdf09..a33762fa 100644 --- a/zokrates_core_test/tests/tests/split_bn.zok +++ b/zokrates_core_test/tests/tests/split_bn.zok @@ -1,4 +1,4 @@ -import "EMBED/unpack" +from "EMBED" import unpack def main(field a) -> (bool[254]): diff --git a/zokrates_core_test/tests/tests/uint/extend.zok b/zokrates_core_test/tests/tests/uint/extend.zok index f50f328a..90ed9d12 100644 --- a/zokrates_core_test/tests/tests/uint/extend.zok +++ b/zokrates_core_test/tests/tests/uint/extend.zok @@ -1,5 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def right_rotate_2(u32 e) -> u32: bool[32] b = to_bits(e) diff --git a/zokrates_core_test/tests/tests/uint/from_to_bits.zok b/zokrates_core_test/tests/tests/uint/from_to_bits.zok index b3b52ccc..1df3ae40 100644 --- a/zokrates_core_test/tests/tests/uint/from_to_bits.zok +++ b/zokrates_core_test/tests/tests/uint/from_to_bits.zok @@ -1,11 +1,11 @@ -import "EMBED/u64_to_bits" as to_bits_64 -import "EMBED/u64_from_bits" as from_bits_64 -import "EMBED/u32_to_bits" as to_bits_32 -import "EMBED/u32_from_bits" as from_bits_32 -import "EMBED/u16_to_bits" as to_bits_16 -import "EMBED/u16_from_bits" as from_bits_16 -import "EMBED/u8_to_bits" as to_bits_8 -import "EMBED/u8_from_bits" as from_bits_8 +import "utils/casts/u64_to_bits" as to_bits_64 +import "utils/casts/u64_from_bits" as from_bits_64 +import "utils/casts/u32_to_bits" as to_bits_32 +import "utils/casts/u32_from_bits" as from_bits_32 +import "utils/casts/u16_to_bits" as to_bits_16 +import "utils/casts/u16_from_bits" as from_bits_16 +import "utils/casts/u8_to_bits" as to_bits_8 +import "utils/casts/u8_from_bits" as from_bits_8 def main(u64 d, u32 e, u16 f, u8 g) -> (u64, u32, u16, u8): bool[64] d_bits = to_bits_64(d) diff --git a/zokrates_core_test/tests/tests/uint/operations.zok b/zokrates_core_test/tests/tests/uint/operations.zok index 032ed3a1..073250b4 100644 --- a/zokrates_core_test/tests/tests/uint/operations.zok +++ b/zokrates_core_test/tests/tests/uint/operations.zok @@ -1,5 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def right_rotate_2(u32 e) -> u32: bool[32] b = to_bits(e) diff --git a/zokrates_core_test/tests/tests/uint/propagation/rotate.zok b/zokrates_core_test/tests/tests/uint/propagation/rotate.zok index 05fdf877..92b41c7d 100644 --- a/zokrates_core_test/tests/tests/uint/propagation/rotate.zok +++ b/zokrates_core_test/tests/tests/uint/propagation/rotate.zok @@ -1,60 +1,9 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +from "EMBED" import u32_to_bits as to_bits +from "EMBED" import u32_from_bits as from_bits -def right_rotate_2(u32 e) -> u32: +def right_rotate(u32 e) -> u32: bool[32] b = to_bits(e) - u32 res = from_bits([...b[30..], ...b[..30]]) - return res - -def right_rotate_4(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[28..], ...b[..28]]) - return res - -def right_rotate_6(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[26..], ...b[..26]]) - return res - -def right_rotate_7(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[25..], ...b[..25]]) - return res - -def right_rotate_11(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[21..], ...b[..21]]) - return res - -def right_rotate_13(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[19..], ...b[..19]]) - return res - -def right_rotate_17(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[15..], ...b[..15]]) - return res - -def right_rotate_18(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[14..], ...b[..14]]) - return res - -def right_rotate_19(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[13..], ...b[..13]]) - return res - -def right_rotate_22(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[10..], ...b[..10]]) - return res - -def right_rotate_25(u32 e) -> u32: - bool[32] b = to_bits(e) - u32 res = from_bits([...b[7..], ...b[..7]]) - + u32 res = from_bits([...b[32-N..], ...b[..32-N]]) return res def main(): @@ -62,7 +11,7 @@ def main(): u32 f = 0x01234567 // rotate - u32 rotated = right_rotate_4(e) + u32 rotated = right_rotate::<4>(e) assert(rotated == 0x81234567) // and @@ -93,16 +42,16 @@ def main(): assert(f == from_bits(expected2)) // S0 - u32 e2 = right_rotate_2(e) - u32 e13 = right_rotate_13(e) - u32 e22 = right_rotate_22(e) + u32 e2 = right_rotate::<2>(e) + u32 e13 = right_rotate::<13>(e) + u32 e22 = right_rotate::<22>(e) u32 S0 = e2 ^ e13 ^ e22 assert(S0 == 0x66146474) // S1 - u32 e6 = right_rotate_6(e) - u32 e11 = right_rotate_11(e) - u32 e25 = right_rotate_25(e) + u32 e6 = right_rotate::<6>(e) + u32 e11 = right_rotate::<11>(e) + u32 e25 = right_rotate::<25>(e) u32 S1 = e6 ^ e11 ^ e25 assert(S1 == 0x3561abda) diff --git a/zokrates_core_test/tests/tests/uint/rotate.zok b/zokrates_core_test/tests/tests/uint/rotate.zok index dd0a6ab4..9e3223ab 100644 --- a/zokrates_core_test/tests/tests/uint/rotate.zok +++ b/zokrates_core_test/tests/tests/uint/rotate.zok @@ -1,5 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def right_rotate_4(u32 e) -> u32: bool[32] b = to_bits(e) diff --git a/zokrates_core_test/tests/tests/uint/sha256.zok b/zokrates_core_test/tests/tests/uint/sha256.zok index 7aabe22c..1c8022c4 100644 --- a/zokrates_core_test/tests/tests/uint/sha256.zok +++ b/zokrates_core_test/tests/tests/uint/sha256.zok @@ -1,49 +1,9 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits - -def right_rotate_2(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[30..], ...b[..30]]) - -def right_rotate_6(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[26..], ...b[..26]]) - -def right_rotate_7(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[25..], ...b[..25]]) - -def right_rotate_11(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[21..], ...b[..21]]) - -def right_rotate_13(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[19..], ...b[..19]]) - -def right_rotate_17(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[15..], ...b[..15]]) - -def right_rotate_18(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[14..], ...b[..14]]) - -def right_rotate_19(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[13..], ...b[..13]]) - -def right_rotate_22(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[10..], ...b[..10]]) - -def right_rotate_25(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[7..], ...b[..7]]) +def right_rotate(u32 x) -> u32: + return (x >> N) | (x << (32 - N)) def extend(u32[64] w, u32 i) -> u32: - u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3) - u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10) + u32 s0 = right_rotate::<7>(w[i-15]) ^ right_rotate::<18>(w[i-15]) ^ (w[i-15] >> 3) + u32 s1 = right_rotate::<17>(w[i-2]) ^ right_rotate::<19>(w[i-2]) ^ (w[i-2] >> 10) return w[i-16] + s0 + w[i-7] + s1 def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: @@ -51,7 +11,7 @@ def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: u32 ch = (e & f) ^ ((!e) & g) // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) - u32 S1 = right_rotate_6(e) ^ right_rotate_11(e) ^ right_rotate_25(e) + u32 S1 = right_rotate::<6>(e) ^ right_rotate::<11>(e) ^ right_rotate::<25>(e) // temp1 := h + S1 + ch + k + w return h + S1 + ch + k + w @@ -61,7 +21,7 @@ def temp2(u32 a, u32 b, u32 c) -> u32: u32 maj = (a & b) ^ (a & c) ^ (b & c) // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) - u32 S0 = right_rotate_2(a) ^ right_rotate_13(a) ^ right_rotate_22(a) + u32 S0 = right_rotate::<2>(a) ^ right_rotate::<13>(a) ^ right_rotate::<22>(a) // temp2 := S0 + maj return S0 + maj diff --git a/zokrates_core_test/tests/tests/uint/temp1.zok b/zokrates_core_test/tests/tests/uint/temp1.zok index 4b3d1e54..b7d799c7 100644 --- a/zokrates_core_test/tests/tests/uint/temp1.zok +++ b/zokrates_core_test/tests/tests/uint/temp1.zok @@ -1,17 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits - -def right_rotate_6(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[26..], ...b[..26]]) - -def right_rotate_11(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[21..], ...b[..21]]) - -def right_rotate_25(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[7..], ...b[..7]]) +def right_rotate(u32 x) -> u32: + return (x >> N) | (x << (32 - N)) // input constraining costs 6 * 33 = 198 constraints, the rest 200 def main(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: @@ -19,7 +7,7 @@ def main(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: u32 ch = (e & f) ^ ((!e) & g) // should be 100 constraints // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) - u32 S1 = right_rotate_6(e) ^ right_rotate_11(e) ^ right_rotate_25(e) // should be 66 constraints + u32 S1 = right_rotate::<6>(e) ^ right_rotate::<11>(e) ^ right_rotate::<25>(e) // should be 66 constraints // temp1 := h + S1 + ch + k + w return h + S1 + ch + k + w // should be 35 constraints \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/uint/temp2.zok b/zokrates_core_test/tests/tests/uint/temp2.zok index ba3b5f10..4ebd25d6 100644 --- a/zokrates_core_test/tests/tests/uint/temp2.zok +++ b/zokrates_core_test/tests/tests/uint/temp2.zok @@ -1,17 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits - -def right_rotate_2(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[30..], ...b[..30]]) - -def right_rotate_13(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[19..], ...b[..19]]) - -def right_rotate_22(u32 e) -> u32: - bool[32] b = to_bits(e) - return from_bits([...b[10..], ...b[..10]]) +def right_rotate(u32 x) -> u32: + return (x >> N) | (x << (32 - N)) // input constraining is 99 constraints, the rest is 265 -> total 364 def main(u32 a, u32 b, u32 c) -> u32: @@ -19,7 +7,7 @@ def main(u32 a, u32 b, u32 c) -> u32: u32 maj = (a & b) ^ (a & c) ^ (b & c) // 165 constraints // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) - u32 S0 = right_rotate_2(a) ^ right_rotate_13(a) ^ right_rotate_22(a) // 66 constraints + u32 S0 = right_rotate::<2>(a) ^ right_rotate::<13>(a) ^ right_rotate::<22>(a) // 66 constraints // temp2 := S0 + maj return S0 + maj // 34 constraints \ No newline at end of file diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 8280154e..5802414f 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -1,9 +1,11 @@ -file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ const_definition* ~ NEWLINE* ~ function_definition* ~ EOI } +file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ symbol_declaration* ~ EOI } pragma = { "#pragma" ~ "curve" ~ curve } curve = @{ (ASCII_ALPHANUMERIC | "_") * } +symbol_declaration = { (import_directive | ty_struct_definition | const_definition | function_definition) ~ NEWLINE* } + import_directive = { main_import_directive | from_import_directive } from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ import_symbol_list ~ NEWLINE* } main_import_directive = { "import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+ } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index be50f71d..4810357d 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -12,12 +12,13 @@ pub use ast::{ Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, - FromExpression, Function, HexLiteralExpression, HexNumberExpression, IdentifierExpression, - ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression, InlineStructExpression, - InlineStructMember, IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter, - PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, - Statement, StructDefinition, StructField, TernaryExpression, ToExpression, Type, - UnaryExpression, UnaryOperator, Underscore, Visibility, + FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, + IdentifierExpression, ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression, + InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, + OptionallyTypedAssignee, Parameter, PostfixExpression, Range, RangeOrExpression, + ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, + SymbolDeclaration, TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, + Underscore, Visibility, }; mod ast { @@ -109,10 +110,7 @@ mod ast { #[pest_ast(rule(Rule::file))] pub struct File<'ast> { pub pragma: Option>, - pub imports: Vec>, - pub structs: Vec>, - pub constants: Vec>, - pub functions: Vec>, + pub declarations: Vec>, pub eoi: EOI, #[pest_ast(outer())] pub span: Span<'ast>, @@ -135,6 +133,16 @@ mod ast { pub span: Span<'ast>, } + #[allow(clippy::large_enum_variant)] + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::symbol_declaration))] + pub enum SymbolDeclaration<'ast> { + Import(ImportDirective<'ast>), + Constant(ConstantDefinition<'ast>), + Struct(StructDefinition<'ast>), + Function(FunctionDefinition<'ast>), + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_struct_definition))] pub struct StructDefinition<'ast> { @@ -155,7 +163,7 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::function_definition))] - pub struct Function<'ast> { + pub struct FunctionDefinition<'ast> { pub id: IdentifierExpression<'ast>, pub generics: Vec>, pub parameters: Vec>, @@ -194,7 +202,7 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::import_symbol))] pub struct ImportSymbol<'ast> { - pub symbol: IdentifierExpression<'ast>, + pub id: IdentifierExpression<'ast>, pub alias: Option>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -1057,52 +1065,52 @@ mod tests { generate_ast(&source), Ok(File { pragma: None, - structs: vec![], - constants: vec![], - functions: vec![Function { - generics: vec![], - id: IdentifierExpression { - value: String::from("main"), - span: Span::new(&source, 33, 37).unwrap() - }, - parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType { - span: Span::new(&source, 44, 49).unwrap() - }))], - statements: vec![Statement::Return(ReturnStatement { - expressions: vec![Expression::add( - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - value: DecimalNumber { + declarations: vec![ + SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective { + source: ImportSource { + value: String::from("foo"), + span: Span::new(&source, 8, 11).unwrap() + }, + alias: None, + span: Span::new(&source, 0, 29).unwrap() + })), + SymbolDeclaration::Function(FunctionDefinition { + generics: vec![], + id: IdentifierExpression { + value: String::from("main"), + span: Span::new(&source, 33, 37).unwrap() + }, + parameters: vec![], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], + statements: vec![Statement::Return(ReturnStatement { + expressions: vec![Expression::add( + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + value: DecimalNumber { + span: Span::new(&source, 59, 60).unwrap() + }, + suffix: None, span: Span::new(&source, 59, 60).unwrap() - }, - suffix: None, - span: Span::new(&source, 59, 60).unwrap() - } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - value: DecimalNumber { + } + )), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + value: DecimalNumber { + span: Span::new(&source, 63, 64).unwrap() + }, + suffix: None, span: Span::new(&source, 63, 64).unwrap() - }, - suffix: None, - span: Span::new(&source, 63, 64).unwrap() - } - )), - Span::new(&source, 59, 64).unwrap() - )], - span: Span::new(&source, 52, 64).unwrap(), - })], - span: Span::new(&source, 29, source.len()).unwrap(), - }], - imports: vec![ImportDirective::Main(MainImportDirective { - source: ImportSource { - value: String::from("foo"), - span: Span::new(&source, 8, 11).unwrap() - }, - alias: None, - span: Span::new(&source, 0, 29).unwrap() - })], + } + )), + Span::new(&source, 59, 64).unwrap() + )], + span: Span::new(&source, 52, 64).unwrap(), + })], + span: Span::new(&source, 29, source.len()).unwrap(), + }) + ], eoi: EOI {}, span: Span::new(&source, 0, 65).unwrap() }) @@ -1118,76 +1126,76 @@ mod tests { generate_ast(&source), Ok(File { pragma: None, - structs: vec![], - constants: vec![], - functions: vec![Function { - generics: vec![], - id: IdentifierExpression { - value: String::from("main"), - span: Span::new(&source, 33, 37).unwrap() - }, - parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType { - span: Span::new(&source, 44, 49).unwrap() - }))], - statements: vec![Statement::Return(ReturnStatement { - expressions: vec![Expression::add( - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - suffix: None, - value: DecimalNumber { - span: Span::new(&source, 59, 60).unwrap() - }, - span: Span::new(&source, 59, 60).unwrap() - } - )), - Expression::mul( + declarations: vec![ + SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective { + source: ImportSource { + value: String::from("foo"), + span: Span::new(&source, 8, 11).unwrap() + }, + alias: None, + span: Span::new(&source, 0, 29).unwrap() + })), + SymbolDeclaration::Function(FunctionDefinition { + generics: vec![], + id: IdentifierExpression { + value: String::from("main"), + span: Span::new(&source, 33, 37).unwrap() + }, + parameters: vec![], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], + statements: vec![Statement::Return(ReturnStatement { + expressions: vec![Expression::add( Expression::Literal(LiteralExpression::DecimalLiteral( DecimalLiteralExpression { suffix: None, value: DecimalNumber { - span: Span::new(&source, 63, 64).unwrap() + span: Span::new(&source, 59, 60).unwrap() }, - span: Span::new(&source, 63, 64).unwrap() + span: Span::new(&source, 59, 60).unwrap() } )), - Expression::pow( + Expression::mul( Expression::Literal(LiteralExpression::DecimalLiteral( DecimalLiteralExpression { suffix: None, value: DecimalNumber { + span: Span::new(&source, 63, 64).unwrap() + }, + span: Span::new(&source, 63, 64).unwrap() + } + )), + Expression::pow( + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 67, 68).unwrap() + }, span: Span::new(&source, 67, 68).unwrap() - }, - span: Span::new(&source, 67, 68).unwrap() - } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - suffix: None, - value: DecimalNumber { + } + )), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 72, 73).unwrap() + }, span: Span::new(&source, 72, 73).unwrap() - }, - span: Span::new(&source, 72, 73).unwrap() - } - )), - Span::new(&source, 67, 73).unwrap() + } + )), + Span::new(&source, 67, 73).unwrap() + ), + Span::new(&source, 63, 73).unwrap() ), - Span::new(&source, 63, 73).unwrap() - ), - Span::new(&source, 59, 73).unwrap() - )], - span: Span::new(&source, 52, 73).unwrap(), - })], - span: Span::new(&source, 29, 74).unwrap(), - }], - imports: vec![ImportDirective::Main(MainImportDirective { - source: ImportSource { - value: String::from("foo"), - span: Span::new(&source, 8, 11).unwrap() - }, - alias: None, - span: Span::new(&source, 0, 29).unwrap() - })], + Span::new(&source, 59, 73).unwrap() + )], + span: Span::new(&source, 52, 73).unwrap(), + })], + span: Span::new(&source, 29, 74).unwrap(), + }) + ], eoi: EOI {}, span: Span::new(&source, 0, 74).unwrap() }) @@ -1203,61 +1211,61 @@ mod tests { generate_ast(&source), Ok(File { pragma: None, - structs: vec![], - constants: vec![], - functions: vec![Function { - generics: vec![], - id: IdentifierExpression { - value: String::from("main"), - span: Span::new(&source, 33, 37).unwrap() - }, - parameters: vec![], - returns: vec![Type::Basic(BasicType::Field(FieldType { - span: Span::new(&source, 44, 49).unwrap() - }))], - statements: vec![Statement::Return(ReturnStatement { - expressions: vec![Expression::if_else( - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - suffix: None, - value: DecimalNumber { + declarations: vec![ + SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective { + source: ImportSource { + value: String::from("foo"), + span: Span::new(&source, 8, 11).unwrap() + }, + alias: None, + span: Span::new(&source, 0, 29).unwrap() + })), + SymbolDeclaration::Function(FunctionDefinition { + generics: vec![], + id: IdentifierExpression { + value: String::from("main"), + span: Span::new(&source, 33, 37).unwrap() + }, + parameters: vec![], + returns: vec![Type::Basic(BasicType::Field(FieldType { + span: Span::new(&source, 44, 49).unwrap() + }))], + statements: vec![Statement::Return(ReturnStatement { + expressions: vec![Expression::if_else( + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 62, 63).unwrap() + }, span: Span::new(&source, 62, 63).unwrap() - }, - span: Span::new(&source, 62, 63).unwrap() - } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - suffix: None, - value: DecimalNumber { + } + )), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 69, 70).unwrap() + }, span: Span::new(&source, 69, 70).unwrap() - }, - span: Span::new(&source, 69, 70).unwrap() - } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( - DecimalLiteralExpression { - suffix: None, - value: DecimalNumber { + } + )), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 76, 77).unwrap() + }, span: Span::new(&source, 76, 77).unwrap() - }, - span: Span::new(&source, 76, 77).unwrap() - } - )), - Span::new(&source, 59, 80).unwrap() - )], - span: Span::new(&source, 52, 80).unwrap(), - })], - span: Span::new(&source, 29, 81).unwrap(), - }], - imports: vec![ImportDirective::Main(MainImportDirective { - source: ImportSource { - value: String::from("foo"), - span: Span::new(&source, 8, 11).unwrap() - }, - alias: None, - span: Span::new(&source, 0, 29).unwrap() - })], + } + )), + Span::new(&source, 59, 80).unwrap() + )], + span: Span::new(&source, 52, 80).unwrap(), + })], + span: Span::new(&source, 29, 81).unwrap(), + }) + ], eoi: EOI {}, span: Span::new(&source, 0, 81).unwrap() }) @@ -1272,9 +1280,7 @@ mod tests { generate_ast(&source), Ok(File { pragma: None, - structs: vec![], - constants: vec![], - functions: vec![Function { + declarations: vec![SymbolDeclaration::Function(FunctionDefinition { generics: vec![], id: IdentifierExpression { value: String::from("main"), @@ -1297,8 +1303,7 @@ mod tests { span: Span::new(&source, 23, 33).unwrap(), })], span: Span::new(&source, 0, 34).unwrap(), - }], - imports: vec![], + })], eoi: EOI {}, span: Span::new(&source, 0, 34).unwrap() }) @@ -1313,9 +1318,7 @@ mod tests { generate_ast(&source), Ok(File { pragma: None, - structs: vec![], - constants: vec![], - functions: vec![Function { + declarations: vec![SymbolDeclaration::Function(FunctionDefinition { generics: vec![], id: IdentifierExpression { value: String::from("main"), @@ -1403,8 +1406,7 @@ mod tests { span: Span::new(&source, 23, 49).unwrap() })], span: Span::new(&source, 0, 50).unwrap(), - }], - imports: vec![], + })], eoi: EOI {}, span: Span::new(&source, 0, 50).unwrap() }) diff --git a/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok b/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok index acb5dc81..d4754a16 100644 --- a/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok +++ b/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok @@ -1,7 +1,7 @@ // https://tools.ietf.org/html/rfc7693 -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" +import "utils/casts/u32_from_bits" def rotr32(u32 x) -> u32: return (x >> N) | (x << (32 - N)) diff --git a/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok b/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok index bbc66f9b..6f6eb3af 100644 --- a/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok +++ b/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok @@ -1,6 +1,6 @@ import "./512bitBool.zok" as pedersen -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "utils/casts/u32_to_bits" as to_bits +import "utils/casts/u32_from_bits" as from_bits def main(u32[16] inputs) -> u32[8]: bool[512] e = [\ diff --git a/zokrates_stdlib/stdlib/hashes/sha256/embed/shaRoundNoBoolCheck.zok b/zokrates_stdlib/stdlib/hashes/sha256/embed/shaRoundNoBoolCheck.zok index 1f971ae0..7e16650b 100644 --- a/zokrates_stdlib/stdlib/hashes/sha256/embed/shaRoundNoBoolCheck.zok +++ b/zokrates_stdlib/stdlib/hashes/sha256/embed/shaRoundNoBoolCheck.zok @@ -1,5 +1,5 @@ #pragma curve bn128 -import "EMBED/sha256round" as sha256round +from "EMBED" import sha256round // a and b is NOT checked to be 0 or 1 // the return value is checked to be 0 or 1 diff --git a/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok b/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok index ba1f2693..50983a90 100644 --- a/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok +++ b/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok @@ -1,4 +1,4 @@ -import "EMBED/u32_from_bits" as from_bits +from "EMBED" import u32_from_bits // convert an array of bool to an array of u32 // the sizes must match (one u32 for 32 bool) otherwise an error will happen @@ -9,7 +9,7 @@ def main(bool[N] bits) -> u32[P]: u32[P] res = [0; P] for u32 i in 0..P do - res[i] = from_bits(bits[32 * i..32 * (i + 1)]) + res[i] = u32_from_bits(bits[32 * i..32 * (i + 1)]) endfor return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/field_to_u16.zok b/zokrates_stdlib/stdlib/utils/casts/field_to_u16.zok index 20b5efa5..9f39cbc1 100644 --- a/zokrates_stdlib/stdlib/utils/casts/field_to_u16.zok +++ b/zokrates_stdlib/stdlib/utils/casts/field_to_u16.zok @@ -1,6 +1,5 @@ -import "EMBED/unpack" as unpack -import "EMBED/u16_from_bits" as from_bits +from "EMBED" import unpack, u16_from_bits def main(field i) -> u16: bool[16] bits = unpack(i) - return from_bits(bits) \ No newline at end of file + return u16_from_bits(bits) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/field_to_u32.zok b/zokrates_stdlib/stdlib/utils/casts/field_to_u32.zok index 5e9827e8..cf14aa90 100644 --- a/zokrates_stdlib/stdlib/utils/casts/field_to_u32.zok +++ b/zokrates_stdlib/stdlib/utils/casts/field_to_u32.zok @@ -1,6 +1,5 @@ -import "EMBED/unpack" as unpack -import "EMBED/u32_from_bits" as from_bits +from "EMBED" import unpack, u32_from_bits def main(field i) -> u32: bool[32] bits = unpack(i) - return from_bits(bits) \ No newline at end of file + return u32_from_bits(bits) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/field_to_u64.zok b/zokrates_stdlib/stdlib/utils/casts/field_to_u64.zok index cf72c5fe..8433dd63 100644 --- a/zokrates_stdlib/stdlib/utils/casts/field_to_u64.zok +++ b/zokrates_stdlib/stdlib/utils/casts/field_to_u64.zok @@ -1,6 +1,5 @@ -import "EMBED/unpack" as unpack -import "EMBED/u64_from_bits" as from_bits +from "EMBED" import unpack, u64_from_bits def main(field i) -> u64: bool[64] bits = unpack(i) - return from_bits(bits) \ No newline at end of file + return u64_from_bits(bits) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/field_to_u8.zok b/zokrates_stdlib/stdlib/utils/casts/field_to_u8.zok index 7efb927c..3045e302 100644 --- a/zokrates_stdlib/stdlib/utils/casts/field_to_u8.zok +++ b/zokrates_stdlib/stdlib/utils/casts/field_to_u8.zok @@ -1,6 +1,5 @@ -import "EMBED/unpack" as unpack -import "EMBED/u8_from_bits" as from_bits +from "EMBED" import unpack, u8_from_bits def main(field i) -> u8: bool[8] bits = unpack(i) - return from_bits(bits) \ No newline at end of file + return u8_from_bits(bits) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u16_from_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u16_from_bits.zok index 0bf8cf4f..35209f6e 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u16_from_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u16_from_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u16_from_bits" as from_bits +from "EMBED" import u16_from_bits def main(bool[16] a) -> u16: - return from_bits(a) \ No newline at end of file + return u16_from_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u16_to_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u16_to_bits.zok index f1dd9fee..33a86e63 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u16_to_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u16_to_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u16_to_bits" as to_bits +from "EMBED" import u16_to_bits def main(u16 a) -> bool[16]: - return to_bits(a) \ No newline at end of file + return u16_to_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok index 690419fb..ca5c8f27 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok @@ -1,7 +1,7 @@ -import "EMBED/u16_to_bits" as to_bits +from "EMBED" import u16_to_bits def main(u16 i) -> field: - bool[16] bits = to_bits(i) + bool[16] bits = u16_to_bits(i) field res = 0 for u32 j in 0..16 do u32 exponent = 16 - j - 1 diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok b/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok index 28c3d65d..71ec03fe 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok @@ -1,4 +1,4 @@ -import "EMBED/u32_to_bits" as to_bits +from "EMBED" import u32_to_bits def main(u32[N] input) -> bool[P]: assert(P == 32 * N) @@ -6,7 +6,7 @@ def main(u32[N] input) -> bool[P]: bool[P] res = [false; P] for u32 i in 0..N do - bool[32] bits = to_bits(input[i]) + bool[32] bits = u32_to_bits(input[i]) for u32 j in 0..32 do res[i * 32 + j] = bits[j] endfor diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_from_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u32_from_bits.zok index f4620c44..1bca1aae 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_from_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_from_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u32_from_bits" as from_bits +from "EMBED" import u32_from_bits def main(bool[32] a) -> u32: - return from_bits(a) \ No newline at end of file + return u32_from_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_to_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u32_to_bits.zok index 3b68cdbd..6087717f 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_to_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_to_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u32_to_bits" as to_bits +from "EMBED" import u32_to_bits def main(u32 a) -> bool[32]: - return to_bits(a) \ No newline at end of file + return u32_to_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok index 7c181d83..a63f9941 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok @@ -1,7 +1,7 @@ -import "EMBED/u32_to_bits" as to_bits +from "EMBED" import u32_to_bits def main(u32 i) -> field: - bool[32] bits = to_bits(i) + bool[32] bits = u32_to_bits(i) field res = 0 for u32 j in 0..32 do u32 exponent = 32 - j - 1 diff --git a/zokrates_stdlib/stdlib/utils/casts/u64_from_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u64_from_bits.zok index be30561f..c7d6fc82 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u64_from_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u64_from_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u64_from_bits" as from_bits +from "EMBED" import u64_from_bits def main(bool[64] a) -> u64: - return from_bits(a) \ No newline at end of file + return u64_from_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u64_to_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u64_to_bits.zok index a7e71ed4..95cca409 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u64_to_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u64_to_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u64_to_bits" as to_bits +from "EMBED" import u64_to_bits def main(u64 a) -> bool[64]: - return to_bits(a) \ No newline at end of file + return u64_to_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u64_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u64_to_field.zok index 20895867..865ca3d8 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u64_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u64_to_field.zok @@ -1,7 +1,7 @@ -import "EMBED/u64_to_bits" as to_bits +from "EMBED" import u64_to_bits def main(u64 i) -> field: - bool[64] bits = to_bits(i) + bool[64] bits = u64_to_bits(i) field res = 0 for u32 j in 0..64 do u32 exponent = 64 - j - 1 diff --git a/zokrates_stdlib/stdlib/utils/casts/u8_from_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u8_from_bits.zok index e1a0ade4..cf251542 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u8_from_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u8_from_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u8_from_bits" as from_bits +from "EMBED" import u8_from_bits def main(bool[8] a) -> u8: - return from_bits(a) \ No newline at end of file + return u8_from_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u8_to_bits.zok b/zokrates_stdlib/stdlib/utils/casts/u8_to_bits.zok index d2ffc8ec..c3a60efe 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u8_to_bits.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u8_to_bits.zok @@ -1,4 +1,4 @@ -import "EMBED/u8_to_bits" as to_bits +from "EMBED" import u8_to_bits def main(u8 a) -> bool[8]: - return to_bits(a) \ No newline at end of file + return u8_to_bits(a) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok index b1d14ba7..82d6c173 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok @@ -1,7 +1,7 @@ -import "EMBED/u8_to_bits" as to_bits +from "EMBED" import u8_to_bits def main(u8 i) -> field: - bool[8] bits = to_bits(i) + bool[8] bits = u8_to_bits(i) field res = 0 for u32 j in 0..8 do u32 exponent = 8 - j - 1 diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok index ca356e8c..4e48909f 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok @@ -1,12 +1,12 @@ #pragma curve bn128 -import "EMBED/unpack" as unpack +import "./unpack" as unpack // Unpack a field element as 256 big-endian bits // Note: uniqueness of the output is not guaranteed // For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)` def main(field i) -> bool[256]: - bool[254] b = unpack(i) + bool[254] b = unpack::<254>(i) return [false, false, ...b] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok index 38bd04a7..d5b7a5cd 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok @@ -1,6 +1,6 @@ #pragma curve bn128 -import "EMBED/unpack" as unpack +from "EMBED" import unpack // Unpack a field element as N big endian bits def main(field i) -> bool[N]: diff --git a/zokrates_stdlib/stdlib/utils/pack/u32/pack256.zok b/zokrates_stdlib/stdlib/utils/pack/u32/pack256.zok index f8c0ae68..ab549658 100644 --- a/zokrates_stdlib/stdlib/utils/pack/u32/pack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/u32/pack256.zok @@ -1,11 +1,20 @@ -import "EMBED/u32_to_bits" as to_bits -from "../bool/pack256.zok" import main as pack256 +import "../../casts/u32_to_bits" +import "../bool/pack256" // pack 256 big-endian bits into one field element // Note: This is not a injective operation as `p` is smaller than `2**256 - 1 for bn128 // For example, `[0, 0,..., 0]` and `bits(p)` both point to `0` def main(u32[8] input) -> field: - bool[256] bits = [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3]), ...to_bits(input[4]), ...to_bits(input[5]), ...to_bits(input[6]), ...to_bits(input[7])] + bool[256] bits = [ + ...u32_to_bits(input[0]), + ...u32_to_bits(input[1]), + ...u32_to_bits(input[2]), + ...u32_to_bits(input[3]), + ...u32_to_bits(input[4]), + ...u32_to_bits(input[5]), + ...u32_to_bits(input[6]), + ...u32_to_bits(input[7]) + ] return pack256(bits)