1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

implement flattening

This commit is contained in:
schaeff 2019-08-07 16:25:09 +02:00
parent bb0973de1e
commit 0673381bce
10 changed files with 420 additions and 131 deletions

14
t.code
View file

@ -1,13 +1,15 @@
from "./u.code" import Foo
struct Bar {
a: Foo,
b: field[2]
a: field,
b: field,
c: field,
}
def f(Foo a) -> (Foo):
return a
struct Baz {
a: Bar
}
def main(Bar a) -> (Foo):
return f(a.a)
def main(Baz a, Baz b, bool c) -> (Baz):
return if c then a else b fi.a

View file

@ -8,16 +8,6 @@ use zokrates_pest_ast as pest;
impl<'ast, T: Field> From<pest::File<'ast>> for absy::Module<'ast, T> {
fn from(prog: pest::File<'ast>) -> absy::Module<T> {
absy::Module {
// types: prog
// .structs
// .into_iter()
// .map(|t| absy::TypeDeclarationNode::from(t))
// .collect(),
// functions: prog
// .functions
// .into_iter()
// .map(|f| absy::FunctionDeclarationNode::from(f))
// .collect(),
symbols: prog
.structs
.into_iter()
@ -602,8 +592,13 @@ impl<'ast> From<pest::Type<'ast>> for UnresolvedTypeNode {
},
pest::Type::Array(t) => {
let inner_type = match t.ty {
pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span),
pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span),
pest::BasicOrStructType::Basic(t) => match t {
pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span),
pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span),
},
pest::BasicOrStructType::Struct(t) => {
UnresolvedType::User(t.span.as_str().to_string()).span(t.span)
}
};
let span = t.span;

View file

