1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #863 from Zokrates/remove-strict-ordering

Relax ordering of symbol declarations
This commit is contained in:
Thibaut Schaeffer 2021-05-17 22:52:11 +02:00 committed by GitHub
commit 62dc3b072e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 1249 additions and 1432 deletions

View file

@ -0,0 +1 @@
Relax ordering of symbol declarations

View file

@ -0,0 +1 @@
Support the use of constants in struct and function declarations

View file

@ -0,0 +1,4 @@
const field SIZE = 2
def main(field[SIZE] n):
return

View file

@ -0,0 +1,4 @@
const u8 SIZE = 0x02
def main(field[SIZE] n):
return

View file

@ -0,0 +1,7 @@
const u32 N = 42
def foo<N>(field[N] a) -> bool:
return true
def main():
return

View file

@ -1,66 +1,71 @@
use crate::absy;
use crate::imports;
use crate::absy::SymbolDefinition;
use num_bigint::BigUint;
use std::path::Path;
use zokrates_pest_ast as pest;
impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
fn from(prog: pest::File<'ast>) -> absy::Module<'ast> {
absy::Module::with_symbols(
prog.structs
.into_iter()
.map(absy::SymbolDeclarationNode::from)
.chain(
prog.constants
.into_iter()
.map(absy::SymbolDeclarationNode::from),
)
.chain(
prog.functions
.into_iter()
.map(absy::SymbolDeclarationNode::from),
),
)
.imports(
prog.imports
.into_iter()
.map(absy::ImportDirective::from)
.flatten(),
)
fn from(file: pest::File<'ast>) -> absy::Module<'ast> {
absy::Module::with_symbols(file.declarations.into_iter().flat_map(|d| match d {
pest::SymbolDeclaration::Import(i) => import_directive_to_symbol_vec(i),
pest::SymbolDeclaration::Constant(c) => vec![c.into()],
pest::SymbolDeclaration::Struct(s) => vec![s.into()],
pest::SymbolDeclaration::Function(f) => vec![f.into()],
}))
}
}
impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportDirective<'ast> {
fn from(import: pest::ImportDirective<'ast>) -> absy::ImportDirective<'ast> {
use crate::absy::NodeValue;
fn import_directive_to_symbol_vec(
import: pest::ImportDirective,
) -> Vec<absy::SymbolDeclarationNode> {
use crate::absy::NodeValue;
match import {
pest::ImportDirective::Main(import) => absy::ImportDirective::Main(
imports::Import::new(None, std::path::Path::new(import.source.span.as_str()))
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span),
),
pest::ImportDirective::From(import) => absy::ImportDirective::From(
import
.symbols
.iter()
.map(|symbol| {
imports::Import::new(
Some(symbol.symbol.span.as_str()),
std::path::Path::new(import.source.span.as_str()),
)
.alias(
symbol
.alias
.as_ref()
.map(|a| a.span.as_str())
.or_else(|| Some(symbol.symbol.span.as_str())),
)
.span(symbol.span.clone())
})
.collect(),
),
match import {
pest::ImportDirective::Main(import) => {
let span = import.span;
let source = Path::new(import.source.span.as_str());
let id = "main";
let alias = import.alias.map(|a| a.span.as_str());
let import = absy::CanonicalImport {
source,
id: absy::SymbolIdentifier::from(id).alias(alias),
}
.span(span.clone());
vec![absy::SymbolDeclaration {
id: alias.unwrap_or(id),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Import(import)),
}
.span(span.clone())]
}
pest::ImportDirective::From(import) => {
let span = import.span;
let source = Path::new(import.source.span.as_str());
import
.symbols
.into_iter()
.map(|symbol| {
let alias = symbol
.alias
.as_ref()
.map(|a| a.span.as_str())
.unwrap_or_else(|| symbol.id.span.as_str());
let import = absy::CanonicalImport {
source,
id: absy::SymbolIdentifier::from(symbol.id.span.as_str())
.alias(Some(alias)),
}
.span(span.clone());
absy::SymbolDeclaration {
id: alias,
symbol: absy::Symbol::Here(absy::SymbolDefinition::Import(import)),
}
.span(span.clone())
})
.collect()
}
}
}
@ -84,7 +89,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(SymbolDefinition::Struct(ty)),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Struct(ty)),
}
.span(span)
}
@ -119,14 +124,14 @@ impl<'ast> From<pest::ConstantDefinition<'ast>> for absy::SymbolDeclarationNode<
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Constant(ty)),
}
.span(span)
}
}
impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> {
impl<'ast> From<pest::FunctionDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
let span = function.span;
@ -175,7 +180,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(SymbolDefinition::Function(function)),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(function)),
}
.span(span)
}
@ -781,7 +786,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::Here(SymbolDefinition::Function(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -801,7 +806,6 @@ mod tests {
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
}
@ -813,7 +817,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::Here(SymbolDefinition::Function(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -831,7 +835,6 @@ mod tests {
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
}
@ -844,7 +847,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::Here(SymbolDefinition::Function(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![
absy::Parameter::private(
@ -884,7 +887,6 @@ mod tests {
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
@ -898,7 +900,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::Here(SymbolDefinition::Function(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone().mock()).into(),
@ -917,7 +919,6 @@ mod tests {
)),
}
.into()],
imports: vec![],
}
}
@ -972,7 +973,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::Here(SymbolDefinition::Function(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -988,7 +989,6 @@ mod tests {
)),
}
.into()],
imports: vec![],
}
}

View file

@ -18,8 +18,6 @@ pub use crate::absy::variable::{Variable, VariableNode};
use crate::embed::FlatEmbed;
use std::path::{Path, PathBuf};
use crate::imports::ImportDirective;
use crate::imports::ImportNode;
use std::fmt;
use num_bigint::BigUint;
@ -44,61 +42,134 @@ pub struct Program<'ast> {
pub main: OwnedModuleId,
}
/// A declaration of a `FunctionSymbol`, be it from an import or a function definition
#[derive(PartialEq, Clone, Debug)]
#[derive(Debug, PartialEq, Clone)]
pub struct SymbolIdentifier<'ast> {
pub id: Identifier<'ast>,
pub alias: Option<Identifier<'ast>>,
}
impl<'ast> From<Identifier<'ast>> for SymbolIdentifier<'ast> {
fn from(id: &'ast str) -> Self {
SymbolIdentifier { id, alias: None }
}
}
impl<'ast> SymbolIdentifier<'ast> {
pub fn alias(mut self, alias: Option<Identifier<'ast>>) -> Self {
self.alias = alias;
self
}
pub fn get_alias(&self) -> Identifier<'ast> {
self.alias.unwrap_or(self.id)
}
}
impl<'ast> fmt::Display for SymbolIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}{}",
self.id,
self.alias.map(|a| format!(" as {}", a)).unwrap_or_default()
)
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct CanonicalImport<'ast> {
pub source: &'ast Path,
pub id: SymbolIdentifier<'ast>,
}
pub type CanonicalImportNode<'ast> = Node<CanonicalImport<'ast>>;
impl<'ast> fmt::Display for CanonicalImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "from \"{}\" import {}", self.source.display(), self.id)
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {
pub module_id: OwnedModuleId,
pub symbol_id: Identifier<'ast>,
}
pub type SymbolImportNode<'ast> = Node<SymbolImport<'ast>>;
impl<'ast> SymbolImport<'ast> {
pub fn with_id_in_module<S: Into<Identifier<'ast>>, U: Into<OwnedModuleId>>(
symbol_id: S,
module_id: U,
) -> Self {
SymbolImport {
symbol_id: symbol_id.into(),
module_id: module_id.into(),
}
}
}
impl<'ast> fmt::Display for SymbolImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"from \"{}\" import {}",
self.module_id.display(),
self.symbol_id
)
}
}
/// A declaration of a symbol
#[derive(Debug, PartialEq, Clone)]
pub struct SymbolDeclaration<'ast> {
pub id: Identifier<'ast>,
pub symbol: Symbol<'ast>,
}
#[allow(clippy::large_enum_variant)]
#[derive(PartialEq, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub enum SymbolDefinition<'ast> {
Import(CanonicalImportNode<'ast>),
Struct(StructDefinitionNode<'ast>),
Constant(ConstantDefinitionNode<'ast>),
Function(FunctionNode<'ast>),
}
impl<'ast> fmt::Debug for SymbolDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SymbolDefinition::Struct(s) => write!(f, "Struct({:?})", s),
SymbolDefinition::Constant(c) => write!(f, "Constant({:?})", c),
SymbolDefinition::Function(func) => write!(f, "Function({:?})", func),
}
}
}
#[derive(PartialEq, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub enum Symbol<'ast> {
Here(SymbolDefinition<'ast>),
There(SymbolImportNode<'ast>),
Flat(FlatEmbed),
}
impl<'ast> fmt::Debug for Symbol<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Symbol::Here(k) => write!(f, "Here({:?})", k),
Symbol::There(i) => write!(f, "There({:?})", i),
Symbol::Flat(flat) => write!(f, "Flat({:?})", flat),
}
}
}
impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
Symbol::Here(ref kind) => match kind {
SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t),
SymbolDefinition::Constant(c) => write!(
match &self.symbol {
Symbol::Here(ref symbol) => match symbol {
SymbolDefinition::Import(ref i) => write!(
f,
"from \"{}\" import {}",
i.value.source.display(),
i.value.id
),
SymbolDefinition::Struct(ref t) => write!(f, "struct {} {}", self.id, t),
SymbolDefinition::Constant(ref c) => write!(
f,
"const {} {} = {}",
c.value.ty, self.id, c.value.expression
),
SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func),
SymbolDefinition::Function(ref func) => {
write!(f, "def {}{}", self.id, func)
}
},
Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id),
Symbol::There(ref i) => write!(
f,
"from \"{}\" import {} as {}",
i.value.module_id.display(),
i.value.symbol_id,
self.id
),
Symbol::Flat(ref flat_fun) => {
write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature())
}
@ -109,25 +180,18 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
pub type SymbolDeclarationNode<'ast> = Node<SymbolDeclaration<'ast>>;
/// A module as a collection of `FunctionDeclaration`s
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Module<'ast> {
/// Symbols of the module
pub symbols: Declarations<'ast>,
pub imports: Vec<ImportNode<'ast>>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty
}
impl<'ast> Module<'ast> {
pub fn with_symbols<I: IntoIterator<Item = SymbolDeclarationNode<'ast>>>(i: I) -> Self {
Module {
symbols: i.into_iter().collect(),
imports: vec![],
}
}
pub fn imports<I: IntoIterator<Item = ImportNode<'ast>>>(mut self, i: I) -> Self {
self.imports = i.into_iter().collect();
self
}
}
pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
@ -169,7 +233,7 @@ impl<'ast> fmt::Display for StructDefinitionField<'ast> {
type StructDefinitionFieldNode<'ast> = Node<StructDefinitionField<'ast>>;
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct ConstantDefinition<'ast> {
pub ty: UnresolvedTypeNode<'ast>,
pub expression: ExpressionNode<'ast>,
@ -183,92 +247,21 @@ impl<'ast> fmt::Display for ConstantDefinition<'ast> {
}
}
impl<'ast> fmt::Debug for ConstantDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"ConstantDefinition({:?}, {:?})",
self.ty, self.expression
)
}
}
/// An import
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {
/// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however.
pub symbol_id: Identifier<'ast>,
/// the id of the module to import from
pub module_id: OwnedModuleId,
}
type SymbolImportNode<'ast> = Node<SymbolImport<'ast>>;
impl<'ast> SymbolImport<'ast> {
pub fn with_id_in_module<S: Into<Identifier<'ast>>, U: Into<OwnedModuleId>>(
symbol_id: S,
module_id: U,
) -> Self {
SymbolImport {
symbol_id: symbol_id.into(),
module_id: module_id.into(),
}
}
}
impl<'ast> fmt::Display for SymbolImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{} from {}",
self.symbol_id,
self.module_id.display().to_string()
)
}
}
impl<'ast> fmt::Display for Module<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut res = vec![];
res.extend(
self.imports
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>(),
);
res.extend(
self.symbols
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>(),
);
let res = self
.symbols
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>();
write!(f, "{}", res.join("\n"))
}
}
impl<'ast> fmt::Debug for Module<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"module(\n\timports:\n\t\t{}\n\tsymbols:\n\t\t{}\n)",
self.imports
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t"),
self.symbols
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t")
)
}
}
pub type ConstantGenericNode<'ast> = Node<Identifier<'ast>>;
/// A function defined locally
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Function<'ast> {
/// Arguments of the function
pub arguments: Vec<ParameterNode<'ast>>,
@ -312,23 +305,8 @@ impl<'ast> fmt::Display for Function<'ast> {
}
}
impl<'ast> fmt::Debug for Function<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Function(arguments: {:?}, ...):\n{}",
self.arguments,
self.statements
.iter()
.map(|x| format!("\t{:?}", x))
.collect::<Vec<_>>()
.join("\n")
)
}
}
/// Something that we can assign to
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Assignee<'ast> {
Identifier(Identifier<'ast>),
Select(Box<AssigneeNode<'ast>>, Box<RangeOrExpression<'ast>>),
@ -337,16 +315,6 @@ pub enum Assignee<'ast> {
pub type AssigneeNode<'ast> = Node<Assignee<'ast>>;
impl<'ast> fmt::Debug for Assignee<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
Assignee::Member(ref s, ref m) => write!(f, "Member({:?}.{:?})", s, m),
}
}
}
impl<'ast> fmt::Display for Assignee<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -359,7 +327,7 @@ impl<'ast> fmt::Display for Assignee<'ast> {
/// A statement in a `Function`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Statement<'ast> {
Return(ExpressionListNode<'ast>),
Declaration(VariableNode<'ast>),
@ -403,31 +371,8 @@ impl<'ast> fmt::Display for Statement<'ast> {
}
}
impl<'ast> fmt::Debug for Statement<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Statement::Return(ref expr) => write!(f, "Return({:?})", expr),
Statement::Declaration(ref var) => write!(f, "Declaration({:?})", var),
Statement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?;
for l in list {
writeln!(f, "\t\t{:?}", l)?;
}
write!(f, "\tendfor")
}
Statement::MultipleDefinition(ref lhs, ref rhs) => {
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
}
}
}
}
/// An element of an inline array, can be a spread `...a` or an expression `a`
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum SpreadOrExpression<'ast> {
Spread(SpreadNode<'ast>),
Expression(ExpressionNode<'ast>),
@ -448,17 +393,8 @@ impl<'ast> fmt::Display for SpreadOrExpression<'ast> {
}
}
impl<'ast> fmt::Debug for SpreadOrExpression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
SpreadOrExpression::Spread(ref s) => write!(f, "{:?}", s),
SpreadOrExpression::Expression(ref e) => write!(f, "{:?}", e),
}
}
}
/// The index in an array selector. Can be a range or an expression.
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum RangeOrExpression<'ast> {
Range(RangeNode<'ast>),
Expression(ExpressionNode<'ast>),
@ -473,13 +409,10 @@ impl<'ast> fmt::Display for RangeOrExpression<'ast> {
}
}
impl<'ast> fmt::Debug for RangeOrExpression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
RangeOrExpression::Range(ref s) => write!(f, "{:?}", s),
RangeOrExpression::Expression(ref e) => write!(f, "{:?}", e),
}
}
/// A spread
#[derive(Debug, Clone, PartialEq)]
pub struct Spread<'ast> {
pub expression: ExpressionNode<'ast>,
}
pub type SpreadNode<'ast> = Node<Spread<'ast>>;
@ -490,20 +423,8 @@ impl<'ast> fmt::Display for Spread<'ast> {
}
}
impl<'ast> fmt::Debug for Spread<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Spread({:?})", self.expression)
}
}
/// A spread
#[derive(Clone, PartialEq)]
pub struct Spread<'ast> {
pub expression: ExpressionNode<'ast>,
}
/// A range
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Range<'ast> {
pub from: Option<ExpressionNode<'ast>>,
pub to: Option<ExpressionNode<'ast>>,
@ -528,14 +449,8 @@ impl<'ast> fmt::Display for Range<'ast> {
}
}
impl<'ast> fmt::Debug for Range<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Range({:?}, {:?})", self.from, self.to)
}
}
/// An expression
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Expression<'ast> {
IntConstant(BigUint),
FieldConstant(BigUint),
@ -672,73 +587,8 @@ impl<'ast> fmt::Display for Expression<'ast> {
}
}
impl<'ast> fmt::Debug for Expression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Expression::U8Constant(ref i) => write!(f, "U8({:x})", i),
Expression::U16Constant(ref i) => write!(f, "U16({:x})", i),
Expression::U32Constant(ref i) => write!(f, "U32({:x})", i),
Expression::U64Constant(ref i) => write!(f, "U64({:x})", i),
Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i),
Expression::IntConstant(ref i) => write!(f, "Int({:?})", i),
Expression::Identifier(ref var) => write!(f, "Ide({})", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
Expression::Mult(ref lhs, ref rhs) => write!(f, "Mult({:?}, {:?})", lhs, rhs),
Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
Expression::Neg(ref e) => write!(f, "Neg({:?})", e),
Expression::Pos(ref e) => write!(f, "Pos({:?})", e),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
Expression::FunctionCall(ref g, ref i, ref p) => {
write!(f, "FunctionCall({:?}, {:?}, (", g, i)?;
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
Expression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs),
Expression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs),
Expression::Eq(ref lhs, ref rhs) => write!(f, "Eq({:?}, {:?})", lhs, rhs),
Expression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
Expression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
Expression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs),
Expression::Not(ref exp) => write!(f, "Not({:?})", exp),
Expression::InlineArray(ref exprs) => {
write!(f, "InlineArray([")?;
f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]")
}
Expression::ArrayInitializer(ref e, ref count) => {
write!(f, "ArrayInitializer({:?}, {:?})", e, count)
}
Expression::InlineStruct(ref id, ref members) => {
write!(f, "InlineStruct({:?}, [", id)?;
f.debug_list().entries(members.iter()).finish()?;
write!(f, "]")
}
Expression::Select(ref array, ref index) => {
write!(f, "Select({:?}, {:?})", array, index)
}
Expression::Member(ref struc, ref id) => write!(f, "Member({:?}, {:?})", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "Or({:?}, {:?})", lhs, rhs),
Expression::BitXor(ref lhs, ref rhs) => write!(f, "BitXor({:?}, {:?})", lhs, rhs),
Expression::BitAnd(ref lhs, ref rhs) => write!(f, "BitAnd({:?}, {:?})", lhs, rhs),
Expression::BitOr(ref lhs, ref rhs) => write!(f, "BitOr({:?}, {:?})", lhs, rhs),
Expression::LeftShift(ref lhs, ref rhs) => write!(f, "LeftShift({:?}, {:?})", lhs, rhs),
Expression::RightShift(ref lhs, ref rhs) => {
write!(f, "RightShift({:?}, {:?})", lhs, rhs)
}
}
}
}
/// A list of expressions, used in return statements
#[derive(Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ExpressionList<'ast> {
pub expressions: Vec<ExpressionNode<'ast>>,
}
@ -756,9 +606,3 @@ impl<'ast> fmt::Display for ExpressionList<'ast> {
write!(f, "")
}
}
impl<'ast> fmt::Debug for ExpressionList<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ExpressionList({:?})", self.expressions)
}
}

