1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #434 from Zokrates/rec-arrays

Arrays of anything
This commit is contained in:
Thibaut Schaeffer 2019-09-23 11:25:52 +02:00 committed by GitHub
commit 04fe7ef1ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 2796 additions and 1053 deletions

View file

@ -1,8 +1,10 @@
## Types
ZoKrates currently exposes three types:
ZoKrates currently exposes two primitive types and a complex array type:
### `field`
### Primitive Types
#### `field`
This is the most basic type in ZoKrates, and it represents a positive integer in `[0, p - 1]` where `p` is a (large) prime number.
@ -14,7 +16,7 @@ While `field` values mostly behave like unsigned integers, one should keep in mi
{{#include ../../../zokrates_cli/examples/book/field_overflow.code}}
```
### `bool`
#### `bool`
ZoKrates has limited support for booleans, to the extent that they can only be used as the condition in `if ... else ... endif` expressions.
@ -22,10 +24,64 @@ You can use them for equality checks, inequality checks and inequality checks be
Note that while equality checks are cheap, inequality checks should be use wisely as they are orders of magnitude more expensive.
### `field[n]`
### Complex Types
Static arrays of `field` can be instantiated with a constant size, and their elements can be accessed and updated:
#### Arrays
ZoKrates supports static arrays, i.e., their length needs to be known at compile time.
Arrays can contain elements of any type and have arbitrary dimensions.
The following examples code shows examples of how to use arrays:
```zokrates
{{#include ../../../zokrates_cli/examples/book/array.code}}
```
##### Declaration and Initialization
An array is defined by appending `[]` to a type literal representing the type of the array's elements.
Initialization always needs to happen in the same statement than declaration, unless the array is declared within a function's signature.
For initialization, a list of comma-separated values is provided within brackets `[]`.
ZoKrates offers a special shorthand syntax to initialize an array with a constant value:
`[value;repetitions]`
The following code provides examples for declaration and initialization:
```zokrates
field[3] a = [1, 2, 3] // initialize a field array with field values
bool[13] b = [false; 13] // initialize a bool array with value false
```
##### Multidimensional Arrays
As an array can contain any type of elements, it can contain arrays again.
There is a special syntax to declare such multi-dimensional arrays, i.e., arrays of arrays.
To declare an array of an inner array, i.e., and array of elements of a type, prepend brackets `[size]` to the declaration of the inner array.
In summary, this leads to the following scheme for array declarations:
`data_type[size of 1st dimension][size of 2nd dimension]`.
Consider the following example:
```zokrates
{{#include ../../../zokrates_cli/examples/book/multidim_array.code}}
```
##### Spreads and Slices
ZoKrates provides some syntactic sugar to retrieve subsets of arrays.
###### Spreads
The spread operator `...` applied to an copies the elements of an existing array.
This can be used to conveniently compose new arrays, as shown in the following example:
```
field[3] = [1, 2, 3]
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
```
###### Slices
An array can also be assigned to by creating a copy of a subset of an existing array.
This operation is called slicing, and the following example shows how to slice in ZoKrates:
```
field[3] a = [1, 2, 3]
field[2] b = a[1..3] // initialize an array copying a slice from `a`
```

View file

@ -0,0 +1,9 @@
def main(bool[3] a) -> (field[3]):
bool[3] c = [true, true || false, true]
a[1] = true || a[2]
a[2] = a[0]
field[3] result = [0; 3]
for field i in 0..3 do
result[i] = if a[i] then 33 else 0 fi
endfor
return result

View file

@ -0,0 +1,12 @@
def main(field[2][2][2] cube) -> (field):
field res = 0
for field i in 0..2 do
for field j in 0..2 do
for field k in 0..2 do
res = res + cube[i][j][k]
endfor
endfor
endfor
return res

View file

@ -0,0 +1,4 @@
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]):
a[i][j][k] = 42
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
return b[0]

View file

@ -0,0 +1,10 @@
def main() -> ():
field[3] a = [1, 2, 3]
bool[3] b = [true, true, false]
field[3][2] c = [[1, 2], [3, 4], [5, 6]]
field[3] aa = [...a]
bool[3] bb = [...b]
field[3][2] cc = [...c]
return

View file

@ -1,7 +1,8 @@
def main() -> (field):
field[3] a = [1, 2, 3] // initialize an array with values
field[3] a = [1, 2, 3] // initialize a field array with field values
a[2] = 4 // set a member to a value
field[4] b = [42; 4] // initialize an array of 4 values all equal to 42
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
field[2] d = a[1..3] // initialize an array copying a slice from `a`
bool[3] e = [true, true || false, true] // initialize a boolean array
return a[0] + b[1] + c[2]

View file

@ -0,0 +1,9 @@
def main() -> (field):
// Array of two elements of array of 3 elements
field[2][3] a = [[1, 2, 3],[4, 5, 6]]
field[3] b = a[0] // should be [1, 2, 3]
// allowed access [0..2][0..3]
return a[1][2]

View file

@ -0,0 +1,6 @@
[
0,
0,
0,
0
]

View file

@ -0,0 +1,3 @@
def main(field[2][2] a) -> (field[2][2]):
a[1][1] = 42
return a

View file

@ -0,0 +1,4 @@
~out_0 0
~out_1 0
~out_2 0
~out_3 42

View file

@ -448,22 +448,30 @@ impl<'ast, T: Field> From<pest::PostfixExpression<'ast>> for absy::ExpressionNod
fn from(expression: pest::PostfixExpression<'ast>) -> absy::ExpressionNode<'ast, T> {
use absy::NodeValue;
assert!(expression.access.len() == 1); // we only allow a single access: function call or array access
let id_str = expression.id.span.as_str();
let id = absy::ExpressionNode::from(expression.id);
match expression.access[0].clone() {
pest::Access::Call(a) => absy::Expression::FunctionCall(
&expression.id.span.as_str(),
a.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.collect(),
),
pest::Access::Select(a) => absy::Expression::Select(
box absy::ExpressionNode::from(expression.id),
box absy::RangeOrExpression::from(a.expression),
),
}
.span(expression.span)
// pest::PostFixExpression contains an array of "accesses": `a(34)[42]` is represented as `[a, [Call(34), Select(42)]]`, but absy::ExpressionNode
// is recursive, so it is `Select(Call(a, 34), 42)`. We apply this transformation here
// we start with the id, and we fold the array of accesses by wrapping the current value
expression.accesses.into_iter().fold(id, |acc, a| match a {
pest::Access::Call(a) => match acc.value {
absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
&id_str,
a.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.collect(),
),
e => unimplemented!("only identifiers are callable, found \"{}\"", e),
}
.span(a.span),
pest::Access::Select(a) => {
absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression))
.span(a.span)
}
})
}
}
@ -501,15 +509,15 @@ impl<'ast, T: Field> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast, T>
use absy::NodeValue;
let a = absy::AssigneeNode::from(assignee.id);
match assignee.indices.len() {
0 => a,
1 => absy::Assignee::ArrayElement(
box a,
box absy::RangeOrExpression::from(assignee.indices[0].clone()),
)
.span(assignee.span),
n => unimplemented!("Array should have one dimension, found {} in {}", n, a),
}
let span = assignee.span;
assignee
.indices
.into_iter()
.map(|i| absy::RangeOrExpression::from(i))
.fold(a, |acc, s| {
absy::Assignee::Select(box acc, box s).span(span.clone())
})
}
}
@ -521,27 +529,34 @@ impl<'ast> From<pest::Type<'ast>> for Type {
pest::BasicType::Boolean(_) => Type::Boolean,
},
pest::Type::Array(t) => {
let size = match t.size {
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => {
unimplemented!("Array size should be constant, found {}", e.span().as_str())
}
let inner_type = match t.ty {
pest::BasicType::Field(_) => Type::FieldElement,
pest::BasicType::Boolean(_) => Type::Boolean,
};
match t.ty {
pest::BasicType::Field(_) => Type::FieldElementArray(size),
_ => unimplemented!(
"Array elements should be field elements, found {}",
t.span.as_str()
),
}
t.dimensions
.into_iter()
.map(|s| match s {
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => unimplemented!(
"Array size should be constant, found {}",
e.span().as_str()
),
})
.rev()
.fold(None, |acc, s| match acc {
None => Some(Type::array(inner_type.clone(), s)),
Some(acc) => Some(Type::array(acc, s)),
})
.unwrap()
}
}
}
@ -658,4 +673,186 @@ mod tests {
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
mod types {
use super::*;
/// Helper method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty
fn wrap(ty: types::Type) -> absy::Module<'static, FieldPrime> {
absy::Module {
functions: vec![absy::FunctionDeclaration {
id: "main",
symbol: absy::FunctionSymbol::Here(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone()).into(),
)
.into()],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![],
}
.into(),
)
.into()],
signature: absy::Signature::new().inputs(vec![ty]),
}
.into(),
),
}
.into()],
imports: vec![],
}
}
#[test]
fn array() {
let vectors = vec![
("field", types::Type::FieldElement),
("bool", types::Type::Boolean),
(
"field[2]",
types::Type::Array(box types::Type::FieldElement, 2),
),
(
"field[2][3]",
types::Type::Array(box Type::Array(box types::Type::FieldElement, 3), 2),
),
(
"bool[2][3]",
types::Type::Array(box Type::Array(box types::Type::Boolean, 3), 2),
),
];
for (ty, expected) in vectors {
let source = format!("def main(private {} a) -> (): return", ty);
let expected = wrap(expected);
let ast = pest::generate_ast(&source).unwrap();
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
}
}
mod postfix {
use super::*;
fn wrap(expression: absy::Expression<'static, FieldPrime>) -> absy::Module<FieldPrime> {
absy::Module {
functions: vec![absy::FunctionDeclaration {
id: "main",
symbol: absy::FunctionSymbol::Here(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![expression.into()],
}
.into(),
)
.into()],
signature: absy::Signature::new(),
}
.into(),
),
}
.into()],
imports: vec![],
}
}
#[test]
fn success() {
// we basically accept `()?[]*` : an optional call at first, then only array accesses
let vectors = vec![
("a", absy::Expression::Identifier("a").into()),
(
"a[3]",
absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(3)).into(),
)
.into(),
),
),
(
"a[3][4]",
absy::Expression::Select(
box absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(3)).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
),
),
(
"a(3)[4]",
absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(FieldPrime::from(3)).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
),
),
(
"a(3)[4][5]",
absy::Expression::Select(
box absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(FieldPrime::from(3)).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(5)).into(),
)
.into(),
),
),
];
for (source, expected) in vectors {
let source = format!("def main() -> (): return {}", source);
let expected = wrap(expected);
let ast = pest::generate_ast(&source).unwrap();
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
}
#[test]
#[should_panic]
fn call_array_element() {
// a call after an array access should be rejected
let source = "def main() -> (): return a[2](3)";
let ast = pest::generate_ast(&source).unwrap();
absy::Module::<FieldPrime>::from(ast);
}
#[test]
#[should_panic]
fn call_call_result() {
// a call after a call should be rejected
let source = "def main() -> (): return a(2)(3)";
let ast = pest::generate_ast(&source).unwrap();
absy::Module::<FieldPrime>::from(ast);
}
}
}

View file

@ -198,7 +198,7 @@ impl<'ast, T: Field> fmt::Debug for Function<'ast, T> {
#[derive(Clone, PartialEq)]
pub enum Assignee<'ast, T: Field> {
Identifier(Identifier<'ast>),
ArrayElement(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
Select(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
}
pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
@ -206,8 +206,8 @@ pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
impl<'ast, T: Field> fmt::Debug for Assignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Assignee::Identifier(ref s) => write!(f, "{}", s),
Assignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
}
}
}
@ -288,6 +288,12 @@ pub enum SpreadOrExpression<'ast, T: Field> {
Expression(ExpressionNode<'ast, T>),
}
impl<'ast, T: Field> From<ExpressionNode<'ast, T>> for SpreadOrExpression<'ast, T> {
fn from(e: ExpressionNode<'ast, T>) -> SpreadOrExpression<'ast, T> {
SpreadOrExpression::Expression(e)
}
}
impl<'ast, T: Field> fmt::Display for SpreadOrExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -499,7 +505,9 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> {
f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]")
}
Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index),
Expression::Select(ref array, ref index) => {
write!(f, "Select({:?}, {:?})", array, index)
}
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
}
}

View file

@ -37,7 +37,14 @@ impl<'ast> Variable<'ast> {
pub fn field_array<S: Into<&'ast str>>(id: S, size: usize) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::FieldElementArray(size),
_type: Type::array(Type::FieldElement, size),
}
}
pub fn array<S: Into<&'ast str>>(id: S, inner_ty: Type, size: usize) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::array(inner_ty, size),
}
}

View file

@ -22,13 +22,16 @@ impl FlatEmbed {
match self {
FlatEmbed::Sha256Round => Signature::new()
.inputs(vec![
Type::FieldElementArray(512),
Type::FieldElementArray(256),
Type::array(Type::FieldElement, 512),
Type::array(Type::FieldElement, 256),
])
.outputs(vec![Type::FieldElementArray(256)]),
.outputs(vec![Type::array(Type::FieldElement, 256)]),
FlatEmbed::Unpack => Signature::new()
.inputs(vec![Type::FieldElement])
.outputs(vec![Type::FieldElementArray(T::get_required_bits())]),
.outputs(vec![Type::array(
Type::FieldElement,
T::get_required_bits(),
)]),
}
}
@ -123,10 +126,10 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
// define the signature of the resulting function
let signature = Signature {
inputs: vec![
Type::FieldElementArray(input_indices.len()),
Type::FieldElementArray(current_hash_indices.len()),
Type::array(Type::FieldElement, input_indices.len()),
Type::array(Type::FieldElement, current_hash_indices.len()),
],
outputs: vec![Type::FieldElementArray(output_indices.len())],
outputs: vec![Type::array(Type::FieldElement, output_indices.len())],
};
// define parameters to the function based on the variables
@ -234,7 +237,7 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
let signature = Signature {
inputs: vec![Type::FieldElement],
outputs: vec![Type::FieldElementArray(nbits)],
outputs: vec![Type::array(Type::FieldElement, nbits)],
};
let outputs = directive_outputs
@ -351,10 +354,10 @@ mod tests {
compiled.signature,
Signature::new()
.inputs(vec![
Type::FieldElementArray(512),
Type::FieldElementArray(256)
Type::array(Type::FieldElement, 512),
Type::array(Type::FieldElement, 256)
])
.outputs(vec![Type::FieldElementArray(256)])
.outputs(vec![Type::array(Type::FieldElement, 256)])
);
// function should have 768 inputs

File diff suppressed because it is too large Load diff

View file

@ -53,7 +53,7 @@ impl<T: Field> fmt::Display for DirectiveStatement<T> {
}
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Helper {
Rust(RustHelper),
#[cfg(feature = "wasm")]

View file

@ -3,7 +3,7 @@ use std::fmt;
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::field::Field;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum RustHelper {
Identity,
ConditionEq,

View file

@ -3,6 +3,7 @@ use std::fmt;
use rustc_hex::FromHex;
use serde::{Deserialize, Deserializer};
use std::hash::{Hash, Hasher};
use std::rc::Rc;
use wasmi::{ImportsBuilder, ModuleInstance, ModuleRef, NopExternals};
use zokrates_field::field::Field;
@ -71,6 +72,14 @@ impl PartialEq for WasmHelper {
}
}
impl Eq for WasmHelper {}
impl Hash for WasmHelper {
fn hash<H: Hasher>(&self, state: &mut H) {
self.1.hash(state);
}
}
impl fmt::Display for WasmHelper {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Hex(\"{:?}\")", &self.1[..])

View file

@ -5,7 +5,7 @@ use std::fmt;
use std::ops::{Add, Div, Mul, Sub};
use zokrates_field::field::Field;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
pub struct QuadComb<T: Field> {
pub left: LinComb<T>,
pub right: LinComb<T>,

View file

@ -10,13 +10,13 @@ mod from_flat;
mod interpreter;
mod witness;
use self::expression::QuadComb;
pub use self::expression::QuadComb;
pub use self::expression::{CanonicalLinComb, LinComb};
pub use self::interpreter::{Error, ExecutionResult};
pub use self::witness::Witness;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Hash, Eq)]
pub enum Statement<T: Field> {
Constraint(QuadComb<T>, LinComb<T>),
Directive(Directive<T>),
@ -32,7 +32,7 @@ impl<T: Field> Statement<T> {
}
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub struct Directive<T: Field> {
pub inputs: Vec<LinComb<T>>,
pub outputs: Vec<FlatVariable>,
@ -102,7 +102,7 @@ impl<T: Field> fmt::Display for Function<T> {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Prog<T: Field> {
pub main: Function<T>,
pub private: Vec<bool>,

View file

@ -0,0 +1,143 @@
//! Module containing the `DuplicateOptimizer` to remove duplicate constraints
use crate::ir::folder::Folder;
use crate::ir::*;
use std::collections::{hash_map::DefaultHasher, HashSet};
use zokrates_field::field::Field;
type Hash = u64;
fn hash<T: Field>(s: &Statement<T>) -> Hash {
use std::hash::Hash;
use std::hash::Hasher;
let mut hasher = DefaultHasher::new();
s.hash(&mut hasher);
hasher.finish()
}
#[derive(Debug)]
pub struct DuplicateOptimizer {
seen: HashSet<Hash>,
}
impl DuplicateOptimizer {
fn new() -> Self {
DuplicateOptimizer {
seen: HashSet::new(),
}
}
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
Self::new().fold_module(p)
}
}
impl<T: Field> Folder<T> for DuplicateOptimizer {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
let hashed = hash(&s);
let result = match self.seen.get(&hashed) {
Some(_) => vec![],
None => vec![s],
};
self.seen.insert(hashed);
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use flat_absy::FlatVariable;
use zokrates_field::field::FieldPrime;
#[test]
fn identity() {
use num::Zero;
let p: Prog<FieldPrime> = Prog {
private: vec![],
main: Function {
id: "main".to_string(),
statements: vec![
Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(3)),
LinComb::summand(3, FlatVariable::new(3)),
),
LinComb::one(),
),
Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(42)),
LinComb::summand(3, FlatVariable::new(3)),
),
LinComb::zero(),
),
],
returns: vec![],
arguments: vec![],
},
};
let expected = p.clone();
assert_eq!(DuplicateOptimizer::optimize(p), expected);
}
#[test]
fn remove_duplicates() {
use num::Zero;
let constraint = Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(3)),
LinComb::summand(3, FlatVariable::new(3)),
),
LinComb::one(),
);
let p: Prog<FieldPrime> = Prog {
private: vec![],
main: Function {
id: "main".to_string(),
statements: vec![
constraint.clone(),
constraint.clone(),
Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(42)),
LinComb::summand(3, FlatVariable::new(3)),
),
LinComb::zero(),
),
constraint.clone(),
constraint.clone(),
],
returns: vec![],
arguments: vec![],
},
};
let expected = Prog {
private: vec![],
main: Function {
id: "main".to_string(),
statements: vec![
constraint.clone(),
Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(42)),
LinComb::summand(3, FlatVariable::new(3)),
),
LinComb::zero(),
),
],
returns: vec![],
arguments: vec![],
},
};
assert_eq!(DuplicateOptimizer::optimize(p), expected);
}
}

View file

@ -4,9 +4,11 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod duplicate;
mod redefinition;
mod tautology;
use self::duplicate::DuplicateOptimizer;
use self::redefinition::RedefinitionOptimizer;
use self::tautology::TautologyOptimizer;
@ -23,6 +25,8 @@ impl<T: Field> Optimize for Prog<T> {
let r = RedefinitionOptimizer::optimize(self);
// remove constraints that are always satisfied
let r = TautologyOptimizer::optimize(r);
// remove duplicate constraints
let r = DuplicateOptimizer::optimize(r);
r
}
}

View file

@ -616,7 +616,7 @@ impl<'ast> Checker<'ast> {
message: format!("Undeclared variable: {:?}", variable_name),
}),
},
Assignee::ArrayElement(box assignee, box index) => {
Assignee::Select(box assignee, box index) => {
let checked_assignee = self.check_assignee(assignee)?;
let checked_index = match index {
RangeOrExpression::Expression(e) => self.check_expression(e)?,
@ -639,7 +639,7 @@ impl<'ast> Checker<'ast> {
}),
}?;
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box checked_assignee,
box checked_typed_index,
))
@ -657,29 +657,43 @@ impl<'ast> Checker<'ast> {
let checked_expression = self.check_expression(s.value.expression)?;
match checked_expression {
TypedExpression::FieldElementArray(e) => match e {
// if we're doing a spread over an inline array, we return the inside of the array: ...[x, y, z] == x, y, z
FieldElementArrayExpression::Value(_, v) => {
Ok(v.into_iter().map(|e| e.into()).collect())
}
e => {
let size = e.size();
Ok((0..size)
.map(|i| {
FieldElementExpression::Select(
box e.clone(),
TypedExpression::Array(e) => {
let ty = e.inner_type().clone();
let size = e.size();
match e.into_inner() {
// if we're doing a spread over an inline array, we return the inside of the array: ...[x, y, z] == x, y, z
// this is not strictly needed, but it makes spreads memory linear rather than quadratic
ArrayExpressionInner::Value(v) => Ok(v),
// otherwise we return a[0], ..., a[a.size() -1 ]
e => Ok((0..size)
.map(|i| match &ty {
Type::FieldElement => FieldElementExpression::Select(
box e.clone().annotate(Type::FieldElement, size),
box FieldElementExpression::Number(T::from(i)),
)
.into()
.into(),
Type::Boolean => BooleanExpression::Select(
box e.clone().annotate(Type::Boolean, size),
box FieldElementExpression::Number(T::from(i)),
)
.into(),
Type::Array(box ty, s) => ArrayExpressionInner::Select(
box e
.clone()
.annotate(Type::Array(box ty.clone(), *s), size),
box FieldElementExpression::Number(T::from(i)),
)
.annotate(ty.clone(), *s)
.into(),
})
.collect())
.collect()),
}
},
}
e => Err(Error {
pos: Some(pos),
message: format!(
"Expected spread operator to apply on field element array, found {}",
"Expected spread operator to apply on array, found {}",
e.get_type()
),
}),
@ -705,9 +719,9 @@ impl<'ast> Checker<'ast> {
Type::FieldElement => {
Ok(FieldElementExpression::Identifier(name.into()).into())
}
Type::FieldElementArray(n) => {
Ok(FieldElementArrayExpression::Identifier(n, name.into()).into())
}
Type::Array(ty, size) => Ok(ArrayExpressionInner::Identifier(name.into())
.annotate(*ty, size)
.into()),
},
None => Err(Error {
pos: Some(pos),
@ -824,10 +838,15 @@ impl<'ast> Checker<'ast> {
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into())
},
(TypedExpression::FieldElementArray(consequence), TypedExpression::FieldElementArray(alternative)) => {
Ok(FieldElementArrayExpression::IfElse(box condition, box consequence, box alternative).into())
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into())
},
_ => unimplemented!()
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
let inner_type = consequence.inner_type().clone();
let size = consequence.size();
Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into())
},
_ => unreachable!("types should match here as we checked them explicitly")
}
false => Err(Error {
pos: Some(pos),
@ -870,7 +889,7 @@ impl<'ast> Checker<'ast> {
let f = &candidates[0];
// the return count has to be 1
match f.signature.outputs.len() {
1 => match f.signature.outputs[0] {
1 => match &f.signature.outputs[0] {
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
FunctionKey {
id: f.id.clone(),
@ -879,15 +898,15 @@ impl<'ast> Checker<'ast> {
arguments_checked,
)
.into()),
Type::FieldElementArray(size) => {
Ok(FieldElementArrayExpression::FunctionCall(
size,
Type::Array(box ty, size) => {
Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
signature: f.signature.clone(),
},
arguments_checked,
)
.annotate(ty.clone(), size.clone())
.into())
}
_ => unimplemented!(),
@ -910,7 +929,9 @@ impl<'ast> Checker<'ast> {
fun_id, query
),
}),
_ => panic!("duplicate definition should have been caught before the call"),
_ => {
unreachable!("duplicate definition should have been caught before the call")
}
}
}
Expression::Lt(box e1, box e2) => {
@ -1013,8 +1034,9 @@ impl<'ast> Checker<'ast> {
match index {
RangeOrExpression::Range(r) => match array {
TypedExpression::FieldElementArray(array) => {
TypedExpression::Array(array) => {
let array_size = array.size();
let inner_type = array.inner_type().clone();
let from = r
.value
@ -1050,27 +1072,44 @@ impl<'ast> Checker<'ast> {
f, t,
),
}),
(f, t, _) => Ok(FieldElementArrayExpression::Value(
t - f,
(f, t, _) => Ok(ArrayExpressionInner::Value(
(f..t)
.map(|i| {
FieldElementExpression::Select(
box array.clone(),
box FieldElementExpression::Number(T::from(i)),
)
.into()
})
.collect(),
)
.annotate(inner_type, t - f)
.into()),
}
}
_ => panic!(""),
e => Err(Error {
pos: Some(pos),
message: format!(
"Cannot access slice of expression {} of type {}",
e,
e.get_type(),
),
}),
},
RangeOrExpression::Expression(e) => match (array, self.check_expression(e)?) {
(
TypedExpression::FieldElementArray(a),
TypedExpression::FieldElement(i),
) => Ok(FieldElementExpression::Select(box a, box i).into()),
(TypedExpression::Array(a), TypedExpression::FieldElement(i)) => {
match a.inner_type().clone() {
Type::FieldElement => {
Ok(FieldElementExpression::Select(box a, box i).into())
}
Type::Boolean => Ok(BooleanExpression::Select(box a, box i).into()),
Type::Array(box ty, size) => {
Ok(ArrayExpressionInner::Select(box a, box i)
.annotate(ty.clone(), size.clone())
.into())
}
}
}
(a, e) => Err(Error {
pos: Some(pos),
message: format!(
@ -1083,9 +1122,6 @@ impl<'ast> Checker<'ast> {
}
}
Expression::InlineArray(expressions) => {
// we should have at least one expression
let size = expressions.len();
assert!(size > 0);
// check each expression, getting its type
let mut expressions_checked = vec![];
for e in expressions {
@ -1094,7 +1130,7 @@ impl<'ast> Checker<'ast> {
}
// we infer the type to be the type of the first element
let inferred_type = expressions_checked.get(0).unwrap().get_type();
let inferred_type = expressions_checked.get(0).unwrap().get_type().clone();
match inferred_type {
Type::FieldElement => {
@ -1115,24 +1151,84 @@ impl<'ast> Checker<'ast> {
),
}),
}?;
unwrapped_expressions.push(unwrapped_e);
unwrapped_expressions.push(unwrapped_e.into());
}
Ok(FieldElementArrayExpression::Value(
unwrapped_expressions.len(),
unwrapped_expressions,
)
.into())
}
_ => Err(Error {
pos: Some(pos),
let size = unwrapped_expressions.len();
message: format!(
"Only arrays of {} are supported, found {}",
Type::FieldElement,
inferred_type
),
}),
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
.annotate(Type::FieldElement, size)
.into())
}
Type::Boolean => {
// we check all expressions have that same type
let mut unwrapped_expressions = vec![];
for e in expressions_checked {
let unwrapped_e = match e {
TypedExpression::Boolean(e) => Ok(e),
e => Err(Error {
pos: Some(pos),
message: format!(
"Expected {} to have type {}, but type is {}",
e,
inferred_type,
e.get_type()
),
}),
}?;
unwrapped_expressions.push(unwrapped_e.into());
}
let size = unwrapped_expressions.len();
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
.annotate(Type::Boolean, size)
.into())
}
ty @ Type::Array(..) => {
// we check all expressions have that same type
let mut unwrapped_expressions = vec![];
for e in expressions_checked {
let unwrapped_e = match e {
TypedExpression::Array(e) => {
if e.get_type() == ty {
Ok(e)
} else {
Err(Error {
pos: Some(pos),
message: format!(
"Expected {} to have type {}, but type is {}",
e,
ty,
e.get_type()
),
})
}
}
e => Err(Error {
pos: Some(pos),
message: format!(
"Expected {} to have type {}, but type is {}",
e,
ty,
e.get_type()
),
}),
}?;
unwrapped_expressions.push(unwrapped_e.into());
}
let size = unwrapped_expressions.len();
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
.annotate(ty, size)
.into())
}
}
}
Expression::And(box e1, box e2) => {
@ -1218,6 +1314,52 @@ mod tests {
use types::Signature;
use zokrates_field::field::FieldPrime;
mod array {
use super::*;
#[test]
fn element_type_mismatch() {
// [3, true]
let a = Expression::InlineArray(vec![
Expression::FieldConstant(FieldPrime::from(3)).mock().into(),
Expression::BooleanConstant(true).mock().into(),
])
.mock();
assert!(Checker::new().check_expression(a).is_err());
// [[0], [0, 0]]
let a = Expression::InlineArray(vec![
Expression::InlineArray(vec![Expression::FieldConstant(FieldPrime::from(0))
.mock()
.into()])
.mock()
.into(),
Expression::InlineArray(vec![
Expression::FieldConstant(FieldPrime::from(0)).mock().into(),
Expression::FieldConstant(FieldPrime::from(0)).mock().into(),
])
.mock()
.into(),
])
.mock();
assert!(Checker::new().check_expression(a).is_err());
// [[0], true]
let a = Expression::InlineArray(vec![
Expression::InlineArray(vec![Expression::FieldConstant(FieldPrime::from(0))
.mock()
.into()])
.mock()
.into(),
Expression::InlineArray(vec![Expression::BooleanConstant(true).mock().into()])
.mock()
.into(),
])
.mock();
assert!(Checker::new().check_expression(a).is_err());
}
}
mod symbols {
use super::*;
use crate::types::Signature;
@ -2291,7 +2433,7 @@ mod tests {
fn array_element() {
// field[33] a
// a[2] = 42
let a = Assignee::ArrayElement(
let a = Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(2)).mock(),
@ -2309,7 +2451,7 @@ mod tests {
assert_eq!(
checker.check_assignee(a),
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::field_array(
"a".into(),
33
@ -2318,5 +2460,50 @@ mod tests {
))
);
}
#[test]
fn array_of_array_element() {
// field[33][42] a
// a[1][2]
let a = Assignee::Select(
box Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(1)).mock(),
),
)
.mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(2)).mock(),
),
)
.mock();
let mut checker: Checker = Checker::new();
checker
.check_statement::<FieldPrime>(
Statement::Declaration(
Variable::array("a", Type::array(Type::FieldElement, 33), 42).mock(),
)
.mock(),
&vec![],
)
.unwrap();
assert_eq!(
checker.check_assignee(a),
Ok(TypedAssignee::Select(
box TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::array(
"a".into(),
Type::array(Type::FieldElement, 33),
42
)),
box FieldElementExpression::Number(FieldPrime::from(1)).into()
),
box FieldElementExpression::Number(FieldPrime::from(2)).into()
))
);
}
}
}

View file

@ -0,0 +1,232 @@
use crate::flat_absy::{FlatExpression, FlatExpressionList, FlatFunction, FlatStatement};
use crate::flat_absy::{FlatParameter, FlatVariable};
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::types::{Signature, Type};
use bellman::pairing::ff::ScalarEngine;
use reduce::Reduce;
use zokrates_embed::{generate_sha256_round_constraints, BellmanConstraint};
use zokrates_field::field::Field;
// util to convert a vector of `(variable_id, coefficient)` to a flat_expression
fn flat_expression_from_vec<T: Field>(
v: Vec<(usize, <<T as Field>::BellmanEngine as ScalarEngine>::Fr)>,
) -> FlatExpression<T> {
match v
.into_iter()
.map(|(key, val)| {
FlatExpression::Mult(
box FlatExpression::Number(T::from_bellman(val)),
box FlatExpression::Identifier(FlatVariable::new(key)),
)
})
.reduce(|acc, e| FlatExpression::Add(box acc, box e))
{
Some(e @ FlatExpression::Mult(..)) => {
FlatExpression::Add(box FlatExpression::Number(T::zero()), box e)
} // the R1CS serializer only recognizes Add
Some(e) => e,
None => FlatExpression::Number(T::zero()),
}
}
impl<T: Field> From<BellmanConstraint<T::BellmanEngine>> for FlatStatement<T> {
fn from(c: zokrates_embed::BellmanConstraint<T::BellmanEngine>) -> FlatStatement<T> {
let rhs_a = flat_expression_from_vec(c.a);
let rhs_b = flat_expression_from_vec(c.b);
let lhs = flat_expression_from_vec(c.c);
FlatStatement::Condition(lhs, FlatExpression::Mult(box rhs_a, box rhs_b))
}
}
/// Returns a flat function which computes a sha256 round
///
/// # Remarks
///
/// The variables inside the function are set in this order:
/// - constraint system variables
/// - arguments
pub fn sha_round<T: Field>() -> FlatFunction<T> {
// Define iterators for all indices at hand
let (r1cs, input_indices, current_hash_indices, output_indices) =
generate_sha256_round_constraints::<T::BellmanEngine>();
// indices of the input
let input_indices = input_indices.into_iter();
// indices of the current hash
let current_hash_indices = current_hash_indices.into_iter();
// indices of the output
let output_indices = output_indices.into_iter();
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
// indices of the sha256round constraint system variables
let cs_indices = (0..variable_count).into_iter();
// indices of the arguments to the function
// apply an offset of `variable_count` to get the indice of our dummy `input` argument
let input_argument_indices = input_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// apply an offset of `variable_count` to get the indice of our dummy `current_hash` argument
let current_hash_argument_indices = current_hash_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// define the signature of the resulting function
let signature = Signature {
inputs: vec![
Type::array(Type::FieldElement, input_indices.len()),
Type::array(Type::FieldElement, current_hash_indices.len()),
],
outputs: vec![Type::array(Type::FieldElement, output_indices.len())],
};
// define parameters to the function based on the variables
let arguments = input_argument_indices
.clone()
.chain(current_hash_argument_indices.clone())
.map(|i| FlatParameter {
id: FlatVariable::new(i),
private: true,
})
.collect();
// define a binding of the first variable in the constraint system to one
let one_binding_statement = FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(T::from(1)),
);
let input_binding_statements =
// bind input and current_hash to inputs
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
FlatStatement::Condition(
FlatVariable::new(cs_index).into(),
FlatVariable::new(argument_index).into(),
)
});
// insert flattened statements to represent constraints
let constraint_statements = r1cs.constraints.into_iter().map(|c| c.into());
// define which subset of the witness is returned
let outputs: Vec<FlatExpression<T>> = output_indices
.map(|o| FlatExpression::Identifier(FlatVariable::new(o)))
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(DirectiveStatement {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
.collect(),
helper: Helper::Rust(RustHelper::Sha256Round),
});
// insert a statement to return the subset of the witness
let return_statement = FlatStatement::Return(FlatExpressionList {
expressions: outputs,
});
let statements = std::iter::once(directive_statement)
.chain(std::iter::once(one_binding_statement))
.chain(input_binding_statements)
.chain(constraint_statements)
.chain(std::iter::once(return_statement))
.collect();
FlatFunction {
id: "main".to_owned(),
arguments,
statements,
signature,
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
#[test]
fn generate_sha256_constraints() {
let compiled = sha_round();
// function should have a signature of 768 inputs and 256 outputs
assert_eq!(
compiled.signature,
Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 512),
Type::array(Type::FieldElement, 256)
])
.outputs(vec![Type::array(Type::FieldElement, 256)])
);
// function should have 768 inputs
assert_eq!(compiled.arguments.len(), 768,);
// function should return 256 values
assert_eq!(
compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Return(v) => Some(v),
_ => None,
})
.next()
.unwrap()
.expressions
.len(),
256,
);
// directive should take 768 inputs and return n_var outputs
let directive = compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Directive(d) => Some(d.clone()),
_ => None,
})
.next()
.unwrap();
assert_eq!(directive.inputs.len(), 768);
assert_eq!(directive.outputs.len(), 26935);
// function input should be offset by variable_count
assert_eq!(
compiled.arguments[0].id,
FlatVariable::new(directive.outputs.len() + 1)
);
// bellman variable #0: index 0 should equal 1
assert_eq!(
compiled.statements[1],
FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(FieldPrime::from(1))
)
);
// bellman input #0: index 1 should equal zokrates input #0: index v_count
assert_eq!(
compiled.statements[2],
FlatStatement::Condition(FlatVariable::new(1).into(), FlatVariable::new(26936).into())
);
let f = crate::ir::Function::from(compiled);
let prog = crate::ir::Prog {
main: f,
private: vec![true; 768],
};
let input = (0..512).map(|_| 0).chain((0..256).map(|_| 1)).collect();
prog.execute(&input).unwrap();
}
}

View file

@ -18,7 +18,7 @@
use std::collections::HashMap;
use typed_absy::{folder::*, *};
use types::FunctionKey;
use types::{FunctionKey, Type};
use zokrates_field::field::Field;
/// An inliner
@ -238,35 +238,26 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
}
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::Identifier(size, id) => {
FieldElementArrayExpression::Identifier(
size,
self.fold_variable(Variable::field_array(id, size)).id,
)
}
FieldElementArrayExpression::FunctionCall(size, key, expressions) => {
//inline the arguments
let expressions: Vec<_> = expressions
.into_iter()
.map(|e| self.fold_expression(e))
.collect();
ArrayExpressionInner::FunctionCall(key, exps) => {
let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect();
match self.try_inline_call(&key, expressions) {
match self.try_inline_call(&key, exps) {
Ok(mut ret) => match ret.pop().unwrap() {
TypedExpression::FieldElementArray(e) => e,
TypedExpression::Array(e) => e.into_inner(),
_ => unreachable!(),
},
Err((key, expressions)) => {
FieldElementArrayExpression::FunctionCall(size, key, expressions)
}
Err((key, expressions)) => ArrayExpressionInner::FunctionCall(key, expressions),
}
}
e => fold_field_array_expression(self, e),
// default
e => fold_array_expression_inner(self, ty, size, e),
}
}
}

View file

@ -7,6 +7,8 @@
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use types::Type;
use zokrates_field::field::Field;
pub struct Propagator<'ast, T: Field> {
@ -25,6 +27,18 @@ impl<'ast, T: Field> Propagator<'ast, T> {
}
}
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
_ => false,
}
}
impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
self.constants = HashMap::new();
@ -33,83 +47,50 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| self.fold_expression(e)).collect())),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
match self.fold_expression(expr) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..)) | e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
},
TypedExpression::FieldElementArray(FieldElementArrayExpression::Value(size, array)) => {
match array.iter().all(|e| match e {
FieldElementExpression::Number(..) => true,
_ => false
}) {
true => {
// all elements of the array are constants
self.constants.insert(TypedAssignee::Identifier(var), FieldElementArrayExpression::Value(size, array).into());
None
},
false => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), FieldElementArrayExpression::Value(size, array).into()))
}
}
},
e => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), e))
}
}
},
// a[b] = c
TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr) => {
let index = self.fold_field_expression(index);
let expr = self.fold_expression(expr);
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(
expressions
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
)),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
let expr = self.fold_expression(expr);
match (index, expr) {
(
FieldElementExpression::Number(n),
TypedExpression::FieldElement(expr @ FieldElementExpression::Number(..))
) => {
// a[42] = 33
// -> store (a[42] -> 33) in the constants, possibly overwriting the previous entry
self.constants.entry(TypedAssignee::Identifier(var)).and_modify(|e| {
match *e {
TypedExpression::FieldElementArray(FieldElementArrayExpression::Value(size, ref mut v)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize] = expr;
} else {
panic!(format!("out of bounds index ({} >= {}) found during static analysis", n_as_usize, size));
}
},
_ => panic!("constants should only store constants")
}
});
None
},
(index, expr) => {
// a[42] = e
// -> remove a from the constants as one of its elements is not constant
self.constants.remove(&TypedAssignee::Identifier(var.clone()));
Some(TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr))
}
}
},
TypedStatement::Definition(..) => panic!("multi dimensinal arrays are not supported, this should have been caught during semantic checking"),
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(self.fold_expression(e1), self.fold_expression(e2)))
},
// we unrolled for loops in the previous step
TypedStatement::For(..) => panic!("for loop is unexpected, it should have been unrolled"),
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(variables, expression_list))
}
};
if is_constant(&expr) {
self.constants.insert(TypedAssignee::Identifier(var), expr);
None
} else {
Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
))
}
}
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
unreachable!("array updates should have been replaced with full array redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(
self.fold_expression(e1),
self.fold_expression(e2),
))
}
// we unrolled for loops in the previous step
TypedStatement::For(..) => {
unreachable!("for loop is unexpected, it should have been unrolled")
}
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(
variables,
expression_list,
))
}
};
match res {
Some(v) => vec![v],
None => vec![],
@ -129,9 +110,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
))) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => {
panic!("constant stored for a field element should be a field element")
}
_ => unreachable!(
"constant stored for a field element should be a field element"
),
},
None => FieldElementExpression::Identifier(id),
}
@ -185,7 +166,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(e1, FieldElementExpression::Number(n2)) => {
FieldElementExpression::Pow(box e1, box FieldElementExpression::Number(n2))
}
(_, e2) => panic!(format!(
(_, e2) => unreachable!(format!(
"non-constant exponent {} detected during static analysis",
e2
)),
@ -201,67 +182,127 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
}
FieldElementExpression::Select(box array, box index) => {
let array = self.fold_field_array_expression(array);
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
match (array, index) {
(
FieldElementArrayExpression::Value(size, v),
FieldElementExpression::Number(n),
) => {
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize].clone()
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
panic!(format!(
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
));
);
}
}
(
FieldElementArrayExpression::Identifier(size, id),
FieldElementExpression::Number(n),
) => match self.constants.get(&TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array(id.clone(), size)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => panic!(""),
},
None => FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(size, id),
box FieldElementExpression::Number(n),
),
},
(a, i) => FieldElementExpression::Select(box a, box i),
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => unreachable!(""),
},
None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => {
FieldElementExpression::Select(box a.annotate(inner_type, size), box i)
}
}
}
e => fold_field_expression(self, e),
}
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::Identifier(size, id) => {
ArrayExpressionInner::Identifier(id) => {
match self
.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
.get(&TypedAssignee::Identifier(Variable::array(
id.clone(),
ty.clone(),
size,
))) {
Some(e) => match e {
TypedExpression::FieldElementArray(e) => e.clone(),
TypedExpression::Array(e) => e.as_inner().clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => FieldElementArrayExpression::Identifier(size, id),
None => ArrayExpressionInner::Identifier(id),
}
}
e => fold_field_array_expression(self, e),
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
ArrayExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Array(e) => e.clone().into_inner(),
_ => unreachable!(""),
},
None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence);
let alternative = self.fold_array_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
@ -380,6 +421,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
e => e,
}
}
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_boolean_expression(consequence);
let alternative = self.fold_boolean_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence,
BooleanExpression::Value(false) => alternative,
c => BooleanExpression::IfElse(box c, box consequence, box alternative),
}
}
e => fold_boolean_expression(self, e),
}
}
@ -494,14 +544,12 @@ mod tests {
#[test]
fn select() {
let e = FieldElementExpression::Select(
box FieldElementArrayExpression::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
FieldElementExpression::Number(FieldPrime::from(2)),
FieldElementExpression::Number(FieldPrime::from(3)),
],
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(1)).into(),
FieldElementExpression::Number(FieldPrime::from(2)).into(),
FieldElementExpression::Number(FieldPrime::from(3)).into(),
])
.annotate(Type::FieldElement, 3),
box FieldElementExpression::Add(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1)),
@ -770,122 +818,4 @@ mod tests {
}
}
}
#[cfg(test)]
mod statement {
use super::*;
#[cfg(test)]
mod definition {
use super::*;
#[test]
fn update_constant_array() {
// field[2] a = [21, 22]
// // constants should store [21, 22]
// a[1] = 42
// // constants should store [21, 42]
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let definition = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(22)),
],
)
.into(),
);
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(definition);
let expected_value: TypedExpression<FieldPrime> =
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(22)),
],
)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
p.fold_statement(overwrite);
let expected_value: TypedExpression<FieldPrime> =
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(42)),
],
)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
}
#[test]
fn update_variable_array() {
// propagation does NOT support "partially constant" arrays. That means that in order for updates to use propagation,
// the array needs to have been defined as `field[3] = [value1, value2, value3]` with all values propagateable to constants
// a passed as input
// // constants should store nothing
// a[1] = 42
// // constants should store nothing
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(overwrite);
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
))),
None
);
}
}
}
}

View file

@ -8,6 +8,7 @@ use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use crate::types::Type;
use std::collections::HashMap;
use std::collections::HashSet;
use zokrates_field::field::Field;
pub struct Unroller<'ast> {
@ -43,76 +44,180 @@ impl<'ast> Unroller<'ast> {
pub fn unroll<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
Unroller::new().fold_program(p)
}
fn choose_many<T: Field>(
base: TypedExpression<'ast, T>,
indices: Vec<FieldElementExpression<'ast, T>>,
new_expression: TypedExpression<'ast, T>,
statements: &mut HashSet<TypedStatement<'ast, T>>,
) -> TypedExpression<'ast, T> {
let mut indices = indices;
match indices.len() {
0 => new_expression,
_ => {
let base = match base {
TypedExpression::Array(e) => e,
e => unreachable!("can't take an element on a {}", e.get_type()),
};
let inner_ty = base.inner_type();
let size = base.size();
let head = indices.pop().unwrap();
let tail = indices;
statements.insert(TypedStatement::Condition(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.annotate(inner_ty.clone(), size)
.into()
}
}
}
}
/// Turn an assignee into its representation as a base variable and a list of indices
/// a[2][3][4] -> (a, [2, 3, 4])
fn linear<'ast, T: Field>(
a: TypedAssignee<'ast, T>,
) -> (Variable, Vec<FieldElementExpression<'ast, T>>) {
match a {
TypedAssignee::Identifier(v) => (v, vec![]),
TypedAssignee::Select(box array, box index) => {
let (v, mut indices) = linear(array);
indices.push(index);
(v, indices)
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Declaration(_) => vec![],
TypedStatement::Definition(TypedAssignee::Identifier(variable), expr) => {
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
vec![TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
expr,
)]
}
TypedStatement::Definition(
TypedAssignee::ArrayElement(array @ box TypedAssignee::Identifier(..), box index),
expr,
) => {
let expr = self.fold_expression(expr);
let index = self.fold_field_expression(index);
let current_array = self.fold_assignee(*array.clone());
let (variable, indices) = linear(assignee);
let current_ssa_variable = match current_array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
let base = match variable.get_type() {
Type::FieldElement => {
FieldElementExpression::Identifier(variable.id.clone().into()).into()
}
Type::Boolean => {
BooleanExpression::Identifier(variable.id.clone().into()).into()
}
Type::Array(box ty, size) => {
ArrayExpressionInner::Identifier(variable.id.clone().into())
.annotate(ty, size)
.into()
}
};
let original_variable = match *array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
};
let mut range_checks = HashSet::new();
let e = Self::choose_many(base, indices, expr, &mut range_checks);
let array_size = match original_variable.get_type() {
Type::FieldElementArray(size) => size,
_ => panic!("array identifier should be a field element array"),
};
let expr = match expr {
TypedExpression::FieldElement(e) => e,
_ => panic!("right side of array element definition must be a field element"),
};
let new_variable = self.issue_next_ssa_variable(original_variable);
let new_array = FieldElementArrayExpression::Value(
array_size,
(0..array_size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box expr.clone(),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
array_size,
current_ssa_variable.id.clone(),
),
box FieldElementExpression::Number(T::from(i)),
),
)
})
.collect(),
);
vec![TypedStatement::Definition(
TypedAssignee::Identifier(new_variable),
new_array.into(),
)]
range_checks
.into_iter()
.chain(std::iter::once(TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
e,
)))
.collect()
}
TypedStatement::MultipleDefinition(variables, exprs) => {
let exprs = self.fold_expression_list(exprs);
@ -179,6 +284,241 @@ mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
#[test]
fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(FieldPrime::from(42)).into();
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(a0.clone().into(), vec![index], e, &mut HashSet::new());
// a[1] = 42
// -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::FieldElement, 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![index],
e.clone().into(),
&mut HashSet::new(),
);
// a[0] = b
// -> a = [0 == 1 ? b : a[0], 1 == 1 ? b : a[1], 2 == 1 ? b : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::array(Type::FieldElement, 3), 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 2), 2);
let e = FieldElementExpression::Number(FieldPrime::from(42));
let indices = vec![
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(0)),
];
let a1 = Unroller::choose_many(
a0.clone().into(),
indices,
e.clone().into(),
&mut HashSet::new(),
);
// a[0][0] = 42
// -> a = [0 == 0 ? [0 == 0 ? 42 : a[0][0], 1 == 0 ? 42 : a[0][1]] : a[0], 1 == 0 ? [0 == 0 ? 42 : a[1][0], 1 == 0 ? 42 : a[1][1]] : a[1]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
);
}
#[cfg(test)]
mod statement {
use super::*;
@ -442,13 +782,11 @@ mod tests {
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
FieldElementExpression::Number(FieldPrime::from(1)),
],
)
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(1)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
);
@ -459,19 +797,17 @@ mod tests {
Identifier::from("a").version(0),
2
)),
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
FieldElementExpression::Number(FieldPrime::from(1))
]
)
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(1)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into()
])
.annotate(Type::FieldElement, 2)
.into()
)]
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::ArrayElement(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
@ -480,28 +816,36 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
FieldElementArrayExpression::Value(
2,
vec![
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
2,
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
),
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(0))
),
),
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
@ -509,18 +853,172 @@ mod tests {
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
2,
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
),
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(1))
),
),
]
)
.into(),
])
.annotate(Type::FieldElement, 2)
.into()
)
.into()
]
);
}
#[test]
fn incremental_array_of_arrays_definition() {
// field[2][2] a = [[0, 1], [2, 3]]
// a[1] = [4, 5]
// should be turned into
// a_0 = [[0, 1], [2, 3]]
// a_1 = [if 0 == 1 then [4, 5] else a_0[0], if 1 == 1 then [4, 5] else a_0[1]]
let mut u = Unroller::new();
let array_of_array_ty = Type::array(Type::array(Type::FieldElement, 2), 2);
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(
Variable::with_id_and_type("a".into(), array_of_array_ty.clone()),
);
assert_eq!(u.fold_statement(s), vec![]);
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
"a".into(),
array_of_array_ty.clone(),
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(0)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(2)).into(),
FieldElementExpression::Number(FieldPrime::from(3)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(0),
array_of_array_ty.clone(),
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(0)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(2)).into(),
FieldElementExpression::Number(FieldPrime::from(3)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into(),
)]
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::with_id_and_type(
"a".into(),
array_of_array_ty.clone(),
)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
);
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(1),
array_of_array_ty.clone()
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(0))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(1))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
)
]
);
}
}
}

View file

@ -44,7 +44,7 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
match a {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::ArrayElement(box a, box index) => TypedAssignee::ArrayElement(
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee(a),
box self.fold_field_expression(index),
),
@ -59,10 +59,14 @@ pub trait Folder<'ast, T: Field>: Sized {
match e {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
TypedExpression::Boolean(e) => self.fold_boolean_expression(e).into(),
TypedExpression::FieldElementArray(e) => self.fold_field_array_expression(e).into(),
TypedExpression::Array(e) => self.fold_array_expression(e).into(),
}
}
fn fold_array_expression(&mut self, e: ArrayExpression<'ast, T>) -> ArrayExpression<'ast, T> {
fold_array_expression(self, e)
}
fn fold_expression_list(
&mut self,
es: TypedExpressionList<'ast, T>,
@ -93,11 +97,13 @@ pub trait Folder<'ast, T: Field>: Sized {
) -> BooleanExpression<'ast, T> {
fold_boolean_expression(self, e)
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
fold_field_array_expression(self, e)
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
fold_array_expression_inner(self, ty, size, e)
}
}
@ -150,32 +156,33 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
vec![res]
}
pub fn fold_field_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
_: &Type,
_: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::Identifier(size, id) => {
FieldElementArrayExpression::Identifier(size, f.fold_name(id))
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)),
ArrayExpressionInner::Value(exprs) => {
ArrayExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
}
FieldElementArrayExpression::Value(size, exprs) => FieldElementArrayExpression::Value(
size,
exprs
.into_iter()
.map(|e| f.fold_field_expression(e))
.collect(),
),
FieldElementArrayExpression::FunctionCall(size, id, exps) => {
ArrayExpressionInner::FunctionCall(id, exps) => {
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
FieldElementArrayExpression::FunctionCall(size, id, exps)
ArrayExpressionInner::FunctionCall(id, exps)
}
FieldElementArrayExpression::IfElse(box condition, box consequence, box alternative) => {
FieldElementArrayExpression::IfElse(
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
box f.fold_boolean_expression(condition),
box f.fold_field_array_expression(consequence),
box f.fold_field_array_expression(alternative),
box f.fold_array_expression(consequence),
box f.fold_array_expression(alternative),
)
}
ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
ArrayExpressionInner::Select(box array, box index)
}
}
}
@ -224,7 +231,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
FieldElementExpression::FunctionCall(key, exps)
}
FieldElementExpression::Select(box array, box index) => {
let array = f.fold_field_array_expression(array);
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
FieldElementExpression::Select(box array, box index)
}
@ -277,6 +284,17 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e = f.fold_boolean_expression(e);
BooleanExpression::Not(box e)
}
BooleanExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_boolean_expression(cons);
let alt = f.fold_boolean_expression(alt);
BooleanExpression::IfElse(box cond, box cons, box alt)
}
BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
BooleanExpression::Select(box array, box index)
}
}
}
@ -299,6 +317,16 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: ArrayExpression<'ast, T>,
) -> ArrayExpression<'ast, T> {
ArrayExpression {
inner: f.fold_array_expression_inner(&e.ty, e.size, e.inner),
..e
}
}
pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,

View file

@ -14,6 +14,7 @@ pub use crate::typed_absy::variable::Variable;
use crate::types::{FunctionKey, Signature, Type};
use embed::FlatEmbed;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt;
use zokrates_field::field::Field;
@ -80,17 +81,21 @@ pub struct TypedModule<'ast, T: Field> {
impl<'ast> fmt::Display for Identifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}_{}_{}",
self.stack
.iter()
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
.collect::<Vec<_>>()
.join("_"),
self.id,
self.version
)
if self.stack.len() == 0 && self.version == 0 {
write!(f, "{}", self.id)
} else {
write!(
f,
"{}_{}_{}",
self.stack
.iter()
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
.collect::<Vec<_>>()
.join("_"),
self.id,
self.version
)
}
}
}
@ -230,7 +235,7 @@ impl<'ast, T: Field> fmt::Debug for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T: Field> {
Identifier(Variable<'ast>),
ArrayElement(
Select(
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
@ -240,11 +245,11 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedAssignee::Identifier(ref v) => v.get_type(),
TypedAssignee::ArrayElement(ref a, _) => {
TypedAssignee::Select(ref a, _) => {
let a_type = a.get_type();
match a_type {
Type::FieldElementArray(_) => Type::FieldElement,
_ => panic!("array element has to take array"),
Type::Array(box t, _) => t,
_ => unreachable!("an array element should only be defined over arrays"),
}
}
}
@ -255,7 +260,7 @@ impl<'ast, T: Field> fmt::Debug for TypedAssignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id),
TypedAssignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
TypedAssignee::Select(ref a, ref e) => write!(f, "{}[{}]", a, e),
}
}
}
@ -267,7 +272,7 @@ impl<'ast, T: Field> fmt::Display for TypedAssignee<'ast, T> {
}
/// A statement in a `TypedFunction`
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedStatement<'ast, T: Field> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
@ -356,7 +361,7 @@ pub trait Typed {
pub enum TypedExpression<'ast, T: Field> {
Boolean(BooleanExpression<'ast, T>),
FieldElement(FieldElementExpression<'ast, T>),
FieldElementArray(FieldElementArrayExpression<'ast, T>),
Array(ArrayExpression<'ast, T>),
}
impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for TypedExpression<'ast, T> {
@ -371,9 +376,9 @@ impl<'ast, T: Field> From<FieldElementExpression<'ast, T>> for TypedExpression<'
}
}
impl<'ast, T: Field> From<FieldElementArrayExpression<'ast, T>> for TypedExpression<'ast, T> {
fn from(e: FieldElementArrayExpression<'ast, T>) -> TypedExpression<T> {
TypedExpression::FieldElementArray(e)
impl<'ast, T: Field> From<ArrayExpression<'ast, T>> for TypedExpression<'ast, T> {
fn from(e: ArrayExpression<'ast, T>) -> TypedExpression<T> {
TypedExpression::Array(e)
}
}
@ -382,7 +387,7 @@ impl<'ast, T: Field> fmt::Display for TypedExpression<'ast, T> {
match *self {
TypedExpression::Boolean(ref e) => write!(f, "{}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{}", e),
TypedExpression::FieldElementArray(ref e) => write!(f, "{}", e),
TypedExpression::Array(ref e) => write!(f, "{}", e.inner),
}
}
}
@ -392,29 +397,48 @@ impl<'ast, T: Field> fmt::Debug for TypedExpression<'ast, T> {
match *self {
TypedExpression::Boolean(ref e) => write!(f, "{:?}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{:?}", e),
TypedExpression::FieldElementArray(ref e) => write!(f, "{:?}", e),
TypedExpression::Array(ref e) => write!(f, "{:?}", e),
}
}
}
impl<'ast, T: Field> fmt::Display for ArrayExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.inner)
}
}
impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.inner)
}
}
impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedExpression::Boolean(_) => Type::Boolean,
TypedExpression::FieldElement(_) => Type::FieldElement,
TypedExpression::FieldElementArray(ref e) => e.get_type(),
TypedExpression::Boolean(ref e) => e.get_type(),
TypedExpression::FieldElement(ref e) => e.get_type(),
TypedExpression::Array(ref e) => e.get_type(),
}
}
}
impl<'ast, T: Field> Typed for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> Typed for ArrayExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
FieldElementArrayExpression::Identifier(n, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::Value(n, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::FunctionCall(n, _, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::IfElse(_, ref consequence, _) => consequence.get_type(),
}
Type::array(self.ty.clone(), self.size)
}
}
impl<'ast, T: Field> Typed for FieldElementExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::FieldElement
}
}
impl<'ast, T: Field> Typed for BooleanExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::Boolean
}
}
@ -422,7 +446,7 @@ pub trait MultiTyped {
fn get_types(&self) -> &Vec<Type>;
}
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedExpressionList<'ast, T: Field> {
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>, Vec<Type>),
}
@ -467,7 +491,7 @@ pub enum FieldElementExpression<'ast, T: Field> {
),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Select(
Box<FieldElementArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
@ -506,30 +530,108 @@ pub enum BooleanExpression<'ast, T: Field> {
Box<BooleanExpression<'ast, T>>,
),
Not(Box<BooleanExpression<'ast, T>>),
}
/// An expression of type `field[n]
/// # Remarks
/// * for now we store the array size in the variants
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum FieldElementArrayExpression<'ast, T: Field> {
Identifier(usize, Identifier<'ast>),
Value(usize, Vec<FieldElementExpression<'ast, T>>),
FunctionCall(usize, FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<FieldElementArrayExpression<'ast, T>>,
Box<FieldElementArrayExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> FieldElementArrayExpression<'ast, T> {
/// An expression of type `array`
/// # Remarks
/// * Contrary to basic types which represented as enums, we wrap an enum `ArrayExpressionInner` in a struct in order to keep track of the type (content and size)
/// of the array. Only using an enum would require generics, which would propagate up to TypedExpression which we want to keep simple, hence this "runtime"
/// type checking
#[derive(Clone, PartialEq, Hash, Eq)]
pub struct ArrayExpression<'ast, T: Field> {
size: usize,
ty: Type,
inner: ArrayExpressionInner<'ast, T>,
}
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum ArrayExpressionInner<'ast, T: Field> {
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> ArrayExpressionInner<'ast, T> {
pub fn annotate(self, ty: Type, size: usize) -> ArrayExpression<'ast, T> {
ArrayExpression {
size,
ty,
inner: self,
}
}
}
impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn inner_type(&self) -> &Type {
&self.ty
}
pub fn size(&self) -> usize {
match *self {
FieldElementArrayExpression::Identifier(s, _)
| FieldElementArrayExpression::Value(s, _)
| FieldElementArrayExpression::FunctionCall(s, ..) => s,
FieldElementArrayExpression::IfElse(_, ref consequence, _) => consequence.size(),
self.size
}
pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> {
&self.inner
}
pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> {
self.inner
}
}
// Downcasts
// Due to the fact that we keep TypedExpression simple, we end up with ArrayExpressionInner::Value whose elements are any TypedExpression, but we enforce by
// construction that these elements are of the type declared in the corresponding ArrayExpression. As we know this by construction, we can downcast the TypedExpression to the correct type
// ArrayExpression { type: Type::FieldElement, size: 42, inner: [TypedExpression::FieldElement(FieldElementExpression), ...]} <- the fact that inner only contains field elements is not enforced by the rust type system
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for FieldElementExpression<'ast, T> {
type Error = ();
fn try_from(
te: TypedExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::FieldElement(e) => Ok(e),
_ => Err(()),
}
}
}
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for BooleanExpression<'ast, T> {
type Error = ();
fn try_from(te: TypedExpression<'ast, T>) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::Boolean(e) => Ok(e),
_ => Err(()),
}
}
}
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for ArrayExpression<'ast, T> {
type Error = ();
fn try_from(te: TypedExpression<'ast, T>) -> Result<ArrayExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::Array(e) => Ok(e),
_ => Err(()),
}
}
}
@ -579,15 +681,21 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
),
BooleanExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
FieldElementArrayExpression::Identifier(_, ref var) => write!(f, "{}", var),
FieldElementArrayExpression::Value(_, ref values) => write!(
ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var),
ArrayExpressionInner::Value(ref values) => write!(
f,
"[{}]",
values
@ -596,7 +704,7 @@ impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
.collect::<Vec<String>>()
.join(", ")
),
FieldElementArrayExpression::FunctionCall(_, ref key, ref p) => {
ArrayExpressionInner::FunctionCall(ref key, ref p) => {
write!(f, "{}(", key.id,)?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
@ -606,13 +714,12 @@ impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
}
write!(f, ")")
}
FieldElementArrayExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
)
}
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
),
ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
@ -654,22 +761,23 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T: Field> fmt::Debug for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
FieldElementArrayExpression::Identifier(_, ref var) => write!(f, "{:?}", var),
FieldElementArrayExpression::Value(_, ref values) => write!(f, "{:?}", values),
FieldElementArrayExpression::FunctionCall(_, ref i, ref p) => {
ArrayExpressionInner::Identifier(ref var) => write!(f, "Identifier({:?})", var),
ArrayExpressionInner::Value(ref values) => write!(f, "Value({:?})", values),
ArrayExpressionInner::FunctionCall(ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?;
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
FieldElementArrayExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
)
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
ArrayExpressionInner::Select(ref id, ref index) => {
write!(f, "Select({:?}, {:?})", id, index)
}
}
}
@ -703,3 +811,70 @@ impl<'ast, T: Field> fmt::Debug for TypedExpressionList<'ast, T> {
}
}
}
// Common behaviour accross expressions
pub trait IfElse<'ast, T: Field> {
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
-> Self;
}
impl<'ast, T: Field> IfElse<'ast, T> for FieldElementExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
FieldElementExpression::IfElse(box condition, box consequence, box alternative)
}
}
impl<'ast, T: Field> IfElse<'ast, T> for BooleanExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
BooleanExpression::IfElse(box condition, box consequence, box alternative)
}
}
impl<'ast, T: Field> IfElse<'ast, T> for ArrayExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.inner_type().clone();
let size = consequence.size();
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative)
.annotate(ty, size)
}
}
pub trait Select<'ast, T: Field> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
}
impl<'ast, T: Field> Select<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index)
}
}
impl<'ast, T: Field> Select<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index)
}
}
impl<'ast, T: Field> Select<'ast, T> for ArrayExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() {
Type::Array(inner, size) => (inner.clone(), size.clone()),
_ => unreachable!(),
};
ArrayExpressionInner::Select(box array, box index).annotate(*ty, size)
}
}