@ -39,7 +39,6 @@ pub use self::signature::UnresolvedSignature;
mod signature {
use std::fmt;
use super::*;
use absy::UnresolvedTypeNode;
#[derive(Clone, PartialEq, Serialize, Deserialize)]

View file

@ -1,5 +1,5 @@
use crate::absy::types::UnresolvedType;
use crate::absy::{Node, NodeValue, UnresolvedTypeNode};
use crate::absy::{Node, UnresolvedTypeNode};
use std::fmt;
use crate::absy::Identifier;

View file

@ -101,7 +101,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatExpression<T>> {
unimplemented!()
flattener.flatten_struct_expression(symbols, statements_flattened, self)
}
fn if_else(
@ -109,15 +109,42 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> {
consequence: Self,
alternative: Self,
) -> Self {
unimplemented!()
StructExpression {
ty: consequence.ty.clone(),
inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative),
}
}
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
unimplemented!()
let members = match array.inner_type() {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpression {
ty: members.clone(),
inner: StructExpressionInner::Select(box array, box index),
}
}
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
unimplemented!()
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty.clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let members = match ty {
Type::Struct(members) => members,
_ => unreachable!(),
};
StructExpression {
ty: members,
inner: StructExpressionInner::Member(box s, member_id),
}
}
}
@ -178,8 +205,25 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
}
}
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
unimplemented!()
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
let members = s.ty.clone();
let ty = members
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
let (ty, size) = match ty {
Type::Array(box ty, size) => (ty, size),
_ => unreachable!(),
};
ArrayExpression {
ty,
size,
inner: ArrayExpressionInner::Member(box s, member_id),
}
}
}
@ -281,7 +325,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res.into_iter().map(|r| r.into()).collect()
}
fn flatten_member_expression<U: Flatten<'ast, T>>(
fn flatten_member_expression(
&mut self,
symbols: &TypedFunctionSymbols<'ast, T>,
statements_flattened: &mut Vec<FlatStatement<T>>,
@ -289,35 +333,270 @@ impl<'ast, T: Field> Flattener<'ast, T> {
member_id: MemberId,
) -> Vec<FlatExpression<T>> {
let members = s.ty;
let expected_output_size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
match s.inner {
StructExpressionInner::Identifier(id) => {
// the struct is encoded as a sequence, so we need to identify the offset at which this member starts
let offset = members
.iter()
.map(|(id, ty)| (id, ty.get_primitive_count()))
.fold((false, 0), |acc, (id, count)| {
if acc.0 && *id != member_id {
(false, acc.1 + count)
} else {
(true, acc.1)
let res =
match s.inner {
StructExpressionInner::Value(values) => {
// If the struct has an explicit value, we get the value at the given member
assert_eq!(values.len(), members.len());
values
.into_iter()
.zip(members.into_iter())
.filter(|(_, (id, _))| *id == member_id)
.flat_map(|(v, (_, t))| match t {
Type::FieldElement => FieldElementExpression::try_from(v)
.unwrap()
.flatten(self, symbols, statements_flattened),
Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten(
self,
symbols,
statements_flattened,
),
Type::Array(box ty, size) => ArrayExpression::try_from(v)
.unwrap()
.flatten(self, symbols, statements_flattened),
Type::Struct(members) => StructExpression::try_from(v)
.unwrap()
.flatten(self, symbols, statements_flattened),
})
.collect()
}
StructExpressionInner::Identifier(id) => {
// If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values.
// the struct is encoded as a sequence, so we need to identify the offset at which this member starts
let offset = members
.iter()
.take_while(|(id, _)| *id != member_id)
.map(|(_, ty)| ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
self.layout.get(&id).unwrap()[offset..(offset + size)]
.into_iter()
.map(|i| i.clone().into())
.collect()
}
StructExpressionInner::Select(box array, box index) => {
// If the struct is an array element `array[index]`, we're accessing `array[index].member`
// We construct `array := array.map(|e| e.member)` and access `array[index]`
let ty = members
.clone()
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
match ty {
Type::FieldElement => {
let array = ArrayExpression {
size: array.size,
ty: Type::FieldElement,
inner: ArrayExpressionInner::Value(
(0..array.size)
.map(|i| {
FieldElementExpression::Member(
box StructExpression {
ty: members.clone(),
inner: StructExpressionInner::Select(
box array.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
},
member_id.clone(),
)
.into()
})
.collect(),
),
};
self.flatten_select_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
)
}
})
.1;
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
self.layout.get(&id).unwrap()[offset..(offset + size)]
.into_iter()
.map(|i| i.clone().into())
.collect()
}
_ => unimplemented!(),
}
Type::Boolean => {
let array = ArrayExpression {
size: array.size,
ty: Type::Boolean,
inner: ArrayExpressionInner::Value(
(0..array.size)
.map(|i| {
BooleanExpression::Member(
box StructExpression {
ty: members.clone(),
inner: StructExpressionInner::Select(
box array.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
},
member_id.clone(),
)
.into()
})
.collect(),
),
};
self.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
)
}
Type::Struct(m) => {
let array = ArrayExpression {
size: array.size,
ty: Type::Struct(m.clone()),
inner: ArrayExpressionInner::Value(
(0..array.size)
.map(|i| {
StructExpression {
ty: m.clone(),
inner: StructExpressionInner::Member(
box StructExpression {
ty: members.clone(),
inner: StructExpressionInner::Select(
box array.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
},
member_id.clone(),
),
}
.into()
})
.collect(),
),
};
self.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
)
}
Type::Array(box ty, size) => {
let array = ArrayExpression {
size: array.size,
ty: Type::Array(box ty.clone(), size),
inner: ArrayExpressionInner::Value(
(0..array.size)
.map(|i| {
ArrayExpression {
size,
ty: ty.clone(),
inner: ArrayExpressionInner::Member(
box StructExpression {
ty: members.clone(),
inner: StructExpressionInner::Select(
box array.clone(),
box FieldElementExpression::Number(
T::from(i),
),
),
},
member_id.clone(),
),
}
.into()
})
.collect(),
),
};
self.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
array,
index,
)
}
}
}
StructExpressionInner::FunctionCall(..) => unreachable!(),
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
// if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member`
// we reduce to `if c then a.member else b.member`
let ty = members
.clone()
.into_iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1;
match ty {
Type::FieldElement => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
FieldElementExpression::member(consequence.clone(), member_id.clone()),
FieldElementExpression::member(alternative.clone(), member_id),
),
Type::Boolean => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
BooleanExpression::member(consequence.clone(), member_id.clone()),
BooleanExpression::member(alternative.clone(), member_id),
),
Type::Struct(m) => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
StructExpression::member(consequence.clone(), member_id.clone()),
StructExpression::member(alternative.clone(), member_id),
),
Type::Array(box ty, size) => self.flatten_if_else_expression(
symbols,
statements_flattened,
condition.clone(),
ArrayExpression::member(consequence.clone(), member_id.clone()),
ArrayExpression::member(alternative.clone(), member_id),
),
}
}
StructExpressionInner::Member(box s0, m_id) => {
let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id);
let offset = members
.iter()
.take_while(|(id, _)| *id != member_id)
.map(|(_, ty)| ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = members
.iter()
.find(|(id, _)| *id == member_id)
.unwrap()
.1
.get_primitive_count();
e[offset..(offset + size)].into()
}
};
assert_eq!(res.len(), expected_output_size);
res
}
fn flatten_select_expression<U: Flatten<'ast, T>>(
@ -362,7 +641,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
.flatten(self, symbols, statements_flattened)
}
ArrayExpressionInner::Member(box s, id) => unimplemented!(),
ArrayExpressionInner::Member(box s, id) => {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
self.flatten_member_expression(symbols, statements_flattened, s, id)
[n * ty.get_primitive_count()..(n + 1) * ty.get_primitive_count()]
.to_vec()
}
ArrayExpressionInner::Select(box array, box index) => {
assert!(n < T::from(size));
let n = n.to_dec_string().parse::<usize>().unwrap();
@ -755,14 +1040,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
alternative,
)[0]
.clone(),
BooleanExpression::Member(box s, id) => self
.flatten_member_expression::<BooleanExpression<'ast, T>>(
symbols,
statements_flattened,
s,
id,
)[0]
.clone(),
BooleanExpression::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
}
BooleanExpression::Select(box array, box index) => self
.flatten_select_expression::<BooleanExpression<'ast, T>>(
symbols,
@ -1150,14 +1430,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value
exprs_flattened.expressions[0].clone()
}
FieldElementExpression::Member(box s, id) => self
.flatten_member_expression::<FieldElementExpression<'ast, T>>(
symbols,
statements_flattened,
s,
id,
)[0]
.clone(),
FieldElementExpression::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone()
}
FieldElementExpression::Select(box array, box index) => self
.flatten_select_expression::<FieldElementExpression<'ast, T>>(
symbols,
@ -1176,9 +1451,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
expr: StructExpression<'ast, T>,
) -> Vec<FlatExpression<T>> {
let ty = expr.get_type();
//assert_eq!(U::get_type(), inner_type);
let expected_output_size = expr.get_type().get_primitive_count();
let members = expr.ty;
match expr.inner {
let res = match expr.inner {
StructExpressionInner::Identifier(x) => self
.layout
.get(&x)
@ -1186,16 +1462,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.iter()
.map(|v| FlatExpression::Identifier(v.clone()))
.collect(),
// StructExpressionInner::Value(values) => {
// values
// .into_iter()
// .flat_map(|v| {
// U::try_from(v)
// .unwrap()
// .flatten(self, symbols, statements_flattened)
// })
// .collect()
// }
StructExpressionInner::Value(values) => values
.into_iter()
.flat_map(|v| self.flatten_expression(symbols, statements_flattened, v))
.collect(),
StructExpressionInner::FunctionCall(key, param_expressions) => {
let exprs_flattened = self.flatten_function_call(
symbols,
@ -1206,30 +1476,40 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
exprs_flattened.expressions
}
// StructExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0
// ..size)
// .flat_map(|i| {
// U::if_else(
// *condition.clone(),
// U::select(
// *consequence.clone(),
// FieldElementExpression::Number(T::from(i)),
// ),
// U::select(
// *alternative.clone(),
// FieldElementExpression::Number(T::from(i)),
// ),
// )
// .flatten(self, symbols, statements_flattened)
// })
// .collect(),
StructExpressionInner::Member(box s, id) => self
.flatten_member_expression::<StructExpression<'ast, T>>(
symbols,
statements_flattened,
s,
id,
),
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
members
.into_iter()
.flat_map(|(id, ty)| match ty {
Type::FieldElement => FieldElementExpression::if_else(
condition.clone(),
FieldElementExpression::member(consequence.clone(), id.clone()),
FieldElementExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Boolean => BooleanExpression::if_else(
condition.clone(),
BooleanExpression::member(consequence.clone(), id.clone()),
BooleanExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Struct(..) => StructExpression::if_else(
condition.clone(),
StructExpression::member(consequence.clone(), id.clone()),
StructExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
Type::Array(..) => ArrayExpression::if_else(
condition.clone(),
ArrayExpression::member(consequence.clone(), id.clone()),
ArrayExpression::member(alternative.clone(), id.clone()),
)
.flatten(self, symbols, statements_flattened),
})
.collect()
}
StructExpressionInner::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)
}
StructExpressionInner::Select(box array, box index) => self
.flatten_select_expression::<StructExpression<'ast, T>>(
symbols,
@ -1237,8 +1517,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
array,
index,
),
_ => unimplemented!("yeah well"),
}
};
assert_eq!(res.len(), expected_output_size);
res
}
/// # Remarks
@ -1300,13 +1582,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flatten(self, symbols, statements_flattened)
})
.collect(),
ArrayExpressionInner::Member(box s, id) => self
.flatten_member_expression::<ArrayExpression<'ast, T>>(
symbols,
statements_flattened,
s,
id,
),
ArrayExpressionInner::Member(box s, id) => {
self.flatten_member_expression(symbols, statements_flattened, s, id)
}
ArrayExpressionInner::Select(box array, box index) => self
.flatten_select_expression::<ArrayExpression<'ast, T>>(
symbols,