View file

@ -74,7 +74,6 @@ impl<V: NodeValue> From<V> for Node<V> {
use crate::absy::types::UnresolvedType;
use crate::absy::*;
use crate::imports::*;
impl<'ast> NodeValue for Expression<'ast> {}
impl<'ast> NodeValue for ExpressionList<'ast> {}
@ -87,10 +86,10 @@ impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for ConstantDefinition<'ast> {}
impl<'ast> NodeValue for Function<'ast> {}
impl<'ast> NodeValue for Module<'ast> {}
impl<'ast> NodeValue for CanonicalImport<'ast> {}
impl<'ast> NodeValue for SymbolImport<'ast> {}
impl<'ast> NodeValue for Variable<'ast> {}
impl<'ast> NodeValue for Parameter<'ast> {}
impl<'ast> NodeValue for Import<'ast> {}
impl<'ast> NodeValue for Spread<'ast> {}
impl<'ast> NodeValue for Range<'ast> {}
impl<'ast> NodeValue for Identifier<'ast> {}

View file

@ -289,7 +289,7 @@ mod test {
assert!(res.unwrap_err().0[0]
.value()
.to_string()
.contains(&"Can't resolve import without a resolver"));
.contains(&"Cannot resolve import without a resolver"));
}
#[test]

View file

