fully recursive arrays
This commit is contained in:
parent
34f41dc53f
commit
b2fcb4c271
8 changed files with 471 additions and 148 deletions
12
zokrates_cli/examples/arrays/cube.code
Normal file
12
zokrates_cli/examples/arrays/cube.code
Normal file
|
@ -0,0 +1,12 @@
|
|||
def main(field[2][2][2] cube) -> (field):
|
||||
field res = 0
|
||||
|
||||
for field i in 0..2 do
|
||||
for field j in 0..2 do
|
||||
for field k in 0..2 do
|
||||
res = res + cube[i][j][k]
|
||||
endfor
|
||||
endfor
|
||||
endfor
|
||||
|
||||
return res
|
|
@ -448,22 +448,26 @@ impl<'ast, T: Field> From<pest::PostfixExpression<'ast>> for absy::ExpressionNod
|
|||
fn from(expression: pest::PostfixExpression<'ast>) -> absy::ExpressionNode<'ast, T> {
|
||||
use absy::NodeValue;
|
||||
|
||||
assert!(expression.access.len() == 1); // we only allow a single access: function call or array access
|
||||
let id_str = expression.id.span.as_str();
|
||||
let id = absy::ExpressionNode::from(expression.id);
|
||||
|
||||
match expression.access[0].clone() {
|
||||
pest::Access::Call(a) => absy::Expression::FunctionCall(
|
||||
&expression.id.span.as_str(),
|
||||
a.expressions
|
||||
.into_iter()
|
||||
.map(|e| absy::ExpressionNode::from(e))
|
||||
.collect(),
|
||||
),
|
||||
pest::Access::Select(a) => absy::Expression::Select(
|
||||
box absy::ExpressionNode::from(expression.id),
|
||||
box absy::RangeOrExpression::from(a.expression),
|
||||
),
|
||||
}
|
||||
.span(expression.span)
|
||||
expression.accesses.into_iter().fold(id, |acc, a| match a {
|
||||
pest::Access::Call(a) => match acc.value {
|
||||
absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
|
||||
&id_str,
|
||||
a.expressions
|
||||
.into_iter()
|
||||
.map(|e| absy::ExpressionNode::from(e))
|
||||
.collect(),
|
||||
),
|
||||
e => unimplemented!("only identifiers are callable, found \"{}\"", e),
|
||||
}
|
||||
.span(a.span),
|
||||
pest::Access::Select(a) => {
|
||||
absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression))
|
||||
.span(a.span)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -521,24 +525,33 @@ impl<'ast> From<pest::Type<'ast>> for Type {
|
|||
pest::BasicType::Boolean(_) => Type::Boolean,
|
||||
},
|
||||
pest::Type::Array(t) => {
|
||||
let size = match t.size {
|
||||
pest::Expression::Constant(c) => match c {
|
||||
pest::ConstantExpression::DecimalNumber(n) => {
|
||||
str::parse::<usize>(&n.value).unwrap()
|
||||
}
|
||||
_ => unimplemented!(
|
||||
"Array size should be a decimal number, found {}",
|
||||
c.span().as_str()
|
||||
),
|
||||
},
|
||||
e => {
|
||||
unimplemented!("Array size should be constant, found {}", e.span().as_str())
|
||||
}
|
||||
let inner_type = match t.ty {
|
||||
pest::BasicType::Field(_) => Type::FieldElement,
|
||||
pest::BasicType::Boolean(_) => Type::Boolean,
|
||||
};
|
||||
match t.ty {
|
||||
pest::BasicType::Field(_) => Type::array(Type::FieldElement, size),
|
||||
pest::BasicType::Boolean(_) => Type::array(Type::Boolean, size),
|
||||
}
|
||||
|
||||
t.size
|
||||
.into_iter()
|
||||
.map(|s| match s {
|
||||
pest::Expression::Constant(c) => match c {
|
||||
pest::ConstantExpression::DecimalNumber(n) => {
|
||||
str::parse::<usize>(&n.value).unwrap()
|
||||
}
|
||||
_ => unimplemented!(
|
||||
"Array size should be a decimal number, found {}",
|
||||
c.span().as_str()
|
||||
),
|
||||
},
|
||||
e => unimplemented!(
|
||||
"Array size should be constant, found {}",
|
||||
e.span().as_str()
|
||||
),
|
||||
})
|
||||
.fold(None, |acc, s| match acc {
|
||||
None => Some(Type::array(inner_type.clone(), s)),
|
||||
Some(acc) => Some(Type::array(acc, s)),
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,16 +33,12 @@ trait Flatten<'ast, T: Field>: TryFrom<TypedExpression<'ast, T>, Error: std::fmt
|
|||
flattener: &mut Flattener<'ast, T>,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> FlatExpression<T>;
|
||||
|
||||
fn get_type() -> Type;
|
||||
) -> Vec<FlatExpression<T>>;
|
||||
|
||||
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
|
||||
-> Self;
|
||||
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
|
||||
|
||||
fn zero() -> Self;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
|
@ -51,12 +47,8 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
flattener: &mut Flattener<'ast, T>,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> FlatExpression<T> {
|
||||
flattener.flatten_field_expression(symbols, statements_flattened, self)
|
||||
}
|
||||
|
||||
fn get_type() -> Type {
|
||||
Type::FieldElement
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
vec![flattener.flatten_field_expression(symbols, statements_flattened, self)]
|
||||
}
|
||||
|
||||
fn if_else(
|
||||
|
@ -70,10 +62,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
FieldElementExpression::Select(box array, box index)
|
||||
}
|
||||
|
||||
fn zero() -> Self {
|
||||
FieldElementExpression::Number(T::from(0))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
|
||||
|
@ -82,12 +70,8 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
|
|||
flattener: &mut Flattener<'ast, T>,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> FlatExpression<T> {
|
||||
flattener.flatten_boolean_expression(symbols, statements_flattened, self)
|
||||
}
|
||||
|
||||
fn get_type() -> Type {
|
||||
Type::Boolean
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
vec![flattener.flatten_boolean_expression(symbols, statements_flattened, self)]
|
||||
}
|
||||
|
||||
fn if_else(
|
||||
|
@ -101,9 +85,58 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
|
|||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
BooleanExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
||||
fn zero() -> Self {
|
||||
BooleanExpression::Value(false)
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
match self.inner_type() {
|
||||
Type::FieldElement => flattener
|
||||
.flatten_array_expression::<FieldElementExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
self,
|
||||
),
|
||||
Type::Boolean => flattener.flatten_array_expression::<BooleanExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
self,
|
||||
),
|
||||
Type::Array(..) => flattener.flatten_array_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
self,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn if_else(
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: Self,
|
||||
alternative: Self,
|
||||
) -> Self {
|
||||
ArrayExpression {
|
||||
ty: consequence.inner_type().clone(),
|
||||
size: consequence.size(),
|
||||
inner: ArrayExpressionInner::IfElse(box condition, box consequence, box alternative),
|
||||
}
|
||||
}
|
||||
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
let (ty, size) = match array.inner_type() {
|
||||
Type::Array(inner, size) => (inner, size),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ArrayExpression {
|
||||
ty: *ty.clone(),
|
||||
size: *size,
|
||||
inner: ArrayExpressionInner::Select(box array, box index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -133,44 +166,76 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: U,
|
||||
alternative: U,
|
||||
) -> FlatExpression<T> {
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let condition = self.flatten_boolean_expression(symbols, statements_flattened, condition);
|
||||
|
||||
let consequence = consequence.flatten(self, symbols, statements_flattened);
|
||||
|
||||
let alternative = alternative.flatten(self, symbols, statements_flattened);
|
||||
|
||||
let size = consequence.len();
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
|
||||
|
||||
let consequence_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
|
||||
let consequence_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(
|
||||
consequence
|
||||
.into_iter()
|
||||
.zip(consequence_ids.iter())
|
||||
.map(|(c, c_id)| FlatStatement::Definition(*c_id, c)),
|
||||
);
|
||||
|
||||
let alternative_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
|
||||
let alternative_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(
|
||||
alternative
|
||||
.into_iter()
|
||||
.zip(alternative_ids.iter())
|
||||
.map(|(a, a_id)| FlatStatement::Definition(*a_id, a)),
|
||||
);
|
||||
|
||||
let term0 = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term0,
|
||||
FlatExpression::Mult(box condition_id.clone().into(), box consequence_id.into()),
|
||||
let term0_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(consequence_ids.iter().zip(term0_ids.iter()).map(
|
||||
|(c_id, t0_id)| {
|
||||
FlatStatement::Definition(
|
||||
*t0_id,
|
||||
FlatExpression::Mult(
|
||||
box condition_id.clone().into(),
|
||||
box FlatExpression::from(*c_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
));
|
||||
let term1 = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term1,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
box alternative_id.into(),
|
||||
),
|
||||
|
||||
let term1_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(alternative_ids.iter().zip(term1_ids.iter()).map(
|
||||
|(a_id, t1_id)| {
|
||||
FlatStatement::Definition(
|
||||
*t1_id,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
box FlatExpression::from(*a_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
));
|
||||
let res = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
res,
|
||||
FlatExpression::Add(box term0.into(), box term1.into()),
|
||||
|
||||
let res: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(term0_ids.iter().zip(term1_ids).zip(res.iter()).map(
|
||||
|((t0_id, t1_id), r_id)| {
|
||||
FlatStatement::Definition(
|
||||
*r_id,
|
||||
FlatExpression::Add(
|
||||
box FlatExpression::from(*t0_id),
|
||||
box FlatExpression::from(t1_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
));
|
||||
res.into()
|
||||
res.into_iter().map(|r| r.into()).collect()
|
||||
}
|
||||
|
||||
fn flatten_select_expression<U: Flatten<'ast, T>>(
|
||||
|
@ -179,29 +244,33 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
array: ArrayExpression<'ast, T>,
|
||||
index: FieldElementExpression<'ast, T>,
|
||||
) -> FlatExpression<T> {
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let size = array.size();
|
||||
let ty = array.inner_type();
|
||||
assert_eq!(ty, &U::get_type());
|
||||
let ty = array.inner_type().clone();
|
||||
//assert_eq!(ty, &U::get_type());
|
||||
|
||||
let element_size = ty.get_primitive_count();
|
||||
|
||||
match index {
|
||||
FieldElementExpression::Number(n) => match array.inner {
|
||||
ArrayExpressionInner::Identifier(id) => {
|
||||
assert!(n < T::from(size));
|
||||
FlatExpression::Identifier(
|
||||
self.layout.get(&id).unwrap().clone()
|
||||
[n.to_dec_string().parse::<usize>().unwrap()],
|
||||
)
|
||||
let n = n.to_dec_string().parse::<usize>().unwrap();
|
||||
self.layout.get(&id).unwrap()[n * element_size..(n + 1) * element_size]
|
||||
.into_iter()
|
||||
.map(|i| i.clone().into())
|
||||
.collect()
|
||||
}
|
||||
ArrayExpressionInner::Value(expressions) => {
|
||||
assert!(n < T::from(size));
|
||||
U::try_from(expressions[n.to_dec_string().parse::<usize>().unwrap()].clone())
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
}
|
||||
ArrayExpressionInner::FunctionCall(..) => {
|
||||
unimplemented!("please use intermediate variables for now")
|
||||
let n = n.to_dec_string().parse::<usize>().unwrap();
|
||||
U::try_from(expressions[n].clone()).unwrap().flatten(
|
||||
self,
|
||||
symbols,
|
||||
statements_flattened,
|
||||
)
|
||||
}
|
||||
ArrayExpressionInner::FunctionCall(..) => unreachable!(),
|
||||
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
|
||||
// [if cond then [a, b] else [c, d]][1] == if cond then [a, b][1] else [c, d][1]
|
||||
U::if_else(
|
||||
|
@ -211,9 +280,23 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
assert!(n < T::from(size));
|
||||
let n = n.to_dec_string().parse::<usize>().unwrap();
|
||||
|
||||
let e = self.flatten_select_expression::<U>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
);
|
||||
e[n * element_size..(n + 1) * element_size]
|
||||
.into_iter()
|
||||
.map(|i| i.clone().into())
|
||||
.collect()
|
||||
}
|
||||
},
|
||||
e => {
|
||||
let size = array.size();
|
||||
// we have array[e] with e an arbitrary expression
|
||||
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
|
||||
// here depending on the size, we could use a proper range check based on bits
|
||||
|
@ -242,32 +325,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// now we flatten to sum(if e == i then array[i] else false)
|
||||
(0..size)
|
||||
.map(|i| {
|
||||
U::if_else(
|
||||
BooleanExpression::Eq(
|
||||
box e.clone(),
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
let term = match array.clone().inner {
|
||||
// a[i] = a[i]
|
||||
ArrayExpressionInner::Identifier(id) => U::select(
|
||||
ArrayExpression {
|
||||
ty: ty.clone(),
|
||||
size,
|
||||
inner: ArrayExpressionInner::Identifier(id),
|
||||
},
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
match array.clone().inner {
|
||||
ArrayExpressionInner::Identifier(id) => U::select(
|
||||
ArrayExpression {
|
||||
ty: ty.clone(),
|
||||
size,
|
||||
inner: ArrayExpressionInner::Identifier(id),
|
||||
},
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
ArrayExpressionInner::Value(expressions) => {
|
||||
assert_eq!(size, expressions.len());
|
||||
U::try_from(expressions[i].clone()).unwrap()
|
||||
}
|
||||
ArrayExpressionInner::FunctionCall(..) => {
|
||||
unimplemented!("please use intermediate variables for now")
|
||||
}
|
||||
ArrayExpressionInner::IfElse(
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
) => U::if_else(
|
||||
// [a_0, ..., a_n][i] = a_i
|
||||
ArrayExpressionInner::Value(expressions) => {
|
||||
assert_eq!(size, expressions.len());
|
||||
U::try_from(expressions[i].clone()).unwrap()
|
||||
}
|
||||
ArrayExpressionInner::FunctionCall(..) => unreachable!(),
|
||||
// (if c then a else b fi)[i] = if c then a[i] else b[i]
|
||||
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
|
||||
U::if_else(
|
||||
*condition,
|
||||
U::select(
|
||||
*consequence,
|
||||
|
@ -277,15 +353,30 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
*alternative,
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
),
|
||||
},
|
||||
U::zero(),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
)
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => U::select(
|
||||
ArrayExpression {
|
||||
ty: ty.clone(),
|
||||
size,
|
||||
inner: ArrayExpressionInner::Select(box array, box index),
|
||||
},
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
};
|
||||
|
||||
(term, FieldElementExpression::Number(T::from(i)))
|
||||
})
|
||||
.fold(FlatExpression::Number(T::from(0)), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
.fold(None, |acc, (term, index)| match acc {
|
||||
None => Some(term),
|
||||
Some(acc) => Some(U::if_else(
|
||||
BooleanExpression::Eq(box e.clone(), box index),
|
||||
term,
|
||||
acc,
|
||||
)),
|
||||
})
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -578,14 +669,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
),
|
||||
)[0]
|
||||
.clone(),
|
||||
BooleanExpression::Select(box array, box index) => self
|
||||
.flatten_select_expression::<BooleanExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
),
|
||||
)[0]
|
||||
.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -719,7 +812,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened,
|
||||
e,
|
||||
),
|
||||
_ => unreachable!(),
|
||||
Type::Array(..) => self.flatten_array_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
e,
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -940,7 +1037,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
),
|
||||
)[0]
|
||||
.clone(),
|
||||
FieldElementExpression::FunctionCall(key, param_expressions) => {
|
||||
let exprs_flattened = self.flatten_function_call(
|
||||
symbols,
|
||||
|
@ -958,7 +1056,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
),
|
||||
)[0]
|
||||
.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -971,8 +1070,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
expr: ArrayExpression<'ast, T>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let ty = expr.get_type();
|
||||
let inner_type = expr.inner_type().clone();
|
||||
assert_eq!(U::get_type(), inner_type);
|
||||
//assert_eq!(U::get_type(), inner_type);
|
||||
let size = expr.size();
|
||||
|
||||
match expr.inner {
|
||||
|
@ -987,7 +1085,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
assert_eq!(size, values.len());
|
||||
values
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
.flat_map(|v| {
|
||||
U::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
|
@ -1007,7 +1105,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
ArrayExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0
|
||||
..size)
|
||||
.map(|i| {
|
||||
.flat_map(|i| {
|
||||
U::if_else(
|
||||
*condition.clone(),
|
||||
U::select(
|
||||
|
@ -1022,6 +1120,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.flatten(self, symbols, statements_flattened)
|
||||
})
|
||||
.collect(),
|
||||
ArrayExpressionInner::Select(box array, box index) => self
|
||||
.flatten_select_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2123,6 +2228,137 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_of_array_sum() {
|
||||
// field[2][2] foo = [[1, 2], [3, 4]]
|
||||
// bar = foo[0][0] + foo[0][1] + foo[1][0] + foo[1][1]
|
||||
// we don't optimise detecting constants, this will be done in an optimiser pass
|
||||
|
||||
let mut flattener = Flattener::new();
|
||||
|
||||
let mut statements_flattened = vec![];
|
||||
let def = TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_array("foo".into(), 4)),
|
||||
ArrayExpression {
|
||||
ty: Type::array(Type::FieldElement, 2),
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Value(vec![
|
||||
ArrayExpression {
|
||||
size: 2,
|
||||
ty: Type::FieldElement,
|
||||
inner: ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(FieldPrime::from(1)).into(),
|
||||
FieldElementExpression::Number(FieldPrime::from(2)).into(),
|
||||
]),
|
||||
}
|
||||
.into(),
|
||||
ArrayExpression {
|
||||
size: 2,
|
||||
ty: Type::FieldElement,
|
||||
inner: ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(FieldPrime::from(3)).into(),
|
||||
FieldElementExpression::Number(FieldPrime::from(4)).into(),
|
||||
]),
|
||||
}
|
||||
.into(),
|
||||
]),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
|
||||
let sum = TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("bar".into())),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::FieldElement,
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::array(Type::FieldElement, 2),
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Identifier("foo".into()),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
),
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::FieldElement,
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::array(Type::FieldElement, 2),
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Identifier("foo".into()),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(1)),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::FieldElement,
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::array(Type::FieldElement, 2),
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Identifier("foo".into()),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(1)),
|
||||
),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::FieldElement,
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Select(
|
||||
box ArrayExpression {
|
||||
ty: Type::array(Type::FieldElement, 2),
|
||||
size: 2,
|
||||
inner: ArrayExpressionInner::Identifier("foo".into()),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(1)),
|
||||
),
|
||||
},
|
||||
box FieldElementExpression::Number(FieldPrime::from(1)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
|
||||
flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, def);
|
||||
|
||||
flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, sum);
|
||||
|
||||
assert_eq!(
|
||||
statements_flattened[4],
|
||||
FlatStatement::Definition(
|
||||
FlatVariable::new(4),
|
||||
FlatExpression::Add(
|
||||
box FlatExpression::Add(
|
||||
box FlatExpression::Add(
|
||||
box FlatExpression::Identifier(FlatVariable::new(0)),
|
||||
box FlatExpression::Identifier(FlatVariable::new(1)),
|
||||
),
|
||||
box FlatExpression::Identifier(FlatVariable::new(2))
|
||||
),
|
||||
box FlatExpression::Identifier(FlatVariable::new(3)),
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_if() {
|
||||
// if 1 == 1 then [1] else [3] fi
|
||||
|
|
|
@ -1084,7 +1084,12 @@ impl<'ast> Checker<'ast> {
|
|||
Ok(FieldElementExpression::Select(box a, box i).into())
|
||||
}
|
||||
Type::Boolean => Ok(BooleanExpression::Select(box a, box i).into()),
|
||||
Type::Array(_, _) => unimplemented!("multi dim wow"),
|
||||
Type::Array(box ty, size) => Ok(ArrayExpression {
|
||||
size: *size,
|
||||
ty: ty.clone(),
|
||||
inner: ArrayExpressionInner::Select(box a, box i),
|
||||
}
|
||||
.into()),
|
||||
}
|
||||
}
|
||||
(a, e) => Err(Error {
|
||||
|
@ -1110,7 +1115,7 @@ impl<'ast> Checker<'ast> {
|
|||
}
|
||||
|
||||
// we infer the type to be the type of the first element
|
||||
let inferred_type = expressions_checked.get(0).unwrap().get_type();
|
||||
let inferred_type = expressions_checked.get(0).unwrap().get_type().clone();
|
||||
|
||||
match inferred_type {
|
||||
Type::FieldElement => {
|
||||
|
@ -1169,16 +1174,49 @@ impl<'ast> Checker<'ast> {
|
|||
}
|
||||
.into())
|
||||
}
|
||||
_ => Err(Error {
|
||||
pos: Some(pos),
|
||||
ty @ Type::Array(..) => {
|
||||
// we check all expressions have that same type
|
||||
let mut unwrapped_expressions = vec![];
|
||||
|
||||
message: format!(
|
||||
"Only arrays of {} or {} are supported, found {}",
|
||||
Type::FieldElement,
|
||||
Type::Boolean,
|
||||
inferred_type
|
||||
),
|
||||
}),
|
||||
for e in expressions_checked {
|
||||
let unwrapped_e = match e {
|
||||
TypedExpression::Array(e) => {
|
||||
if e.get_type() == ty {
|
||||
Ok(e)
|
||||
} else {
|
||||
Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected {} to have type {}, but type is {}",
|
||||
e,
|
||||
ty,
|
||||
e.get_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
e => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected {} to have type {}, but type is {}",
|
||||
e,
|
||||
ty,
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
}?;
|
||||
unwrapped_expressions.push(unwrapped_e.into());
|
||||
}
|
||||
|
||||
Ok(ArrayExpression {
|
||||
ty,
|
||||
size: unwrapped_expressions.len(),
|
||||
inner: ArrayExpressionInner::Value(unwrapped_expressions),
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
Expression::And(box e1, box e2) => {
|
||||
|
|
|
@ -178,6 +178,11 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
box f.fold_array_expression(alternative),
|
||||
)
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
ArrayExpressionInner::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -543,6 +543,10 @@ pub enum ArrayExpressionInner<'ast, T: Field> {
|
|||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
||||
|
@ -581,6 +585,17 @@ impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for BooleanExpression<'as
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for ArrayExpression<'ast, T> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(te: TypedExpression<'ast, T>) -> Result<ArrayExpression<'ast, T>, Self::Error> {
|
||||
match te {
|
||||
TypedExpression::Array(e) => Ok(e),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
@ -664,6 +679,7 @@ impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
|
|||
"if {} then {} else {} fi",
|
||||
condition, consequent, alternative
|
||||
),
|
||||
ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -720,6 +736,9 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
|
|||
"IfElse({:?}, {:?}, {:?})",
|
||||
condition, consequent, alternative
|
||||
),
|
||||
ArrayExpressionInner::Select(ref id, ref index) => {
|
||||
write!(f, "Select({:?}, {:?})", id, index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ ty_field = {"field"}
|
|||
ty_bool = {"bool"}
|
||||
ty_basic = { ty_field | ty_bool }
|
||||
// (unidimensional for now) arrays of (basic for now) types
|
||||
ty_array = { ty_basic ~ ("[" ~ expression ~ "]") }
|
||||
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
|
||||
ty = { ty_array | ty_basic }
|
||||
type_list = _{(ty ~ ("," ~ ty)*)?}
|
||||
|
||||
|
|
|
@ -212,7 +212,7 @@ mod ast {
|
|||
#[pest_ast(rule(Rule::ty_array))]
|
||||
pub struct ArrayType<'ast> {
|
||||
pub ty: BasicType<'ast>,
|
||||
pub size: Expression<'ast>,
|
||||
pub size: Vec<Expression<'ast>>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
@ -403,7 +403,7 @@ mod ast {
|
|||
#[pest_ast(rule(Rule::postfix_expression))]
|
||||
pub struct PostfixExpression<'ast> {
|
||||
pub id: IdentifierExpression<'ast>,
|
||||
pub access: Vec<Access<'ast>>,
|
||||
pub accesses: Vec<Access<'ast>>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue