1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

implement multidim updates

This commit is contained in:
schaeff 2019-09-09 09:59:24 +02:00
parent 05b59393e5
commit ca6d92b349
10 changed files with 661 additions and 444 deletions

View file

@ -0,0 +1,4 @@
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]):
a[i][j][k] = 42
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
return b[0]

View file

@ -505,19 +505,15 @@ impl<'ast, T: Field> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast, T>
use absy::NodeValue;
let a = absy::AssigneeNode::from(assignee.id);
match assignee.indices.len() {
0 => a,
1 => absy::Assignee::ArrayElement(
box a,
box absy::RangeOrExpression::from(assignee.indices[0].clone()),
)
.span(assignee.span),
n => unimplemented!(
"Assignment to array of {} dimensions not supported yet (found in {})",
n,
a
),
}
let span = assignee.span;
assignee
.indices
.into_iter()
.map(|i| absy::RangeOrExpression::from(i))
.fold(a, |acc, s| {
absy::Assignee::Select(box acc, box s).span(span.clone())
})
}
}

View file

@ -198,7 +198,7 @@ impl<'ast, T: Field> fmt::Debug for Function<'ast, T> {
#[derive(Clone, PartialEq)]
pub enum Assignee<'ast, T: Field> {
Identifier(Identifier<'ast>),
ArrayElement(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
Select(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
}
pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
@ -206,8 +206,8 @@ pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
impl<'ast, T: Field> fmt::Debug for Assignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Assignee::Identifier(ref s) => write!(f, "{}", s),
Assignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
}
}
}

View file

@ -153,6 +153,8 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
println!("{}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast);

View file

@ -250,12 +250,30 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
let e = self.flatten_select_expression::<U>(
symbols,
statements_flattened,
array,
index,
);
let e = match array.inner_type() {
Type::FieldElement => self
.flatten_select_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
Type::Boolean => self
.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
Type::Array(..) => self
.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
};
e[n * element_size..(n + 1) * element_size]
.into_iter()
.map(|i| i.clone().into())
@ -1151,7 +1169,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.map(|(v, e)| FlatStatement::Definition(v, e)),
);
}
TypedAssignee::ArrayElement(..) => unreachable!(
TypedAssignee::Select(..) => unreachable!(
"array element redefs should have been replaced by array redefs in unroll"
),
}

View file

