1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

merge rec-arrays

This commit is contained in:
schaeff 2019-09-10 12:52:03 +02:00
commit 494eca0599
15 changed files with 1367 additions and 756 deletions

View file

@ -0,0 +1,4 @@
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]):
a[i][j][k] = 42
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
return b[0]

View file

@ -0,0 +1,6 @@
[
0,
0,
0,
0
]

View file

@ -0,0 +1,3 @@
def main(field[2][2] a) -> (field[2][2]):
a[1][1] = 42
return a

View file

@ -0,0 +1,4 @@
~out_0 0
~out_1 0
~out_2 0
~out_3 42

View file

@ -586,15 +586,15 @@ impl<'ast, T: Field> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast, T>
use absy::NodeValue;
let a = absy::AssigneeNode::from(assignee.id);
match assignee.indices.len() {
0 => a,
1 => absy::Assignee::ArrayElement(
box a,
box absy::RangeOrExpression::from(assignee.indices[0].clone()),
)
.span(assignee.span),
n => unimplemented!("Array should have one dimension, found {} in {}", n, a),
}
let span = assignee.span;
assignee
.indices
.into_iter()
.map(|i| absy::RangeOrExpression::from(i))
.fold(a, |acc, s| {
absy::Assignee::Select(box acc, box s).span(span.clone())
})
}
}
@ -622,7 +622,7 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
let span = t.span;
t.size
t.dimensions
.into_iter()
.map(|s| match s {
pest::Expression::Constant(c) => match c {
@ -776,4 +776,200 @@ mod tests {
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
mod types {
use super::*;
/// Helper method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty
fn wrap(ty: absy::UnresolvedType) -> absy::Module<'static, FieldPrime> {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone().mock()).into(),
)
.into()],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![],
}
.into(),
)
.into()],
signature: absy::UnresolvedSignature::new().inputs(vec![ty.mock()]),
}
.into(),
),
}
.into()],
imports: vec![],
}
}
#[test]
fn array() {
let vectors = vec![
("field", absy::UnresolvedType::FieldElement),
("bool", absy::UnresolvedType::Boolean),
(
"field[2]",
absy::UnresolvedType::Array(box absy::UnresolvedType::FieldElement.mock(), 2),
),
(
"field[2][3]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
2,
)
.mock(),
3,
),
),
(
"bool[2][3]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(),
2,
)
.mock(),
3,
),
),
];
for (ty, expected) in vectors {
let source = format!("def main(private {} a) -> (): return", ty);
let expected = wrap(expected);
let ast = pest::generate_ast(&source).unwrap();
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
}
}
mod postfix {
use super::*;
fn wrap(expression: absy::Expression<'static, FieldPrime>) -> absy::Module<FieldPrime> {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![expression.into()],
}
.into(),
)
.into()],
signature: absy::UnresolvedSignature::new(),
}
.into(),
),
}
.into()],
imports: vec![],
}
}
#[test]
fn success() {
// we basically accept `()?[]*` : an optional call at first, then only array accesses
let vectors = vec![
("a", absy::Expression::Identifier("a").into()),
(
"a[3]",
absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(3)).into(),
)
.into(),
),
),
(
"a[3][4]",
absy::Expression::Select(
box absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(3)).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
),
),
(
"a(3)[4]",
absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(FieldPrime::from(3)).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
),
),
(
"a(3)[4][5]",
absy::Expression::Select(
box absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(FieldPrime::from(3)).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(4)).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(FieldPrime::from(5)).into(),
)
.into(),
),
),
];
for (source, expected) in vectors {
let source = format!("def main() -> (): return {}", source);
let expected = wrap(expected);
let ast = pest::generate_ast(&source).unwrap();
assert_eq!(absy::Module::<FieldPrime>::from(ast), expected);
}
}
#[test]
#[should_panic]
fn call_array_element() {
// a call after an array access should be rejected
let source = "def main() -> (): return a[2](3)";
let ast = pest::generate_ast(&source).unwrap();
absy::Module::<FieldPrime>::from(ast);
}
#[test]
#[should_panic]
fn call_call_result() {
// a call after a call should be rejected
let source = "def main() -> (): return a(2)(3)";
let ast = pest::generate_ast(&source).unwrap();
absy::Module::<FieldPrime>::from(ast);
}
}
}

View file

@ -253,7 +253,7 @@ impl<'ast, T: Field> fmt::Debug for Function<'ast, T> {
#[derive(Clone, PartialEq)]
pub enum Assignee<'ast, T: Field> {
Identifier(Identifier<'ast>),
ArrayElement(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
Select(Box<AssigneeNode<'ast, T>>, Box<RangeOrExpression<'ast, T>>),
}
pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
@ -261,8 +261,8 @@ pub type AssigneeNode<'ast, T> = Node<Assignee<'ast, T>>;
impl<'ast, T: Field> fmt::Debug for Assignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Assignee::Identifier(ref s) => write!(f, "{}", s),
Assignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
}
}
}
@ -578,7 +578,9 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> {
f.debug_list().entries(members.iter()).finish()?;
write!(f, "]")
}
Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index),
Expression::Select(ref array, ref index) => {
write!(f, "Select({:?}, {:?})", array, index)
}
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
}

