1
0
Fork 0
mirror of synced 2025-09-23 04:08:33 +00:00

Merge pull request #448 from Zokrates/structs

Implement structures
This commit is contained in:
Thibaut Schaeffer 2019-10-07 12:16:44 +09:00 committed by GitHub
commit 8abf996252
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 4740 additions and 1278 deletions

View file

@ -42,9 +42,9 @@ jobs:
- restore_cache:
keys:
- v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }}
- run:
name: Check format
command: rustup component add rustfmt-preview; cargo fmt --all -- --check
# - run:
# name: Check format
# command: rustup component add rustfmt; cargo fmt --all -- --check
- run:
name: Install libsnark prerequisites
command: ./scripts/install_libsnark_prerequisites.sh

View file

@ -1,27 +1,66 @@
## Imports
You can separate your code into multiple ZoKrates files using `import` statements, ignoring the `.zok` extension of the imported file:
You can separate your code into multiple ZoKrates files using `import` statements to import symbols, ignoring the `.zok` extension of the imported file.
### Import syntax
#### Symbol selection
The preferred way to import a symbol is by module and name:
```zokrates
from "./path/to/my/module" import MySymbol
// `MySymbol` is now in scope.
```
#### Aliasing
The `as` keyword enables renaming symbols:
```zokrates
from "./path/to/my/module" import MySymbol as MyAlias
// `MySymbol` is now in scope under the alias MyAlias.
```
#### Legacy
The legacy way to import a symbol is by only specifying a module:
```
import "./path/to/my/module"
```
In this case, the name of the symbol is assumed to be `main` and the alias is assumed to be the module's filename so that the above is equivalent to
```zokrates
from "./path/to/my/module" import main as module
// `main` is now in scope under the alias `module`.
```
Note that this legacy method is likely to be become deprecated, so it is recommended to use the preferred way instead.
### Symbols
Two type of symbols can be imported
#### Functions
Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is infered.
#### User-defined types
User-defined types declared with the `struct` keyword are imported by name.
### Relative Imports
You can import a resource in the same folder directly, like this:
```zokrates
import "./mycode"
from "./mycode" import foo
```
There also is a handy syntax to import from the parent directory:
```zokrates
import "../mycode"
from "../mycode" import foo
```
Also imports further up the file-system are supported:
```zokrates
import "../../../mycode"
```
You can also choose to rename the imported resource, like so:
```zokrates
import "./mycode" as abc
from "../../../mycode" import foo
```
### Absolute Imports

View file

@ -1,10 +1,10 @@
## Types
# Types
ZoKrates currently exposes two primitive types and a complex array type:
ZoKrates currently exposes two primitive types and two complex types:
### Primitive Types
## Primitive Types
#### `field`
### `field`
This is the most basic type in ZoKrates, and it represents a positive integer in `[0, p - 1]` where `p` is a (large) prime number.
@ -16,7 +16,7 @@ While `field` values mostly behave like unsigned integers, one should keep in mi
{{#include ../../../zokrates_cli/examples/book/field_overflow.zok}}
```
#### `bool`
### `bool`
ZoKrates has limited support for booleans, to the extent that they can only be used as the condition in `if ... else ... endif` expressions.
@ -24,9 +24,11 @@ You can use them for equality checks, inequality checks and inequality checks be
Note that while equality checks are cheap, inequality checks should be use wisely as they are orders of magnitude more expensive.
### Complex Types
## Complex Types
#### Arrays
ZoKrates provides two complex types, Arrays and Structs.
### Arrays
ZoKrates supports static arrays, i.e., their length needs to be known at compile time.
Arrays can contain elements of any type and have arbitrary dimensions.
@ -37,10 +39,10 @@ The following examples code shows examples of how to use arrays:
{{#include ../../../zokrates_cli/examples/book/array.zok}}
```
##### Declaration and Initialization
#### Declaration and Initialization
An array is defined by appending `[]` to a type literal representing the type of the array's elements.
Initialization always needs to happen in the same statement than declaration, unless the array is declared within a function's signature.
Initialization always needs to happen in the same statement as declaration, unless the array is declared within a function's signature.
For initialization, a list of comma-separated values is provided within brackets `[]`.
@ -54,7 +56,7 @@ The following code provides examples for declaration and initialization:
bool[13] b = [false; 13] // initialize a bool array with value false
```
##### Multidimensional Arrays
#### Multidimensional Arrays
As an array can contain any type of elements, it can contain arrays again.
There is a special syntax to declare such multi-dimensional arrays, i.e., arrays of arrays.
@ -67,21 +69,59 @@ Consider the following example:
{{#include ../../../zokrates_cli/examples/book/multidim_array.zok}}
```
##### Spreads and Slices
#### Spreads and Slices
ZoKrates provides some syntactic sugar to retrieve subsets of arrays.
###### Spreads
The spread operator `...` applied to an copies the elements of an existing array.
##### Spreads
The spread operator `...` applied to an array copies the elements of the existing array.
This can be used to conveniently compose new arrays, as shown in the following example:
```
field[3] = [1, 2, 3]
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
```
###### Slices
##### Slices
An array can also be assigned to by creating a copy of a subset of an existing array.
This operation is called slicing, and the following example shows how to slice in ZoKrates:
```
field[3] a = [1, 2, 3]
field[2] b = a[1..3] // initialize an array copying a slice from `a`
```
### Structs
A struct is a composite datatype representing a named collection of variables.
The contained variables can be of any type.
The following code shows an example of how to use structs.
```zokrates
{{#include ../../../zokrates_cli/examples/book/structs.code}}
```
#### Definition
Before a struct data type can be used, it needs to be defined.
A struct definition starts with the `struct` keyword followed by a name. Afterwards, a new-line separated list of variables is declared in curly braces `{}`. For example:
```zokrates
struct Point {
field x
field y
}
```
#### Declaration and Initialization
Initialization of a variable of a struct type always needs to happen in the same statement as declaration, unless the struct-typed variable is declared within a function's signature.
The following example shows declaration and initialization of a variable of the `Point` struct type:
```zokrates
{{#include ../../../zokrates_cli/examples/book/struct_init.code}}
```
#### Assignment
The variables within a struct instance, the so called members, can be accessed through the `.` operator as shown in the following extended example:
```zokrates
{{#include ../../../zokrates_cli/examples/book/struct_assign.code}}
```

View file

@ -0,0 +1,10 @@
struct Point {
field x
field y
}
def main(field a) -> (Point):
Point p = Point {x: 1, y: 0}
p.x = a
p.y = p.x
return p

View file

@ -0,0 +1,8 @@
struct Point {
field x
field y
}
def main() -> (Point):
Point p = Point {x: 1, y: 0}
return p

View file

@ -0,0 +1,14 @@
struct Bar {
field[2] c
bool d
}
struct Foo {
Bar a
bool b
}
def main() -> (Foo):
Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}]
f[0].a.c = [42, 43]
return f[0]

View file

@ -0,0 +1,16 @@
struct Point {
field x
field y
}
def main(Point p, Point q) -> (Point):
field a = 42
field d = 21
field dpxpyqxqy = d * p.x * p.y * q.x * q.y
return Point {
x: (p.x * q.y + q.x * p.y) / (1 + dpxpyqxqy),
y: (q.x * q.y - a * p.x * p.y) / (1 - dpxpyqxqy)
}

View file

@ -0,0 +1,29 @@
struct Bar {
field[2] c
bool d
}
struct Foo {
Bar a
bool b
}
def main() -> (Foo):
Foo[2] f = [
Foo {
a: Bar {
c: [0, 0],
d: false
},
b: true
},
Foo {
a: Bar {
c: [0, 0],
d: false
},
b: true
}
]
f[0].a.c = [42, 43]
return f[0]

View file