@ -56,94 +56,6 @@ impl From<io::Error> for Error {
}
}
#[derive(PartialEq, Clone)]
pub enum ImportDirective<'ast> {
Main(ImportNode<'ast>),
From(Vec<ImportNode<'ast>>),
}
impl<'ast> IntoIterator for ImportDirective<'ast> {
type Item = ImportNode<'ast>;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
let vec = match self {
ImportDirective::Main(v) => vec![v],
ImportDirective::From(v) => v,
};
vec.into_iter()
}
}
type ImportPath<'ast> = &'ast Path;
#[derive(PartialEq, Clone)]
pub struct Import<'ast> {
source: ImportPath<'ast>,
symbol: Option<Identifier<'ast>>,
alias: Option<Identifier<'ast>>,
}
pub type ImportNode<'ast> = Node<Import<'ast>>;
impl<'ast> Import<'ast> {
pub fn new(symbol: Option<Identifier<'ast>>, source: ImportPath<'ast>) -> Import<'ast> {
Import {
symbol,
source,
alias: None,
}
}
pub fn get_alias(&self) -> &Option<Identifier<'ast>> {
&self.alias
}
pub fn new_with_alias(
symbol: Option<Identifier<'ast>>,
source: ImportPath<'ast>,
alias: Identifier<'ast>,
) -> Import<'ast> {
Import {
symbol,
source,
alias: Some(alias),
}
}
pub fn alias(mut self, alias: Option<Identifier<'ast>>) -> Self {
self.alias = alias;
self
}
pub fn get_source(&self) -> &ImportPath<'ast> {
&self.source
}
}
impl<'ast> fmt::Display for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(f, "import {} as {}", self.source.display(), alias),
None => write!(f, "import {}", self.source.display()),
}
}
}
impl<'ast> fmt::Debug for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(
f,
"import(source: {}, alias: {})",
self.source.display(),
alias
),
None => write!(f, "import(source: {})", self.source.display()),
}
}
}
pub struct Importer;
impl Importer {
@ -154,255 +66,157 @@ impl Importer {
modules: &mut HashMap<OwnedModuleId, Module<'ast>>,
arena: &'ast Arena<String>,
) -> Result<Module<'ast>, CompileErrors> {
let mut symbols: Vec<_> = vec![];
let symbols: Vec<_> = destination
.symbols
.into_iter()
.map(|s| match s.value.symbol {
Symbol::Here(SymbolDefinition::Import(import)) => {
Importer::resolve::<T, E>(import, &location, resolver, modules, arena)
}
_ => Ok(s),
})
.collect::<Result<_, _>>()?;
for import in destination.imports {
let pos = import.pos();
let import = import.value;
let alias = import.alias;
// handle the case of special bellman and packing imports
if import.source.starts_with("EMBED") {
match import.source.to_str().unwrap() {
#[cfg(feature = "bellman")]
"EMBED/sha256round" => {
if T::id() != Bn128Field::id() {
return Err(CompileErrorInner::ImportError(
Error::new(format!(
"Embed sha256round cannot be used with curve {}",
T::name()
))
.with_pos(Some(pos)),
)
.in_file(&location)
.into());
} else {
let alias = alias.unwrap_or("sha256round");
Ok(Module::with_symbols(symbols))
}
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
.start_end(pos.0, pos.1),
);
}
}
"EMBED/unpack" => {
let alias = alias.unwrap_or("unpack");
fn resolve<'ast, T: Field, E: Into<Error>>(
import: CanonicalImportNode<'ast>,
location: &Path,
resolver: Option<&dyn Resolver<E>>,
modules: &mut HashMap<OwnedModuleId, Module<'ast>>,
arena: &'ast Arena<String>,
) -> Result<SymbolDeclarationNode<'ast>, CompileErrors> {
let pos = import.pos();
let module_id = import.value.source;
let symbol = import.value.id;
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Unpack),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u64_to_bits" => {
let alias = alias.unwrap_or("u64_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u32_to_bits" => {
let alias = alias.unwrap_or("u32_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U32ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u16_to_bits" => {
let alias = alias.unwrap_or("u16_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U16ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u8_to_bits" => {
let alias = alias.unwrap_or("u8_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U8ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u64_from_bits" => {
let alias = alias.unwrap_or("u64_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U64FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u32_from_bits" => {
let alias = alias.unwrap_or("u32_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U32FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u16_from_bits" => {
let alias = alias.unwrap_or("u16_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U16FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u8_from_bits" => {
let alias = alias.unwrap_or("u8_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U8FromBits),
}
.start_end(pos.0, pos.1),
);
}
s => {
let symbol_declaration = match module_id.to_str().unwrap() {
"EMBED" => match symbol.id {
#[cfg(feature = "bellman")]
"sha256round" => {
if T::id() != Bn128Field::id() {
return Err(CompileErrorInner::ImportError(
Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)),
Error::new(format!(
"Embed sha256round cannot be used with curve {}",
T::name()
))
.with_pos(Some(pos)),
)
.in_file(&location)
.in_file(location)
.into());
} else {
SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
}
}
} else {
// to resolve imports, we need a resolver
match resolver {
Some(res) => match res.resolve(location.clone(), import.source.to_path_buf()) {
Ok((source, new_location)) => {
// generate an alias from the imported path if none was given explicitely
let alias = import.alias.unwrap_or(
std::path::Path::new(import.source)
.file_stem()
.ok_or_else(|| {
CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
import.source.display()
)))
.in_file(&location),
)
})?
.to_str()
.unwrap(),
);
match modules.get(&new_location) {
Some(_) => {}
None => {
let source = arena.alloc(source);
let compiled = compile_module::<T, E>(
source,
new_location.clone(),
resolver,
modules,
&arena,
)?;
assert!(modules
.insert(new_location.clone(), compiled)
.is_none());
}
};
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::There(
SymbolImport::with_id_in_module(
import.symbol.unwrap_or("main"),
new_location.display().to_string(),
)
.start_end(pos.0, pos.1),
),
}
.start_end(pos.0, pos.1),
);
}
Err(err) => {
return Err(CompileErrorInner::ImportError(
err.into().with_pos(Some(pos)),
)
.in_file(&location)
.into());
}
},
None => {
return Err(CompileErrorInner::from(Error::new(
"Can't resolve import without a resolver",
))
.in_file(&location)
.into());
}
"unpack" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Unpack),
},
"u64_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
},
"u32_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U32ToBits),
},
"u16_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U16ToBits),
},
"u8_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U8ToBits),
},
"u64_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64FromBits),
},
"u32_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U32FromBits),
},
"u16_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U16FromBits),
},
"u8_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U8FromBits),
},
s => {
return Err(CompileErrorInner::ImportError(
Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)),
)
.in_file(location)
.into());
}
}
}
},
_ => match resolver {
Some(res) => match res.resolve(location.to_path_buf(), module_id.to_path_buf()) {
Ok((source, new_location)) => {
let alias = symbol.alias.unwrap_or(
module_id
.file_stem()
.ok_or_else(|| {
CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
module_id.display()
)))
.in_file(location),
)
})?
.to_str()
.unwrap(),
);
symbols.extend(destination.symbols);
match modules.get(&new_location) {
Some(_) => {}
None => {
let source = arena.alloc(source);
let compiled = compile_module::<T, E>(
source,
new_location.clone(),
resolver,
modules,
&arena,
)?;
Ok(Module {
imports: vec![],
symbols,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn create_with_no_alias() {
assert_eq!(
Import::new(None, Path::new("./foo/bar/baz.zok")),
Import {
symbol: None,
source: Path::new("./foo/bar/baz.zok"),
alias: None,
}
);
}
#[test]
fn create_with_alias() {
assert_eq!(
Import::new_with_alias(None, Path::new("./foo/bar/baz.zok"), &"myalias"),
Import {
symbol: None,
source: Path::new("./foo/bar/baz.zok"),
alias: Some("myalias"),
}
);
assert!(modules.insert(new_location.clone(), compiled).is_none());
}
};
SymbolDeclaration {
id: &alias,
symbol: Symbol::There(
SymbolImport::with_id_in_module(symbol.id, new_location)
.start_end(pos.0, pos.1),
),
}
}
Err(err) => {
return Err(
CompileErrorInner::ImportError(err.into().with_pos(Some(pos)))
.in_file(location)
.into(),
);
}
},
None => {
return Err(CompileErrorInner::from(Error::new(
"Cannot resolve import without a resolver",
))
.in_file(location)
.into());
}
},
};
Ok(symbol_declaration.start_end(pos.0, pos.1))
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,20 +1,37 @@
use crate::static_analysis::propagation::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, T: Field> {
pub struct ConstantInliner<'ast, 'a, T: Field> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
}
impl<'ast, T: Field> ConstantInliner<'ast, T> {
pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self {
ConstantInliner { modules, location }
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
) -> Self {
ConstantInliner {
modules,
location,
propagator,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone());
let mut constants = HashMap::new();
let mut inliner = ConstantInliner::new(
p.modules.clone(),
p.main.clone(),
Propagator::with_constants(&mut constants),
);
inliner.fold_program(p)
}
@ -51,12 +68,18 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> {
let _ = self.change_location(location);
symbol
}
TypedConstantSymbol::Here(tc) => self.fold_constant(tc),
TypedConstantSymbol::Here(tc) => {
let tc: TypedConstant<T> = self.fold_constant(tc);
TypedConstant {
expression: self.propagator.fold_expression(tc.expression).unwrap(),
..tc
}
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
@ -71,6 +94,62 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
}
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
match t {
DeclarationType::Array(ref array_ty) => match array_ty.size {
Constant::Identifier(name, _) => {
let tc = self.get_constant(&name.into()).unwrap();
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
match expression.inner {
UExpressionInner::Value(v) => DeclarationType::array((
self.fold_declaration_type(*array_ty.ty.clone()),
Constant::Concrete(v as u32),
)),
_ => unreachable!("expected u32 value"),
}
}
_ => t,
},
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
members: struct_ty
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
.collect(),
..struct_ty
}),
_ => t,
}
}
fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> {
use self::GType::*;
match t {
Array(ref array_type) => match &array_type.size.inner {
UExpressionInner::Identifier(v) => match self.get_constant(v) {
Some(tc) => {
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
Type::array(GArrayType::new(
self.fold_type(*array_type.ty.clone()),
expression,
))
}
None => t,
},
_ => t,
},
Struct(struct_type) => Type::struc(GStructType {
members: struct_type
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_type(*m.ty)))
.collect(),
..struct_type
}),
_ => t,
}
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,
@ -636,11 +715,9 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)
.into()])],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
@ -675,9 +752,8 @@ mod tests {
const_b_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
)),
),