View file

@ -138,6 +138,8 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
let source = arena.alloc(source);
println!("{:?}", source);
let compiled = compile_program(source, location.clone(), resolve_option, &arena)?;
// check semantics
@ -150,9 +152,13 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
)
})?;
println!("{}", typed_ast);
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
println!("{}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast);

View file

@ -27,19 +27,17 @@ pub struct Flattener<'ast, T: Field> {
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
#[rustfmt::skip]
trait Flatten<'ast, T: Field>: TryFrom<TypedExpression<'ast, T>, Error: std::fmt::Debug> {
trait Flatten<'ast, T: Field>
: TryFrom<TypedExpression<'ast, T>, Error: std::fmt::Debug>
+ IfElse<'ast, T>
+ Select<'ast, T>
+ Member<'ast, T> {
fn flatten(
self,
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatExpression<T>>;
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
-> Self;
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self;
}
impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
@ -51,22 +49,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
) -> Vec<FlatExpression<T>> {
vec![flattener.flatten_field_expression(symbols, statements_flattened, self)]
}
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
FieldElementExpression::IfElse(box condition, box consequence, box alternative)
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index)
}
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
FieldElementExpression::Member(box s, id)
}
}
impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
@ -78,22 +60,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
) -> Vec<FlatExpression<T>> {
vec![flattener.flatten_boolean_expression(symbols, statements_flattened, self)]
}
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
BooleanExpression::IfElse(box condition, box consequence, box alternative)
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index)
}
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
BooleanExpression::Member(box s, id)
}
}
impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
@ -105,41 +71,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
) -> Vec<FlatExpression<T>> {
flattener.flatten_struct_expression(symbols, statements_flattened, self)
}
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.ty().clone();
StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty)
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let members = match array.inner_type().clone() {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Select(box array, box index).annotate(members)
}
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let members = match ty {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Member(box s, member_id).annotate(members)
}
}
impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
@ -173,43 +104,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
),
}
}
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.inner_type().clone();
let size = consequence.size();
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative)
.annotate(ty, size)
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() {
Type::Array(inner, size) => (inner.clone(), size.clone()),
_ => unreachable!(),
};
ArrayExpressionInner::Select(box array, box index).annotate(*ty, size)
}
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let (ty, size) = match ty {
Type::Array(box ty, size) => (ty, size),
_ => unreachable!(),
};
ArrayExpressionInner::Member(box s, member_id).annotate(ty, size)
}
}
impl<'ast, T: Field> Flattener<'ast, T> {
@ -623,12 +517,37 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
let e = self.flatten_select_expression::<U>(
symbols,
statements_flattened,
array,
index,
);
let e = match array.inner_type() {
Type::FieldElement => self
.flatten_select_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
Type::Boolean => self
.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
Type::Array(..) => self
.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
Type::Struct(..) => self
.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
};
e[n * element_size..(n + 1) * element_size]
.into_iter()
.map(|i| i.clone().into())
@ -1616,223 +1535,41 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// define n variables with n the number of primitive types for v_type
// assign them to the n primitive types for expr
let rhs = self.flatten_expression(symbols, statements_flattened, expr.clone());
let rhs = self.flatten_expression(symbols, statements_flattened, expr);
match expr.get_type() {
Type::FieldElement | Type::Boolean => {
match assignee {
TypedAssignee::Identifier(ref v) => {
let var = self.use_variable(&v)[0];
// handle return of function call
statements_flattened
.push(FlatStatement::Definition(var, rhs[0].clone()));
}
TypedAssignee::ArrayElement(box array, box index) => {
let expr = match expr {
TypedExpression::FieldElement(e) => e,
_ => panic!("not a field element as rhs of array element update, should have been caught at semantic")
};
match index {
FieldElementExpression::Number(n) => match array {
TypedAssignee::Identifier(id) => {
let var = self.issue_new_variables(1);
let variables = self.layout.get_mut(&id.id).unwrap();
variables
[n.to_dec_string().parse::<usize>().unwrap()] =
var[0];
statements_flattened.push(FlatStatement::Definition(
var[0],
rhs[0].clone(),
));
}
_ => panic!("no multidimension array for now"),
},
e => {
// we have array[e] with e an arbitrary expression
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
// here depending on the size, we could use a proper range check based on bits
let size = match array.get_type() {
Type::Array(_, n) => n,
_ => panic!("checker should generate array element based on non array")
};
let range_check = (0..size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box e.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
box FieldElementExpression::Number(T::from(1)),
box FieldElementExpression::Number(T::from(0)),
)
})
.fold(
FieldElementExpression::Number(T::from(0)),
|acc, e| {
FieldElementExpression::Add(box acc, box e)
},
);
let range_check_statement = TypedStatement::Condition(
FieldElementExpression::Number(T::from(1)).into(),
range_check.into(),
);
self.flatten_statement(
symbols,
statements_flattened,
range_check_statement,
);
// now we redefine the whole array, updating only the piece that changed
// stat(array[i] = if e == i then `expr` else `array[i]`)
let vars = match array {
TypedAssignee::Identifier(v) => self.use_variable(&v),
_ => unimplemented!(),
};
let statements = vars
.into_iter()
.enumerate()
.map(|(i, v)| {
let rhs = FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box e.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
box expr.clone(),
box e.clone(),
);
let rhs_flattened = self.flatten_field_expression(
symbols,
statements_flattened,
rhs,
);
FlatStatement::Definition(v, rhs_flattened)
})
.collect::<Vec<_>>();
statements_flattened.extend(statements);
}
}
}
}
}
Type::Array(..) => {
let vars = match assignee {
TypedAssignee::Identifier(v) => self.use_variable(&v),
_ => unimplemented!(),
};
match assignee {
TypedAssignee::Identifier(ref v) => {
let vars = self.use_variable(&v);
// handle return of function call
statements_flattened.extend(
vars.into_iter()
.zip(rhs.into_iter())
.map(|(v, r)| FlatStatement::Definition(v, r)),
);
}
Type::Struct(..) => {
let vars = match assignee {
TypedAssignee::Identifier(v) => self.use_variable(&v),
_ => unimplemented!(),
};
statements_flattened.extend(
vars.into_iter()
.zip(rhs.into_iter())
.map(|(v, r)| FlatStatement::Definition(v, r)),
.zip(rhs)
.map(|(v, e)| FlatStatement::Definition(v, e)),
);
}
TypedAssignee::Select(..) => unreachable!(
"array element redefs should have been replaced by array redefs in unroll"
),
}
}
TypedStatement::Condition(expr1, expr2) => {
TypedStatement::Condition(lhs, rhs) => {
// flatten expr1 and expr2 to n flattened expressions with n the number of primitive types for expr1
// add n conditions to check equality of the n expressions
match (expr1, expr2) {
(TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => {
let (lhs, rhs) = (
self.flatten_field_expression(symbols, statements_flattened, e1),
self.flatten_field_expression(symbols, statements_flattened, e2),
);
let lhs = self.flatten_expression(symbols, statements_flattened, lhs);
let rhs = self.flatten_expression(symbols, statements_flattened, rhs);
if lhs.is_linear() {
statements_flattened.push(FlatStatement::Condition(lhs, rhs));
} else if rhs.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(rhs, lhs));
} else {
unimplemented!()
}
assert_eq!(lhs.len(), rhs.len());
for (l, r) in lhs.into_iter().zip(rhs.into_iter()) {
if l.is_linear() {
statements_flattened.push(FlatStatement::Condition(l, r));
} else if r.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(r, l));
} else {
unimplemented!()
}
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
let (lhs, rhs) = (
self.flatten_boolean_expression(symbols, statements_flattened, e1),
self.flatten_boolean_expression(symbols, statements_flattened, e2),
);
if lhs.is_linear() {
statements_flattened.push(FlatStatement::Condition(lhs, rhs));
} else if rhs.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(rhs, lhs));
} else {
unimplemented!()
}
}
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
let (lhs, rhs) = (
match e1.inner_type() {
Type::FieldElement => self
.flatten_array_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
e1,
),
Type::Boolean => self
.flatten_array_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
e1,
),
_ => unreachable!(),
},
match e2.inner_type() {
Type::FieldElement => self
.flatten_array_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
e2,
),
Type::Boolean => self
.flatten_array_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
e2,
),
_ => unreachable!(),
},
);
assert_eq!(lhs.len(), rhs.len());
for (l, r) in lhs.into_iter().zip(rhs.into_iter()) {
if l.is_linear() {
statements_flattened.push(FlatStatement::Condition(l, r));
} else if r.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(r, l));
} else {
unimplemented!()
}
}
}
_ => panic!(
"non matching types in condition should have been caught at semantic stage"
),
}
}
TypedStatement::For(..) => unreachable!("static analyser should have unrolled"),