@ -1,48 +1,96 @@
use absy;
use imports;
use types::Type;
use zokrates_field::field::Field;
use zokrates_pest_ast as pest;
impl<'ast, T: Field> From<pest::File<'ast>> for absy::Module<'ast, T> {
fn from(prog: pest::File<'ast>) -> absy::Module<T> {
absy::Module {
functions: prog
.functions
absy::Module::with_symbols(
prog.structs
.into_iter()
.map(|f| absy::FunctionDeclarationNode::from(f))
.collect(),
imports: prog
.imports
.into_iter()
.map(|i| absy::ImportNode::from(i))
.collect(),
}
.map(|t| absy::SymbolDeclarationNode::from(t))
.chain(
prog.functions
.into_iter()
.map(|f| absy::SymbolDeclarationNode::from(f)),
),
)
.imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i)))
}
}
impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
fn from(import: pest::ImportDirective<'ast>) -> absy::ImportNode {
use absy::NodeValue;
imports::Import::new(import.source.span.as_str())
match import {
pest::ImportDirective::Main(import) => {
imports::Import::new(None, import.source.span.as_str())
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span)
}
pest::ImportDirective::From(import) => imports::Import::new(
Some(import.symbol.span.as_str()),
import.source.span.as_str(),
)
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span)
.span(import.span),
}
}
}
impl<'ast, T: Field> From<pest::Function<'ast>> for absy::FunctionDeclarationNode<'ast, T> {
fn from(function: pest::Function<'ast>) -> absy::FunctionDeclarationNode<T> {
impl<'ast, T: Field> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'ast, T> {
fn from(definition: pest::StructDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast, T> {
use absy::NodeValue;
let span = definition.span;
let id = definition.id.span.as_str();
let ty = absy::StructType {
fields: definition
.fields
.into_iter()
.map(|f| absy::StructFieldNode::from(f))
.collect(),
}
.span(span.clone());
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::HereType(ty),
}
.span(span)
}
}
impl<'ast> From<pest::StructField<'ast>> for absy::StructFieldNode<'ast> {
fn from(field: pest::StructField<'ast>) -> absy::StructFieldNode {
use absy::NodeValue;
let span = field.span;
let id = field.id.span.as_str();
let ty = absy::UnresolvedTypeNode::from(field.ty);
absy::StructField { id, ty }.span(span)
}
}
impl<'ast, T: Field> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast, T> {
fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<T> {
use absy::NodeValue;
let span = function.span;
let signature = absy::Signature::new()
let signature = absy::UnresolvedSignature::new()
.inputs(
function
.parameters
.clone()
.into_iter()
.map(|p| absy::ParameterNode::from(p).value.id.value.get_type())
.map(|p| absy::UnresolvedTypeNode::from(p.ty))
.collect(),
)
.outputs(
@ -50,7 +98,7 @@ impl<'ast, T: Field> From<pest::Function<'ast>> for absy::FunctionDeclarationNod
.returns
.clone()
.into_iter()
.map(|r| Type::from(r))
.map(|r| absy::UnresolvedTypeNode::from(r))
.collect(),
);
@ -71,9 +119,9 @@ impl<'ast, T: Field> From<pest::Function<'ast>> for absy::FunctionDeclarationNod
}
.span(span.clone());
absy::FunctionDeclaration {
absy::SymbolDeclaration {
id,
symbol: absy::FunctionSymbol::Here(function),
symbol: absy::Symbol::HereFunction(function),
}
.span(span)
}
@ -91,8 +139,11 @@ impl<'ast> From<pest::Parameter<'ast>> for absy::ParameterNode<'ast> {
})
.unwrap_or(false);
let variable =
absy::Variable::new(param.id.span.as_str(), Type::from(param.ty)).span(param.id.span);
let variable = absy::Variable::new(
param.id.span.as_str(),
absy::UnresolvedTypeNode::from(param.ty),
)
.span(param.id.span);
absy::Parameter::new(variable, private).span(param.span)
}
@ -123,7 +174,11 @@ fn statements_from_multi_assignment<'ast, T: Field>(
.filter(|i| i.ty.is_some())
.map(|i| {
absy::Statement::Declaration(
absy::Variable::new(i.id.span.as_str(), Type::from(i.ty.unwrap())).span(i.id.span),
absy::Variable::new(
i.id.span.as_str(),
absy::UnresolvedTypeNode::from(i.ty.unwrap()),
)
.span(i.id.span),
)
.span(i.span)
});
@ -158,8 +213,11 @@ fn statements_from_definition<'ast, T: Field>(
vec![
absy::Statement::Declaration(
absy::Variable::new(definition.id.span.as_str(), Type::from(definition.ty))
.span(definition.id.span.clone()),
absy::Variable::new(
definition.id.span.as_str(),
absy::UnresolvedTypeNode::from(definition.ty),
)
.span(definition.id.span.clone()),
)
.span(definition.span.clone()),
absy::Statement::Definition(
@ -218,7 +276,7 @@ impl<'ast, T: Field> From<pest::IterationStatement<'ast>> for absy::StatementNod
let from = absy::ExpressionNode::from(statement.from);
let to = absy::ExpressionNode::from(statement.to);
let index = statement.index.span.as_str();
let ty = Type::from(statement.ty);
let ty = absy::UnresolvedTypeNode::from(statement.ty);
let statements: Vec<absy::StatementNode<T>> = statement
.statements
.into_iter()
@ -262,6 +320,7 @@ impl<'ast, T: Field> From<pest::Expression<'ast>> for absy::ExpressionNode<'ast,
pest::Expression::Identifier(e) => absy::ExpressionNode::from(e),
pest::Expression::Postfix(e) => absy::ExpressionNode::from(e),
pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e),
pest::Expression::InlineStruct(e) => absy::ExpressionNode::from(e),
pest::Expression::ArrayInitializer(e) => absy::ExpressionNode::from(e),
pest::Expression::Unary(e) => absy::ExpressionNode::from(e),
}
@ -414,6 +473,25 @@ impl<'ast, T: Field> From<pest::InlineArrayExpression<'ast>> for absy::Expressio
}
}
impl<'ast, T: Field> From<pest::InlineStructExpression<'ast>> for absy::ExpressionNode<'ast, T> {
fn from(s: pest::InlineStructExpression<'ast>) -> absy::ExpressionNode<'ast, T> {
use absy::NodeValue;
absy::Expression::InlineStruct(
s.ty.span.as_str().to_string(),
s.members
.into_iter()
.map(|member| {
(
member.id.span.as_str(),
absy::ExpressionNode::from(member.expression),
)
})
.collect(),
)
.span(s.span)
}
}
impl<'ast, T: Field> From<pest::ArrayInitializerExpression<'ast>>
for absy::ExpressionNode<'ast, T>
{
@ -471,6 +549,9 @@ impl<'ast, T: Field> From<pest::PostfixExpression<'ast>> for absy::ExpressionNod
absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression))
.span(a.span)
}
pest::Access::Member(m) => {
absy::Expression::Member(box acc, box m.id.span.as_str()).span(m.span)
}
})
}
}
@ -511,29 +592,44 @@ impl<'ast, T: Field> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast, T>
let a = absy::AssigneeNode::from(assignee.id);
let span = assignee.span;
assignee
.indices
.into_iter()
.map(|i| absy::RangeOrExpression::from(i))
.fold(a, |acc, s| {
absy::Assignee::Select(box acc, box s).span(span.clone())
})
assignee.accesses.into_iter().fold(a, |acc, s| {
match s {
pest::AssigneeAccess::Select(s) => {
absy::Assignee::Select(box acc, box absy::RangeOrExpression::from(s.expression))
}
pest::AssigneeAccess::Member(m) => {
absy::Assignee::Member(box acc, box m.id.span.as_str())
}
}
.span(span.clone())
})
}
}
impl<'ast> From<pest::Type<'ast>> for Type {
fn from(t: pest::Type<'ast>) -> Type {
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode {
use absy::NodeValue;
match t {
pest::Type::Basic(t) => match t {
pest::BasicType::Field(_) => Type::FieldElement,
pest::BasicType::Boolean(_) => Type::Boolean,
pest::BasicType::Field(t) => absy::UnresolvedType::FieldElement.span(t.span),
pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span),
},
pest::Type::Array(t) => {
let inner_type = match t.ty {
pest::BasicType::Field(_) => Type::FieldElement,
pest::BasicType::Boolean(_) => Type::Boolean,
pest::BasicOrStructType::Basic(t) => match t {
pest::BasicType::Field(t) => {
absy::UnresolvedType::FieldElement.span(t.span)
}
pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span),
},
pest::BasicOrStructType::Struct(t) => {
absy::UnresolvedType::User(t.span.as_str().to_string()).span(t.span)
}
};
let span = t.span;
t.dimensions
.into_iter()
.map(|s| match s {
@ -553,10 +649,14 @@ impl<'ast> From<pest::Type<'ast>> for Type {
})
.rev()
.fold(None, |acc, s| match acc {
None => Some(Type::array(inner_type.clone(), s)),
Some(acc) => Some(Type::array(acc, s)),
None => Some(absy::UnresolvedType::array(inner_type.clone(), s)),
Some(acc) => Some(absy::UnresolvedType::array(acc.span(span.clone()), s)),
})
.unwrap()
.span(span.clone())
}
pest::Type::Struct(s) => {
absy::UnresolvedType::User(s.id.span.as_str().to_string()).span(s.span)
}
}
}
@ -565,6 +665,7 @@ impl<'ast> From<pest::Type<'ast>> for Type {
#[cfg(test)]
mod tests {
use super::*;
use absy::NodeValue;
use zokrates_field::field::FieldPrime;
#[test]
@ -572,9 +673,9 @@ mod tests {
let source = "def main() -> (field): return 42";
let ast = pest::generate_ast(&source).unwrap();
let expected: absy::Module<FieldPrime> = absy::Module {
functions: vec![absy::FunctionDeclaration {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::FunctionSymbol::Here(
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -587,9 +688,9 @@ mod tests {
.into(),
)
.into()],
signature: absy::Signature::new()
signature: absy::UnresolvedSignature::new()
.inputs(vec![])
.outputs(vec![Type::FieldElement]),
.outputs(vec![absy::UnresolvedType::FieldElement.mock()]),
}
.into(),
),
@ -605,9 +706,9 @@ mod tests {
let source = "def main() -> (bool): return true";
let ast = pest::generate_ast(&source).unwrap();
let expected: absy::Module<FieldPrime> = absy::Module {
functions: vec![absy::FunctionDeclaration {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::FunctionSymbol::Here(
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -617,9 +718,9 @@ mod tests {
.into(),
)
.into()],
signature: absy::Signature::new()
signature: absy::UnresolvedSignature::new()
.inputs(vec![])
.outputs(vec![Type::Boolean]),
.outputs(vec![absy::UnresolvedType::Boolean.mock()]),
}
.into(),
),
@ -636,17 +737,25 @@ mod tests {
let ast = pest::generate_ast(&source).unwrap();
let expected: absy::Module<FieldPrime> = absy::Module {
functions: vec![absy::FunctionDeclaration {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::FunctionSymbol::Here(
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![
absy::Parameter::private(
absy::Variable::field_element(&source[23..24]).into(),
absy::Variable::new(
&source[23..24],
absy::UnresolvedType::FieldElement.mock(),
)
.into(),
)
.into(),
absy::Parameter::public(
absy::Variable::boolean(&source[31..32]).into(),
absy::Variable::new(
&source[31..32],
absy::UnresolvedType::Boolean.mock(),
)
.into(),
)
.into(),
],
@ -660,9 +769,12 @@ mod tests {
.into(),
)
.into()],
signature: absy::Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![Type::FieldElement]),
signature: absy::UnresolvedSignature::new()
.inputs(vec![
absy::UnresolvedType::FieldElement.mock(),
absy::UnresolvedType::Boolean.mock(),
])
.outputs(vec![absy::UnresolvedType::FieldElement.mock()]),
}
.into(),
),
@ -678,14 +790,14 @@ mod tests {
use super::*;
/// Helper method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty
fn wrap(ty: types::Type) -> absy::Module<'static, FieldPrime> {
fn wrap(ty: absy::UnresolvedType) -> absy::Module<'static, FieldPrime> {
absy::Module {
functions: vec![absy::FunctionDeclaration {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::FunctionSymbol::Here(
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone()).into(),
absy::Variable::new("a", ty.clone().mock()).into(),
)
.into()],
statements: vec![absy::Statement::Return(
@ -695,7 +807,7 @@ mod tests {
.into(),
)
.into()],
signature: absy::Signature::new().inputs(vec![ty]),
signature: absy::UnresolvedSignature::new().inputs(vec![ty.mock()]),
}
.into(),
),
@ -708,19 +820,33 @@ mod tests {
#[test]
fn array() {
let vectors = vec![
("field", types::Type::FieldElement),
("bool", types::Type::Boolean),
("field", absy::UnresolvedType::FieldElement),
("bool", absy::UnresolvedType::Boolean),
(
"field[2]",
types::Type::Array(box types::Type::FieldElement, 2),
absy::UnresolvedType::Array(box absy::UnresolvedType::FieldElement.mock(), 2),
),
(
"field[2][3]",
types::Type::Array(box Type::Array(box types::Type::FieldElement, 3), 2),
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
3,
)
.mock(),
2,
),
),
(
"bool[2][3]",
types::Type::Array(box Type::Array(box types::Type::Boolean, 3), 2),
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(),
3,
)
.mock(),
2,
),
),
];
@ -737,9 +863,9 @@ mod tests {
use super::*;
fn wrap(expression: absy::Expression<'static, FieldPrime>) -> absy::Module<FieldPrime> {
absy::Module {
functions: vec![absy::FunctionDeclaration {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::FunctionSymbol::Here(
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -749,7 +875,7 @@ mod tests {
.into(),
)
.into()],
signature: absy::Signature::new(),
signature: absy::UnresolvedSignature::new(),
}
.into(),
),

View file

@ -8,12 +8,13 @@
mod from_ast;
mod node;
pub mod parameter;
pub mod types;
pub mod variable;
pub use crate::absy::node::{Node, NodeValue};
pub use crate::absy::parameter::{Parameter, ParameterNode};
use crate::absy::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType, UserTypeId};
pub use crate::absy::variable::{Variable, VariableNode};
use crate::types::{FunctionIdentifier, Signature};
use embed::FlatEmbed;
use crate::imports::ImportNode;
@ -31,8 +32,8 @@ pub type ModuleId = String;
/// A collection of `Module`s
pub type Modules<'ast, T> = HashMap<ModuleId, Module<'ast, T>>;
/// A collection of `FunctionDeclaration`. Duplicates are allowed here as they are fine syntatically.
pub type FunctionDeclarations<'ast, T> = Vec<FunctionDeclarationNode<'ast, T>>;
/// A collection of `SymbolDeclaration`. Duplicates are allowed here as they are fine syntactically.
pub type Declarations<'ast, T> = Vec<SymbolDeclarationNode<'ast, T>>;
/// A `Program` is a collection of `Module`s and an id of the main `Module`
pub struct Program<'ast, T: Field> {
@ -42,17 +43,26 @@ pub struct Program<'ast, T: Field> {
/// A declaration of a `FunctionSymbol`, be it from an import or a function definition
#[derive(PartialEq, Debug, Clone)]
pub struct FunctionDeclaration<'ast, T: Field> {
pub struct SymbolDeclaration<'ast, T: Field> {
pub id: Identifier<'ast>,
pub symbol: FunctionSymbol<'ast, T>,
pub symbol: Symbol<'ast, T>,
}
impl<'ast, T: Field> fmt::Display for FunctionDeclaration<'ast, T> {
#[derive(PartialEq, Debug, Clone)]
pub enum Symbol<'ast, T: Field> {
HereType(StructTypeNode<'ast>),
HereFunction(FunctionNode<'ast, T>),
There(SymbolImportNode<'ast>),
Flat(FlatEmbed),
}
impl<'ast, T: Field> fmt::Display for SymbolDeclaration<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
FunctionSymbol::Here(ref fun) => write!(f, "def {}{}", self.id, fun),
FunctionSymbol::There(ref import) => write!(f, "import {} as {}", import, self.id),
FunctionSymbol::Flat(ref flat_fun) => write!(
Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t),
Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun),
Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id),
Symbol::Flat(ref flat_fun) => write!(
f,
"def {}{}:\n\t// hidden",
self.id,
@ -62,50 +72,95 @@ impl<'ast, T: Field> fmt::Display for FunctionDeclaration<'ast, T> {
}
}
type FunctionDeclarationNode<'ast, T> = Node<FunctionDeclaration<'ast, T>>;
pub type SymbolDeclarationNode<'ast, T> = Node<SymbolDeclaration<'ast, T>>;
/// A module as a collection of `FunctionDeclaration`s
#[derive(Clone, PartialEq)]
pub struct Module<'ast, T: Field> {
/// Functions of the module
pub functions: FunctionDeclarations<'ast, T>,
/// Symbols of the module
pub symbols: Declarations<'ast, T>,
pub imports: Vec<ImportNode<'ast>>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty
}
/// A function, be it defined in this module, imported from another module or a flat embed
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionSymbol<'ast, T: Field> {
Here(FunctionNode<'ast, T>),
There(FunctionImportNode<'ast>),
Flat(FlatEmbed),
impl<'ast, T: Field> Module<'ast, T> {
pub fn with_symbols<I: IntoIterator<Item = SymbolDeclarationNode<'ast, T>>>(i: I) -> Self {
Module {
symbols: i.into_iter().collect(),
imports: vec![],
}
}
pub fn imports<I: IntoIterator<Item = ImportNode<'ast>>>(mut self, i: I) -> Self {
self.imports = i.into_iter().collect();
self
}
}
/// A function import
pub type UnresolvedTypeNode = Node<UnresolvedType>;
/// A struct type definition
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionImport<'ast> {
/// the id of the function in the target module. Note: there may be many candidates as imports statements do not specify the signature
pub function_id: Identifier<'ast>,
pub struct StructType<'ast> {
pub fields: Vec<StructFieldNode<'ast>>,
}
impl<'ast> fmt::Display for StructType<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
self.fields
.iter()
.map(|fi| fi.to_string())
.collect::<Vec<_>>()
.join("\n")
)
}
}
pub type StructTypeNode<'ast> = Node<StructType<'ast>>;
/// A struct type definition
#[derive(Debug, Clone, PartialEq)]
pub struct StructField<'ast> {
pub id: Identifier<'ast>,
pub ty: UnresolvedTypeNode,
}
impl<'ast> fmt::Display for StructField<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}: {},", self.id, self.ty)
}
}
type StructFieldNode<'ast> = Node<StructField<'ast>>;
/// An import
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {
/// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however.
pub symbol_id: Identifier<'ast>,
/// the id of the module to import from
pub module_id: ModuleId,
}
type FunctionImportNode<'ast> = Node<FunctionImport<'ast>>;
type SymbolImportNode<'ast> = Node<SymbolImport<'ast>>;
impl<'ast> FunctionImport<'ast> {
impl<'ast> SymbolImport<'ast> {
pub fn with_id_in_module<S: Into<Identifier<'ast>>, U: Into<ModuleId>>(
function_id: S,
symbol_id: S,
module_id: U,
) -> Self {
FunctionImport {
function_id: function_id.into(),
SymbolImport {
symbol_id: symbol_id.into(),
module_id: module_id.into(),
}
}
}
impl<'ast> fmt::Display for FunctionImport<'ast> {
impl<'ast> fmt::Display for SymbolImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} from {}", self.function_id, self.module_id)
write!(f, "{} from {}", self.symbol_id, self.module_id)
}
}
@ -119,7 +174,7 @@ impl<'ast, T: Field> fmt::Display for Module<'ast, T> {
.collect::<Vec<_>>(),
);
res.extend(
self.functions
self.symbols
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>(),
@ -132,13 +187,13 @@ impl<'ast, T: Field> fmt::Debug for Module<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"module(\n\timports:\n\t\t{}\n\tfunctions:\n\t\t{}\n)",
"module(\n\timports:\n\t\t{}\n\tsymbols:\n\t\t{}\n)",
self.imports
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t"),
self.functions
self.symbols
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
@ -155,7 +210,7 @@ pub struct Function<'ast, T: Field> {
/// Vector of statements that are executed when running the function
pub statements: Vec<StatementNode<'ast, T>>,
/// function signature
pub signature: Signature,
pub signature: UnresolvedSignature,
}
pub type FunctionNode<'ast, T> = Node<Function<'ast, T>>;
@ -199,6 +254,7 @@ impl<'ast, T: Field> fmt::Debug for Function<'ast, T> {
pub enum Assignee<'ast, T: Field> {
Identifier(Identifier<'ast>),
Select(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
Member(Box<AssigneeNode<'ast, T>>, Box<Identifier<'ast>>),
}
pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
@ -208,6 +264,7 @@ impl<'ast, T: Field> fmt::Debug for Assignee<'ast, T> {
match *self {
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
Assignee::Member(ref s, ref m) => write!(f, "Member({:?}.{:?})", s, m),
}
}
}
@ -414,10 +471,12 @@ pub enum Expression<'ast, T: Field> {
And(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
Not(Box<ExpressionNode<'ast, T>>),
InlineArray(Vec<SpreadOrExpression<'ast, T>>),
InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast, T>)>),
Select(
Box<ExpressionNode<'ast, T>>,
Box<RangeOrExpression<'ast, T>>,
),
Member(Box<ExpressionNode<'ast, T>>, Box<Identifier<'ast>>),
Or(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
}
@ -466,7 +525,18 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> {
}
write!(f, "]")
}
Expression::InlineStruct(ref id, ref members) => {
write!(f, "{} {{", id)?;
for (i, (member_id, e)) in members.iter().enumerate() {
write!(f, "{}: {}", member_id, e)?;
if i < members.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "}}")
}
Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index),
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
}
}
@ -505,9 +575,15 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> {
f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]")
}
Expression::InlineStruct(ref id, ref members) => {
write!(f, "InlineStruct({:?}, [", id)?;
f.debug_list().entries(members.iter()).finish()?;
write!(f, "]")
}
Expression::Select(ref array, ref index) => {
write!(f, "Select({:?}, {:?})", array, index)
}
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
}
}