View file

@ -1072,6 +1072,16 @@ impl<'ast> Checker<'ast> {
unimplemented!("handle consequence alternative inner type mismatch")
}
},
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
if consequence.get_type() == alternative.get_type() {
Ok(StructExpression {
ty: consequence.ty.clone(),
inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative)
}.into())
} else {
unimplemented!("handle consequence alternative inner type mismatch")
}
},
_ => unimplemented!()
}
false => Err(Error {
@ -1367,8 +1377,8 @@ impl<'ast> Checker<'ast> {
// check that the struct has that field and return the type if it does
let ty =
s.ty.iter()
.find(|(member_id, ty)| member_id == id)
.map(|(member_id, ty)| ty);
.find(|(member_id, _)| member_id == id)
.map(|(_, ty)| ty);
match ty {
Some(ty) => match ty {

View file

@ -1,12 +1,9 @@
use absy::UnresolvedTypeNode;
use std::fmt;
pub type Identifier<'ast> = &'ast str;
pub type MemberId = String;
pub type UserTypeId = String;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Type {
FieldElement,

View file

@ -1,4 +1,3 @@
use crate::absy;
use crate::typed_absy::types::Type;
use crate::typed_absy::Identifier;
use std::fmt;

View file

@ -18,7 +18,8 @@ parameter = {vis? ~ ty ~ identifier}
ty_field = {"field"}
ty_bool = {"bool"}
ty_basic = { ty_field | ty_bool }
ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ }
ty_basic_or_struct = { ty_basic | ty_struct }
ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic | ty_struct }
type_list = _{(ty ~ ("," ~ ty)*)?}
// structs

View file

@ -9,12 +9,13 @@ extern crate lazy_static;
pub use ast::{
Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee,
AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, CallAccess,
ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function,
IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement,
MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression,
ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField,
TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility,
AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator,
CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression,
Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression,
IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range,
RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement,
StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression,
UnaryOperator, Visibility,
};
mod ast {
@ -251,12 +252,19 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_array))]
pub struct ArrayType<'ast> {
pub ty: BasicType<'ast>,
pub ty: BasicOrStructType<'ast>,
pub size: Vec<Expression<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_basic_or_struct))]
pub enum BasicOrStructType<'ast> {
Struct(StructType<'ast>),
Basic(BasicType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::ty_bool))]
pub struct BooleanType<'ast> {