View file

@ -927,8 +927,9 @@ impl<'ast> Checker<'ast> {
message: format!("Undeclared variable: {:?}", variable_name),
}),
},
Assignee::ArrayElement(box assignee, box index) => {
Assignee::Select(box assignee, box index) => {
let checked_assignee = self.check_assignee(assignee, module_id, &types)?;
let checked_index = match index {
RangeOrExpression::Expression(e) => {
self.check_expression(e, module_id, &types)?
@ -952,7 +953,7 @@ impl<'ast> Checker<'ast> {
}),
}?;
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box checked_assignee,
box checked_typed_index,
))
@ -1013,7 +1014,7 @@ impl<'ast> Checker<'ast> {
pos: Some(pos),
message: format!(
"Expected spread operator to apply on field element array, found {}",
"Expected spread operator to apply on array, found {}",
e.get_type()
),
}),
@ -1165,14 +1166,13 @@ impl<'ast> Checker<'ast> {
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into())
},
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into())
},
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
if consequence.get_type() == alternative.get_type() && consequence.size() == alternative.size() {
let inner_type = consequence.inner_type().clone();
let size = consequence.size();
Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type ,size).into())
} else {
unimplemented!("handle consequence alternative inner type mismatch")
}
let inner_type = consequence.inner_type().clone();
let size = consequence.size();
Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into())
},
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
if consequence.get_type() == alternative.get_type() {
@ -1182,7 +1182,7 @@ impl<'ast> Checker<'ast> {
unimplemented!("handle consequence alternative inner type mismatch")
}
},
_ => unimplemented!()
_ => unreachable!("types should match here as we checked them explicitly")
}
false => Err(Error {
pos: Some(pos),
@ -1225,7 +1225,7 @@ impl<'ast> Checker<'ast> {
let f = &candidates[0];
// the return count has to be 1
match f.signature.outputs.len() {
1 => match f.signature.outputs[0].clone() {
1 => match &f.signature.outputs[0] {
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
FunctionKey {
id: f.id.clone(),
@ -1234,15 +1234,6 @@ impl<'ast> Checker<'ast> {
arguments_checked,
)
.into()),
Type::Array(ty, size) => Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
signature: f.signature.clone(),
},
arguments_checked,
)
.annotate(*ty, size)
.into()),
Type::Struct(members) => Ok(StructExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
@ -1252,6 +1243,17 @@ impl<'ast> Checker<'ast> {
)
.annotate(members.clone())
.into()),
Type::Array(box ty, size) => {
Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
signature: f.signature.clone(),
},
arguments_checked,
)
.annotate(ty.clone(), size.clone())
.into())
}
_ => unimplemented!(),
},
n => Err(Error {
@ -1272,7 +1274,9 @@ impl<'ast> Checker<'ast> {
fun_id, query
),
}),
_ => panic!("duplicate definition should have been caught before the call"),
_ => {
unreachable!("duplicate definition should have been caught before the call")
}
}
}
Expression::Lt(box e1, box e2) => {
@ -1428,7 +1432,14 @@ impl<'ast> Checker<'ast> {
.into()),
}
}
_ => panic!(""),
e => Err(Error {
pos: Some(pos),
message: format!(
"Cannot access slice of expression {} of type {}",
e,
e.get_type(),
),
}),
},
RangeOrExpression::Expression(e) => {
match (array, self.check_expression(e, module_id, &types)?) {
@ -4094,7 +4105,7 @@ mod tests {
fn array_element() {
// field[33] a
// a[2] = 42
let a = Assignee::ArrayElement(
let a = Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(2)).mock(),
@ -4123,7 +4134,7 @@ mod tests {
assert_eq!(
checker.check_assignee(a, &module_id, &types),
Ok(TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::field_array(
"a".into(),
33
@ -4137,8 +4148,8 @@ mod tests {
fn array_of_array_element() {
// field[33][42] a
// a[1][2]
let a = Assignee::ArrayElement(
box Assignee::ArrayElement(
let a = Assignee::Select(
box Assignee::Select(
box Assignee::Identifier("a").mock(),
box RangeOrExpression::Expression(
Expression::FieldConstant(FieldPrime::from(1)).mock(),
@ -4176,8 +4187,8 @@ mod tests {
assert_eq!(
checker.check_assignee(a, &module_id, &types),
Ok(TypedAssignee::ArrayElement(
box TypedAssignee::ArrayElement(
Ok(TypedAssignee::Select(
box TypedAssignee::Select(
box TypedAssignee::Identifier(typed_absy::Variable::array(
"a".into(),
Type::array(Type::FieldElement, 33),

View file

@ -8,7 +8,7 @@ use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use typed_absy::types::Type;
use typed_absy::types::{MemberId, Type};
use zokrates_field::field::Field;
pub struct Propagator<'ast, T: Field> {
@ -27,6 +27,18 @@ impl<'ast, T: Field> Propagator<'ast, T> {
}
}
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
_ => false,
}
}
impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
self.constants = HashMap::new();
@ -35,105 +47,50 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| self.fold_expression(e)).collect())),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
match self.fold_expression(expr) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..)) | e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
},
TypedExpression::Array(e) => {
let ty = e.inner_type().clone();
let size = e.size();
TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)),
TypedStatement::Return(expressions) => Some(TypedStatement::Return(
expressions
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
)),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
let expr = self.fold_expression(expr);
match e.into_inner() {
ArrayExpressionInner::Value(array) => {
let array: Vec<_> = array.into_iter().map(|e| self.fold_expression(e)).collect();
match array.iter().all(|e| match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
_ => false
}) {
true => {
// all elements of the array are constants
self.constants.insert(TypedAssignee::Identifier(var), ArrayExpressionInner::Value(array).annotate(
ty,
size).into());
None
},
false => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array).annotate(
ty,
size).into()))
}
}
},
e => Some(TypedStatement::Definition(TypedAssignee::Identifier(var), TypedExpression::Array(e.annotate(ty, size))))
}
},
e => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), e))
}
}
},
// a[b] = c
TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr) => {
let index = self.fold_field_expression(index);
let expr = self.fold_expression(expr);
match (index, expr) {
(
FieldElementExpression::Number(n),
TypedExpression::FieldElement(expr @ FieldElementExpression::Number(..))
) => {
// a[42] = 33
// -> store (a[42] -> 33) in the constants, possibly overwriting the previous entry
self.constants.entry(TypedAssignee::Identifier(var)).and_modify(|e| {
match e {
TypedExpression::Array(e) => {
let size = e.size();
match e.as_inner_mut() {
ArrayExpressionInner::Value(ref mut v) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize] = expr.into();
} else {
panic!(format!("out of bounds index ({} >= {}) found during static analysis", n_as_usize, size));
}
},
_ => panic!("constants should only store constants")
}
},
_ => unreachable!()
}
});
None
},
(index, expr) => {
// a[42] = e
// -> remove a from the constants as one of its elements is not constant
self.constants.remove(&TypedAssignee::Identifier(var.clone()));
Some(TypedStatement::Definition(TypedAssignee::ArrayElement(box TypedAssignee::Identifier(var), box index), expr))
}
}
},
TypedStatement::Definition(..) => panic!("multi dimensinal arrays are not supported, this should have been caught during semantic checking"),
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(self.fold_expression(e1), self.fold_expression(e2)))
},
// we unrolled for loops in the previous step
TypedStatement::For(..) => panic!("for loop is unexpected, it should have been unrolled"),
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(variables, expression_list))
}
};
if is_constant(&expr) {
self.constants.insert(TypedAssignee::Identifier(var), expr);
None
} else {
Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
))
}
}
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
unreachable!("array updates should have been replaced with full array redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(
self.fold_expression(e1),
self.fold_expression(e2),
))
}
// we unrolled for loops in the previous step
TypedStatement::For(..) => {
unreachable!("for loop is unexpected, it should have been unrolled")
}
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
Some(TypedStatement::MultipleDefinition(
variables,
expression_list,
))
}
};
match res {
Some(v) => vec![v],
None => vec![],
@ -153,9 +110,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
))) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => {
panic!("constant stored for a field element should be a field element")
}
_ => unreachable!(
"constant stored for a field element should be a field element"
),
},
None => FieldElementExpression::Identifier(id),
}
@ -209,7 +166,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(e1, FieldElementExpression::Number(n2)) => {
FieldElementExpression::Pow(box e1, box FieldElementExpression::Number(n2))
}
(_, e2) => panic!(format!(
(_, e2) => unreachable!(format!(
"non-constant exponent {} detected during static analysis",
e2
)),
@ -237,14 +194,14 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
if n_as_usize < size {
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
panic!(format!(
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
));
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::ArrayElement(
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
@ -254,7 +211,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => panic!(""),
_ => unreachable!(""),
},
None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
@ -293,10 +250,140 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
None => ArrayExpressionInner::Identifier(id),
}
}
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
ArrayExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Array(e) => e.clone().into_inner(),
_ => unreachable!(""),
},
None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence);
let alternative = self.fold_array_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &Vec<(MemberId, Type)>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(id) => {
match self
.constants
.get(&TypedAssignee::Identifier(Variable::struc(
id.clone(),
ty.clone(),
))) {
Some(e) => match e {
TypedExpression::Struct(e) => e.as_inner().clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => StructExpressionInner::Identifier(id),
}
}
StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
StructExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Struct(e) => e.clone().into_inner(),
_ => unreachable!(""),
},
None => StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => {
StructExpressionInner::Select(box a.annotate(inner_type, size), box i)
}
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_struct_expression(consequence);
let alternative = self.fold_struct_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => StructExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
@ -809,116 +896,4 @@ mod tests {
}
}
}
#[cfg(test)]
mod statement {
use super::*;
#[cfg(test)]
mod definition {
use super::*;
#[test]
fn update_constant_array() {
// field[2] a = [21, 22]
// // constants should store [21, 22]
// a[1] = 42
// // constants should store [21, 42]
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let definition = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(22)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
);
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(definition);
let expected_value: TypedExpression<FieldPrime> =
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(22)).into(),
])
.annotate(Type::FieldElement, 2)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
p.fold_statement(overwrite);
let expected_value: TypedExpression<FieldPrime> =
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(21)).into(),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
])
.annotate(Type::FieldElement, 2)
.into();
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
)))
.unwrap(),
&expected_value
);
}
#[test]
fn update_variable_array() {
// propagation does NOT support "partially constant" arrays. That means that in order for updates to use propagation,
// the array needs to have been defined as `field[3] = [value1, value2, value3]` with all values propagateable to constants
// a passed as input
// // constants should store nothing
// a[1] = 42
// // constants should store nothing
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let overwrite = TypedStatement::Definition(
TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
FieldElementExpression::Number(FieldPrime::from(42)).into(),
);
let mut p = Propagator::new();
p.fold_statement(declaration);
p.fold_statement(overwrite);
assert_eq!(
p.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
"a".into(),
2
))),
None
);
}
}
}
}