View file

@ -74,10 +74,13 @@ impl<'ast, T: Field> NodeValue for Expression<'ast, T> {}
impl<'ast, T: Field> NodeValue for ExpressionList<'ast, T> {}
impl<'ast, T: Field> NodeValue for Assignee<'ast, T> {}
impl<'ast, T: Field> NodeValue for Statement<'ast, T> {}
impl<'ast, T: Field> NodeValue for FunctionDeclaration<'ast, T> {}
impl<'ast, T: Field> NodeValue for SymbolDeclaration<'ast, T> {}
impl NodeValue for UnresolvedType {}
impl<'ast> NodeValue for StructType<'ast> {}
impl<'ast> NodeValue for StructField<'ast> {}
impl<'ast, T: Field> NodeValue for Function<'ast, T> {}
impl<'ast, T: Field> NodeValue for Module<'ast, T> {}
impl<'ast> NodeValue for FunctionImport<'ast> {}
impl<'ast> NodeValue for SymbolImport<'ast> {}
impl<'ast> NodeValue for Variable<'ast> {}
impl<'ast> NodeValue for Parameter<'ast> {}
impl<'ast> NodeValue for Import<'ast> {}

View file

@ -0,0 +1,98 @@
use absy::UnresolvedTypeNode;
use std::fmt;
pub type Identifier<'ast> = &'ast str;
pub type MemberId = String;
pub type UserTypeId = String;
#[derive(Clone, PartialEq, Serialize, Deserialize, Debug)]
pub enum UnresolvedType {
FieldElement,
Boolean,
Array(Box<UnresolvedTypeNode>, usize),
User(UserTypeId),
}
impl fmt::Display for UnresolvedType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
UnresolvedType::FieldElement => write!(f, "field"),
UnresolvedType::Boolean => write!(f, "bool"),
UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
UnresolvedType::User(i) => write!(f, "{}", i),
}
}
}
impl UnresolvedType {
pub fn array(ty: UnresolvedTypeNode, size: usize) -> Self {
UnresolvedType::Array(box ty, size)
}
}
pub type FunctionIdentifier<'ast> = &'ast str;
pub use self::signature::UnresolvedSignature;
mod signature {
use std::fmt;
use absy::UnresolvedTypeNode;
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct UnresolvedSignature {
pub inputs: Vec<UnresolvedTypeNode>,
pub outputs: Vec<UnresolvedTypeNode>,
}
impl fmt::Debug for UnresolvedSignature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Signature(inputs: {:?}, outputs: {:?})",
self.inputs, self.outputs
)
}
}
impl fmt::Display for UnresolvedSignature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for (i, t) in self.inputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.inputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ") -> (")?;
for (i, t) in self.outputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.outputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
}
impl UnresolvedSignature {
pub fn new() -> UnresolvedSignature {
UnresolvedSignature {
inputs: vec![],
outputs: vec![],
}
}
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode>) -> Self {
self.inputs = inputs;
self
}
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode>) -> Self {
self.outputs = outputs;
self
}
}
}

View file

@ -1,55 +1,27 @@
use crate::absy::Node;
use crate::absy::types::UnresolvedType;
use crate::absy::{Node, UnresolvedTypeNode};
use std::fmt;
use types::Type;
use crate::absy::Identifier;
#[derive(Clone, PartialEq, Hash, Eq)]
#[derive(Clone, PartialEq)]
pub struct Variable<'ast> {
pub id: Identifier<'ast>,
pub _type: Type,
pub _type: UnresolvedTypeNode,
}
pub type VariableNode<'ast> = Node<Variable<'ast>>;
impl<'ast> Variable<'ast> {
pub fn new<S: Into<&'ast str>>(id: S, t: Type) -> Variable<'ast> {
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode) -> Variable<'ast> {
Variable {
id: id.into(),
_type: t,
}
}
pub fn field_element<S: Into<&'ast str>>(id: S) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::FieldElement,
}
}
pub fn boolean<S: Into<&'ast str>>(id: S) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::Boolean,
}
}
pub fn field_array<S: Into<&'ast str>>(id: S, size: usize) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::array(Type::FieldElement, size),
}
}
pub fn array<S: Into<&'ast str>>(id: S, inner_ty: Type, size: usize) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::array(inner_ty, size),
}
}
pub fn get_type(&self) -> Type {
self._type.clone()
pub fn get_type(&self) -> UnresolvedType {
self._type.value.clone()
}
}

View file

@ -5,7 +5,7 @@ use flat_absy::{
};
use reduce::Reduce;
use std::collections::HashMap;
use types::{FunctionKey, Signature, Type};
use typed_absy::types::{FunctionKey, Signature, Type};
use zokrates_embed::{generate_sha256_round_constraints, BellmanConstraint};
use zokrates_field::field::Field;

View file

@ -12,7 +12,7 @@ pub use self::flat_parameter::FlatParameter;
pub use self::flat_variable::FlatVariable;
use crate::helpers::DirectiveStatement;
use crate::types::Signature;
use crate::typed_absy::types::Signature;
use std::collections::HashMap;
use std::fmt;
use zokrates_field::field::Field;

View file

