1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

fully recursive arrays

This commit is contained in:
schaeff 2019-07-16 15:30:49 +02:00
parent 34f41dc53f
commit b2fcb4c271
8 changed files with 471 additions and 148 deletions

View file

@ -0,0 +1,12 @@
def main(field[2][2][2] cube) -> (field):
field res = 0
for field i in 0..2 do
for field j in 0..2 do
for field k in 0..2 do
res = res + cube[i][j][k]
endfor
endfor
endfor
return res

View file

@ -448,22 +448,26 @@ impl<'ast, T: Field> From<pest::PostfixExpression<'ast>> for absy::ExpressionNod
fn from(expression: pest::PostfixExpression<'ast>) -> absy::ExpressionNode<'ast, T> {
use absy::NodeValue;
assert!(expression.access.len() == 1); // we only allow a single access: function call or array access
let id_str = expression.id.span.as_str();
let id = absy::ExpressionNode::from(expression.id);
match expression.access[0].clone() {
pest::Access::Call(a) => absy::Expression::FunctionCall(
&expression.id.span.as_str(),
a.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.collect(),
),
pest::Access::Select(a) => absy::Expression::Select(
box absy::ExpressionNode::from(expression.id),
box absy::RangeOrExpression::from(a.expression),
),
}
.span(expression.span)
expression.accesses.into_iter().fold(id, |acc, a| match a {
pest::Access::Call(a) => match acc.value {
absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
&id_str,
a.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.collect(),
),
e => unimplemented!("only identifiers are callable, found \"{}\"", e),
}
.span(a.span),
pest::Access::Select(a) => {
absy::Expression::Select(box acc, box absy::RangeOrExpression::from(a.expression))
.span(a.span)
}
})
}
}
@ -521,24 +525,33 @@ impl<'ast> From<pest::Type<'ast>> for Type {
pest::BasicType::Boolean(_) => Type::Boolean,
},
pest::Type::Array(t) => {
let size = match t.size {
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => {
unimplemented!("Array size should be constant, found {}", e.span().as_str())
}
let inner_type = match t.ty {
pest::BasicType::Field(_) => Type::FieldElement,
pest::BasicType::Boolean(_) => Type::Boolean,
};
match t.ty {
pest::BasicType::Field(_) => Type::array(Type::FieldElement, size),
pest::BasicType::Boolean(_) => Type::array(Type::Boolean, size),
}
t.size
.into_iter()
.map(|s| match s {
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => unimplemented!(
"Array size should be constant, found {}",
e.span().as_str()
),
})
.fold(None, |acc, s| match acc {
None => Some(Type::array(inner_type.clone(), s)),
Some(acc) => Some(Type::array(acc, s)),
})
.unwrap()
}
}
}

View file