View file

@ -18,8 +18,13 @@ impl<'ast> Variable<'ast> {
Self::with_id_and_type(id, Type::Boolean)
}
#[cfg(test)]
pub fn field_array(id: Identifier<'ast>, size: usize) -> Variable<'ast> {
Self::with_id_and_type(id, Type::FieldElementArray(size))
Self::array(id, Type::FieldElement, size)
}
pub fn array(id: Identifier<'ast>, ty: Type, size: usize) -> Variable<'ast> {
Self::with_id_and_type(id, Type::array(ty, size))
}
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {

View file

@ -9,35 +9,39 @@ mod signature;
pub enum Type {
FieldElement,
Boolean,
FieldElementArray(usize),
Array(Box<Type>, usize),
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::FieldElementArray(size) => write!(f, "{}[{}]", Type::FieldElement, size),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl fmt::Debug for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::FieldElementArray(size) => write!(f, "{}[{}]", Type::FieldElement, size),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl Type {
pub fn array(ty: Type, size: usize) -> Self {
Type::Array(box ty, size)
}
fn to_slug(&self) -> String {
match *self {
match self {
Type::FieldElement => String::from("f"),
Type::Boolean => String::from("b"),
Type::FieldElementArray(size) => format!("{}[{}]", Type::FieldElement.to_slug(), size),
Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size),
}
}
@ -46,7 +50,7 @@ impl Type {
match self {
Type::FieldElement => 1,
Type::Boolean => 1,
Type::FieldElementArray(size) => size * Type::FieldElement.get_primitive_count(),
Type::Array(ty, size) => size * ty.get_primitive_count(),
}
}
}
@ -101,7 +105,7 @@ mod tests {
#[test]
fn array() {
let t = Type::FieldElementArray(42);
let t = Type::Array(box Type::FieldElement, 42);
assert_eq!(t.get_primitive_count(), 42);
}
}

View file

@ -111,4 +111,50 @@ mod tests {
assert_eq!(s.to_string(), String::from("(field, bool) -> (bool)"));
}
#[test]
fn slug_0() {
let s = Signature::new().inputs(vec![]).outputs(vec![]);
assert_eq!(s.to_slug(), String::from("io"));
}
#[test]
fn slug_1() {
let s = Signature::new()
.inputs(vec![Type::FieldElement, Type::Boolean])
.outputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::Boolean,
Type::FieldElement,
]);
assert_eq!(s.to_slug(), String::from("ifbo2fbf"));
}
#[test]
fn slug_2() {
let s = Signature::new()
.inputs(vec![
Type::FieldElement,
Type::FieldElement,
Type::FieldElement,
])
.outputs(vec![Type::FieldElement, Type::Boolean, Type::FieldElement]);
assert_eq!(s.to_slug(), String::from("i3fofbf"));
}
#[test]
fn array_slug() {
let s = Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 42),
Type::array(Type::FieldElement, 21),
])
.outputs(vec![]);
assert_eq!(s.to_slug(), String::from("if[42]f[21]o"));
}
}