@ -7,12 +7,10 @@
use crate::flat_absy::*;
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::typed_absy::types::{FunctionIdentifier, FunctionKey, MemberId, Signature, Type};
use crate::typed_absy::*;
use crate::types::Type;
use crate::types::{FunctionKey, Signature};
use std::collections::HashMap;
use std::convert::TryFrom;
use types::FunctionIdentifier;
use zokrates_field::field::Field;
/// Flattener, computes flattened program.
@ -29,7 +27,7 @@ pub struct Flattener<'ast, T: Field> {
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
trait Flatten<'ast, T: Field>:
TryFrom<TypedExpression<'ast, T>, Error = ()> + IfElse<'ast, T> + Select<'ast, T>
TryFrom<TypedExpression<'ast, T>, Error = ()> + IfElse<'ast, T> + Select<'ast, T> + Member<'ast, T>
{
fn flatten(
self,
@ -61,6 +59,17 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
}
}
impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
fn flatten(
self,
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatExpression<T>> {
flattener.flatten_struct_expression(symbols, statements_flattened, self)
}
}
impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
fn flatten(
self,
@ -85,6 +94,11 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
statements_flattened,
self,
),
Type::Struct(..) => flattener.flatten_array_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
self,
),
}
}
}
@ -194,6 +208,163 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res.into_iter().map(|r| r.into()).collect()
}
fn flatten_member_expression(
&mut self,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
s: StructExpression<'ast, T>,
member_id: MemberId,
) -> Vec<FlatExpression<T>> {
let members = s.ty().clone();
let expected_output_size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
let res =
match s.into_inner() {
StructExpressionInner::Value(values) => {
// If the struct has an explicit value, we get the value at the given member
assert_eq!(values.len(), members.len());
values
.into_iter()
.zip(members.into_iter())
.filter(|(_, (id, _))| *id == member_id)
.flat_map(|(v, (_, t))| match t {
Type::FieldElement => FieldElementExpression::try_from(v)
.unwrap()
.flatten(self, symbols, statements_flattened),
Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten(
self,
symbols,
statements_flattened,
),
Type::Array(..) => ArrayExpression::try_from(v).unwrap().flatten(
self,
symbols,
statements_flattened,
),
Type::Struct(..) => StructExpression::try_from(v).unwrap().flatten(
self,
symbols,
statements_flattened,
),
})
.collect()
}
StructExpressionInner::Identifier(id) => {
// If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values.
// the struct is encoded as a sequence, so we need to identify the offset at which this member starts
let offset = members
.iter()
.take_while(|(id, _)| *id != member_id)
.map(|(_, ty)| ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
self.layout.get(&id).unwrap()[offset..(offset + size)]
.into_iter()
.map(|i| i.clone().into())
.collect()
}
StructExpressionInner::Select(box array, box index) => {
let offset = members
.iter()
.take_while(|(id, _)| *id != member_id)
.map(|(_, ty)| ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
self.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
)[offset..offset + size]
.to_vec()
}
StructExpressionInner::FunctionCall(..) => unreachable!(),
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
// if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member`
// we reduce to `if c then a.member else b.member`
let ty = members
.clone()
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
match ty {
Type::FieldElement => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
FieldElementExpression::member(consequence.clone(), member_id.clone()),
FieldElementExpression::member(alternative.clone(), member_id),
),
Type::Boolean => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
BooleanExpression::member(consequence.clone(), member_id.clone()),
BooleanExpression::member(alternative.clone(), member_id),
),
Type::Struct(..) => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
StructExpression::member(consequence.clone(), member_id.clone()),
StructExpression::member(alternative.clone(), member_id),
),
Type::Array(..) => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
ArrayExpression::member(consequence.clone(), member_id.clone()),
ArrayExpression::member(alternative.clone(), member_id),
),
}
}
StructExpressionInner::Member(box s0, m_id) => {
let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id);
let offset = members
.iter()
.take_while(|(id, _)| *id != member_id)
.map(|(_, ty)| ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
e[offset..(offset + size)].into()
}
};
assert_eq!(res.len(), expected_output_size);
res
}
/// Flatten an array selection expression
///
/// # Arguments
@ -244,6 +415,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
.flatten(self, symbols, statements_flattened)
}
ArrayExpressionInner::Member(box s, id) => {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
self.flatten_member_expression(symbols, statements_flattened, s, id)
[n * ty.get_primitive_count()..(n + 1) * ty.get_primitive_count()]
.to_vec()
}
ArrayExpressionInner::Select(box array, box index) => {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
@ -270,6 +448,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
array,
index,
),
Type::Struct(..) => self
.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
};
e[n * element_size..(n + 1) * element_size]
@ -333,6 +518,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
),
)
}
ArrayExpressionInner::Member(box s, id) => U::member(s, id),
ArrayExpressionInner::Select(box array, box index) => U::select(
ArrayExpressionInner::Select(box array, box index)
.annotate(ty.clone(), size),
@ -647,6 +833,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
alternative,
)[0]
.clone(),
BooleanExpression::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
}
BooleanExpression::Select(box array, box index) => self
.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
@ -802,7 +991,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened,
e,
),
Type::Struct(..) => self.flatten_array_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
e,
),
},
TypedExpression::Struct(e) => {
self.flatten_struct_expression(symbols, statements_flattened, e)
}
}
}
@ -1035,6 +1232,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value
exprs_flattened.expressions[0].clone()
}
FieldElementExpression::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
}
FieldElementExpression::Select(box array, box index) => self
.flatten_select_expression::<FieldElementExpression<'ast, T>>(
symbols,
@ -1046,6 +1246,92 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
/// Flattens an array expression
///
/// # Arguments
///
/// * `symbols` - Available functions in in this context
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `expr` - `StructExpression` that will be flattened.
fn flatten_struct_expression(
&mut self,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
expr: StructExpression<'ast, T>,
) -> Vec<FlatExpression<T>> {
let ty = expr.get_type();
let expected_output_size = expr.get_type().get_primitive_count();
let members = expr.ty().clone();
let res = match expr.into_inner() {
StructExpressionInner::Identifier(x) => self
.layout
.get(&x)
.unwrap()
.iter()
.map(|v| FlatExpression::Identifier(v.clone()))
.collect(),
StructExpressionInner::Value(values) => values
.into_iter()
.flat_map(|v| self.flatten_expression(symbols, statements_flattened, v))
.collect(),
StructExpressionInner::FunctionCall(key, param_expressions) => {
let exprs_flattened = self.flatten_function_call(
symbols,
statements_flattened,
key.id,
vec![ty],
param_expressions,
);
exprs_flattened.expressions
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
members
.into_iter()
.flat_map(|(id, ty)| match ty {
Type::FieldElement => FieldElementExpression::if_else(
condition.clone(),
FieldElementExpression::member(consequence.clone(), id.clone()),
FieldElementExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Boolean => BooleanExpression::if_else(
condition.clone(),
BooleanExpression::member(consequence.clone(), id.clone()),
BooleanExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Struct(..) => StructExpression::if_else(
condition.clone(),
StructExpression::member(consequence.clone(), id.clone()),
StructExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Array(..) => ArrayExpression::if_else(
condition.clone(),
ArrayExpression::member(consequence.clone(), id.clone()),
ArrayExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
})
.collect()
}
StructExpressionInner::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)
}
StructExpressionInner::Select(box array, box index) => self
.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
};
assert_eq!(res.len(), expected_output_size);
res
}
/// Flattens an array expression
///
/// # Arguments
@ -1112,6 +1398,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flatten(self, symbols, statements_flattened)
})
.collect(),
ArrayExpressionInner::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)
}
ArrayExpressionInner::Select(box array, box index) => self
.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,
@ -1170,6 +1459,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
TypedAssignee::Select(..) => unreachable!(
"array element redefs should have been replaced by array redefs in unroll"
),
TypedAssignee::Member(..) => unreachable!(
"struct member redefs should have been replaced by struct redef in unroll"
),
}
}
TypedStatement::Condition(lhs, rhs) => {
@ -1256,7 +1548,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let arguments_flattened = funct
.arguments
.into_iter()
.flat_map(|p| self.use_parameter(&p, &mut statements_flattened))
.flat_map(|p| self.use_parameter(&p))
.collect();
// flatten statements in functions and apply substitution
@ -1304,29 +1596,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
///
/// * `name` - a String that holds the name of the variable
fn use_variable(&mut self, variable: &Variable<'ast>) -> Vec<FlatVariable> {
let vars = match variable.get_type() {
Type::FieldElement => self.issue_new_variables(1),
Type::Boolean => self.issue_new_variables(1),
Type::Array(ty, size) => self.issue_new_variables(ty.get_primitive_count() * size),
};
let vars = self.issue_new_variables(variable.get_type().get_primitive_count());
self.layout.insert(variable.id.clone(), vars.clone());
vars
}
fn use_parameter(
&mut self,
parameter: &Parameter<'ast>,
statements: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatParameter> {
fn use_parameter(&mut self, parameter: &Parameter<'ast>) -> Vec<FlatParameter> {
let variables = self.use_variable(&parameter.id);
match parameter.id.get_type() {
Type::Boolean => statements.extend(Self::boolean_constraint(&variables)),
Type::Array(box Type::Boolean, _) => {
statements.extend(Self::boolean_constraint(&variables))
}
_ => {}
};
variables
.into_iter()
@ -1347,21 +1624,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..count).map(|_| self.issue_new_variable()).collect()
}
fn boolean_constraint(variables: &Vec<FlatVariable>) -> Vec<FlatStatement<T>> {
variables
.iter()
.map(|v| {
FlatStatement::Condition(
FlatExpression::Identifier(*v),
FlatExpression::Mult(
box FlatExpression::Identifier(*v),
box FlatExpression::Identifier(*v),
),
)
})
.collect()
}
// create an internal variable. We do not register it in the layout
fn use_sym(&mut self) -> FlatVariable {
self.issue_new_variable()
@ -1386,62 +1648,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Signature;
use crate::types::Type;
use crate::typed_absy::types::Signature;
use crate::typed_absy::types::Type;
use zokrates_field::field::FieldPrime;
mod boolean_checks {
use super::*;
#[test]
fn boolean_arg() {
// def main(bool a):
// return a
//
// -> should flatten to
//
// def main(_0) -> (1):
// _0 * _0 == _0
// return _0
let function: TypedFunction<FieldPrime> = TypedFunction {
arguments: vec![Parameter::private(Variable::boolean("a".into()))],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
"a".into(),
)
.into()])],
signature: Signature::new()
.inputs(vec![Type::Boolean])
.outputs(vec![Type::Boolean]),
};
let expected = FlatFunction {
arguments: vec![FlatParameter::private(FlatVariable::new(0))],
statements: vec![
FlatStatement::Condition(
FlatExpression::Identifier(FlatVariable::new(0)),
FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(0)),
box FlatExpression::Identifier(FlatVariable::new(0)),
),
),
FlatStatement::Return(FlatExpressionList {
expressions: vec![FlatExpression::Identifier(FlatVariable::new(0))],
}),
],
signature: Signature::new()
.inputs(vec![Type::Boolean])
.outputs(vec![Type::Boolean]),
};
let mut flattener = Flattener::new();
let flat_function = flattener.flatten_function(&mut HashMap::new(), function);
assert_eq!(flat_function, expected);
}
}
#[test]
fn powers_zero() {
// def main():

View file

@ -57,15 +57,17 @@ impl From<io::Error> for Error {
#[derive(PartialEq, Clone)]
pub struct Import<'ast> {
source: Identifier<'ast>,
symbol: Option<Identifier<'ast>>,
alias: Option<Identifier<'ast>>,
}
pub type ImportNode<'ast> = Node<Import<'ast>>;
impl<'ast> Import<'ast> {
pub fn new(source: Identifier<'ast>) -> Import<'ast> {
pub fn new(symbol: Option<Identifier<'ast>>, source: Identifier<'ast>) -> Import<'ast> {
Import {
source: source,
symbol,
source,
alias: None,
}
}
@ -74,9 +76,14 @@ impl<'ast> Import<'ast> {
&self.alias
}
pub fn new_with_alias(source: Identifier<'ast>, alias: Identifier<'ast>) -> Import<'ast> {
pub fn new_with_alias(
symbol: Option<Identifier<'ast>>,
source: Identifier<'ast>,
alias: Identifier<'ast>,
) -> Import<'ast> {
Import {
source: source,
symbol,
source,
alias: Some(alias),
}
}
@ -124,7 +131,7 @@ impl Importer {
modules: &mut HashMap<ModuleId, Module<'ast, T>>,
arena: &'ast Arena<String>,
) -> Result<Module<'ast, T>, CompileErrors> {
let mut functions: Vec<_> = vec![];
let mut symbols: Vec<_> = vec![];
for import in destination.imports {
let pos = import.pos();
@ -136,10 +143,10 @@ impl Importer {
"EMBED/sha256round" => {
let alias = alias.unwrap_or("sha256round");
functions.push(
FunctionDeclaration {
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: FunctionSymbol::Flat(FlatEmbed::Sha256Round),
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
.start_end(pos.0, pos.1),
);
@ -147,10 +154,10 @@ impl Importer {
"EMBED/unpack" => {
let alias = alias.unwrap_or("unpack");
functions.push(
FunctionDeclaration {
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: FunctionSymbol::Flat(FlatEmbed::Unpack),
symbol: Symbol::Flat(FlatEmbed::Unpack),
}
.start_end(pos.0, pos.1),
);
@ -185,12 +192,12 @@ impl Importer {
modules.insert(import.source.to_string(), compiled);
functions.push(
FunctionDeclaration {
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: FunctionSymbol::There(
FunctionImport::with_id_in_module(
"main",
symbol: Symbol::There(
SymbolImport::with_id_in_module(
import.symbol.unwrap_or("main"),
import.source.clone(),
)
.start_end(pos.0, pos.1),
@ -218,11 +225,12 @@ impl Importer {
}
}
functions.extend(destination.functions);
symbols.extend(destination.symbols);
Ok(Module {
imports: vec![],
functions: functions,
symbols,
..destination
})
}
}
@ -235,8 +243,9 @@ mod tests {
#[test]
fn create_with_no_alias() {
assert_eq!(
Import::new("./foo/bar/baz.zok"),
Import::new(None, "./foo/bar/baz.zok"),
Import {
symbol: None,
source: "./foo/bar/baz.zok",
alias: None,
}
@ -246,8 +255,9 @@ mod tests {
#[test]
fn create_with_alias() {
assert_eq!(
Import::new_with_alias("./foo/bar/baz.zok", &"myalias"),
Import::new_with_alias(None, "./foo/bar/baz.zok", &"myalias"),
Import {
symbol: None,
source: "./foo/bar/baz.zok",
alias: Some("myalias"),
}

View file

@ -35,7 +35,6 @@ mod parser;
mod semantics;
mod static_analysis;
mod typed_absy;
mod types;
pub mod absy;
pub mod compile;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,134 @@
//! Add runtime boolean checks on user inputs
//!
//! Example:
//! ```zokrates
//! struct Foo {
//! bar: bool
//! }
//!
//! def main(Foo f) -> ():
//! f.bar == f.bar && f.bar
//! return
//! ```
//!
//! Becomes
//!
//! ```zokrates
//! struct Foo {
//! bar: bool
//! }
//!
//! def main(Foo f) -> ():
//! f.bar == f.bar && f.bar
//! return
//! ```
//!
//! @file constrain_inputs.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2019
use crate::typed_absy::folder::Folder;
use crate::typed_absy::types::Type;
use crate::typed_absy::*;
use zokrates_field::field::Field;
pub struct InputConstrainer<'ast, T: Field> {
constraints: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, T: Field> InputConstrainer<'ast, T> {
fn new() -> Self {
InputConstrainer {
constraints: vec![],
}
}
pub fn constrain(p: TypedProgram<T>) -> TypedProgram<T> {
InputConstrainer::new().fold_program(p)
}
fn constrain_expression(&mut self, e: TypedExpression<'ast, T>) {
match e {
TypedExpression::FieldElement(_) => {}
TypedExpression::Boolean(b) => self.constraints.push(TypedStatement::Condition(
b.clone().into(),
BooleanExpression::And(box b.clone(), box b).into(),
)),
TypedExpression::Array(a) => {
for i in 0..a.size() {
let e = match a.inner_type() {
Type::FieldElement => FieldElementExpression::select(
a.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
Type::Boolean => BooleanExpression::select(
a.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
Type::Array(..) => ArrayExpression::select(
a.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
Type::Struct(..) => StructExpression::select(
a.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
};
self.constrain_expression(e);
}
}
TypedExpression::Struct(s) => {
for (id, ty) in s.ty() {
let e = match ty {
Type::FieldElement => {
FieldElementExpression::member(s.clone(), id.clone()).into()
}
Type::Boolean => BooleanExpression::member(s.clone(), id.clone()).into(),
Type::Array(..) => ArrayExpression::member(s.clone(), id.clone()).into(),
Type::Struct(..) => StructExpression::member(s.clone(), id.clone()).into(),
};
self.constrain_expression(e);
}
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for InputConstrainer<'ast, T> {
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
let v = p.id.clone();
let e = match v.get_type() {
Type::FieldElement => FieldElementExpression::Identifier(v.id).into(),
Type::Boolean => BooleanExpression::Identifier(v.id).into(),
Type::Struct(members) => StructExpressionInner::Identifier(v.id)
.annotate(members)
.into(),
Type::Array(box ty, size) => ArrayExpressionInner::Identifier(v.id)
.annotate(ty, size)
.into(),
};
self.constrain_expression(e);
p
}
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
TypedFunction {
arguments: f
.arguments
.into_iter()
.map(|a| self.fold_parameter(a))
.collect(),
statements: self.constraints.drain(..).chain(f.statements).collect(),
..f
}
}
}

View file

@ -17,8 +17,8 @@
//! where any call in `main` must be to `_SHA_256_ROUND` or `_UNPACK`
use std::collections::HashMap;
use typed_absy::types::{FunctionKey, MemberId, Type};
use typed_absy::{folder::*, *};
use types::{FunctionKey, Type};
use zokrates_field::field::Field;
/// An inliner
@ -260,12 +260,36 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
e => fold_array_expression_inner(self, ty, size, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &Vec<(MemberId, Type)>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::FunctionCall(key, exps) => {
let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect();
match self.try_inline_call(&key, exps) {
Ok(mut ret) => match ret.pop().unwrap() {
TypedExpression::Struct(e) => e.into_inner(),
_ => unreachable!(),
},
Err((key, expressions)) => {
StructExpressionInner::FunctionCall(key, expressions)
}
}
}
// default
e => fold_struct_expression_inner(self, ty, e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use types::{FunctionKey, Signature, Type};
use typed_absy::types::{FunctionKey, Signature, Type};
use zokrates_field::field::FieldPrime;
#[test]

View file

@ -4,11 +4,13 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod constrain_inputs;
mod flat_propagation;
mod inline;
mod propagation;
mod unroll;
use self::constrain_inputs::InputConstrainer;
use self::inline::Inliner;
use self::propagation::Propagator;
use self::unroll::Unroller;
@ -28,6 +30,8 @@ impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
let r = Inliner::inline(r);
// propagate
let r = Propagator::propagate(r);
// constrain inputs
let r = InputConstrainer::constrain(r);
r
}
}

View file

@ -8,7 +8,7 @@ use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use types::Type;
use typed_absy::types::{MemberId, Type};
use zokrates_field::field::Field;
pub struct Propagator<'ast, T: Field> {
@ -35,6 +35,10 @@ fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
TypedExpression::Struct(a) => match a.as_inner() {
StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
_ => false,
}
}
@ -71,6 +75,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
unreachable!("array updates should have been replaced with full array redef")
}
TypedStatement::Definition(TypedAssignee::Member(..), _) => {
unreachable!("struct update should have been replaced with full struct redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
@ -224,6 +231,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
}
}
FieldElementExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!(),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 {
TypedExpression::FieldElement(s) => s,
_ => unreachable!(),
}
}
inner => FieldElementExpression::Member(box inner.annotate(members), m),
}
}
e => fold_field_expression(self, e),
}
}
@ -302,10 +327,124 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
ArrayExpressionInner::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!(),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 {
TypedExpression::Array(a) => a.into_inner(),
_ => unreachable!(),
}
}
inner => ArrayExpressionInner::Member(box inner.annotate(members), m),
}
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &Vec<(MemberId, Type)>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(id) => {
match self
.constants
.get(&TypedAssignee::Identifier(Variable::struc(
id.clone(),
ty.clone(),
))) {
Some(e) => match e {
TypedExpression::Struct(e) => e.as_inner().clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => StructExpressionInner::Identifier(id),
}
}
StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
StructExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Struct(e) => e.clone().into_inner(),
_ => unreachable!(""),
},
None => StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => {
StructExpressionInner::Select(box a.annotate(inner_type, size), box i)
}
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_struct_expression(consequence);
let alternative = self.fold_struct_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => StructExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
StructExpressionInner::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!(),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 {
TypedExpression::Struct(s) => s.into_inner(),
_ => unreachable!(),
}
}
inner => StructExpressionInner::Member(box inner.annotate(members), m),
}
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
@ -430,6 +569,24 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
c => BooleanExpression::IfElse(box c, box consequence, box alternative),
}
}
BooleanExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!(),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members.iter().zip(v).find(|(id, _)| id.0 == m).unwrap().1 {
TypedExpression::Boolean(s) => s,
_ => unreachable!(),
}
}
inner => BooleanExpression::Member(box inner.annotate(members), m),
}
}
e => fold_boolean_expression(self, e),
}
}

View file

@ -5,8 +5,8 @@
//! @date 2018
use crate::typed_absy::folder::*;
use crate::typed_absy::types::{MemberId, Type};
use crate::typed_absy::*;
use crate::types::Type;
use std::collections::HashMap;
use std::collections::HashSet;
use zokrates_field::field::Field;
@ -47,7 +47,7 @@ impl<'ast> Unroller<'ast> {
fn choose_many<T: Field>(
base: TypedExpression<'ast, T>,
indices: Vec<FieldElementExpression<'ast, T>>,
indices: Vec<Access<'ast, T>>,
new_expression: TypedExpression<'ast, T>,
statements: &mut HashSet<TypedStatement<'ast, T>>,
) -> TypedExpression<'ast, T> {
@ -55,131 +55,256 @@ impl<'ast> Unroller<'ast> {
match indices.len() {
0 => new_expression,
_ => {
let base = match base {
TypedExpression::Array(e) => e,
e => unreachable!("can't take an element on a {}", e.get_type()),
};
_ => match base {
TypedExpression::Array(base) => {
let inner_ty = base.inner_type();
let size = base.size();
let inner_ty = base.inner_type();
let size = base.size();
let head = indices.remove(0);
let tail = indices;
let head = indices.pop().unwrap();
let tail = indices;
statements.insert(TypedStatement::Condition(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
match head {
Access::Select(head) => {
statements.insert(TypedStatement::Condition(
BooleanExpression::Lt(
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Struct(..) => StructExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Struct(e) => e,
e => unreachable!(
"the interior was expected to be a struct, was {}",
e.get_type()
),
},
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.annotate(inner_ty.clone(), size)
.into()
}
.annotate(inner_ty.clone(), size)
.into()
}
Access::Member(..) => unreachable!("can't get a member from an array"),
}
}
TypedExpression::Struct(base) => {
let members = match base.get_type() {
Type::Struct(members) => members.clone(),
_ => unreachable!(),
};
let head = indices.remove(0);
let tail = indices;
match head {
Access::Member(head) => StructExpressionInner::Value(
members
.clone()
.into_iter()
.map(|(id, t)| match t {
Type::FieldElement => {
if id == head {
Self::choose_many(
FieldElementExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
FieldElementExpression::member(base.clone(), id.clone())
.into()
}
}
Type::Boolean => {
if id == head {
Self::choose_many(
BooleanExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
BooleanExpression::member(base.clone(), id.clone())
.into()
}
}
Type::Array(..) => {
if id == head {
Self::choose_many(
ArrayExpression::member(base.clone(), head.clone())
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
ArrayExpression::member(base.clone(), id.clone()).into()
}
}
Type::Struct(..) => {
if id == head {
Self::choose_many(
StructExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
StructExpression::member(base.clone(), id.clone())
.into()
}
}
})
.collect(),
)
.annotate(members)
.into(),
Access::Select(..) => unreachable!("can't get a element from a struct"),
}
}
e => unreachable!("can't make an access on a {}", e.get_type()),
},
}
}
}
/// Turn an assignee into its representation as a base variable and a list of indices
#[derive(Clone, Debug)]
enum Access<'ast, T: Field> {
Select(FieldElementExpression<'ast, T>),
Member(MemberId),
}
/// Turn an assignee into its representation as a base variable and a list accesses
/// a[2][3][4] -> (a, [2, 3, 4])
fn linear<'ast, T: Field>(
a: TypedAssignee<'ast, T>,
) -> (Variable, Vec<FieldElementExpression<'ast, T>>) {
fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec<Access<'ast, T>>) {
match a {
TypedAssignee::Identifier(v) => (v, vec![]),
TypedAssignee::Select(box array, box index) => {
let (v, mut indices) = linear(array);
indices.push(index);
indices.push(Access::Select(index));
(v, indices)
}
TypedAssignee::Member(box s, m) => {
let (v, mut indices) = linear(s);
indices.push(Access::Member(m));
(v, indices)
}
}
@ -206,12 +331,20 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
.annotate(ty, size)
.into()
}
Type::Struct(members) => {
StructExpressionInner::Identifier(variable.id.clone().into())
.annotate(members)
.into()
}
};
let base = self.fold_expression(base);
let indices = indices
.into_iter()
.map(|i| self.fold_field_expression(i))
.map(|a| match a {
Access::Select(i) => Access::Select(self.fold_field_expression(i)),
a => a,
})
.collect();
let mut range_checks = HashSet::new();
@ -298,7 +431,12 @@ mod tests {
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(a0.clone().into(), vec![index], e, &mut HashSet::new());
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![Access::Select(index)],
e,
&mut HashSet::new(),
);
// a[1] = 42
// -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
@ -356,7 +494,7 @@ mod tests {
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![index],
vec![Access::Select(index)],
e.clone().into(),
&mut HashSet::new(),
);
@ -414,8 +552,8 @@ mod tests {
let e = FieldElementExpression::Number(FieldPrime::from(42));
let indices = vec![
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(0)),
Access::Select(FieldElementExpression::Number(FieldPrime::from(0))),
Access::Select(FieldElementExpression::Number(FieldPrime::from(0))),
];
let a1 = Unroller::choose_many(
@ -528,7 +666,7 @@ mod tests {
#[cfg(test)]
mod statement {
use super::*;
use crate::types::{FunctionKey, Signature};
use crate::typed_absy::types::{FunctionKey, Signature};
#[test]
fn for_loop() {
@ -709,7 +847,7 @@ mod tests {
#[test]
fn incremental_multiple_definition() {
use crate::types::Type;
use crate::typed_absy::types::Type;
// field a
// a = 2

View file

@ -48,6 +48,7 @@ pub trait Folder<'ast, T: Field>: Sized {
box self.fold_assignee(a),
box self.fold_field_expression(index),
),
TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m),
}
}
@ -60,6 +61,7 @@ pub trait Folder<'ast, T: Field>: Sized {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
TypedExpression::Boolean(e) => self.fold_boolean_expression(e).into(),
TypedExpression::Array(e) => self.fold_array_expression(e).into(),
TypedExpression::Struct(e) => self.fold_struct_expression(e).into(),
}
}
@ -67,6 +69,13 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_array_expression(self, e)
}
fn fold_struct_expression(
&mut self,
e: StructExpression<'ast, T>,
) -> StructExpression<'ast, T> {
fold_struct_expression(self, e)
}
fn fold_expression_list(
&mut self,
es: TypedExpressionList<'ast, T>,
@ -105,6 +114,13 @@ pub trait Folder<'ast, T: Field>: Sized {
) -> ArrayExpressionInner<'ast, T> {
fold_array_expression_inner(self, ty, size, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &Vec<(MemberId, Type)>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
fold_struct_expression_inner(self, ty, e)
}
}
pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
@ -178,6 +194,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
box f.fold_array_expression(alternative),
)
}
ArrayExpressionInner::Member(box s, id) => {
let s = f.fold_struct_expression(s);
ArrayExpressionInner::Member(box s, id)
}
ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
@ -186,6 +206,39 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
_: &Vec<(MemberId, Type)>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)),
StructExpressionInner::Value(exprs) => {
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
}
StructExpressionInner::FunctionCall(id, exps) => {
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
StructExpressionInner::FunctionCall(id, exps)
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
StructExpressionInner::IfElse(
box f.fold_boolean_expression(condition),
box f.fold_struct_expression(consequence),
box f.fold_struct_expression(alternative),
)
}
StructExpressionInner::Member(box s, id) => {
let s = f.fold_struct_expression(s);
StructExpressionInner::Member(box s, id)
}
StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
StructExpressionInner::Select(box array, box index)
}
}
}
pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: FieldElementExpression<'ast, T>,
@ -230,6 +283,10 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
FieldElementExpression::FunctionCall(key, exps)
}
FieldElementExpression::Member(box s, id) => {
let s = f.fold_struct_expression(s);
FieldElementExpression::Member(box s, id)
}
FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
@ -290,6 +347,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let alt = f.fold_boolean_expression(alt);
BooleanExpression::IfElse(box cond, box cons, box alt)
}
BooleanExpression::Member(box s, id) => {
let s = f.fold_struct_expression(s);
BooleanExpression::Member(box s, id)
}
BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
@ -327,6 +388,16 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> StructExpression<'ast, T> {
StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner),
..e
}
}
pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,