@ -601,6 +601,8 @@ impl<'ast> Checker<'ast> {
&mut self,
assignee: AssigneeNode<'ast, T>,
) -> Result<TypedAssignee<'ast, T>, Error> {
println!("{:?}", assignee.value);
let pos = assignee.pos();
// check that the assignee is declared
match assignee.value {
@ -616,7 +618,7 @@ impl<'ast> Checker<'ast> {
message: format!("Undeclared variable: {:?}", variable_name),
}),
},
Assignee::ArrayElement(box assignee, box index) => {
Assignee::Select(box assignee, box index) => {
let checked_assignee = self.check_assignee(assignee)?;
let checked_index = match index {
RangeOrExpression::Expression(e) => self.check_expression(e)?,
@ -639,7 +641,7 @@ impl<'ast> Checker<'ast> {
}),
}?;
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box checked_assignee,
box checked_typed_index,
))
@ -2422,7 +2424,7 @@ mod tests {
fn array_element() {
// field[33] a
// a[2] = 42
let a = Assignee::ArrayElement(
let a = Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(2)).mock(),
@ -2440,7 +2442,7 @@ mod tests {
assert_eq!(
checker.check_assignee(a),
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::field_array(
"a".into(),
33
@ -2454,8 +2456,8 @@ mod tests {
fn array_of_array_element() {
// field[33][42] a
// a[1][2]
let a = Assignee::ArrayElement(
box Assignee::ArrayElement(
let a = Assignee::Select(
box Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(1)).mock(),
@ -2481,8 +2483,8 @@ mod tests {
assert_eq!(
checker.check_assignee(a),
Ok(TypedAssignee::ArrayElement(
box TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::array(
"a".into(),
Type::array(Type::FieldElement, 33),

View file

@ -35,105 +35,90 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| self.fold_expression(e)).collect())),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
match self.fold_expression(expr) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..)) | e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
},
TypedExpression::Array(e) => {
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(
expressions
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
)),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
match self.fold_expression(expr) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..))
| e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
}
TypedExpression::Array(e) => {
let ty = e.inner_type().clone();
let size = e.size();
match e.into_inner() {
ArrayExpressionInner::Value(array) => {
let array: Vec<_> = array.into_iter().map(|e| self.fold_expression(e)).collect();
let array: Vec<_> =
array.into_iter().map(|e| self.fold_expression(e)).collect();
match array.iter().all(|e| match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::FieldElement(
FieldElementExpression::Number(..),
) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
_ => false
_ => false,
}) {
true => {
// all elements of the array are constants
self.constants.insert(TypedAssignee::Identifier(var), ArrayExpressionInner::Value(array).annotate(
ty,
size).into());
self.constants.insert(
TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array)
.annotate(ty, size)
.into(),
);
None
},
false => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array).annotate(
ty,
size).into()))
}
false => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array)
.annotate(ty, size)
.into(),
)),
}
},
e => Some(TypedStatement::Definition(TypedAssignee::Identifier(var), TypedExpression::Array(e.annotate(ty, size))))
}
e => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
TypedExpression::Array(e.annotate(ty, size)),
)),
}
},
e => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), e))
}
}
},
// a[b] = c
TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr) => {
let index = self.fold_field_expression(index);
let expr = self.fold_expression(expr);
match (index, expr) {
(
FieldElementExpression::Number(n),
TypedExpression::FieldElement(expr @ FieldElementExpression::Number(..))
) => {
// a[42] = 33
// -> store (a[42] -> 33) in the constants, possibly overwriting the previous entry
self.constants.entry(TypedAssignee::Identifier(var)).and_modify(|e| {
match e {
TypedExpression::Array(e) => {
let size = e.size();
match e.as_inner_mut() {
ArrayExpressionInner::Value(ref mut v) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize] = expr.into();
} else {
panic!(format!("out of bounds index ({} >= {}) found during static analysis", n_as_usize, size));
}
},
_ => panic!("constants should only store constants")
}
},
_ => unreachable!()
}
});
None
},
(index, expr) => {
// a[42] = e
// -> remove a from the constants as one of its elements is not constant
self.constants.remove(&TypedAssignee::Identifier(var.clone()));
Some(TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr))
}
}
},
TypedStatement::Definition(..) => panic!("multi dimensinal arrays are not supported, this should have been caught during semantic checking"),
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(self.fold_expression(e1), self.fold_expression(e2)))
},
// we unrolled for loops in the previous step
TypedStatement::For(..) => panic!("for loop is unexpected, it should have been unrolled"),
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(variables, expression_list))
}
};
}
e => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
e,
)),
}
}
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
unreachable!("array updates should have been replaced with full array redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(
self.fold_expression(e1),
self.fold_expression(e2),
))
}
// we unrolled for loops in the previous step
TypedStatement::For(..) => {
panic!("for loop is unexpected, it should have been unrolled")
}
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(
variables,
expression_list,
))
}
};
match res {
Some(v) => vec![v],
None => vec![],
@ -244,7 +229,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::ArrayElement(
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
@ -293,6 +278,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
None => ArrayExpressionInner::Identifier(id),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence);
let alternative = self.fold_array_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
@ -809,116 +803,4 @@ mod tests {
}
}
}
#[cfg(test)]
mod statement {
use super::*;
#[cfg(test)]
mod definition {
use super::*;
#[test]
fn update_constant_array() {
// field[2] a = [21, 22]
// // constants should store [21, 22]
// a[1] = 42
// // constants should store [21, 42]
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let definition = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(22)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
);
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(definition);
let expected_value: TypedExpression<FieldPrime> =
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(22)).into(),
])
.annotate(Type::FieldElement, 2)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
p.fold_statement(overwrite);
let expected_value: TypedExpression<FieldPrime> =
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
])
.annotate(Type::FieldElement, 2)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
}
#[test]
fn update_variable_array() {
// propagation does NOT support "partially constant" arrays. That means that in order for updates to use propagation,
// the array needs to have been defined as `field[3] = [value1, value2, value3]` with all values propagateable to constants
// a passed as input
// // constants should store nothing
// a[1] = 42
// // constants should store nothing
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(overwrite);
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
))),
None
);
}
}
}
}

View file

