From 0673381bce4b28262860692ee9ec8b1afb3e595c Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 7 Aug 2019 16:25:09 +0200 Subject: [PATCH] implement flattening --- t.code | 14 +- zokrates_core/src/absy/from_ast.rs | 19 +- zokrates_core/src/absy/types.rs | 1 - zokrates_core/src/absy/variable.rs | 2 +- zokrates_core/src/flatten/mod.rs | 472 ++++++++++++++++++----- zokrates_core/src/semantics.rs | 14 +- zokrates_core/src/typed_absy/types.rs | 3 - zokrates_core/src/typed_absy/variable.rs | 1 - zokrates_parser/src/zokrates.pest | 3 +- zokrates_pest_ast/src/lib.rs | 22 +- 10 files changed, 420 insertions(+), 131 deletions(-) diff --git a/t.code b/t.code index 351d0176..cd71dab9 100644 --- a/t.code +++ b/t.code @@ -1,13 +1,15 @@ from "./u.code" import Foo struct Bar { - a: Foo, - b: field[2] + a: field, + b: field, + c: field, } -def f(Foo a) -> (Foo): - return a +struct Baz { + a: Bar +} -def main(Bar a) -> (Foo): - return f(a.a) +def main(Baz a, Baz b, bool c) -> (Baz): + return if c then a else b fi.a diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 9feec687..39f61c5f 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -8,16 +8,6 @@ use zokrates_pest_ast as pest; impl<'ast, T: Field> From> for absy::Module<'ast, T> { fn from(prog: pest::File<'ast>) -> absy::Module { absy::Module { - // types: prog - // .structs - // .into_iter() - // .map(|t| absy::TypeDeclarationNode::from(t)) - // .collect(), - // functions: prog - // .functions - // .into_iter() - // .map(|f| absy::FunctionDeclarationNode::from(f)) - // .collect(), symbols: prog .structs .into_iter() @@ -602,8 +592,13 @@ impl<'ast> From> for UnresolvedTypeNode { }, pest::Type::Array(t) => { let inner_type = match t.ty { - pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), - pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + pest::BasicOrStructType::Basic(t) => match t { + pest::BasicType::Field(t) => UnresolvedType::FieldElement.span(t.span), + pest::BasicType::Boolean(t) => UnresolvedType::Boolean.span(t.span), + }, + pest::BasicOrStructType::Struct(t) => { + UnresolvedType::User(t.span.as_str().to_string()).span(t.span) + } }; let span = t.span; diff --git a/zokrates_core/src/absy/types.rs b/zokrates_core/src/absy/types.rs index bc2f6bac..d5e22303 100644 --- a/zokrates_core/src/absy/types.rs +++ b/zokrates_core/src/absy/types.rs @@ -39,7 +39,6 @@ pub use self::signature::UnresolvedSignature; mod signature { use std::fmt; - use super::*; use absy::UnresolvedTypeNode; #[derive(Clone, PartialEq, Serialize, Deserialize)] diff --git a/zokrates_core/src/absy/variable.rs b/zokrates_core/src/absy/variable.rs index bf96e69f..f03b3f0d 100644 --- a/zokrates_core/src/absy/variable.rs +++ b/zokrates_core/src/absy/variable.rs @@ -1,5 +1,5 @@ use crate::absy::types::UnresolvedType; -use crate::absy::{Node, NodeValue, UnresolvedTypeNode}; +use crate::absy::{Node, UnresolvedTypeNode}; use std::fmt; use crate::absy::Identifier; diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index acb3fef7..7c967cdd 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -101,7 +101,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> { symbols: &TypedFunctionSymbols<'ast, T>, statements_flattened: &mut Vec>, ) -> Vec> { - unimplemented!() + flattener.flatten_struct_expression(symbols, statements_flattened, self) } fn if_else( @@ -109,15 +109,42 @@ impl<'ast, T: Field> Flatten<'ast, T> for StructExpression<'ast, T> { consequence: Self, alternative: Self, ) -> Self { - unimplemented!() + StructExpression { + ty: consequence.ty.clone(), + inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative), + } } fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { - unimplemented!() + let members = match array.inner_type() { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select(box array, box index), + } } - fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - unimplemented!() + fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { + let members = s.ty.clone(); + + let ty = members + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + let members = match ty { + Type::Struct(members) => members, + _ => unreachable!(), + }; + + StructExpression { + ty: members, + inner: StructExpressionInner::Member(box s, member_id), + } } } @@ -178,8 +205,25 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> { } } - fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - unimplemented!() + fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { + let members = s.ty.clone(); + + let ty = members + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + let (ty, size) = match ty { + Type::Array(box ty, size) => (ty, size), + _ => unreachable!(), + }; + + ArrayExpression { + ty, + size, + inner: ArrayExpressionInner::Member(box s, member_id), + } } } @@ -281,7 +325,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { res.into_iter().map(|r| r.into()).collect() } - fn flatten_member_expression>( + fn flatten_member_expression( &mut self, symbols: &TypedFunctionSymbols<'ast, T>, statements_flattened: &mut Vec>, @@ -289,35 +333,270 @@ impl<'ast, T: Field> Flattener<'ast, T> { member_id: MemberId, ) -> Vec> { let members = s.ty; + let expected_output_size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); - match s.inner { - StructExpressionInner::Identifier(id) => { - // the struct is encoded as a sequence, so we need to identify the offset at which this member starts - let offset = members - .iter() - .map(|(id, ty)| (id, ty.get_primitive_count())) - .fold((false, 0), |acc, (id, count)| { - if acc.0 && *id != member_id { - (false, acc.1 + count) - } else { - (true, acc.1) + let res = + match s.inner { + StructExpressionInner::Value(values) => { + // If the struct has an explicit value, we get the value at the given member + assert_eq!(values.len(), members.len()); + values + .into_iter() + .zip(members.into_iter()) + .filter(|(_, (id, _))| *id == member_id) + .flat_map(|(v, (_, t))| match t { + Type::FieldElement => FieldElementExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten( + self, + symbols, + statements_flattened, + ), + Type::Array(box ty, size) => ArrayExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + Type::Struct(members) => StructExpression::try_from(v) + .unwrap() + .flatten(self, symbols, statements_flattened), + }) + .collect() + } + StructExpressionInner::Identifier(id) => { + // If the struct is an identifier, we allocated variables in the layout for that identifier. We need to access a subset of these values. + // the struct is encoded as a sequence, so we need to identify the offset at which this member starts + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + self.layout.get(&id).unwrap()[offset..(offset + size)] + .into_iter() + .map(|i| i.clone().into()) + .collect() + } + StructExpressionInner::Select(box array, box index) => { + // If the struct is an array element `array[index]`, we're accessing `array[index].member` + // We construct `array := array.map(|e| e.member)` and access `array[index]` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + match ty { + Type::FieldElement => { + let array = ArrayExpression { + size: array.size, + ty: Type::FieldElement, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + FieldElementExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) } - }) - .1; - // we also need the size of this member - let size = members - .iter() - .find(|(id, _)| *id == member_id) - .unwrap() - .1 - .get_primitive_count(); - self.layout.get(&id).unwrap()[offset..(offset + size)] - .into_iter() - .map(|i| i.clone().into()) - .collect() - } - _ => unimplemented!(), - } + Type::Boolean => { + let array = ArrayExpression { + size: array.size, + ty: Type::Boolean, + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + BooleanExpression::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ) + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Struct(m) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Struct(m.clone()), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + StructExpression { + ty: m.clone(), + inner: StructExpressionInner::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + Type::Array(box ty, size) => { + let array = ArrayExpression { + size: array.size, + ty: Type::Array(box ty.clone(), size), + inner: ArrayExpressionInner::Value( + (0..array.size) + .map(|i| { + ArrayExpression { + size, + ty: ty.clone(), + inner: ArrayExpressionInner::Member( + box StructExpression { + ty: members.clone(), + inner: StructExpressionInner::Select( + box array.clone(), + box FieldElementExpression::Number( + T::from(i), + ), + ), + }, + member_id.clone(), + ), + } + .into() + }) + .collect(), + ), + }; + self.flatten_select_expression::>( + symbols, + statements_flattened, + array, + index, + ) + } + } + } + StructExpressionInner::FunctionCall(..) => unreachable!(), + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + // if the struct is `(if c then a else b)`, we want to access `(if c then a else b).member` + // we reduce to `if c then a.member else b.member` + let ty = members + .clone() + .into_iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1; + + match ty { + Type::FieldElement => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + FieldElementExpression::member(consequence.clone(), member_id.clone()), + FieldElementExpression::member(alternative.clone(), member_id), + ), + Type::Boolean => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + BooleanExpression::member(consequence.clone(), member_id.clone()), + BooleanExpression::member(alternative.clone(), member_id), + ), + Type::Struct(m) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + StructExpression::member(consequence.clone(), member_id.clone()), + StructExpression::member(alternative.clone(), member_id), + ), + Type::Array(box ty, size) => self.flatten_if_else_expression( + symbols, + statements_flattened, + condition.clone(), + ArrayExpression::member(consequence.clone(), member_id.clone()), + ArrayExpression::member(alternative.clone(), member_id), + ), + } + } + StructExpressionInner::Member(box s0, m_id) => { + let e = self.flatten_member_expression(symbols, statements_flattened, s0, m_id); + + let offset = members + .iter() + .take_while(|(id, _)| *id != member_id) + .map(|(_, ty)| ty.get_primitive_count()) + .sum(); + + // we also need the size of this member + let size = members + .iter() + .find(|(id, _)| *id == member_id) + .unwrap() + .1 + .get_primitive_count(); + + e[offset..(offset + size)].into() + } + }; + + assert_eq!(res.len(), expected_output_size); + res } fn flatten_select_expression>( @@ -362,7 +641,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) .flatten(self, symbols, statements_flattened) } - ArrayExpressionInner::Member(box s, id) => unimplemented!(), + ArrayExpressionInner::Member(box s, id) => { + assert!(n < T::from(size)); + let n = n.to_dec_string().parse::().unwrap(); + self.flatten_member_expression(symbols, statements_flattened, s, id) + [n * ty.get_primitive_count()..(n + 1) * ty.get_primitive_count()] + .to_vec() + } ArrayExpressionInner::Select(box array, box index) => { assert!(n < T::from(size)); let n = n.to_dec_string().parse::().unwrap(); @@ -755,14 +1040,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { alternative, )[0] .clone(), - BooleanExpression::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - )[0] - .clone(), + BooleanExpression::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone() + } BooleanExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1150,14 +1430,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(exprs_flattened.expressions.len() == 1); // outside of MultipleDefinition, FunctionCalls must return a single value exprs_flattened.expressions[0].clone() } - FieldElementExpression::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - )[0] - .clone(), + FieldElementExpression::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id)[0].clone() + } FieldElementExpression::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1176,9 +1451,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { expr: StructExpression<'ast, T>, ) -> Vec> { let ty = expr.get_type(); - //assert_eq!(U::get_type(), inner_type); + let expected_output_size = expr.get_type().get_primitive_count(); + let members = expr.ty; - match expr.inner { + let res = match expr.inner { StructExpressionInner::Identifier(x) => self .layout .get(&x) @@ -1186,16 +1462,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .iter() .map(|v| FlatExpression::Identifier(v.clone())) .collect(), - // StructExpressionInner::Value(values) => { - // values - // .into_iter() - // .flat_map(|v| { - // U::try_from(v) - // .unwrap() - // .flatten(self, symbols, statements_flattened) - // }) - // .collect() - // } + StructExpressionInner::Value(values) => values + .into_iter() + .flat_map(|v| self.flatten_expression(symbols, statements_flattened, v)) + .collect(), StructExpressionInner::FunctionCall(key, param_expressions) => { let exprs_flattened = self.flatten_function_call( symbols, @@ -1206,30 +1476,40 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); exprs_flattened.expressions } - // StructExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => (0 - // ..size) - // .flat_map(|i| { - // U::if_else( - // *condition.clone(), - // U::select( - // *consequence.clone(), - // FieldElementExpression::Number(T::from(i)), - // ), - // U::select( - // *alternative.clone(), - // FieldElementExpression::Number(T::from(i)), - // ), - // ) - // .flatten(self, symbols, statements_flattened) - // }) - // .collect(), - StructExpressionInner::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - ), + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + members + .into_iter() + .flat_map(|(id, ty)| match ty { + Type::FieldElement => FieldElementExpression::if_else( + condition.clone(), + FieldElementExpression::member(consequence.clone(), id.clone()), + FieldElementExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Boolean => BooleanExpression::if_else( + condition.clone(), + BooleanExpression::member(consequence.clone(), id.clone()), + BooleanExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Struct(..) => StructExpression::if_else( + condition.clone(), + StructExpression::member(consequence.clone(), id.clone()), + StructExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + Type::Array(..) => ArrayExpression::if_else( + condition.clone(), + ArrayExpression::member(consequence.clone(), id.clone()), + ArrayExpression::member(alternative.clone(), id.clone()), + ) + .flatten(self, symbols, statements_flattened), + }) + .collect() + } + StructExpressionInner::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id) + } StructExpressionInner::Select(box array, box index) => self .flatten_select_expression::>( symbols, @@ -1237,8 +1517,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { array, index, ), - _ => unimplemented!("yeah well"), - } + }; + + assert_eq!(res.len(), expected_output_size); + res } /// # Remarks @@ -1300,13 +1582,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten(self, symbols, statements_flattened) }) .collect(), - ArrayExpressionInner::Member(box s, id) => self - .flatten_member_expression::>( - symbols, - statements_flattened, - s, - id, - ), + ArrayExpressionInner::Member(box s, id) => { + self.flatten_member_expression(symbols, statements_flattened, s, id) + } ArrayExpressionInner::Select(box array, box index) => self .flatten_select_expression::>( symbols, diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index bb979727..8d6f69f0 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1072,6 +1072,16 @@ impl<'ast> Checker<'ast> { unimplemented!("handle consequence alternative inner type mismatch") } }, + (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { + if consequence.get_type() == alternative.get_type() { + Ok(StructExpression { + ty: consequence.ty.clone(), + inner: StructExpressionInner::IfElse(box condition, box consequence, box alternative) + }.into()) + } else { + unimplemented!("handle consequence alternative inner type mismatch") + } + }, _ => unimplemented!() } false => Err(Error { @@ -1367,8 +1377,8 @@ impl<'ast> Checker<'ast> { // check that the struct has that field and return the type if it does let ty = s.ty.iter() - .find(|(member_id, ty)| member_id == id) - .map(|(member_id, ty)| ty); + .find(|(member_id, _)| member_id == id) + .map(|(_, ty)| ty); match ty { Some(ty) => match ty { diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 249b7e1e..e3a49e96 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,12 +1,9 @@ -use absy::UnresolvedTypeNode; use std::fmt; pub type Identifier<'ast> = &'ast str; pub type MemberId = String; -pub type UserTypeId = String; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum Type { FieldElement, diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index d28646b3..f38236c3 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -1,4 +1,3 @@ -use crate::absy; use crate::typed_absy::types::Type; use crate::typed_absy::Identifier; use std::fmt; diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index a2bcea3d..e1ca18d2 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -18,7 +18,8 @@ parameter = {vis? ~ ty ~ identifier} ty_field = {"field"} ty_bool = {"bool"} ty_basic = { ty_field | ty_bool } -ty_array = { ty_basic ~ ("[" ~ expression ~ "]")+ } +ty_basic_or_struct = { ty_basic | ty_struct } +ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ } ty = { ty_array | ty_basic | ty_struct } type_list = _{(ty ~ ("," ~ ty)*)?} // structs diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 0d8a4c59..2033b671 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -9,12 +9,13 @@ extern crate lazy_static; pub use ast::{ Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, - AssignmentStatement, BasicType, BinaryExpression, BinaryOperator, CallAccess, - ConstantExpression, DefinitionStatement, Expression, File, FromExpression, Function, - IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, IterationStatement, - MultiAssignmentStatement, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, - TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, + AssignmentStatement, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, + CallAccess, ConstantExpression, DefinitionStatement, Expression, File, FromExpression, + Function, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, + IterationStatement, MultiAssignmentStatement, Parameter, PostfixExpression, Range, + RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, + StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, + UnaryOperator, Visibility, }; mod ast { @@ -251,12 +252,19 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_array))] pub struct ArrayType<'ast> { - pub ty: BasicType<'ast>, + pub ty: BasicOrStructType<'ast>, pub size: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::ty_basic_or_struct))] + pub enum BasicOrStructType<'ast> { + Struct(StructType<'ast>), + Basic(BasicType<'ast>), + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::ty_bool))] pub struct BooleanType<'ast> {