View file

@ -7,11 +7,13 @@
pub mod folder;
mod parameter;
pub mod types;
mod variable;
pub use crate::typed_absy::parameter::Parameter;
pub use crate::typed_absy::variable::Variable;
use crate::types::{FunctionKey, Signature, Type};
use crate::typed_absy::types::{FunctionKey, MemberId, Signature, Type};
use embed::FlatEmbed;
use std::collections::HashMap;
use std::convert::TryFrom;
@ -72,7 +74,7 @@ impl<'ast, T: Field> fmt::Display for TypedProgram<'ast, T> {
}
}
/// A
/// A typed program as a collection of functions. Types have been resolved during semantic checking.
#[derive(PartialEq, Clone)]
pub struct TypedModule<'ast, T: Field> {
/// Functions of the program
@ -239,6 +241,7 @@ pub enum TypedAssignee<'ast, T: Field> {
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Member(Box<TypedAssignee<'ast, T>>, MemberId),
}
impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
@ -252,6 +255,15 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
_ => unreachable!("an array element should only be defined over arrays"),
}
}
TypedAssignee::Member(ref s, ref m) => {
let s_type = s.get_type();
match s_type {
Type::Struct(members) => {
members.iter().find(|(id, _)| id == m).unwrap().1.clone()
}
_ => unreachable!("a struct access should only be defined over structs"),
}
}
}
}
}
@ -261,6 +273,7 @@ impl<'ast, T: Field> fmt::Debug for TypedAssignee<'ast, T> {
match *self {
TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id),
TypedAssignee::Select(ref a, ref e) => write!(f, "{}[{}]", a, e),
TypedAssignee::Member(ref s, ref m) => write!(f, "{}.{}", s, m),
}
}
}
@ -362,6 +375,7 @@ pub enum TypedExpression<'ast, T: Field> {
Boolean(BooleanExpression<'ast, T>),
FieldElement(FieldElementExpression<'ast, T>),
Array(ArrayExpression<'ast, T>),
Struct(StructExpression<'ast, T>),
}
impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for TypedExpression<'ast, T> {
@ -382,12 +396,19 @@ impl<'ast, T: Field> From<ArrayExpression<'ast, T>> for TypedExpression<'ast, T>
}
}
impl<'ast, T: Field> From<StructExpression<'ast, T>> for TypedExpression<'ast, T> {
fn from(e: StructExpression<'ast, T>) -> TypedExpression<T> {
TypedExpression::Struct(e)
}
}
impl<'ast, T: Field> fmt::Display for TypedExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedExpression::Boolean(ref e) => write!(f, "{}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{}", e),
TypedExpression::Array(ref e) => write!(f, "{}", e.inner),
TypedExpression::Array(ref e) => write!(f, "{}", e),
TypedExpression::Struct(ref s) => write!(f, "{}", s),
}
}
}
@ -398,6 +419,7 @@ impl<'ast, T: Field> fmt::Debug for TypedExpression<'ast, T> {
TypedExpression::Boolean(ref e) => write!(f, "{:?}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{:?}", e),
TypedExpression::Array(ref e) => write!(f, "{:?}", e),
TypedExpression::Struct(ref s) => write!(f, "{}", s),
}
}
}
@ -414,12 +436,57 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Field> fmt::Display for StructExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
StructExpressionInner::Value(ref values) => write!(
f,
"{{{}}}",
self.ty
.iter()
.map(|(id, _)| id)
.zip(values.iter())
.map(|(id, o)| format!("{}: {}", id, o.to_string()))
.collect::<Vec<String>>()
.join(", ")
),
StructExpressionInner::FunctionCall(ref key, ref p) => {
write!(f, "{}(", key.id,)?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
)
}
StructExpressionInner::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
StructExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
impl<'ast, T: Field> fmt::Debug for StructExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.inner)
}
}
impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedExpression::Boolean(ref e) => e.get_type(),
TypedExpression::FieldElement(ref e) => e.get_type(),
TypedExpression::Array(ref e) => e.get_type(),
TypedExpression::Struct(ref s) => s.get_type(),
}
}
}
@ -430,6 +497,12 @@ impl<'ast, T: Field> Typed for ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Field> Typed for StructExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::Struct(self.ty.clone())
}
}
impl<'ast, T: Field> Typed for FieldElementExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::FieldElement
@ -490,6 +563,7 @@ pub enum FieldElementExpression<'ast, T: Field> {
Box<FieldElementExpression<'ast, T>>,
),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -535,6 +609,7 @@ pub enum BooleanExpression<'ast, T: Field> {
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -563,6 +638,7 @@ pub enum ArrayExpressionInner<'ast, T: Field> {
Box<ArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -597,6 +673,49 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
}
}
#[derive(Clone, PartialEq, Hash, Eq)]
pub struct StructExpression<'ast, T: Field> {
ty: Vec<(MemberId, Type)>,
inner: StructExpressionInner<'ast, T>,
}
impl<'ast, T: Field> StructExpression<'ast, T> {
pub fn ty(&self) -> &Vec<(MemberId, Type)> {
&self.ty
}
pub fn as_inner(&self) -> &StructExpressionInner<'ast, T> {
&self.inner
}
pub fn into_inner(self) -> StructExpressionInner<'ast, T> {
self.inner
}
}
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum StructExpressionInner<'ast, T: Field> {
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> StructExpressionInner<'ast, T> {
pub fn annotate(self, ty: Vec<(MemberId, Type)>) -> StructExpression<'ast, T> {
StructExpression { ty, inner: self }
}
}
// Downcasts
// Due to the fact that we keep TypedExpression simple, we end up with ArrayExpressionInner::Value whose elements are any TypedExpression, but we enforce by
// construction that these elements are of the type declared in the corresponding ArrayExpression. As we know this by construction, we can downcast the TypedExpression to the correct type
@ -636,6 +755,17 @@ impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for ArrayExpression<'ast,
}
}
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for StructExpression<'ast, T> {
type Error = ();
fn try_from(te: TypedExpression<'ast, T>) -> Result<StructExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::Struct(e) => Ok(e),
_ => Err(()),
}
}
}
impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -663,6 +793,7 @@ impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
}
write!(f, ")")
}
FieldElementExpression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
FieldElementExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
@ -686,6 +817,7 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> {
"if {} then {} else {} fi",
condition, consequent, alternative
),
BooleanExpression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
BooleanExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
@ -719,6 +851,7 @@ impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
"if {} then {} else {} fi",
condition, consequent, alternative
),
ArrayExpressionInner::Member(ref s, ref id) => write!(f, "{}.{}", s, id),
ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
@ -754,6 +887,9 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
FieldElementExpression::Member(ref struc, ref id) => {
write!(f, "Member({:?}, {:?})", struc, id)
}
FieldElementExpression::Select(ref id, ref index) => {
write!(f, "Select({:?}, {:?})", id, index)
}
@ -776,6 +912,9 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
ArrayExpressionInner::Member(ref struc, ref id) => {
write!(f, "Member({:?}, {:?})", struc, id)
}
ArrayExpressionInner::Select(ref id, ref index) => {
write!(f, "Select({:?}, {:?})", id, index)
}
@ -783,6 +922,33 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
}
}
impl<'ast, T: Field> fmt::Debug for StructExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
StructExpressionInner::Identifier(ref var) => write!(f, "{:?}", var),
StructExpressionInner::Value(ref values) => write!(f, "{:?}", values),
StructExpressionInner::FunctionCall(ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?;
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
)
}
StructExpressionInner::Member(ref struc, ref id) => {
write!(f, "Member({:?}, {:?})", struc, id)
}
StructExpressionInner::Select(ref id, ref index) => {
write!(f, "Select({:?}, {:?})", id, index)
}
}
}
}
impl<'ast, T: Field> fmt::Display for TypedExpressionList<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -852,6 +1018,17 @@ impl<'ast, T: Field> IfElse<'ast, T> for ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Field> IfElse<'ast, T> for StructExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.ty().clone();
StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty)
}
}
pub trait Select<'ast, T: Field> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
}
@ -878,3 +1055,68 @@ impl<'ast, T: Field> Select<'ast, T> for ArrayExpression<'ast, T> {
ArrayExpressionInner::Select(box array, box index).annotate(*ty, size)
}
}
impl<'ast, T: Field> Select<'ast, T> for StructExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let members = match array.inner_type().clone() {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Select(box array, box index).annotate(members)
}
}
pub trait Member<'ast, T: Field> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self;
}
impl<'ast, T: Field> Member<'ast, T> for FieldElementExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
FieldElementExpression::Member(box s, member_id)
}
}
impl<'ast, T: Field> Member<'ast, T> for BooleanExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
BooleanExpression::Member(box s, member_id)
}
}
impl<'ast, T: Field> Member<'ast, T> for ArrayExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let (ty, size) = match ty {
Type::Array(box ty, size) => (ty, size),
_ => unreachable!(),
};
ArrayExpressionInner::Member(box s, member_id).annotate(ty, size)
}
}
impl<'ast, T: Field> Member<'ast, T> for StructExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let members = match ty {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Member(box s, member_id).annotate(members)
}
}