View file

@ -31,10 +31,21 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_function_symbol(self, s)
}
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
fold_declaration_function_key(self, key)
}
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
fold_function(self, f)
}
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
fold_signature(self, s)
}
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
DeclarationParameter {
id: self.fold_declaration_variable(p.id),
@ -190,6 +201,7 @@ pub trait Folder<'ast, T: Field>: Sized {
) -> ArrayExpressionInner<'ast, T> {
fold_array_expression_inner(self, ty, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
@ -212,7 +224,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
functions: m
.functions
.into_iter()
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
.map(|(key, fun)| {
(
f.fold_declaration_function_key(key),
f.fold_function_symbol(fun),
)
})
.collect(),
}
}
@ -653,6 +670,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
DeclarationFunctionKey {
signature: f.fold_signature(key.signature),
..key
}
}
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -668,7 +695,26 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
..fun
signature: f.fold_signature(fun.signature),
}
}
fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> DeclarationSignature<'ast> {
DeclarationSignature {
generics: s.generics,
inputs: s
.inputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect(),
outputs: s
.outputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect(),
}
}
@ -721,9 +767,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> StructExpression<'ast, T> {
let ty = f.fold_struct_type(e.ty);
StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner),
..e
inner: f.fold_struct_expression_inner(&ty, e.inner),
ty,
}
}