@ -8,6 +8,7 @@ use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use crate::types::Type;
use std::collections::HashMap;
use std::collections::HashSet;
use zokrates_field::field::Field;
pub struct Unroller<'ast> {
@ -43,132 +44,181 @@ impl<'ast> Unroller<'ast> {
pub fn unroll<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
Unroller::new().fold_program(p)
}
fn choose_many<T: Field>(
base: TypedExpression<'ast, T>,
indices: Vec<FieldElementExpression<'ast, T>>,
new_expression: TypedExpression<'ast, T>,
statements: &mut HashSet<TypedStatement<'ast, T>>,
) -> TypedExpression<'ast, T> {
let mut indices = indices;
match indices.len() {
0 => new_expression,
_ => {
let base = match base {
TypedExpression::Array(e) => e,
e => unreachable!("can't take an element on a {}", e.get_type()),
};
let inner_ty = base.inner_type();
let size = base.size();
let head = indices.pop().unwrap();
let tail = indices;
statements.insert(TypedStatement::Condition(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.annotate(inner_ty.clone(), size)
.into()
}
}
}
}
/// Turn an assignee into its representation as a base variable and a list of indices
/// a[2][3][4] -> (a, [2, 3, 4])
fn linear<'ast, T: Field>(
a: TypedAssignee<'ast, T>,
) -> (Variable, Vec<FieldElementExpression<'ast, T>>) {
match a {
TypedAssignee::Identifier(v) => (v, vec![]),
TypedAssignee::Select(box array, box index) => {
let (v, mut indices) = linear(array);
indices.push(index);
(v, indices)
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Declaration(_) => vec![],
TypedStatement::Definition(TypedAssignee::Identifier(variable), expr) => {
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
vec![TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
expr,
)]
}
TypedStatement::Definition(
TypedAssignee::ArrayElement(array @ box TypedAssignee::Identifier(..), box index),
expr,
) => {
let expr = self.fold_expression(expr);
let index = self.fold_field_expression(index);
let current_array = self.fold_assignee(*array.clone());
let (variable, indices) = linear(assignee);
let current_ssa_variable = match current_array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
};
let original_variable = match *array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
};
let array_size = match original_variable.get_type() {
Type::Array(_, size) => size,
_ => panic!("array identifier should be a field element array"),
};
let new_variable = self.issue_next_ssa_variable(original_variable);
let new_array = match expr {
TypedExpression::FieldElement(e) => ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box e.clone(),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
)
.annotate(Type::FieldElement, array_size),
box FieldElementExpression::Number(T::from(i)),
),
)
.into()
})
.collect(),
)
.annotate(Type::FieldElement, array_size),
TypedExpression::Boolean(e) => ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
BooleanExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box e.clone(),
box BooleanExpression::Select(
box ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
)
.annotate(Type::Boolean, array_size),
box FieldElementExpression::Number(T::from(i)),
),
)
.into()
})
.collect(),
)
.annotate(Type::Boolean, array_size),
TypedExpression::Array(e) => {
let array_inner_type = Type::array(e.inner_type().clone(), e.size());
let array_size = array_size;
let element_inner_type = e.inner_type().clone();
let element_size = e.size();
ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box e.clone(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
)
.annotate(array_inner_type.clone(), array_size),
box FieldElementExpression::Number(T::from(i)),
)
.annotate(e.inner_type().clone(), e.size()),
)
.annotate(element_inner_type.clone(), element_size)
.into()
})
.collect(),
)
.annotate(array_inner_type.clone(), array_size)
let base = match variable.get_type() {
Type::FieldElement => {
FieldElementExpression::Identifier(variable.id.clone().into()).into()
}
Type::Boolean => {
BooleanExpression::Identifier(variable.id.clone().into()).into()
}
Type::Array(box ty, size) => {
ArrayExpressionInner::Identifier(variable.id.clone().into())
.annotate(ty, size)
.into()
}
};
vec![TypedStatement::Definition(
TypedAssignee::Identifier(new_variable),
new_array.into(),
)]
let mut range_checks = HashSet::new();
let e = Self::choose_many(base, indices, expr, &mut range_checks);
range_checks
.into_iter()
.chain(std::iter::once(TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
e,
)))
.collect()
}
TypedStatement::Definition(
TypedAssignee::ArrayElement(box TypedAssignee::ArrayElement(..), _),
_,
) => unreachable!("multi array with redefs breaks ssa now"),
TypedStatement::MultipleDefinition(variables, exprs) => {
let exprs = self.fold_expression_list(exprs);
let variables = variables
@ -234,6 +284,241 @@ mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
#[test]
fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(FieldPrime::from(42)).into();
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(a0.clone().into(), vec![index], e, &mut HashSet::new());
// a[1] = 42
// -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::FieldElement, 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![index],
e.clone().into(),
&mut HashSet::new(),
);
// a[0] = b
// -> a = [0 == 1 ? b : a[0], 1 == 1 ? b : a[1], 2 == 1 ? b : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::array(Type::FieldElement, 3), 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 2), 2);
let e = FieldElementExpression::Number(FieldPrime::from(42));
let indices = vec![
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(0)),
];
let a1 = Unroller::choose_many(
a0.clone().into(),
indices,
e.clone().into(),
&mut HashSet::new(),
);
// a[0][0] = 42
// -> a = [0 == 0 ? [0 == 0 ? 42 : a[0][0], 1 == 0 ? 42 : a[0][1]] : a[0], 1 == 0 ? [0 == 0 ? 42 : a[1][0], 1 == 0 ? 42 : a[1][1]] : a[1]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
);
}
#[cfg(test)]
mod statement {
use super::*;
@ -522,7 +807,7 @@ mod tests {
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::ArrayElement(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
@ -531,46 +816,56 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(0))
),
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(1))
),
)
.into(),
])
.annotate(Type::FieldElement, 2)
.into()
)]
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(0))
),
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(1))
),
)
.into(),
])
.annotate(Type::FieldElement, 2)
.into()
)
]
);
}
@ -642,7 +937,7 @@ mod tests {
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::ArrayElement(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::with_id_and_type(
"a".into(),
array_of_array_ty.clone(),
@ -659,60 +954,70 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(1),
array_of_array_ty.clone()
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(1),
array_of_array_ty.clone()
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(0))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(1))
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(0))
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(1))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
)]
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
)
]
);
}
}