View file

@ -1,4 +1,3 @@
use crate::absy;
use crate::typed_absy::Variable;
use std::fmt;
@ -30,12 +29,3 @@ impl<'ast> fmt::Debug for Parameter<'ast> {
write!(f, "Parameter(variable: {:?})", self.id)
}
}
impl<'ast> From<absy::Parameter<'ast>> for Parameter<'ast> {
fn from(p: absy::Parameter<'ast>) -> Parameter {
Parameter {
private: p.private,
id: p.id.value.into(),
}
}
}

View file

@ -0,0 +1,290 @@
use std::fmt;
pub type Identifier<'ast> = &'ast str;
pub type MemberId = String;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub enum Type {
FieldElement,
Boolean,
Array(Box<Type>, usize),
Struct(Vec<(MemberId, Type)>),
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
Type::Struct(ref members) => write!(
f,
"{{{}}}",
members
.iter()
.map(|(id, t)| format!("{}: {}", id, t))
.collect::<Vec<_>>()
.join(", ")
),
}
}
}
impl fmt::Debug for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
Type::Struct(ref members) => write!(
f,
"{{{}}}",
members
.iter()
.map(|(id, t)| format!("{}: {}", id, t))
.collect::<Vec<_>>()
.join(", ")
),
}
}
}
impl Type {
pub fn array(ty: Type, size: usize) -> Self {
Type::Array(box ty, size)
}
fn to_slug(&self) -> String {
match self {
Type::FieldElement => String::from("f"),
Type::Boolean => String::from("b"),
Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size),
Type::Struct(members) => format!(
"{{{}}}",
members
.iter()
.map(|(id, ty)| format!("{}:{}", id, ty))
.collect::<Vec<_>>()
.join(",")
),
}
}
// the number of field elements the type maps to
pub fn get_primitive_count(&self) -> usize {
match self {
Type::FieldElement => 1,
Type::Boolean => 1,
Type::Array(ty, size) => size * ty.get_primitive_count(),
Type::Struct(members) => members.iter().map(|(_, t)| t.get_primitive_count()).sum(),
}
}
}
pub type FunctionIdentifier<'ast> = &'ast str;
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
pub struct FunctionKey<'ast> {
pub id: FunctionIdentifier<'ast>,
pub signature: Signature,
}
impl<'ast> FunctionKey<'ast> {
pub fn with_id<S: Into<Identifier<'ast>>>(id: S) -> Self {
FunctionKey {
id: id.into(),
signature: Signature::new(),
}
}
pub fn signature(mut self, signature: Signature) -> Self {
self.signature = signature;
self
}
pub fn id<S: Into<Identifier<'ast>>>(mut self, id: S) -> Self {
self.id = id.into();
self
}
pub fn to_slug(&self) -> String {
format!("{}_{}", self.id, self.signature.to_slug())
}
}
pub use self::signature::Signature;
pub mod signature {
use super::*;
use std::fmt;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub struct Signature {
pub inputs: Vec<Type>,
pub outputs: Vec<Type>,
}
impl fmt::Debug for Signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Signature(inputs: {:?}, outputs: {:?})",
self.inputs, self.outputs
)
}
}
impl fmt::Display for Signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for (i, t) in self.inputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.inputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ") -> (")?;
for (i, t) in self.outputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.outputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
}
impl Signature {
/// Returns a slug for a signature, with the following encoding:
/// i{inputs}o{outputs} where {inputs} and {outputs} each encode a list of types.
/// A list of types is encoded by compressing sequences of the same type like so:
///
/// [field, field, field] -> 3f
/// [field] -> f
/// [field, bool, field] -> fbf
/// [field, field, bool, field] -> 2fbf
///
pub fn to_slug(&self) -> String {
let to_slug = |types| {
let mut res = vec![];
for t in types {
let len = res.len();
if len == 0 {
res.push((1, t))
} else {
if res[len - 1].1 == t {
res[len - 1].0 += 1;
} else {
res.push((1, t))
}
}
}
res.into_iter()
.map(|(n, t): (usize, &Type)| {
let mut r = String::new();
if n > 1 {
r.push_str(&format!("{}", n));
}
r.push_str(&t.to_slug());
r
})
.fold(String::new(), |mut acc, e| {
acc.push_str(&e);
acc
})
};
format!("i{}o{}", to_slug(&self.inputs), to_slug(&self.outputs))
}
pub fn new() -> Signature {
Signature {
inputs: vec![],
outputs: vec![],
}
}
pub fn inputs(mut self, inputs: Vec<Type>) -> Self {
self.inputs = inputs;
self
}
pub fn outputs(mut self, outputs: Vec<Type>) -> Self {
self.outputs = outputs;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn signature() {
let s = Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![Type::Boolean]);
assert_eq!(s.to_string(), String::from("(field, bool) -> (bool)"));
}
#[test]
fn slug_0() {
let s = Signature::new().inputs(vec![]).outputs(vec![]);
assert_eq!(s.to_slug(), String::from("io"));
}
#[test]
fn slug_1() {
let s = Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::Boolean,
Type::FieldElement,
]);
assert_eq!(s.to_slug(), String::from("ifbo2fbf"));
}
#[test]
fn slug_2() {
let s = Signature::new()
.inputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::FieldElement,
])
.outputs(vec![Type::FieldElement, Type::Boolean, Type::FieldElement]);
assert_eq!(s.to_slug(), String::from("i3fofbf"));
}
#[test]
fn array_slug() {
let s = Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 42),
Type::array(Type::FieldElement, 21),
])
.outputs(vec![]);
assert_eq!(s.to_slug(), String::from("if[42]f[21]o"));
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn array() {
let t = Type::Array(box Type::FieldElement, 42);
assert_eq!(t.get_primitive_count(), 42);
}
}