View file

@ -290,8 +290,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
ty: Type<'ast, T>,
expression: TypedExpression<'ast, T>,
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {

View file

@ -42,6 +42,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_function_symbol(self, s)
}
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, Self::Error> {
fold_declaration_function_key(self, key)
}
fn fold_function(
&mut self,
f: TypedFunction<'ast, T>,
@ -49,6 +56,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_function(self, f)
}
fn fold_signature(
&mut self,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, Self::Error> {
fold_signature(self, s)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
@ -723,6 +737,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
Ok(e)
}
pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
Ok(DeclarationFunctionKey {
signature: f.fold_signature(key.signature)?,
..key
})
}
pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -741,7 +765,26 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter()
.flatten()
.collect(),
..fun
signature: f.fold_signature(fun.signature)?,
})
}
fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, F::Error> {
Ok(DeclarationSignature {
generics: s.generics,
inputs: s
.inputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect::<Result<_, _>>()?,
outputs: s
.outputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect::<Result<_, _>>()?,
})
}
@ -801,9 +844,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, F::Error> {
let ty = f.fold_struct_type(e.ty)?;
Ok(StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner)?,
..e
inner: f.fold_struct_expression_inner(&ty, e.inner)?,
ty,
})
}

View file

@ -1,4 +1,4 @@
use crate::typed_absy::{OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{TryFrom, TryInto};
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
@ -54,6 +54,7 @@ pub struct SpecializationError;
pub enum Constant<'ast> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Identifier(&'ast str, usize),
}
impl<'ast> From<u32> for Constant<'ast> {
@ -79,6 +80,7 @@ impl<'ast> fmt::Display for Constant<'ast> {
match self {
Constant::Generic(i) => write!(f, "{}", i),
Constant::Concrete(v) => write!(f, "{}", v),
Constant::Identifier(v, _) => write!(f, "{}", v),
}
}
}
@ -96,6 +98,9 @@ impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
}
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
Constant::Identifier(v, size) => {
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
}
}
}
}
@ -920,6 +925,7 @@ pub mod signature {
}
},
Constant::Concrete(s0) => s1 == *s0 as usize,
Constant::Identifier(_, s0) => s1 == *s0,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
@ -945,6 +951,7 @@ pub mod signature {
let size = match t0.size {
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
Constant::Concrete(s) => Ok(s.into()),
Constant::Identifier(_, s) => Ok((s as u32).into()),
}?;
GType::Array(GArrayType { size, ty })

View file

@ -1,4 +1,5 @@
const field[2] ARRAY = [1, 2]
const u32 N = 2
const field[N] ARRAY = [1, 2]
def main() -> field[2]:
def main() -> field[N]:
return ARRAY

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/array_size.zok",
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["42", "42"]
},
"output": {
"Ok": {
"values": ["42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 SIZE = 2
def main(field[SIZE] a) -> field[SIZE]:
field[SIZE] b = a
return b

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/mixed.zok",
"max_constraint_count": 6,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1", "2", "1", "3", "4", "0"]
}
}
}
]
}

View file