View file

@ -8,6 +8,7 @@ use crate::typed_absy::folder::*;
use crate::typed_absy::types::Type;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::collections::HashSet;
use zokrates_field::field::Field;
pub struct Unroller<'ast> {
@ -43,98 +44,212 @@ impl<'ast> Unroller<'ast> {
pub fn unroll<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
Unroller::new().fold_program(p)
}
fn choose_many<T: Field>(
base: TypedExpression<'ast, T>,
indices: Vec<FieldElementExpression<'ast, T>>,
new_expression: TypedExpression<'ast, T>,
statements: &mut HashSet<TypedStatement<'ast, T>>,
) -> TypedExpression<'ast, T> {
let mut indices = indices;
match indices.len() {
0 => new_expression,
_ => {
let base = match base {
TypedExpression::Array(e) => e,
e => unreachable!("can't take an element on a {}", e.get_type()),
};
let inner_ty = base.inner_type();
let size = base.size();
let head = indices.pop().unwrap();
let tail = indices;
statements.insert(TypedStatement::Condition(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Struct(..) => StructExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Struct(e) => e,
e => unreachable!(
"the interior was expected to be a struct, was {}",
e.get_type()
),
},
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.annotate(inner_ty.clone(), size)
.into()
}
}
}
}
/// Turn an assignee into its representation as a base variable and a list of indices
/// a[2][3][4] -> (a, [2, 3, 4])
fn linear<'ast, T: Field>(
a: TypedAssignee<'ast, T>,
) -> (Variable, Vec<FieldElementExpression<'ast, T>>) {
match a {
TypedAssignee::Identifier(v) => (v, vec![]),
TypedAssignee::Select(box array, box index) => {
let (v, mut indices) = linear(array);
indices.push(index);
(v, indices)
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Declaration(_) => vec![],
TypedStatement::Definition(TypedAssignee::Identifier(variable), expr) => {
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
vec![TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
expr,
)]
}
TypedStatement::Definition(
TypedAssignee::ArrayElement(array @ box TypedAssignee::Identifier(..), box index),
expr,
) => {
let expr = self.fold_expression(expr);
let index = self.fold_field_expression(index);
let current_array = self.fold_assignee(*array.clone());
let (variable, indices) = linear(assignee);
let current_ssa_variable = match current_array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
let base = match variable.get_type() {
Type::FieldElement => {
FieldElementExpression::Identifier(variable.id.clone().into()).into()
}
Type::Boolean => {
BooleanExpression::Identifier(variable.id.clone().into()).into()
}
Type::Array(box ty, size) => {
ArrayExpressionInner::Identifier(variable.id.clone().into())
.annotate(ty, size)
.into()
}
Type::Struct(members) => {
StructExpressionInner::Identifier(variable.id.clone().into())
.annotate(members)
.into()
}
};
let original_variable = match *array {
TypedAssignee::Identifier(v) => v,
_ => panic!("assignee should be an identifier"),
};
let mut range_checks = HashSet::new();
let e = Self::choose_many(base, indices, expr, &mut range_checks);
let array_size = match original_variable.get_type() {
Type::Array(_, size) => size,
_ => panic!("array identifier should be a field element array"),
};
let new_variable = self.issue_next_ssa_variable(original_variable);
let new_array = match expr {
TypedExpression::FieldElement(e) => ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box e.clone(),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
)
.annotate(Type::FieldElement, array_size),
box FieldElementExpression::Number(T::from(i)),
),
)
.into()
})
.collect(),
)
.annotate(Type::FieldElement, array_size),
TypedExpression::Boolean(e) => ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
BooleanExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box e.clone(),
box BooleanExpression::Select(
box ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
)
.annotate(Type::Boolean, array_size),
box FieldElementExpression::Number(T::from(i)),
),
)
.into()
})
.collect(),
)
.annotate(Type::Boolean, array_size),
TypedExpression::Array(..) => unimplemented!(),
TypedExpression::Struct(..) => unimplemented!(),
};
vec![TypedStatement::Definition(
TypedAssignee::Identifier(new_variable),
new_array.into(),
)]
range_checks
.into_iter()
.chain(std::iter::once(TypedStatement::Definition(
TypedAssignee::Identifier(self.issue_next_ssa_variable(variable)),
e,
)))
.collect()
}
TypedStatement::MultipleDefinition(variables, exprs) => {
let exprs = self.fold_expression_list(exprs);
@ -201,6 +316,241 @@ mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
#[test]
fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(FieldPrime::from(42)).into();
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(a0.clone().into(), vec![index], e, &mut HashSet::new());
// a[1] = 42
// -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::FieldElement, 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = FieldElementExpression::Number(FieldPrime::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![index],
e.clone().into(),
&mut HashSet::new(),
);
// a[0] = b
// -> a = [0 == 1 ? b : a[0], 1 == 1 ? b : a[1], 2 == 1 ? b : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(2))
)
)
.into()
])
.annotate(Type::array(Type::FieldElement, 3), 3)
.into()
);
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 2), 2);
let e = FieldElementExpression::Number(FieldPrime::from(42));
let indices = vec![
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(0)),
];
let a1 = Unroller::choose_many(
a0.clone().into(),
indices,
e.clone().into(),
&mut HashSet::new(),
);
// a[0][0] = 42
// -> a = [0 == 0 ? [0 == 0 ? 42 : a[0][0], 1 == 0 ? 42 : a[0][1]] : a[0], 1 == 0 ? [0 == 0 ? 42 : a[1][0], 1 == 0 ? 42 : a[1][1]] : a[1]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(FieldPrime::from(1))
)
)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
);
}
#[cfg(test)]
mod statement {
use super::*;
@ -489,7 +839,7 @@ mod tests {
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::ArrayElement(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
@ -498,47 +848,209 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
Identifier::from("a").version(1),
2
)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(0))
),
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(1))
),
)
.into(),
])
.annotate(Type::FieldElement, 2)
.into()
)
]
);
}
#[test]
fn incremental_array_of_arrays_definition() {
// field[2][2] a = [[0, 1], [2, 3]]
// a[1] = [4, 5]
// should be turned into
// a_0 = [[0, 1], [2, 3]]
// a_1 = [if 0 == 1 then [4, 5] else a_0[0], if 1 == 1 then [4, 5] else a_0[1]]
let mut u = Unroller::new();
let array_of_array_ty = Type::array(Type::array(Type::FieldElement, 2), 2);
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(
Variable::with_id_and_type("a".into(), array_of_array_ty.clone()),
);
assert_eq!(u.fold_statement(s), vec![]);
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
"a".into(),
array_of_array_ty.clone(),
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(0))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(0))
),
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(FieldPrime::from(1))
),
)
.into(),
FieldElementExpression::Number(FieldPrime::from(0)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
])
.annotate(Type::FieldElement, 2)
.into()
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(2)).into(),
FieldElementExpression::Number(FieldPrime::from(3)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(0),
array_of_array_ty.clone(),
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(0)).into(),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(2)).into(),
FieldElementExpression::Number(FieldPrime::from(3)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into(),
)]
);
let s: TypedStatement<FieldPrime> = TypedStatement::Definition(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::with_id_and_type(
"a".into(),
array_of_array_ty.clone(),
)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
);
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
BooleanExpression::Lt(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
Identifier::from("a").version(1),
array_of_array_ty.clone()
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(0)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(0))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(4)).into(),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
])
.annotate(Type::FieldElement, 2)
.into(),
box ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box FieldElementExpression::Number(FieldPrime::from(1))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
)
]
);
}
}
}