View file

@ -44,7 +44,7 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
match a {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::ArrayElement(box a, box index) => TypedAssignee::ArrayElement(
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee(a),
box self.fold_field_expression(index),
),

View file

@ -231,7 +231,7 @@ impl<'ast, T: Field> fmt::Debug for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T: Field> {
Identifier(Variable<'ast>),
ArrayElement(
Select(
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
@ -241,11 +241,11 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedAssignee::Identifier(ref v) => v.get_type(),
TypedAssignee::ArrayElement(ref a, _) => {
TypedAssignee::Select(ref a, _) => {
let a_type = a.get_type();
match a_type {
Type::Array(box t, _) => t,
_ => unreachable!("an array element should only be defined over arrays, this will be handled gracefully in the future"),
_ => unreachable!("an array element should only be defined over arrays"),
}
}
}
@ -256,7 +256,7 @@ impl<'ast, T: Field> fmt::Debug for TypedAssignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id),
TypedAssignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
TypedAssignee::Select(ref a, ref e) => write!(f, "{}[{}]", a, e),
}
}
}
@ -268,7 +268,7 @@ impl<'ast, T: Field> fmt::Display for TypedAssignee<'ast, T> {
}
/// A statement in a `TypedFunction`
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedStatement<'ast, T: Field> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
@ -413,8 +413,8 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> {
impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedExpression::Boolean(_) => Type::Boolean,
TypedExpression::FieldElement(_) => Type::FieldElement,
TypedExpression::Boolean(ref e) => e.get_type(),
TypedExpression::FieldElement(ref e) => e.get_type(),
TypedExpression::Array(ref e) => e.get_type(),
}
}
@ -426,11 +426,23 @@ impl<'ast, T: Field> Typed for ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Field> Typed for FieldElementExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::FieldElement
}
}
impl<'ast, T: Field> Typed for BooleanExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::Boolean
}
}
pub trait MultiTyped {
fn get_types(&self) -> &Vec<Type>;
}
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedExpressionList<'ast, T: Field> {
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>, Vec<Type>),
}
@ -574,10 +586,6 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> {
self.inner
}
pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> {
&mut self.inner
}
}
// Downcasts
@ -745,8 +753,8 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ArrayExpressionInner::Identifier(ref var) => write!(f, "{:?}", var),
ArrayExpressionInner::Value(ref values) => write!(f, "{:?}", values),
ArrayExpressionInner::Identifier(ref var) => write!(f, "Identifier({:?})", var),
ArrayExpressionInner::Value(ref values) => write!(f, "Value({:?})", values),
ArrayExpressionInner::FunctionCall(ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?;
f.debug_list().entries(p.iter()).finish()?;