@ -0,0 +1,15 @@
const u32 N = 2
const bool B = true
struct Foo {
field[N] a
bool b
}
const Foo[N] F = [
Foo { a: [1, 2], b: B },
Foo { a: [3, 4], b: !B }
]
def main() -> Foo[N]:
return F

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/propagate.zok",
"max_constraint_count": 4,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["42", "42", "42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 TWO = 2
const u32 FOUR = TWO * TWO
def main() -> field[FOUR]:
return [42; FOUR]

View file

@ -1,6 +1,6 @@
{
"entry_point": "./tests/tests/constants/struct.zok",
"max_constraint_count": 1,
"max_constraint_count": 6,
"tests": [
{
"input": {
@ -8,7 +8,7 @@
},
"output": {
"Ok": {
"values": ["4"]
"values": ["1", "2", "3", "4", "5", "6"]
}
}
}

View file

@ -1,9 +1,14 @@
struct Foo {
field a
field b
const u32 N = 2
struct State {
field[N] a
field[N][N] b
}
const Foo FOO = Foo { a: 2, b: 2 }
const State STATE = State {
a: [1, 2],
b: [[3, 4], [5, 6]]
}
def main() -> field:
return FOO.a + FOO.b
def main() -> State:
return STATE

View file

@ -1,4 +1,4 @@
import "EMBED/unpack" as unpack
from "EMBED" import unpack
def main(field x):
bool[1] bits = unpack(x)

View file

@ -1,5 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def rotl32<N>(u32 e) -> u32:
bool[32] b = to_bits(e)

View file

@ -1,5 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def rotr32<N>(u32 e) -> u32:
bool[32] b = to_bits(e)

View file

@ -1,4 +1,4 @@
import "EMBED/unpack"
from "EMBED" import unpack
def main(field a) -> (bool[255]):

View file

@ -1,4 +1,4 @@
import "EMBED/unpack"
from "EMBED" import unpack
def main(field a) -> (bool[254]):

View file

@ -1,5 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def right_rotate_2(u32 e) -> u32:
bool[32] b = to_bits(e)

View file

@ -1,11 +1,11 @@
import "EMBED/u64_to_bits" as to_bits_64
import "EMBED/u64_from_bits" as from_bits_64
import "EMBED/u32_to_bits" as to_bits_32
import "EMBED/u32_from_bits" as from_bits_32
import "EMBED/u16_to_bits" as to_bits_16
import "EMBED/u16_from_bits" as from_bits_16
import "EMBED/u8_to_bits" as to_bits_8
import "EMBED/u8_from_bits" as from_bits_8
import "utils/casts/u64_to_bits" as to_bits_64
import "utils/casts/u64_from_bits" as from_bits_64
import "utils/casts/u32_to_bits" as to_bits_32
import "utils/casts/u32_from_bits" as from_bits_32
import "utils/casts/u16_to_bits" as to_bits_16
import "utils/casts/u16_from_bits" as from_bits_16
import "utils/casts/u8_to_bits" as to_bits_8
import "utils/casts/u8_from_bits" as from_bits_8
def main(u64 d, u32 e, u16 f, u8 g) -> (u64, u32, u16, u8):
bool[64] d_bits = to_bits_64(d)

View file

@ -1,5 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def right_rotate_2(u32 e) -> u32:
bool[32] b = to_bits(e)

View file

@ -1,60 +1,9 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
from "EMBED" import u32_to_bits as to_bits
from "EMBED" import u32_from_bits as from_bits
def right_rotate_2(u32 e) -> u32:
def right_rotate<N>(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[30..], ...b[..30]])
return res
def right_rotate_4(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[28..], ...b[..28]])
return res
def right_rotate_6(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[26..], ...b[..26]])
return res
def right_rotate_7(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[25..], ...b[..25]])
return res
def right_rotate_11(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[21..], ...b[..21]])
return res
def right_rotate_13(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[19..], ...b[..19]])
return res
def right_rotate_17(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[15..], ...b[..15]])
return res
def right_rotate_18(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[14..], ...b[..14]])
return res
def right_rotate_19(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[13..], ...b[..13]])
return res
def right_rotate_22(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[10..], ...b[..10]])
return res
def right_rotate_25(u32 e) -> u32:
bool[32] b = to_bits(e)
u32 res = from_bits([...b[7..], ...b[..7]])
u32 res = from_bits([...b[32-N..], ...b[..32-N]])
return res
def main():
@ -62,7 +11,7 @@ def main():
u32 f = 0x01234567
// rotate
u32 rotated = right_rotate_4(e)
u32 rotated = right_rotate::<4>(e)
assert(rotated == 0x81234567)
// and
@ -93,16 +42,16 @@ def main():
assert(f == from_bits(expected2))
// S0
u32 e2 = right_rotate_2(e)
u32 e13 = right_rotate_13(e)
u32 e22 = right_rotate_22(e)
u32 e2 = right_rotate::<2>(e)
u32 e13 = right_rotate::<13>(e)
u32 e22 = right_rotate::<22>(e)
u32 S0 = e2 ^ e13 ^ e22
assert(S0 == 0x66146474)
// S1
u32 e6 = right_rotate_6(e)
u32 e11 = right_rotate_11(e)
u32 e25 = right_rotate_25(e)
u32 e6 = right_rotate::<6>(e)
u32 e11 = right_rotate::<11>(e)
u32 e25 = right_rotate::<25>(e)
u32 S1 = e6 ^ e11 ^ e25
assert(S1 == 0x3561abda)

View file

@ -1,5 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def right_rotate_4(u32 e) -> u32:
bool[32] b = to_bits(e)

View file

@ -1,49 +1,9 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
def right_rotate_2(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[30..], ...b[..30]])
def right_rotate_6(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[26..], ...b[..26]])
def right_rotate_7(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[25..], ...b[..25]])
def right_rotate_11(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[21..], ...b[..21]])
def right_rotate_13(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[19..], ...b[..19]])
def right_rotate_17(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[15..], ...b[..15]])
def right_rotate_18(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[14..], ...b[..14]])
def right_rotate_19(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[13..], ...b[..13]])
def right_rotate_22(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[10..], ...b[..10]])
def right_rotate_25(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[7..], ...b[..7]])
def right_rotate<N>(u32 x) -> u32:
return (x >> N) | (x << (32 - N))
def extend(u32[64] w, u32 i) -> u32:
u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3)
u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10)
u32 s0 = right_rotate::<7>(w[i-15]) ^ right_rotate::<18>(w[i-15]) ^ (w[i-15] >> 3)
u32 s1 = right_rotate::<17>(w[i-2]) ^ right_rotate::<19>(w[i-2]) ^ (w[i-2] >> 10)
return w[i-16] + s0 + w[i-7] + s1
def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32:
@ -51,7 +11,7 @@ def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32:
u32 ch = (e & f) ^ ((!e) & g)
// S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
u32 S1 = right_rotate_6(e) ^ right_rotate_11(e) ^ right_rotate_25(e)
u32 S1 = right_rotate::<6>(e) ^ right_rotate::<11>(e) ^ right_rotate::<25>(e)
// temp1 := h + S1 + ch + k + w
return h + S1 + ch + k + w
@ -61,7 +21,7 @@ def temp2(u32 a, u32 b, u32 c) -> u32:
u32 maj = (a & b) ^ (a & c) ^ (b & c)
// S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
u32 S0 = right_rotate_2(a) ^ right_rotate_13(a) ^ right_rotate_22(a)
u32 S0 = right_rotate::<2>(a) ^ right_rotate::<13>(a) ^ right_rotate::<22>(a)
// temp2 := S0 + maj
return S0 + maj

View file

@ -1,17 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
def right_rotate_6(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[26..], ...b[..26]])
def right_rotate_11(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[21..], ...b[..21]])
def right_rotate_25(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[7..], ...b[..7]])
def right_rotate<N>(u32 x) -> u32:
return (x >> N) | (x << (32 - N))
// input constraining costs 6 * 33 = 198 constraints, the rest 200
def main(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32:
@ -19,7 +7,7 @@ def main(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32:
u32 ch = (e & f) ^ ((!e) & g) // should be 100 constraints
// S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
u32 S1 = right_rotate_6(e) ^ right_rotate_11(e) ^ right_rotate_25(e) // should be 66 constraints
u32 S1 = right_rotate::<6>(e) ^ right_rotate::<11>(e) ^ right_rotate::<25>(e) // should be 66 constraints
// temp1 := h + S1 + ch + k + w
return h + S1 + ch + k + w // should be 35 constraints

View file

@ -1,17 +1,5 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
def right_rotate_2(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[30..], ...b[..30]])
def right_rotate_13(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[19..], ...b[..19]])
def right_rotate_22(u32 e) -> u32:
bool[32] b = to_bits(e)
return from_bits([...b[10..], ...b[..10]])
def right_rotate<N>(u32 x) -> u32:
return (x >> N) | (x << (32 - N))
// input constraining is 99 constraints, the rest is 265 -> total 364
def main(u32 a, u32 b, u32 c) -> u32:
@ -19,7 +7,7 @@ def main(u32 a, u32 b, u32 c) -> u32:
u32 maj = (a & b) ^ (a & c) ^ (b & c) // 165 constraints
// S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
u32 S0 = right_rotate_2(a) ^ right_rotate_13(a) ^ right_rotate_22(a) // 66 constraints
u32 S0 = right_rotate::<2>(a) ^ right_rotate::<13>(a) ^ right_rotate::<22>(a) // 66 constraints
// temp2 := S0 + maj
return S0 + maj // 34 constraints

View file