@ -33,16 +33,12 @@ trait Flatten<'ast, T: Field>: TryFrom<TypedExpression<'ast, T>, Error: std::fmt
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> FlatExpression<T>;
fn get_type() -> Type;
) -> Vec<FlatExpression<T>>;
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
-> Self;
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
fn zero() -> Self;
}
impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
@ -51,12 +47,8 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> FlatExpression<T> {
flattener.flatten_field_expression(symbols, statements_flattened, self)
}
fn get_type() -> Type {
Type::FieldElement
) -> Vec<FlatExpression<T>> {
vec![flattener.flatten_field_expression(symbols, statements_flattened, self)]
}
fn if_else(
@ -70,10 +62,6 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index)
}
fn zero() -> Self {
FieldElementExpression::Number(T::from(0))
}
}
impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
@ -82,12 +70,8 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> FlatExpression<T> {
flattener.flatten_boolean_expression(symbols, statements_flattened, self)
}
fn get_type() -> Type {
Type::Boolean
) -> Vec<FlatExpression<T>> {
vec![flattener.flatten_boolean_expression(symbols, statements_flattened, self)]
}
fn if_else(
@ -101,9 +85,58 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index)
}
}
fn zero() -> Self {
BooleanExpression::Value(false)
impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
fn flatten(
self,
flattener: &mut Flattener<'ast, T>,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatExpression<T>> {
match self.inner_type() {
Type::FieldElement => flattener
.flatten_array_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
self,
),
Type::Boolean => flattener.flatten_array_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
self,
),
Type::Array(..) => flattener.flatten_array_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
self,
),
}
}
fn if_else(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
ArrayExpression {
ty: consequence.inner_type().clone(),
size: consequence.size(),
inner: ArrayExpressionInner::IfElse(box condition, box consequence, box alternative),
}
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() {
Type::Array(inner, size) => (inner, size),
_ => unreachable!(),
};
ArrayExpression {
ty: *ty.clone(),
size: *size,
inner: ArrayExpressionInner::Select(box array, box index),
}
}
}
@ -133,44 +166,76 @@ impl<'ast, T: Field> Flattener<'ast, T> {
condition: BooleanExpression<'ast, T>,
consequence: U,
alternative: U,
) -> FlatExpression<T> {
) -> Vec<FlatExpression<T>> {
let condition = self.flatten_boolean_expression(symbols, statements_flattened, condition);
let consequence = consequence.flatten(self, symbols, statements_flattened);
let alternative = alternative.flatten(self, symbols, statements_flattened);
let size = consequence.len();
let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
let consequence_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
let consequence_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(
consequence
.into_iter()
.zip(consequence_ids.iter())
.map(|(c, c_id)| FlatStatement::Definition(*c_id, c)),
);
let alternative_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
let alternative_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(
alternative
.into_iter()
.zip(alternative_ids.iter())
.map(|(a, a_id)| FlatStatement::Definition(*a_id, a)),
);
let term0 = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term0,
FlatExpression::Mult(box condition_id.clone().into(), box consequence_id.into()),
let term0_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(consequence_ids.iter().zip(term0_ids.iter()).map(
|(c_id, t0_id)| {
FlatStatement::Definition(
*t0_id,
FlatExpression::Mult(
box condition_id.clone().into(),
box FlatExpression::from(*c_id),
),
)
},
));
let term1 = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term1,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
box alternative_id.into(),
),
let term1_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(alternative_ids.iter().zip(term1_ids.iter()).map(
|(a_id, t1_id)| {
FlatStatement::Definition(
*t1_id,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
box FlatExpression::from(*a_id),
),
)
},
));
let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
res,
FlatExpression::Add(box term0.into(), box term1.into()),
let res: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(term0_ids.iter().zip(term1_ids).zip(res.iter()).map(
|((t0_id, t1_id), r_id)| {
FlatStatement::Definition(
*r_id,
FlatExpression::Add(
box FlatExpression::from(*t0_id),
box FlatExpression::from(t1_id),
),
)
},
));
res.into()
res.into_iter().map(|r| r.into()).collect()
}
fn flatten_select_expression<U: Flatten<'ast, T>>(
@ -179,29 +244,33 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened: &mut Vec<FlatStatement<T>>,
array: ArrayExpression<'ast, T>,
index: FieldElementExpression<'ast, T>,
) -> FlatExpression<T> {
) -> Vec<FlatExpression<T>> {
let size = array.size();
let ty = array.inner_type();
assert_eq!(ty, &U::get_type());
let ty = array.inner_type().clone();
//assert_eq!(ty, &U::get_type());
let element_size = ty.get_primitive_count();
match index {
FieldElementExpression::Number(n) => match array.inner {
ArrayExpressionInner::Identifier(id) => {
assert!(n < T::from(size));
FlatExpression::Identifier(
self.layout.get(&id).unwrap().clone()
[n.to_dec_string().parse::<usize>().unwrap()],
)
let n = n.to_dec_string().parse::<usize>().unwrap();
self.layout.get(&id).unwrap()[n * element_size..(n + 1) * element_size]
.into_iter()
.map(|i| i.clone().into())
.collect()
}
ArrayExpressionInner::Value(expressions) => {
assert!(n < T::from(size));
U::try_from(expressions[n.to_dec_string().parse::<usize>().unwrap()].clone())
.unwrap()
.flatten(self, symbols, statements_flattened)
}
ArrayExpressionInner::FunctionCall(..) => {
unimplemented!("please use intermediate variables for now")
let n = n.to_dec_string().parse::<usize>().unwrap();
U::try_from(expressions[n].clone()).unwrap().flatten(
self,
symbols,
statements_flattened,
)
}
ArrayExpressionInner::FunctionCall(..) => unreachable!(),
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
// [if cond then [a, b] else [c, d]][1] == if cond then [a, b][1] else [c, d][1]
U::if_else(
@ -211,9 +280,23 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
.flatten(self, symbols, statements_flattened)
}
ArrayExpressionInner::Select(box array, box index) => {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
let e = self.flatten_select_expression::<U>(
symbols,
statements_flattened,
array,
index,
);
e[n * element_size..(n + 1) * element_size]
.into_iter()
.map(|i| i.clone().into())
.collect()
}
},
e => {
let size = array.size();
// we have array[e] with e an arbitrary expression
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
// here depending on the size, we could use a proper range check based on bits
@ -242,32 +325,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// now we flatten to sum(if e == i then array[i] else false)
(0..size)
.map(|i| {
U::if_else(
BooleanExpression::Eq(
box e.clone(),
box FieldElementExpression::Number(T::from(i)),
let term = match array.clone().inner {
// a[i] = a[i]
ArrayExpressionInner::Identifier(id) => U::select(
ArrayExpression {
ty: ty.clone(),
size,
inner: ArrayExpressionInner::Identifier(id),
},
FieldElementExpression::Number(T::from(i)),
),
match array.clone().inner {
ArrayExpressionInner::Identifier(id) => U::select(
ArrayExpression {
ty: ty.clone(),
size,
inner: ArrayExpressionInner::Identifier(id),
},
FieldElementExpression::Number(T::from(i)),
),
ArrayExpressionInner::Value(expressions) => {
assert_eq!(size, expressions.len());
U::try_from(expressions[i].clone()).unwrap()
}
ArrayExpressionInner::FunctionCall(..) => {
unimplemented!("please use intermediate variables for now")
}
ArrayExpressionInner::IfElse(
condition,
consequence,
alternative,
) => U::if_else(
// [a_0, ..., a_n][i] = a_i
ArrayExpressionInner::Value(expressions) => {
assert_eq!(size, expressions.len());
U::try_from(expressions[i].clone()).unwrap()
}
ArrayExpressionInner::FunctionCall(..) => unreachable!(),
// (if c then a else b fi)[i] = if c then a[i] else b[i]
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
U::if_else(
*condition,
U::select(
*consequence,
@ -277,15 +353,30 @@ impl<'ast, T: Field> Flattener<'ast, T> {
*alternative,
FieldElementExpression::Number(T::from(i)),
),
),
},
U::zero(),
)
.flatten(self, symbols, statements_flattened)
)
}
ArrayExpressionInner::Select(box array, box index) => U::select(
ArrayExpression {
ty: ty.clone(),
size,
inner: ArrayExpressionInner::Select(box array, box index),
},
FieldElementExpression::Number(T::from(i)),
),
};
(term, FieldElementExpression::Number(T::from(i)))
})
.fold(FlatExpression::Number(T::from(0)), |acc, e| {
FlatExpression::Add(box acc, box e)
.fold(None, |acc, (term, index)| match acc {
None => Some(term),
Some(acc) => Some(U::if_else(
BooleanExpression::Eq(box e.clone(), box index),
term,
acc,
)),
})
.unwrap()
.flatten(self, symbols, statements_flattened)
}
}
}
@ -578,14 +669,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
condition,
consequence,
alternative,
),
)[0]
.clone(),
BooleanExpression::Select(box array, box index) => self
.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
)[0]
.clone(),
}
}
@ -719,7 +812,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened,
e,
),
_ => unreachable!(),
Type::Array(..) => self.flatten_array_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
e,
),
},
}
}
@ -940,7 +1037,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
condition,
consequence,
alternative,
),
)[0]
.clone(),
FieldElementExpression::FunctionCall(key, param_expressions) => {
let exprs_flattened = self.flatten_function_call(
symbols,
@ -958,7 +1056,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened,
array,
index,
),
)[0]
.clone(),
}
}
@ -971,8 +1070,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
expr: ArrayExpression<'ast, T>,
) -> Vec<FlatExpression<T>> {
let ty = expr.get_type();
let inner_type = expr.inner_type().clone();
assert_eq!(U::get_type(), inner_type);
//assert_eq!(U::get_type(), inner_type);
let size = expr.size();
match expr.inner {
@ -987,7 +1085,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert_eq!(size, values.len());
values
.into_iter()
.map(|v| {
.flat_map(|v| {
U::try_from(v)
.unwrap()
.flatten(self, symbols, statements_flattened)
@ -1007,7 +1105,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
ArrayExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0
..size)
.map(|i| {
.flat_map(|i| {
U::if_else(
*condition.clone(),
U::select(
@ -1022,6 +1120,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flatten(self, symbols, statements_flattened)
})
.collect(),
ArrayExpressionInner::Select(box array, box index) => self
.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
),
}
}
@ -2123,6 +2228,137 @@ mod tests {
);
}
#[test]
fn array_of_array_sum() {
// field[2][2] foo = [[1, 2], [3, 4]]
// bar = foo[0][0] + foo[0][1] + foo[1][0] + foo[1][1]
// we don't optimise detecting constants, this will be done in an optimiser pass
let mut flattener = Flattener::new();
let mut statements_flattened = vec![];
let def = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("foo".into(), 4)),
ArrayExpression {
ty: Type::array(Type::FieldElement, 2),
size: 2,
inner: ArrayExpressionInner::Value(vec![
ArrayExpression {
size: 2,
ty: Type::FieldElement,
inner: ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(1)).into(),
FieldElementExpression::Number(FieldPrime::from(2)).into(),
]),
}
.into(),
ArrayExpression {
size: 2,
ty: Type::FieldElement,
inner: ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(FieldPrime::from(3)).into(),
FieldElementExpression::Number(FieldPrime::from(4)).into(),
]),
}
.into(),
]),
}
.into(),
);
let sum = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("bar".into())),
FieldElementExpression::Add(
box FieldElementExpression::Add(
box FieldElementExpression::Add(
box FieldElementExpression::Select(
box ArrayExpression {
ty: Type::FieldElement,
size: 2,
inner: ArrayExpressionInner::Select(
box ArrayExpression {
ty: Type::array(Type::FieldElement, 2),
size: 2,
inner: ArrayExpressionInner::Identifier("foo".into()),
},
box FieldElementExpression::Number(FieldPrime::from(0)),
),
},
box FieldElementExpression::Number(FieldPrime::from(0)),
),
box FieldElementExpression::Select(
box ArrayExpression {
ty: Type::FieldElement,
size: 2,
inner: ArrayExpressionInner::Select(
box ArrayExpression {
ty: Type::array(Type::FieldElement, 2),
size: 2,
inner: ArrayExpressionInner::Identifier("foo".into()),
},
box FieldElementExpression::Number(FieldPrime::from(0)),
),
},
box FieldElementExpression::Number(FieldPrime::from(1)),
),
),
box FieldElementExpression::Select(
box ArrayExpression {
ty: Type::FieldElement,
size: 2,
inner: ArrayExpressionInner::Select(
box ArrayExpression {
ty: Type::array(Type::FieldElement, 2),
size: 2,
inner: ArrayExpressionInner::Identifier("foo".into()),
},
box FieldElementExpression::Number(FieldPrime::from(1)),
),
},
box FieldElementExpression::Number(FieldPrime::from(0)),
),
),
box FieldElementExpression::Select(
box ArrayExpression {
ty: Type::FieldElement,
size: 2,
inner: ArrayExpressionInner::Select(
box ArrayExpression {
ty: Type::array(Type::FieldElement, 2),
size: 2,
inner: ArrayExpressionInner::Identifier("foo".into()),
},
box FieldElementExpression::Number(FieldPrime::from(1)),
),
},
box FieldElementExpression::Number(FieldPrime::from(1)),
),
)
.into(),
);
flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, def);
flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, sum);
assert_eq!(
statements_flattened[4],
FlatStatement::Definition(
FlatVariable::new(4),
FlatExpression::Add(
box FlatExpression::Add(
box FlatExpression::Add(
box FlatExpression::Identifier(FlatVariable::new(0)),
box FlatExpression::Identifier(FlatVariable::new(1)),
),
box FlatExpression::Identifier(FlatVariable::new(2))
),
box FlatExpression::Identifier(FlatVariable::new(3)),
)
)
);
}
#[test]
fn array_if() {
// if 1 == 1 then [1] else [3] fi

View file

@ -1084,7 +1084,12 @@ impl<'ast> Checker<'ast> {
Ok(FieldElementExpression::Select(box a, box i).into())
}
Type::Boolean => Ok(BooleanExpression::Select(box a, box i).into()),
Type::Array(_, _) => unimplemented!("multi dim wow"),
Type::Array(box ty, size) => Ok(ArrayExpression {
size: *size,
ty: ty.clone(),
inner: ArrayExpressionInner::Select(box a, box i),
}
.into()),
}
}
(a, e) => Err(Error {
@ -1110,7 +1115,7 @@ impl<'ast> Checker<'ast> {
}
// we infer the type to be the type of the first element
let inferred_type = expressions_checked.get(0).unwrap().get_type();
let inferred_type = expressions_checked.get(0).unwrap().get_type().clone();
match inferred_type {
Type::FieldElement => {
@ -1169,16 +1174,49 @@ impl<'ast> Checker<'ast> {
}
.into())
}
_ => Err(Error {
pos: Some(pos),
ty @ Type::Array(..) => {
// we check all expressions have that same type
let mut unwrapped_expressions = vec![];
message: format!(
"Only arrays of {} or {} are supported, found {}",
Type::FieldElement,
Type::Boolean,
inferred_type
),
}),
for e in expressions_checked {
let unwrapped_e = match e {
TypedExpression::Array(e) => {
if e.get_type() == ty {
Ok(e)
} else {
Err(Error {
pos: Some(pos),
message: format!(
"Expected {} to have type {}, but type is {}",
e,
ty,
e.get_type()
),
})
}
}
e => Err(Error {
pos: Some(pos),
message: format!(
"Expected {} to have type {}, but type is {}",
e,
ty,
e.get_type()
),
}),
}?;
unwrapped_expressions.push(unwrapped_e.into());
}
Ok(ArrayExpression {
ty,
size: unwrapped_expressions.len(),
inner: ArrayExpressionInner::Value(unwrapped_expressions),
}
.into())
}
}
}
Expression::And(box e1, box e2) => {

View file

@ -178,6 +178,11 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
box f.fold_array_expression(alternative),
)
}
ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
ArrayExpressionInner::Select(box array, box index)
}
}
}