View file

@ -16,7 +16,7 @@ ty_field = {"field"}
ty_bool = {"bool"}
ty_basic = { ty_field | ty_bool }
// (unidimensional for now) arrays of (basic for now) types
ty_array = { ty_basic ~ ("[" ~ expression ~ "]") }
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic }
type_list = _{(ty ~ ("," ~ ty)*)?}

View file

@ -9,8 +9,8 @@ extern crate lazy_static;
pub use ast::{
Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee,
AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, CallAccess,
ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function,
AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, BooleanType, CallAccess,
ConstantExpression, DefinitionStatement, Expression, FieldType, File, FromExpression, Function,
IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement,
MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression,
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, TernaryExpression, ToExpression,
@ -193,36 +193,33 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty))]
pub enum Type<'ast> {
Basic(BasicType<'ast>),
Basic(BasicType),
Array(ArrayType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_basic))]
pub enum BasicType<'ast> {
pub enum BasicType {
Field(FieldType),
Boolean(BooleanType<'ast>),
Boolean(BooleanType),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_field))]
pub struct FieldType {}
pub struct FieldType;
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_array))]
pub struct ArrayType<'ast> {
pub ty: BasicType<'ast>,
pub size: Expression<'ast>,
pub ty: BasicType,
pub dimensions: Vec<Expression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_bool))]
pub struct BooleanType<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
pub struct BooleanType;
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::parameter))]
@ -403,7 +400,7 @@ mod ast {
#[pest_ast(rule(Rule::postfix_expression))]
pub struct PostfixExpression<'ast> {
pub id: IdentifierExpression<'ast>,
pub access: Vec<Access<'ast>>,
pub accesses: Vec<Access<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}