@ -1,9 +1,11 @@
file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ const_definition* ~ NEWLINE* ~ function_definition* ~ EOI }
file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ symbol_declaration* ~ EOI }
pragma = { "#pragma" ~ "curve" ~ curve }
curve = @{ (ASCII_ALPHANUMERIC | "_") * }
symbol_declaration = { (import_directive | ty_struct_definition | const_definition | function_definition) ~ NEWLINE* }
import_directive = { main_import_directive | from_import_directive }
from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ import_symbol_list ~ NEWLINE* }
main_import_directive = { "import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+ }

View file

@ -12,12 +12,13 @@ pub use ast::{
Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator,
CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber,
DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File,
FromExpression, Function, HexLiteralExpression, HexNumberExpression, IdentifierExpression,
ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression, InlineStructExpression,
InlineStructMember, IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter,
PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression,
Statement, StructDefinition, StructField, TernaryExpression, ToExpression, Type,
UnaryExpression, UnaryOperator, Underscore, Visibility,
FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression,
IdentifierExpression, ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression,
InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression,
OptionallyTypedAssignee, Parameter, PostfixExpression, Range, RangeOrExpression,
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField,
SymbolDeclaration, TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator,
Underscore, Visibility,
};
mod ast {
@ -109,10 +110,7 @@ mod ast {
#[pest_ast(rule(Rule::file))]
pub struct File<'ast> {
pub pragma: Option<Pragma<'ast>>,
pub imports: Vec<ImportDirective<'ast>>,
pub structs: Vec<StructDefinition<'ast>>,
pub constants: Vec<ConstantDefinition<'ast>>,
pub functions: Vec<Function<'ast>>,
pub declarations: Vec<SymbolDeclaration<'ast>>,
pub eoi: EOI,
#[pest_ast(outer())]
pub span: Span<'ast>,
@ -135,6 +133,16 @@ mod ast {
pub span: Span<'ast>,
}
#[allow(clippy::large_enum_variant)]
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::symbol_declaration))]
pub enum SymbolDeclaration<'ast> {
Import(ImportDirective<'ast>),
Constant(ConstantDefinition<'ast>),
Struct(StructDefinition<'ast>),
Function(FunctionDefinition<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_struct_definition))]
pub struct StructDefinition<'ast> {
@ -155,7 +163,7 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::function_definition))]
pub struct Function<'ast> {
pub struct FunctionDefinition<'ast> {
pub id: IdentifierExpression<'ast>,
pub generics: Vec<IdentifierExpression<'ast>>,
pub parameters: Vec<Parameter<'ast>>,
@ -194,7 +202,7 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::import_symbol))]
pub struct ImportSymbol<'ast> {
pub symbol: IdentifierExpression<'ast>,
pub id: IdentifierExpression<'ast>,
pub alias: Option<IdentifierExpression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
@ -1057,52 +1065,52 @@ mod tests {
generate_ast(&source),
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
value: DecimalNumber {
declarations: vec![
SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})),
SymbolDeclaration::Function(FunctionDefinition {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
value: DecimalNumber {
span: Span::new(&source, 59, 60).unwrap()
},
suffix: None,
span: Span::new(&source, 59, 60).unwrap()
},
suffix: None,
span: Span::new(&source, 59, 60).unwrap()
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
value: DecimalNumber {
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
value: DecimalNumber {
span: Span::new(&source, 63, 64).unwrap()
},
suffix: None,
span: Span::new(&source, 63, 64).unwrap()
},
suffix: None,
span: Span::new(&source, 63, 64).unwrap()
}
)),
Span::new(&source, 59, 64).unwrap()
)],
span: Span::new(&source, 52, 64).unwrap(),
})],
span: Span::new(&source, 29, source.len()).unwrap(),
}],
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})],
}
)),
Span::new(&source, 59, 64).unwrap()
)],
span: Span::new(&source, 52, 64).unwrap(),
})],
span: Span::new(&source, 29, source.len()).unwrap(),
})
],
eoi: EOI {},
span: Span::new(&source, 0, 65).unwrap()
})
@ -1118,76 +1126,76 @@ mod tests {
generate_ast(&source),
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 59, 60).unwrap()
},
span: Span::new(&source, 59, 60).unwrap()
}
)),
Expression::mul(
declarations: vec![
SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})),
SymbolDeclaration::Function(FunctionDefinition {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 63, 64).unwrap()
span: Span::new(&source, 59, 60).unwrap()
},
span: Span::new(&source, 63, 64).unwrap()
span: Span::new(&source, 59, 60).unwrap()
}
)),
Expression::pow(
Expression::mul(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 63, 64).unwrap()
},
span: Span::new(&source, 63, 64).unwrap()
}
)),
Expression::pow(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 67, 68).unwrap()
},
span: Span::new(&source, 67, 68).unwrap()
},
span: Span::new(&source, 67, 68).unwrap()
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 72, 73).unwrap()
},
span: Span::new(&source, 72, 73).unwrap()
},
span: Span::new(&source, 72, 73).unwrap()
}
)),
Span::new(&source, 67, 73).unwrap()
}
)),
Span::new(&source, 67, 73).unwrap()
),
Span::new(&source, 63, 73).unwrap()
),
Span::new(&source, 63, 73).unwrap()
),
Span::new(&source, 59, 73).unwrap()
)],
span: Span::new(&source, 52, 73).unwrap(),
})],
span: Span::new(&source, 29, 74).unwrap(),
}],
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})],
Span::new(&source, 59, 73).unwrap()
)],
span: Span::new(&source, 52, 73).unwrap(),
})],
span: Span::new(&source, 29, 74).unwrap(),
})
],
eoi: EOI {},
span: Span::new(&source, 0, 74).unwrap()
})
@ -1203,61 +1211,61 @@ mod tests {
generate_ast(&source),
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::if_else(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
declarations: vec![
SymbolDeclaration::Import(ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})),
SymbolDeclaration::Function(FunctionDefinition {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::if_else(
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 62, 63).unwrap()
},
span: Span::new(&source, 62, 63).unwrap()
},
span: Span::new(&source, 62, 63).unwrap()
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 69, 70).unwrap()
},
span: Span::new(&source, 69, 70).unwrap()
},
span: Span::new(&source, 69, 70).unwrap()
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
}
)),
Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
span: Span::new(&source, 76, 77).unwrap()
},
span: Span::new(&source, 76, 77).unwrap()
},
span: Span::new(&source, 76, 77).unwrap()
}
)),
Span::new(&source, 59, 80).unwrap()
)],
span: Span::new(&source, 52, 80).unwrap(),
})],
span: Span::new(&source, 29, 81).unwrap(),
}],
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
})],
}
)),
Span::new(&source, 59, 80).unwrap()
)],
span: Span::new(&source, 52, 80).unwrap(),
})],
span: Span::new(&source, 29, 81).unwrap(),
})
],
eoi: EOI {},
span: Span::new(&source, 0, 81).unwrap()
})
@ -1272,9 +1280,7 @@ mod tests {
generate_ast(&source),
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
declarations: vec![SymbolDeclaration::Function(FunctionDefinition {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
@ -1297,8 +1303,7 @@ mod tests {
span: Span::new(&source, 23, 33).unwrap(),
})],
span: Span::new(&source, 0, 34).unwrap(),
}],
imports: vec![],
})],
eoi: EOI {},
span: Span::new(&source, 0, 34).unwrap()
})
@ -1313,9 +1318,7 @@ mod tests {
generate_ast(&source),
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
declarations: vec![SymbolDeclaration::Function(FunctionDefinition {
generics: vec![],
id: IdentifierExpression {
value: String::from("main"),
@ -1403,8 +1406,7 @@ mod tests {
span: Span::new(&source, 23, 49).unwrap()
})],
span: Span::new(&source, 0, 50).unwrap(),
}],
imports: vec![],
})],
eoi: EOI {},
span: Span::new(&source, 0, 50).unwrap()
})

View file

@ -1,7 +1,7 @@
// https://tools.ietf.org/html/rfc7693
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits"
import "utils/casts/u32_from_bits"
def rotr32<N>(u32 x) -> u32:
return (x >> N) | (x << (32 - N))

View file