View file

@ -44,7 +44,7 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
match a {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::ArrayElement(box a, box index) => TypedAssignee::ArrayElement(
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee(a),
box self.fold_field_expression(index),
),

View file

@ -83,7 +83,7 @@ pub struct TypedModule<'ast, T: Field> {
impl<'ast> fmt::Display for Identifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.version == 0 && self.stack.len() == 0 {
if self.stack.len() == 0 && self.version == 0 {
write!(f, "{}", self.id)
} else {
write!(
@ -237,7 +237,7 @@ impl<'ast, T: Field> fmt::Debug for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T: Field> {
Identifier(Variable<'ast>),
ArrayElement(
Select(
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
@ -247,11 +247,11 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedAssignee::Identifier(ref v) => v.get_type(),
TypedAssignee::ArrayElement(ref a, _) => {
TypedAssignee::Select(ref a, _) => {
let a_type = a.get_type();
match a_type {
Type::Array(box t, _) => t,
_ => panic!("array element has to take array"),
_ => unreachable!("an array element should only be defined over arrays"),
}
}
}
@ -262,7 +262,7 @@ impl<'ast, T: Field> fmt::Debug for TypedAssignee<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id),
TypedAssignee::ArrayElement(ref a, ref e) => write!(f, "{}[{}]", a, e),
TypedAssignee::Select(ref a, ref e) => write!(f, "{}[{}]", a, e),
}
}
}
@ -274,7 +274,7 @@ impl<'ast, T: Field> fmt::Display for TypedAssignee<'ast, T> {
}
/// A statement in a `TypedFunction`
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedStatement<'ast, T: Field> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
@ -427,7 +427,39 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> {
impl<'ast, T: Field> fmt::Display for StructExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.inner)
match self.inner {
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
StructExpressionInner::Value(ref values) => write!(
f,
"{{{}}}",
self.ty
.iter()
.map(|(id, _)| id)
.zip(values.iter())
.map(|(id, o)| format!("{}: {}", id, o.to_string()))
.collect::<Vec<String>>()
.join(", ")
),
StructExpressionInner::FunctionCall(ref key, ref p) => {
write!(f, "{}(", key.id,)?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
)
}
StructExpressionInner::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
StructExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
@ -440,8 +472,8 @@ impl<'ast, T: Field> fmt::Debug for StructExpression<'ast, T> {
impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedExpression::Boolean(_) => Type::Boolean,
TypedExpression::FieldElement(_) => Type::FieldElement,
TypedExpression::Boolean(ref e) => e.get_type(),
TypedExpression::FieldElement(ref e) => e.get_type(),
TypedExpression::Array(ref e) => e.get_type(),
TypedExpression::Struct(ref s) => s.get_type(),
}
@ -460,11 +492,23 @@ impl<'ast, T: Field> Typed for StructExpression<'ast, T> {
}
}
impl<'ast, T: Field> Typed for FieldElementExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::FieldElement
}
}
impl<'ast, T: Field> Typed for BooleanExpression<'ast, T> {
fn get_type(&self) -> Type {
Type::Boolean
}
}
pub trait MultiTyped {
fn get_types(&self) -> &Vec<Type>;
}
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedExpressionList<'ast, T: Field> {
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>, Vec<Type>),
}
@ -611,10 +655,6 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> {
self.inner
}
pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> {
&mut self.inner
}
}
#[derive(Clone, PartialEq, Hash, Eq)]
@ -628,6 +668,10 @@ impl<'ast, T: Field> StructExpression<'ast, T> {
&self.ty
}
pub fn as_inner(&self) -> &StructExpressionInner<'ast, T> {
&self.inner
}
pub fn into_inner(self) -> StructExpressionInner<'ast, T> {
self.inner
}
@ -795,42 +839,6 @@ impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
}
}
impl<'ast, T: Field> fmt::Display for StructExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
StructExpressionInner::Value(ref values) => write!(
f,
"[{}]",
values
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", ")
),
StructExpressionInner::FunctionCall(ref key, ref p) => {
write!(f, "{}(", key.id,)?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
)
}
StructExpressionInner::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
StructExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
impl<'ast, T: Field> fmt::Debug for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
@ -874,8 +882,8 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ArrayExpressionInner::Identifier(ref var) => write!(f, "{:?}", var),
ArrayExpressionInner::Value(ref values) => write!(f, "{:?}", values),
ArrayExpressionInner::Identifier(ref var) => write!(f, "Identifier({:?})", var),
ArrayExpressionInner::Value(ref values) => write!(f, "Value({:?})", values),
ArrayExpressionInner::FunctionCall(ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?;
f.debug_list().entries(p.iter()).finish()?;
@ -951,3 +959,146 @@ impl<'ast, T: Field> fmt::Debug for TypedExpressionList<'ast, T> {
}
}
}
// Common behaviour accross expressions
pub trait IfElse<'ast, T: Field> {
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
-> Self;
}
impl<'ast, T: Field> IfElse<'ast, T> for FieldElementExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
FieldElementExpression::IfElse(box condition, box consequence, box alternative)
}
}
impl<'ast, T: Field> IfElse<'ast, T> for BooleanExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
BooleanExpression::IfElse(box condition, box consequence, box alternative)
}
}
impl<'ast, T: Field> IfElse<'ast, T> for ArrayExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.inner_type().clone();
let size = consequence.size();
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative)
.annotate(ty, size)
}
}
impl<'ast, T: Field> IfElse<'ast, T> for StructExpression<'ast, T> {
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
let ty = consequence.ty().clone();
StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty)
}
}
pub trait Select<'ast, T: Field> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
}
impl<'ast, T: Field> Select<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index)
}
}
impl<'ast, T: Field> Select<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index)
}
}
impl<'ast, T: Field> Select<'ast, T> for ArrayExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() {
Type::Array(inner, size) => (inner.clone(), size.clone()),
_ => unreachable!(),
};
ArrayExpressionInner::Select(box array, box index).annotate(*ty, size)
}
}
impl<'ast, T: Field> Select<'ast, T> for StructExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let members = match array.inner_type().clone() {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Select(box array, box index).annotate(members)
}
}
pub trait Member<'ast, T: Field> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self;
}
impl<'ast, T: Field> Member<'ast, T> for FieldElementExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
FieldElementExpression::Member(box s, member_id)
}
}
impl<'ast, T: Field> Member<'ast, T> for BooleanExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
BooleanExpression::Member(box s, member_id)
}
}
impl<'ast, T: Field> Member<'ast, T> for ArrayExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let (ty, size) = match ty {
Type::Array(box ty, size) => (ty, size),
_ => unreachable!(),
};
ArrayExpressionInner::Member(box s, member_id).annotate(ty, size)
}
}
impl<'ast, T: Field> Member<'ast, T> for StructExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty().clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let members = match ty {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpressionInner::Member(box s, member_id).annotate(members)
}
}

View file

@ -1,4 +1,4 @@
use crate::typed_absy::types::Type;
use crate::typed_absy::types::{MemberId, Type};
use crate::typed_absy::Identifier;
use std::fmt;
@ -26,6 +26,10 @@ impl<'ast> Variable<'ast> {
Self::with_id_and_type(id, Type::array(ty, size))
}
pub fn struc(id: Identifier<'ast>, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
Self::with_id_and_type(id, Type::Struct(ty))
}
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {
Variable { id, _type }
}

View file

@ -256,7 +256,7 @@ mod ast {
#[pest_ast(rule(Rule::ty_array))]
pub struct ArrayType<'ast> {
pub ty: BasicOrStructType<'ast>,
pub size: Vec<Expression<'ast>>,
pub dimensions: Vec<Expression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}