View file

@ -543,6 +543,10 @@ pub enum ArrayExpressionInner<'ast, T: Field> {
Box<ArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> ArrayExpression<'ast, T> {
@ -581,6 +585,17 @@ impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for BooleanExpression<'as
}
}
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for ArrayExpression<'ast, T> {
type Error = ();
fn try_from(te: TypedExpression<'ast, T>) -> Result<ArrayExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::Array(e) => Ok(e),
_ => Err(()),
}
}
}
impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -664,6 +679,7 @@ impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
"if {} then {} else {} fi",
condition, consequent, alternative
),
ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
@ -720,6 +736,9 @@ impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
ArrayExpressionInner::Select(ref id, ref index) => {
write!(f, "Select({:?}, {:?})", id, index)
}
}
}
}

View file

@ -16,7 +16,7 @@ ty_field = {"field"}
ty_bool = {"bool"}
ty_basic = { ty_field | ty_bool }
// (unidimensional for now) arrays of (basic for now) types
ty_array = { ty_basic ~ ("[" ~ expression ~ "]") }
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic }
type_list = _{(ty ~ ("," ~ ty)*)?}

View file

@ -212,7 +212,7 @@ mod ast {
#[pest_ast(rule(Rule::ty_array))]
pub struct ArrayType<'ast> {
pub ty: BasicType<'ast>,
pub size: Expression<'ast>,
pub size: Vec<Expression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
@ -403,7 +403,7 @@ mod ast {
#[pest_ast(rule(Rule::postfix_expression))]
pub struct PostfixExpression<'ast> {
pub id: IdentifierExpression<'ast>,
pub access: Vec<Access<'ast>>,
pub accesses: Vec<Access<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}