@ -1,6 +1,6 @@
import "./512bitBool.zok" as pedersen
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "utils/casts/u32_to_bits" as to_bits
import "utils/casts/u32_from_bits" as from_bits
def main(u32[16] inputs) -> u32[8]:
bool[512] e = [\

View file

@ -1,5 +1,5 @@
#pragma curve bn128
import "EMBED/sha256round" as sha256round
from "EMBED" import sha256round
// a and b is NOT checked to be 0 or 1
// the return value is checked to be 0 or 1

View file

@ -1,4 +1,4 @@
import "EMBED/u32_from_bits" as from_bits
from "EMBED" import u32_from_bits
// convert an array of bool to an array of u32
// the sizes must match (one u32 for 32 bool) otherwise an error will happen
@ -9,7 +9,7 @@ def main<N, P>(bool[N] bits) -> u32[P]:
u32[P] res = [0; P]
for u32 i in 0..P do
res[i] = from_bits(bits[32 * i..32 * (i + 1)])
res[i] = u32_from_bits(bits[32 * i..32 * (i + 1)])
endfor
return res

View file

@ -1,6 +1,5 @@
import "EMBED/unpack" as unpack
import "EMBED/u16_from_bits" as from_bits
from "EMBED" import unpack, u16_from_bits
def main(field i) -> u16:
bool[16] bits = unpack(i)
return from_bits(bits)
return u16_from_bits(bits)

View file

@ -1,6 +1,5 @@
import "EMBED/unpack" as unpack
import "EMBED/u32_from_bits" as from_bits
from "EMBED" import unpack, u32_from_bits
def main(field i) -> u32:
bool[32] bits = unpack(i)
return from_bits(bits)
return u32_from_bits(bits)

View file

@ -1,6 +1,5 @@
import "EMBED/unpack" as unpack
import "EMBED/u64_from_bits" as from_bits
from "EMBED" import unpack, u64_from_bits
def main(field i) -> u64:
bool[64] bits = unpack(i)
return from_bits(bits)
return u64_from_bits(bits)

View file

@ -1,6 +1,5 @@
import "EMBED/unpack" as unpack
import "EMBED/u8_from_bits" as from_bits
from "EMBED" import unpack, u8_from_bits
def main(field i) -> u8:
bool[8] bits = unpack(i)
return from_bits(bits)
return u8_from_bits(bits)

View file

@ -1,4 +1,4 @@
import "EMBED/u16_from_bits" as from_bits
from "EMBED" import u16_from_bits
def main(bool[16] a) -> u16:
return from_bits(a)
return u16_from_bits(a)

View file

@ -1,4 +1,4 @@
import "EMBED/u16_to_bits" as to_bits
from "EMBED" import u16_to_bits
def main(u16 a) -> bool[16]:
return to_bits(a)
return u16_to_bits(a)

View file

@ -1,7 +1,7 @@
import "EMBED/u16_to_bits" as to_bits
from "EMBED" import u16_to_bits
def main(u16 i) -> field:
bool[16] bits = to_bits(i)
bool[16] bits = u16_to_bits(i)
field res = 0
for u32 j in 0..16 do
u32 exponent = 16 - j - 1

View file

@ -1,4 +1,4 @@
import "EMBED/u32_to_bits" as to_bits
from "EMBED" import u32_to_bits
def main<N, P>(u32[N] input) -> bool[P]:
assert(P == 32 * N)
@ -6,7 +6,7 @@ def main<N, P>(u32[N] input) -> bool[P]:
bool[P] res = [false; P]
for u32 i in 0..N do
bool[32] bits = to_bits(input[i])
bool[32] bits = u32_to_bits(input[i])
for u32 j in 0..32 do
res[i * 32 + j] = bits[j]
endfor

View file

@ -1,4 +1,4 @@
import "EMBED/u32_from_bits" as from_bits
from "EMBED" import u32_from_bits
def main(bool[32] a) -> u32:
return from_bits(a)
return u32_from_bits(a)

View file

@ -1,4 +1,4 @@
import "EMBED/u32_to_bits" as to_bits
from "EMBED" import u32_to_bits
def main(u32 a) -> bool[32]:
return to_bits(a)
return u32_to_bits(a)

View file

@ -1,7 +1,7 @@
import "EMBED/u32_to_bits" as to_bits
from "EMBED" import u32_to_bits
def main(u32 i) -> field:
bool[32] bits = to_bits(i)
bool[32] bits = u32_to_bits(i)
field res = 0
for u32 j in 0..32 do
u32 exponent = 32 - j - 1

View file

@ -1,4 +1,4 @@
import "EMBED/u64_from_bits" as from_bits
from "EMBED" import u64_from_bits
def main(bool[64] a) -> u64:
return from_bits(a)
return u64_from_bits(a)

View file

@ -1,4 +1,4 @@
import "EMBED/u64_to_bits" as to_bits
from "EMBED" import u64_to_bits
def main(u64 a) -> bool[64]:
return to_bits(a)
return u64_to_bits(a)

View file

@ -1,7 +1,7 @@
import "EMBED/u64_to_bits" as to_bits
from "EMBED" import u64_to_bits
def main(u64 i) -> field:
bool[64] bits = to_bits(i)
bool[64] bits = u64_to_bits(i)
field res = 0
for u32 j in 0..64 do
u32 exponent = 64 - j - 1

View file

@ -1,4 +1,4 @@
import "EMBED/u8_from_bits" as from_bits
from "EMBED" import u8_from_bits
def main(bool[8] a) -> u8:
return from_bits(a)
return u8_from_bits(a)

View file

@ -1,4 +1,4 @@
import "EMBED/u8_to_bits" as to_bits
from "EMBED" import u8_to_bits
def main(u8 a) -> bool[8]:
return to_bits(a)
return u8_to_bits(a)

View file

@ -1,7 +1,7 @@
import "EMBED/u8_to_bits" as to_bits
from "EMBED" import u8_to_bits
def main(u8 i) -> field:
bool[8] bits = to_bits(i)
bool[8] bits = u8_to_bits(i)
field res = 0
for u32 j in 0..8 do
u32 exponent = 8 - j - 1

View file

@ -1,12 +1,12 @@
#pragma curve bn128
import "EMBED/unpack" as unpack
import "./unpack" as unpack
// Unpack a field element as 256 big-endian bits
// Note: uniqueness of the output is not guaranteed
// For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)`
def main(field i) -> bool[256]:
bool[254] b = unpack(i)
bool[254] b = unpack::<254>(i)
return [false, false, ...b]

View file

@ -1,6 +1,6 @@
#pragma curve bn128
import "EMBED/unpack" as unpack
from "EMBED" import unpack
// Unpack a field element as N big endian bits
def main<N>(field i) -> bool[N]:

View file

@ -1,11 +1,20 @@
import "EMBED/u32_to_bits" as to_bits
from "../bool/pack256.zok" import main as pack256
import "../../casts/u32_to_bits"
import "../bool/pack256"
// pack 256 big-endian bits into one field element
// Note: This is not a injective operation as `p` is smaller than `2**256 - 1 for bn128
// For example, `[0, 0,..., 0]` and `bits(p)` both point to `0`
def main(u32[8] input) -> field:
bool[256] bits = [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3]), ...to_bits(input[4]), ...to_bits(input[5]), ...to_bits(input[6]), ...to_bits(input[7])]
bool[256] bits = [
...u32_to_bits(input[0]),
...u32_to_bits(input[1]),
...u32_to_bits(input[2]),
...u32_to_bits(input[3]),
...u32_to_bits(input[4]),
...u32_to_bits(input[5]),
...u32_to_bits(input[6]),
...u32_to_bits(input[7])
]
return pack256(bits)