View file

@ -1,6 +1,5 @@
use crate::absy;
use crate::typed_absy::types::{MemberId, Type};
use crate::typed_absy::Identifier;
use crate::types::Type;
use std::fmt;
#[derive(Clone, PartialEq, Hash, Eq)]
@ -27,6 +26,10 @@ impl<'ast> Variable<'ast> {
Self::with_id_and_type(id, Type::array(ty, size))
}
pub fn struc(id: Identifier<'ast>, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
Self::with_id_and_type(id, Type::Struct(ty))
}
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {
Variable { id, _type }
}
@ -48,15 +51,15 @@ impl<'ast> fmt::Debug for Variable<'ast> {
}
}
impl<'ast> From<absy::Variable<'ast>> for Variable<'ast> {
fn from(v: absy::Variable) -> Variable {
Variable::with_id_and_type(
Identifier {
id: v.id,
version: 0,
stack: vec![],
},
v._type,
)
}
}
// impl<'ast> From<absy::Variable<'ast>> for Variable<'ast> {
// fn from(v: absy::Variable) -> Variable {
// Variable::with_id_and_type(
// Identifier {
// id: v.id,
// version: 0,
// stack: vec![],
// },
// v._type,
// )
// }
// }

View file

@ -1,111 +0,0 @@
pub use crate::types::signature::Signature;
use std::fmt;
pub type Identifier<'ast> = &'ast str;
mod signature;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Type {
FieldElement,
Boolean,
Array(Box<Type>, usize),
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl fmt::Debug for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl Type {
pub fn array(ty: Type, size: usize) -> Self {
Type::Array(box ty, size)
}
fn to_slug(&self) -> String {
match self {
Type::FieldElement => String::from("f"),
Type::Boolean => String::from("b"),
Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size),
}
}
// the number of field elements the type maps to
pub fn get_primitive_count(&self) -> usize {
match self {
Type::FieldElement => 1,
Type::Boolean => 1,
Type::Array(ty, size) => size * ty.get_primitive_count(),
}
}
}
#[derive(Clone, PartialEq, Hash, Eq)]
pub struct Variable<'ast> {
pub id: Identifier<'ast>,
pub _type: Type,
}
impl<'ast> fmt::Display for Variable<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} {}", self._type, self.id,)
}
}
impl<'ast> fmt::Debug for Variable<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,)
}
}
pub type FunctionIdentifier<'ast> = &'ast str;
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
pub struct FunctionKey<'ast> {
pub id: FunctionIdentifier<'ast>,
pub signature: Signature,
}
impl<'ast> FunctionKey<'ast> {
pub fn with_id<S: Into<Identifier<'ast>>>(id: S) -> Self {
FunctionKey {
id: id.into(),
signature: Signature::new(),
}
}
pub fn signature(mut self, signature: Signature) -> Self {
self.signature = signature;
self
}
pub fn to_slug(&self) -> String {
format!("{}_{}", self.id, self.signature.to_slug())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn array() {
let t = Type::Array(box Type::FieldElement, 42);
assert_eq!(t.get_primitive_count(), 42);
}
}

View file

@ -1,160 +0,0 @@
use crate::types::Type;
use std::fmt;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Signature {
pub inputs: Vec<Type>,
pub outputs: Vec<Type>,
}
impl fmt::Debug for Signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Signature(inputs: {:?}, outputs: {:?})",
self.inputs, self.outputs
)
}
}
impl fmt::Display for Signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for (i, t) in self.inputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.inputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ") -> (")?;
for (i, t) in self.outputs.iter().enumerate() {
write!(f, "{}", t)?;
if i < self.outputs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
}
impl Signature {
/// Returns a slug for a signature, with the following encoding:
/// i{inputs}o{outputs} where {inputs} and {outputs} each encode a list of types.
/// A list of types is encoded by compressing sequences of the same type like so:
///
/// [field, field, field] -> 3f
/// [field] -> f
/// [field, bool, field] -> fbf
/// [field, field, bool, field] -> 2fbf
///
pub fn to_slug(&self) -> String {
let to_slug = |types| {
let mut res = vec![];
for t in types {
let len = res.len();
if len == 0 {
res.push((1, t))
} else {
if res[len - 1].1 == t {
res[len - 1].0 += 1;
} else {
res.push((1, t))
}
}
}
res.into_iter()
.map(|(n, t): (usize, &Type)| {
let mut r = String::new();
if n > 1 {
r.push_str(&format!("{}", n));
}
r.push_str(&t.to_slug());
r
})
.fold(String::new(), |mut acc, e| {
acc.push_str(&e);
acc
})
};
format!("i{}o{}", to_slug(&self.inputs), to_slug(&self.outputs))
}
pub fn new() -> Signature {
Signature {
inputs: vec![],
outputs: vec![],
}
}
pub fn inputs(mut self, inputs: Vec<Type>) -> Self {
self.inputs = inputs;
self
}
pub fn outputs(mut self, outputs: Vec<Type>) -> Self {
self.outputs = outputs;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn signature() {
let s = Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![Type::Boolean]);
assert_eq!(s.to_string(), String::from("(field, bool) -> (bool)"));
}
#[test]
fn slug_0() {
let s = Signature::new().inputs(vec![]).outputs(vec![]);
assert_eq!(s.to_slug(), String::from("io"));
}
#[test]
fn slug_1() {
let s = Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::Boolean,
Type::FieldElement,
]);
assert_eq!(s.to_slug(), String::from("ifbo2fbf"));
}
#[test]
fn slug_2() {
let s = Signature::new()
.inputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::FieldElement,
])
.outputs(vec![Type::FieldElement, Type::Boolean, Type::FieldElement]);
assert_eq!(s.to_slug(), String::from("i3fofbf"));
}
#[test]
fn array_slug() {
let s = Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 42),
Type::array(Type::FieldElement, 21),
])
.outputs(vec![]);
assert_eq!(s.to_slug(), String::from("if[42]f[21]o"));
}
}

View file

@ -0,0 +1,2 @@
def main(bool[3] a) -> (bool[3]):
return a

View file

@ -0,0 +1,38 @@
{
"entry_point": "./tests/tests/arrays/identity.code",
"tests": [
{
"input": {
"values": ["0", "0", "0"]
},
"output": {
"Ok": {
"values": ["0", "0", "0"]
}
}
},
{
"input": {
"values": ["1", "0", "1"]
},
"output": {
"Ok": {
"values": ["1", "0", "1"]
}
}
},
{
"input": {
"values": ["2", "1", "1"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "4",
"right": "2"
}
}
}
}
]
}

View file

@ -0,0 +1,7 @@
struct A {
field a
bool b
}
def main(A a) -> (A):
return a

View file

@ -0,0 +1,38 @@
{
"entry_point": "./tests/tests/structs/identity.code",
"tests": [
{
"input": {
"values": ["42", "0"]
},
"output": {
"Ok": {
"values": ["42", "0"]
}
}
},
{
"input": {
"values": ["42", "1"]
},
"output": {
"Ok": {
"values": ["42", "1"]
}
}
},
{
"input": {
"values": ["42", "3"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "9",
"right": "3"
}
}
}
}
]
}

View file

@ -113,6 +113,26 @@ mod tests {
};
}
#[test]
fn parse_single_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 22, [
multi_assignment_statement(0, 9, [
optionally_typed_identifier(0, 1, [
identifier(0, 1)
]),
identifier(4, 7),
])
])
]
};
}
#[test]
fn parse_invalid_identifier() {
fails_with! {
@ -125,6 +145,50 @@ mod tests {
};
}
#[test]
fn parse_struct_def() {
parses_to! {
parser: ZoKratesParser,
input: "struct Foo { field foo\n field[2] bar }
",
rule: Rule::ty_struct_definition,
tokens: [
ty_struct_definition(0, 39, [
identifier(7, 10),
struct_field(13, 22, [
ty(13, 18, [
ty_basic(13, 18, [
ty_field(13, 18)
])
]),
identifier(19, 22)
]),
struct_field(24, 36, [
ty(24, 33, [
ty_array(24, 33, [
ty_basic_or_struct(24, 29, [
ty_basic(24, 29, [
ty_field(24, 29)
])
]),
expression(30, 31, [
term(30, 31, [
primary_expression(30, 31, [
constant(30, 31, [
decimal_number(30, 31)
])
])
])
])
])
]),
identifier(33, 36)
])
])
]
};
}
#[test]
fn parse_invalid_identifier_because_keyword() {
fails_with! {

View file

@ -3,8 +3,11 @@
* Author: Jacob Eberhardt, Thibaut Schaeffer
*/
file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ function_definition* ~ EOI }
import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+}
file = { SOI ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ function_definition* ~ EOI }
import_directive = { main_import_directive | from_import_directive }
from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ identifier ~ ("as" ~ identifier)? ~ NEWLINE*}
main_import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+}
import_source = @{(!"\"" ~ ANY)*}
function_definition = {"def" ~ identifier ~ "(" ~ parameter_list ~ ")" ~ "->" ~ "(" ~ type_list ~ ")" ~ ":" ~ NEWLINE* ~ statement* }
@ -15,10 +18,16 @@ parameter = {vis? ~ ty ~ identifier}
ty_field = {"field"}
ty_bool = {"bool"}
ty_basic = { ty_field | ty_bool }
// (unidimensional for now) arrays of (basic for now) types
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic }
ty_basic_or_struct = { ty_basic | ty_struct }
ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic | ty_struct }
type_list = _{(ty ~ ("," ~ ty)*)?}
// structs
ty_struct = { identifier }
// type definitions
ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* }
struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? }
struct_field = { ty ~ identifier }
vis_private = {"private"}
vis_public = {"public"}
@ -42,13 +51,13 @@ assignment_statement = {assignee ~ "=" ~ expression } // TODO: Is this optimal?
expression_statement = {expression}
optionally_typed_identifier_list = _{ optionally_typed_identifier ~ ("," ~ optionally_typed_identifier)* }
optionally_typed_identifier = { ty? ~ identifier }
optionally_typed_identifier = { (identifier) | (ty ~ identifier) } // we don't use { ty? ~ identifier } as with a single token, it gets parsed as `ty` but we want `identifier`
// Expressions
expression_list = _{(expression ~ ("," ~ expression)*)?}
expression = { term ~ (op_binary ~ term)* }
term = { ("(" ~ expression ~ ")") | conditional_expression | postfix_expression | primary_expression | inline_array_expression | array_initializer_expression | unary_expression }
term = { ("(" ~ expression ~ ")") | inline_struct_expression | conditional_expression | postfix_expression | primary_expression | inline_array_expression | array_initializer_expression | unary_expression }
spread = { "..." ~ expression }
range = { from_expression? ~ ".." ~ to_expression? }
from_expression = { expression }
@ -57,16 +66,21 @@ to_expression = { expression }
conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"}
postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers. Not sure that's what we want.
access = { array_access | call_access }
access = { array_access | call_access | member_access }
array_access = { "[" ~ range_or_expression ~ "]" }
call_access = { "(" ~ expression_list ~ ")" }
member_access = { "." ~ identifier }
primary_expression = { identifier
| constant
}
inline_array_expression = { "[" ~ inline_array_inner ~ "]" }
inline_array_inner = _{(spread_or_expression ~ ("," ~ spread_or_expression)*)?}
inline_struct_expression = { identifier ~ "{" ~ NEWLINE* ~ inline_struct_member_list ~ NEWLINE* ~ "}" }
inline_struct_member_list = _{(inline_struct_member ~ ("," ~ NEWLINE* ~ inline_struct_member)*)? ~ ","? }
inline_struct_member = { identifier ~ ":" ~ expression }
inline_array_expression = { "[" ~ NEWLINE* ~ inline_array_inner ~ NEWLINE* ~ "]" }
inline_array_inner = _{(spread_or_expression ~ ("," ~ NEWLINE* ~ spread_or_expression)*)?}
spread_or_expression = { spread | expression }
range_or_expression = { range | expression }
array_initializer_expression = { "[" ~ expression ~ ";" ~ constant ~ "]" }
@ -75,7 +89,8 @@ unary_expression = { op_unary ~ term }
// End Expressions
assignee = { identifier ~ ("[" ~ range_or_expression ~ "]")* }
assignee = { identifier ~ assignee_access* }
assignee_access = { array_access | member_access }
identifier = @{ ((!keyword ~ ASCII_ALPHA) | (keyword ~ (ASCII_ALPHANUMERIC | "_"))) ~ (ASCII_ALPHANUMERIC | "_")* }
constant = { decimal_number | boolean_literal }
decimal_number = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }

View file

@ -9,12 +9,13 @@ extern crate lazy_static;
pub use ast::{
Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee,
AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, BooleanType, CallAccess,
ConstantExpression, DefinitionStatement, Expression, FieldType, File, FromExpression, Function,
IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement,
AssigneeAccess, AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression,
BinaryOperator, CallAccess, ConstantExpression, DefinitionStatement, Expression, File,
FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource,
InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement,
MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression,
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, TernaryExpression, ToExpression,
Type, UnaryExpression, UnaryOperator, Visibility,
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField,
TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility,
};
mod ast {
@ -121,6 +122,9 @@ mod ast {
Rule::postfix_expression => Expression::Postfix(
PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(),
),
Rule::inline_struct_expression => Expression::InlineStruct(
InlineStructExpression::from_pest(&mut pair.into_inner()).unwrap(),
),
Rule::inline_array_expression => Expression::InlineArray(
InlineArrayExpression::from_pest(&mut pair.into_inner()).unwrap(),
),
@ -155,12 +159,31 @@ mod ast {
#[pest_ast(rule(Rule::file))]
pub struct File<'ast> {
pub imports: Vec<ImportDirective<'ast>>,
pub structs: Vec<StructDefinition<'ast>>,
pub functions: Vec<Function<'ast>>,
pub eoi: EOI,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_struct_definition))]
pub struct StructDefinition<'ast> {
pub id: IdentifierExpression<'ast>,
pub fields: Vec<StructField<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::struct_field))]
pub struct StructField<'ast> {
pub ty: Type<'ast>,
pub id: IdentifierExpression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::function_definition))]
pub struct Function<'ast> {
@ -174,13 +197,30 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::import_directive))]
pub struct ImportDirective<'ast> {
pub enum ImportDirective<'ast> {
Main(MainImportDirective<'ast>),
From(FromImportDirective<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::main_import_directive))]
pub struct MainImportDirective<'ast> {
pub source: ImportSource<'ast>,
pub alias: Option<IdentifierExpression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::from_import_directive))]
pub struct FromImportDirective<'ast> {
pub source: ImportSource<'ast>,
pub symbol: IdentifierExpression<'ast>,
pub alias: Option<IdentifierExpression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::import_source))]
pub struct ImportSource<'ast> {
@ -193,33 +233,55 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty))]
pub enum Type<'ast> {
Basic(BasicType),
Basic(BasicType<'ast>),
Array(ArrayType<'ast>),
Struct(StructType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_basic))]
pub enum BasicType {
Field(FieldType),
Boolean(BooleanType),
pub enum BasicType<'ast> {
Field(FieldType<'ast>),
Boolean(BooleanType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_field))]
pub struct FieldType;
pub struct FieldType<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_array))]
pub struct ArrayType<'ast> {
pub ty: BasicType,
pub ty: BasicOrStructType<'ast>,
pub dimensions: Vec<Expression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_basic_or_struct))]
pub enum BasicOrStructType<'ast> {
Struct(StructType<'ast>),
Basic(BasicType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_bool))]
pub struct BooleanType;
pub struct BooleanType<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_struct))]
pub struct StructType<'ast> {
pub id: IdentifierExpression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::parameter))]
@ -353,6 +415,7 @@ mod ast {
Identifier(IdentifierExpression<'ast>),
Constant(ConstantExpression<'ast>),
InlineArray(InlineArrayExpression<'ast>),
InlineStruct(InlineStructExpression<'ast>),
ArrayInitializer(ArrayInitializerExpression<'ast>),
Unary(UnaryExpression<'ast>),
}
@ -422,6 +485,24 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::inline_struct_expression))]
pub struct InlineStructExpression<'ast> {
pub ty: IdentifierExpression<'ast>,
pub members: Vec<InlineStructMember<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::inline_struct_member))]
pub struct InlineStructMember<'ast> {
pub id: IdentifierExpression<'ast>,
pub expression: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::array_initializer_expression))]
pub struct ArrayInitializerExpression<'ast> {
@ -445,6 +526,14 @@ mod ast {
pub enum Access<'ast> {
Call(CallAccess<'ast>),
Select(ArrayAccess<'ast>),
Member(MemberAccess<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::assignee_access))]
pub enum AssigneeAccess<'ast> {
Select(ArrayAccess<'ast>),
Member(MemberAccess<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
@ -463,6 +552,14 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::member_access))]
pub struct MemberAccess<'ast> {
pub id: IdentifierExpression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct BinaryExpression<'ast> {
pub op: BinaryOperator,
@ -518,6 +615,7 @@ mod ast {
Expression::Ternary(t) => &t.span,
Expression::Postfix(p) => &p.span,
Expression::InlineArray(a) => &a.span,
Expression::InlineStruct(s) => &s.span,
Expression::ArrayInitializer(a) => &a.span,
Expression::Unary(u) => &u.span,
}
@ -593,8 +691,8 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::assignee))]
pub struct Assignee<'ast> {
pub id: IdentifierExpression<'ast>, // a
pub indices: Vec<RangeOrExpression<'ast>>, // [42 + x][31][7]
pub id: IdentifierExpression<'ast>, // a
pub accesses: Vec<AssigneeAccess<'ast>>, // [42 + x].foo[7]
#[pest_ast(outer())]
pub span: Span<'ast>,
}
@ -696,13 +794,16 @@ mod tests {
assert_eq!(
generate_ast(&source),
Ok(File {
structs: vec![],
functions: vec![Function {
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {}))],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Constant(ConstantExpression::DecimalNumber(
@ -723,14 +824,14 @@ mod tests {
})],
span: Span::new(&source, 29, source.len()).unwrap(),
}],
imports: vec![ImportDirective {
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
}],
})],
eoi: EOI {},
span: Span::new(&source, 0, 65).unwrap()
})
@ -745,13 +846,16 @@ mod tests {
assert_eq!(
generate_ast(&source),
Ok(File {
structs: vec![],
functions: vec![Function {
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {}))],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::add(
Expression::Constant(ConstantExpression::DecimalNumber(
@ -790,14 +894,14 @@ mod tests {
})],
span: Span::new(&source, 29, 74).unwrap(),
}],
imports: vec![ImportDirective {
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
}],
})],
eoi: EOI {},
span: Span::new(&source, 0, 74).unwrap()
})
@ -812,13 +916,16 @@ mod tests {
assert_eq!(
generate_ast(&source),
Ok(File {
structs: vec![],
functions: vec![Function {
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 33, 37).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {}))],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 44, 49).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::if_else(
Expression::Constant(ConstantExpression::DecimalNumber(
@ -845,14 +952,14 @@ mod tests {
})],
span: Span::new(&source, 29, 81).unwrap(),
}],
imports: vec![ImportDirective {
imports: vec![ImportDirective::Main(MainImportDirective {
source: ImportSource {
value: String::from("foo"),
span: Span::new(&source, 8, 11).unwrap()
},
alias: None,
span: Span::new(&source, 0, 29).unwrap()
}],
})],
eoi: EOI {},
span: Span::new(&source, 0, 81).unwrap()
})
@ -866,13 +973,16 @@ mod tests {
assert_eq!(
generate_ast(&source),
Ok(File {
structs: vec![],
functions: vec![Function {
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 4, 8).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {}))],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 15, 20).unwrap()
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::Constant(ConstantExpression::DecimalNumber(
DecimalNumberExpression {
@ -898,13 +1008,16 @@ mod tests {
assert_eq!(
generate_ast(&source),
Ok(File {
structs: vec![],
functions: vec![Function {
id: IdentifierExpression {
value: String::from("main"),
span: Span::new(&source, 4, 8).unwrap()
},
parameters: vec![],
returns: vec![Type::Basic(BasicType::Field(FieldType {}))],
returns: vec![Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 15, 20).unwrap()
}))],
statements: vec![Statement::MultiAssignment(MultiAssignmentStatement {
function_id: IdentifierExpression {
value: String::from("foo"),
@ -912,7 +1025,9 @@ mod tests {
},
lhs: vec![
OptionallyTypedIdentifier {
ty: Some(Type::Basic(BasicType::Field(FieldType {}))),
ty: Some(Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 23, 28).unwrap()
}))),
id: IdentifierExpression {
value: String::from("a"),
span: Span::new(&source, 29, 30).unwrap(),
@ -965,13 +1080,19 @@ mod tests {
#[test]
fn playground() {
let source = r#"import "heyman" as yo
struct Foo {
field[2] foo
Bar bar
}
def main(private field[23] a) -> (bool[234 + 6]):
field a = 1
a[32 + x][55] = y
for field i in 0..3 do
a == 1 + 2 + 3+ 4+ 5+ 6+ 6+ 7+ 8 + 4+ 5+ 3+ 4+ 2+ 3
endfor
a == 1
a.member == 1
return a
"#;
let res = generate_ast(&source);