implement flattening
This commit is contained in:
parent
bb0973de1e
commit
0673381bce
10 changed files with 420 additions and 131 deletions
14
t.code
14
t.code
|
@ -1,13 +1,15 @@
|
|||
from "./u.code" import Foo
|
||||
|
||||
struct Bar {
|
||||
a: Foo,
|
||||
b: field[2]
|
||||
a: field,
|
||||
b: field,
|
||||
c: field,
|
||||
}
|
||||
|
||||
def f(Foo a) -> (Foo):
|
||||
return a
|
||||
struct Baz {
|
||||
a: Bar
|
||||
}
|
||||
|
||||
def main(Bar a) -> (Foo):
|
||||
return f(a.a)
|
||||
def main(Baz a, Baz b, bool c) -> (Baz):
|
||||
return if c then a else b fi.a
|
||||
|
||||
|
|
|
@ -8,16 +8,6 @@ use zokrates_pest_ast as pest;
|
|||
impl<'ast, T: Field> From<pest::File<'ast>> for absy::Module<'ast, T> {
|
||||
fn from(prog: pest::File<'ast>) -> absy::Module<T> {
|
||||
absy::Module {
|
||||
// types: prog
|
||||
// .structs
|
||||
// .into_iter()
|
||||
// .map(|t| absy::TypeDeclarationNode::from(t))
|
||||
// .collect(),
|
||||
// functions: prog
|
||||
// .functions
|
||||
// .into_iter()
|
||||
// .map(|f| absy::FunctionDeclarationNode::from(f))
|
||||
// .collect(),
|
||||
symbols: prog
|
||||
.structs
|
||||
.into_iter()
|
||||
|
@ -602,8 +592,13 @@ impl<'ast> From<pest::Type<'ast>> for UnresolvedTypeNode {
|
|||
},
|
||||
pest::Type::Array(t) => {
|
||||
let inner_type = match t.ty {
|
||||
pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span),
|
||||
pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span),
|
||||
pest::BasicOrStructType::Basic(t) => match t {
|
||||
pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span),
|
||||
pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span),
|
||||
},
|
||||
pest::BasicOrStructType::Struct(t) => {
|
||||
UnresolvedType::User(t.span.as_str().to_string()).span(t.span)
|
||||
}
|
||||
};
|
||||
|
||||
let span = t.span;
|
||||
|
|
|
@ -39,7 +39,6 @@ pub use self::signature::UnresolvedSignature;
|
|||
mod signature {
|
||||
use std::fmt;
|
||||
|
||||
use super::*;
|
||||
use absy::UnresolvedTypeNode;
|
||||
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize)]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::absy::types::UnresolvedType;
|
||||
use crate::absy::{Node, NodeValue, UnresolvedTypeNode};
|
||||
use crate::absy::{Node, UnresolvedTypeNode};
|
||||
use std::fmt;
|
||||
|
||||
use crate::absy::Identifier;
|
||||
|
|
|
@ -101,7 +101,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
|
|||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
unimplemented!()
|
||||
flattener.flatten_struct_expression(symbols, statements_flattened, self)
|
||||
}
|
||||
|
||||
fn if_else(
|
||||
|
@ -109,15 +109,42 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
|
|||
consequence: Self,
|
||||
alternative: Self,
|
||||
) -> Self {
|
||||
unimplemented!()
|
||||
StructExpression {
|
||||
ty: consequence.ty.clone(),
|
||||
inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative),
|
||||
}
|
||||
}
|
||||
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
unimplemented!()
|
||||
let members = match array.inner_type() {
|
||||
Type::Struct(members) => members,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
StructExpression {
|
||||
ty: members.clone(),
|
||||
inner: StructExpressionInner::Select(box array, box index),
|
||||
}
|
||||
}
|
||||
|
||||
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
|
||||
unimplemented!()
|
||||
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
|
||||
let members = s.ty.clone();
|
||||
|
||||
let ty = members
|
||||
.into_iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
let members = match ty {
|
||||
Type::Struct(members) => members,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
StructExpression {
|
||||
ty: members,
|
||||
inner: StructExpressionInner::Member(box s, member_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,8 +205,25 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
|
||||
unimplemented!()
|
||||
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
|
||||
let members = s.ty.clone();
|
||||
|
||||
let ty = members
|
||||
.into_iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
let (ty, size) = match ty {
|
||||
Type::Array(box ty, size) => (ty, size),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ArrayExpression {
|
||||
ty,
|
||||
size,
|
||||
inner: ArrayExpressionInner::Member(box s, member_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,7 +325,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
res.into_iter().map(|r| r.into()).collect()
|
||||
}
|
||||
|
||||
fn flatten_member_expression<U: Flatten<'ast, T>>(
|
||||
fn flatten_member_expression(
|
||||
&mut self,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
|
@ -289,35 +333,270 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
member_id: MemberId,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let members = s.ty;
|
||||
let expected_output_size = members
|
||||
.iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1
|
||||
.get_primitive_count();
|
||||
|
||||
match s.inner {
|
||||
StructExpressionInner::Identifier(id) => {
|
||||
// the struct is encoded as a sequence, so we need to identify the offset at which this member starts
|
||||
let offset = members
|
||||
.iter()
|
||||
.map(|(id, ty)| (id, ty.get_primitive_count()))
|
||||
.fold((false, 0), |acc, (id, count)| {
|
||||
if acc.0 && *id != member_id {
|
||||
(false, acc.1 + count)
|
||||
} else {
|
||||
(true, acc.1)
|
||||
let res =
|
||||
match s.inner {
|
||||
StructExpressionInner::Value(values) => {
|
||||
// If the struct has an explicit value, we get the value at the given member
|
||||
assert_eq!(values.len(), members.len());
|
||||
values
|
||||
.into_iter()
|
||||
.zip(members.into_iter())
|
||||
.filter(|(_, (id, _))| *id == member_id)
|
||||
.flat_map(|(v, (_, t))| match t {
|
||||
Type::FieldElement => FieldElementExpression::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten(
|
||||
self,
|
||||
symbols,
|
||||
statements_flattened,
|
||||
),
|
||||
Type::Array(box ty, size) => ArrayExpression::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Struct(members) => StructExpression::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
StructExpressionInner::Identifier(id) => {
|
||||
// If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values.
|
||||
// the struct is encoded as a sequence, so we need to identify the offset at which this member starts
|
||||
let offset = members
|
||||
.iter()
|
||||
.take_while(|(id, _)| *id != member_id)
|
||||
.map(|(_, ty)| ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
let size = members
|
||||
.iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1
|
||||
.get_primitive_count();
|
||||
self.layout.get(&id).unwrap()[offset..(offset + size)]
|
||||
.into_iter()
|
||||
.map(|i| i.clone().into())
|
||||
.collect()
|
||||
}
|
||||
StructExpressionInner::Select(box array, box index) => {
|
||||
// If the struct is an array element `array[index]`, we're accessing `array[index].member`
|
||||
// We construct `array := array.map(|e| e.member)` and access `array[index]`
|
||||
let ty = members
|
||||
.clone()
|
||||
.into_iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
match ty {
|
||||
Type::FieldElement => {
|
||||
let array = ArrayExpression {
|
||||
size: array.size,
|
||||
ty: Type::FieldElement,
|
||||
inner: ArrayExpressionInner::Value(
|
||||
(0..array.size)
|
||||
.map(|i| {
|
||||
FieldElementExpression::Member(
|
||||
box StructExpression {
|
||||
ty: members.clone(),
|
||||
inner: StructExpressionInner::Select(
|
||||
box array.clone(),
|
||||
box FieldElementExpression::Number(
|
||||
T::from(i),
|
||||
),
|
||||
),
|
||||
},
|
||||
member_id.clone(),
|
||||
)
|
||||
.into()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
self.flatten_select_expression::<FieldElementExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
)
|
||||
}
|
||||
})
|
||||
.1;
|
||||
// we also need the size of this member
|
||||
let size = members
|
||||
.iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1
|
||||
.get_primitive_count();
|
||||
self.layout.get(&id).unwrap()[offset..(offset + size)]
|
||||
.into_iter()
|
||||
.map(|i| i.clone().into())
|
||||
.collect()
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
Type::Boolean => {
|
||||
let array = ArrayExpression {
|
||||
size: array.size,
|
||||
ty: Type::Boolean,
|
||||
inner: ArrayExpressionInner::Value(
|
||||
(0..array.size)
|
||||
.map(|i| {
|
||||
BooleanExpression::Member(
|
||||
box StructExpression {
|
||||
ty: members.clone(),
|
||||
inner: StructExpressionInner::Select(
|
||||
box array.clone(),
|
||||
box FieldElementExpression::Number(
|
||||
T::from(i),
|
||||
),
|
||||
),
|
||||
},
|
||||
member_id.clone(),
|
||||
)
|
||||
.into()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
self.flatten_select_expression::<BooleanExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
)
|
||||
}
|
||||
Type::Struct(m) => {
|
||||
let array = ArrayExpression {
|
||||
size: array.size,
|
||||
ty: Type::Struct(m.clone()),
|
||||
inner: ArrayExpressionInner::Value(
|
||||
(0..array.size)
|
||||
.map(|i| {
|
||||
StructExpression {
|
||||
ty: m.clone(),
|
||||
inner: StructExpressionInner::Member(
|
||||
box StructExpression {
|
||||
ty: members.clone(),
|
||||
inner: StructExpressionInner::Select(
|
||||
box array.clone(),
|
||||
box FieldElementExpression::Number(
|
||||
T::from(i),
|
||||
),
|
||||
),
|
||||
},
|
||||
member_id.clone(),
|
||||
),
|
||||
}
|
||||
.into()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
self.flatten_select_expression::<StructExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
)
|
||||
}
|
||||
Type::Array(box ty, size) => {
|
||||
let array = ArrayExpression {
|
||||
size: array.size,
|
||||
ty: Type::Array(box ty.clone(), size),
|
||||
inner: ArrayExpressionInner::Value(
|
||||
(0..array.size)
|
||||
.map(|i| {
|
||||
ArrayExpression {
|
||||
size,
|
||||
ty: ty.clone(),
|
||||
inner: ArrayExpressionInner::Member(
|
||||
box StructExpression {
|
||||
ty: members.clone(),
|
||||
inner: StructExpressionInner::Select(
|
||||
box array.clone(),
|
||||
box FieldElementExpression::Number(
|
||||
T::from(i),
|
||||
),
|
||||
),
|
||||
},
|
||||
member_id.clone(),
|
||||
),
|
||||
}
|
||||
.into()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
self.flatten_select_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
StructExpressionInner::FunctionCall(..) => unreachable!(),
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
// if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member`
|
||||
// we reduce to `if c then a.member else b.member`
|
||||
let ty = members
|
||||
.clone()
|
||||
.into_iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
match ty {
|
||||
Type::FieldElement => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
condition.clone(),
|
||||
FieldElementExpression::member(consequence.clone(), member_id.clone()),
|
||||
FieldElementExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
Type::Boolean => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
condition.clone(),
|
||||
BooleanExpression::member(consequence.clone(), member_id.clone()),
|
||||
BooleanExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
Type::Struct(m) => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
condition.clone(),
|
||||
StructExpression::member(consequence.clone(), member_id.clone()),
|
||||
StructExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
Type::Array(box ty, size) => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
condition.clone(),
|
||||
ArrayExpression::member(consequence.clone(), member_id.clone()),
|
||||
ArrayExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
}
|
||||
}
|
||||
StructExpressionInner::Member(box s0, m_id) => {
|
||||
let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id);
|
||||
|
||||
let offset = members
|
||||
.iter()
|
||||
.take_while(|(id, _)| *id != member_id)
|
||||
.map(|(_, ty)| ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
let size = members
|
||||
.iter()
|
||||
.find(|(id, _)| *id == member_id)
|
||||
.unwrap()
|
||||
.1
|
||||
.get_primitive_count();
|
||||
|
||||
e[offset..(offset + size)].into()
|
||||
}
|
||||
};
|
||||
|
||||
assert_eq!(res.len(), expected_output_size);
|
||||
res
|
||||
}
|
||||
|
||||
fn flatten_select_expression<U: Flatten<'ast, T>>(
|
||||
|
@ -362,7 +641,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
.flatten(self, symbols, statements_flattened)
|
||||
}
|
||||
ArrayExpressionInner::Member(box s, id) => unimplemented!(),
|
||||
ArrayExpressionInner::Member(box s, id) => {
|
||||
assert!(n < T::from(size));
|
||||
let n = n.to_dec_string().parse::<usize>().unwrap();
|
||||
self.flatten_member_expression(symbols, statements_flattened, s, id)
|
||||
[n * ty.get_primitive_count()..(n + 1) * ty.get_primitive_count()]
|
||||
.to_vec()
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
assert!(n < T::from(size));
|
||||
let n = n.to_dec_string().parse::<usize>().unwrap();
|
||||
|
@ -755,14 +1040,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
alternative,
|
||||
)[0]
|
||||
.clone(),
|
||||
BooleanExpression::Member(box s, id) => self
|
||||
.flatten_member_expression::<BooleanExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
s,
|
||||
id,
|
||||
)[0]
|
||||
.clone(),
|
||||
BooleanExpression::Member(box s, id) => {
|
||||
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
|
||||
}
|
||||
BooleanExpression::Select(box array, box index) => self
|
||||
.flatten_select_expression::<BooleanExpression<'ast, T>>(
|
||||
symbols,
|
||||
|
@ -1150,14 +1430,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value
|
||||
exprs_flattened.expressions[0].clone()
|
||||
}
|
||||
FieldElementExpression::Member(box s, id) => self
|
||||
.flatten_member_expression::<FieldElementExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
s,
|
||||
id,
|
||||
)[0]
|
||||
.clone(),
|
||||
FieldElementExpression::Member(box s, id) => {
|
||||
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
|
||||
}
|
||||
FieldElementExpression::Select(box array, box index) => self
|
||||
.flatten_select_expression::<FieldElementExpression<'ast, T>>(
|
||||
symbols,
|
||||
|
@ -1176,9 +1451,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
expr: StructExpression<'ast, T>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let ty = expr.get_type();
|
||||
//assert_eq!(U::get_type(), inner_type);
|
||||
let expected_output_size = expr.get_type().get_primitive_count();
|
||||
let members = expr.ty;
|
||||
|
||||
match expr.inner {
|
||||
let res = match expr.inner {
|
||||
StructExpressionInner::Identifier(x) => self
|
||||
.layout
|
||||
.get(&x)
|
||||
|
@ -1186,16 +1462,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.iter()
|
||||
.map(|v| FlatExpression::Identifier(v.clone()))
|
||||
.collect(),
|
||||
// StructExpressionInner::Value(values) => {
|
||||
// values
|
||||
// .into_iter()
|
||||
// .flat_map(|v| {
|
||||
// U::try_from(v)
|
||||
// .unwrap()
|
||||
// .flatten(self, symbols, statements_flattened)
|
||||
// })
|
||||
// .collect()
|
||||
// }
|
||||
StructExpressionInner::Value(values) => values
|
||||
.into_iter()
|
||||
.flat_map(|v| self.flatten_expression(symbols, statements_flattened, v))
|
||||
.collect(),
|
||||
StructExpressionInner::FunctionCall(key, param_expressions) => {
|
||||
let exprs_flattened = self.flatten_function_call(
|
||||
symbols,
|
||||
|
@ -1206,30 +1476,40 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
exprs_flattened.expressions
|
||||
}
|
||||
// StructExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0
|
||||
// ..size)
|
||||
// .flat_map(|i| {
|
||||
// U::if_else(
|
||||
// *condition.clone(),
|
||||
// U::select(
|
||||
// *consequence.clone(),
|
||||
// FieldElementExpression::Number(T::from(i)),
|
||||
// ),
|
||||
// U::select(
|
||||
// *alternative.clone(),
|
||||
// FieldElementExpression::Number(T::from(i)),
|
||||
// ),
|
||||
// )
|
||||
// .flatten(self, symbols, statements_flattened)
|
||||
// })
|
||||
// .collect(),
|
||||
StructExpressionInner::Member(box s, id) => self
|
||||
.flatten_member_expression::<StructExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
s,
|
||||
id,
|
||||
),
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
members
|
||||
.into_iter()
|
||||
.flat_map(|(id, ty)| match ty {
|
||||
Type::FieldElement => FieldElementExpression::if_else(
|
||||
condition.clone(),
|
||||
FieldElementExpression::member(consequence.clone(), id.clone()),
|
||||
FieldElementExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Boolean => BooleanExpression::if_else(
|
||||
condition.clone(),
|
||||
BooleanExpression::member(consequence.clone(), id.clone()),
|
||||
BooleanExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Struct(..) => StructExpression::if_else(
|
||||
condition.clone(),
|
||||
StructExpression::member(consequence.clone(), id.clone()),
|
||||
StructExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Array(..) => ArrayExpression::if_else(
|
||||
condition.clone(),
|
||||
ArrayExpression::member(consequence.clone(), id.clone()),
|
||||
ArrayExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
StructExpressionInner::Member(box s, id) => {
|
||||
self.flatten_member_expression(symbols, statements_flattened, s, id)
|
||||
}
|
||||
StructExpressionInner::Select(box array, box index) => self
|
||||
.flatten_select_expression::<StructExpression<'ast, T>>(
|
||||
symbols,
|
||||
|
@ -1237,8 +1517,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
array,
|
||||
index,
|
||||
),
|
||||
_ => unimplemented!("yeah well"),
|
||||
}
|
||||
};
|
||||
|
||||
assert_eq!(res.len(), expected_output_size);
|
||||
res
|
||||
}
|
||||
|
||||
/// # Remarks
|
||||
|
@ -1300,13 +1582,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.flatten(self, symbols, statements_flattened)
|
||||
})
|
||||
.collect(),
|
||||
ArrayExpressionInner::Member(box s, id) => self
|
||||
.flatten_member_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
s,
|
||||
id,
|
||||
),
|
||||
ArrayExpressionInner::Member(box s, id) => {
|
||||
self.flatten_member_expression(symbols, statements_flattened, s, id)
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => self
|
||||
.flatten_select_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
|
|
|
@ -1072,6 +1072,16 @@ impl<'ast> Checker<'ast> {
|
|||
unimplemented!("handle consequence alternative inner type mismatch")
|
||||
}
|
||||
},
|
||||
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
|
||||
if consequence.get_type() == alternative.get_type() {
|
||||
Ok(StructExpression {
|
||||
ty: consequence.ty.clone(),
|
||||
inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative)
|
||||
}.into())
|
||||
} else {
|
||||
unimplemented!("handle consequence alternative inner type mismatch")
|
||||
}
|
||||
},
|
||||
_ => unimplemented!()
|
||||
}
|
||||
false => Err(Error {
|
||||
|
@ -1367,8 +1377,8 @@ impl<'ast> Checker<'ast> {
|
|||
// check that the struct has that field and return the type if it does
|
||||
let ty =
|
||||
s.ty.iter()
|
||||
.find(|(member_id, ty)| member_id == id)
|
||||
.map(|(member_id, ty)| ty);
|
||||
.find(|(member_id, _)| member_id == id)
|
||||
.map(|(_, ty)| ty);
|
||||
|
||||
match ty {
|
||||
Some(ty) => match ty {
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
use absy::UnresolvedTypeNode;
|
||||
use std::fmt;
|
||||
|
||||
pub type Identifier<'ast> = &'ast str;
|
||||
|
||||
pub type MemberId = String;
|
||||
|
||||
pub type UserTypeId = String;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum Type {
|
||||
FieldElement,
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::absy;
|
||||
use crate::typed_absy::types::Type;
|
||||
use crate::typed_absy::Identifier;
|
||||
use std::fmt;
|
||||
|
|
|
@ -18,7 +18,8 @@ parameter = {vis? ~ ty ~ identifier}
|
|||
ty_field = {"field"}
|
||||
ty_bool = {"bool"}
|
||||
ty_basic = { ty_field | ty_bool }
|
||||
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
|
||||
ty_basic_or_struct = { ty_basic | ty_struct }
|
||||
ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
|
||||
ty = { ty_array | ty_basic | ty_struct }
|
||||
type_list = _{(ty ~ ("," ~ ty)*)?}
|
||||
// structs
|
||||
|
|
|
@ -9,12 +9,13 @@ extern crate lazy_static;
|
|||
|
||||
pub use ast::{
|
||||
Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee,
|
||||
AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, CallAccess,
|
||||
ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function,
|
||||
IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement,
|
||||
MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression,
|
||||
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField,
|
||||
TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility,
|
||||
AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator,
|
||||
CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression,
|
||||
Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression,
|
||||
IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range,
|
||||
RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement,
|
||||
StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression,
|
||||
UnaryOperator, Visibility,
|
||||
};
|
||||
|
||||
mod ast {
|
||||
|
@ -251,12 +252,19 @@ mod ast {
|
|||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
#[pest_ast(rule(Rule::ty_array))]
|
||||
pub struct ArrayType<'ast> {
|
||||
pub ty: BasicType<'ast>,
|
||||
pub ty: BasicOrStructType<'ast>,
|
||||
pub size: Vec<Expression<'ast>>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
#[pest_ast(rule(Rule::ty_basic_or_struct))]
|
||||
pub enum BasicOrStructType<'ast> {
|
||||
Struct(StructType<'ast>),
|
||||
Basic(BasicType<'ast>),
|
||||
}
|
||||
|
||||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
#[pest_ast(rule(Rule::ty_bool))]
|
||||
pub struct BooleanType<'ast> {
|
||||
|
|
Loading…
Reference in a new issue