diff --git a/changelogs/unreleased/1010-dark64 b/changelogs/unreleased/1010-dark64 new file mode 100644 index 00000000..6aa232ca --- /dev/null +++ b/changelogs/unreleased/1010-dark64 @@ -0,0 +1 @@ +Introduce ternary operator \ No newline at end of file diff --git a/zokrates_book/src/language/control_flow.md b/zokrates_book/src/language/control_flow.md index 9e47d289..38a57465 100644 --- a/zokrates_book/src/language/control_flow.md +++ b/zokrates_book/src/language/control_flow.md @@ -26,6 +26,12 @@ An if-expression allows you to branch your code depending on a boolean condition {{#include ../../../zokrates_cli/examples/book/if_else.zok}} ``` +The conditional expression can also be written using a ternary operator: + +```zokrates +{{#include ../../../zokrates_cli/examples/book/ternary.zok}} +``` + There are two important caveats when it comes to conditional expressions. Before we go into them, let's define two concepts: - for an execution of the program, *an executed branch* is a branch which has to be paid for when executing the program, generating proofs, etc. - for an execution of the program, *a logically executed branch* is a branch which is "chosen" by the condition of an if-expression. This is the more intuitive notion of execution, and there is only one for each if-expression. diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md index b82d59b7..7f013b22 100644 --- a/zokrates_book/src/language/operators.md +++ b/zokrates_book/src/language/operators.md @@ -16,10 +16,12 @@ The following table lists the precedence and associativity of all operators. Ope | `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | | `&&` | Boolean AND |   |   | ✓ | Left | | | || | Boolean OR |   |   | ✓ | Left | | -| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | +| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | [^4] | [^1]: The exponent must be a compile-time constant of type `u32` [^2]: The right operand must be a compile time constant of type `u32` -[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant \ No newline at end of file +[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant + +[^4]: Conditional expression can also be written using a ternary operator: `c ? x : y` \ No newline at end of file diff --git a/zokrates_cli/examples/book/ternary.zok b/zokrates_cli/examples/book/ternary.zok new file mode 100644 index 00000000..c90ecc5c --- /dev/null +++ b/zokrates_cli/examples/book/ternary.zok @@ -0,0 +1,3 @@ +def main(field x) -> field: + field y = x + 2 == 3 ? 1 : 5 + return y \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/ternary_precedence.zok b/zokrates_cli/examples/compile_errors/ternary_precedence.zok new file mode 100644 index 00000000..180b081e --- /dev/null +++ b/zokrates_cli/examples/compile_errors/ternary_precedence.zok @@ -0,0 +1,5 @@ +def main(bool a) -> field: + // ternary expression should be wrapped in parentheses 1 + (a ? 2 : 3) + // otherwise the whole expression is parsed as (1 + a) ? 2 : 3 + field x = 1 + a ? 2 : 3 + return x \ No newline at end of file diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 64d723f0..81f0b903 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -378,6 +378,7 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { match expression { pest::Expression::Binary(e) => absy::ExpressionNode::from(e), pest::Expression::Ternary(e) => absy::ExpressionNode::from(e), + pest::Expression::IfElse(e) => absy::ExpressionNode::from(e), pest::Expression::Literal(e) => absy::ExpressionNode::from(e), pest::Expression::Identifier(e) => absy::ExpressionNode::from(e), pest::Expression::Postfix(e) => absy::ExpressionNode::from(e), @@ -478,13 +479,27 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { } } +impl<'ast> From> for absy::ExpressionNode<'ast> { + fn from(expression: pest::IfElseExpression<'ast>) -> absy::ExpressionNode<'ast> { + use crate::absy::NodeValue; + absy::Expression::Conditional( + box absy::ExpressionNode::from(*expression.condition), + box absy::ExpressionNode::from(*expression.consequence), + box absy::ExpressionNode::from(*expression.alternative), + absy::ConditionalKind::IfElse, + ) + .span(expression.span) + } +} + impl<'ast> From> for absy::ExpressionNode<'ast> { fn from(expression: pest::TernaryExpression<'ast>) -> absy::ExpressionNode<'ast> { use crate::absy::NodeValue; - absy::Expression::IfElse( - box absy::ExpressionNode::from(*expression.first), - box absy::ExpressionNode::from(*expression.second), - box absy::ExpressionNode::from(*expression.third), + absy::Expression::Conditional( + box absy::ExpressionNode::from(*expression.condition), + box absy::ExpressionNode::from(*expression.consequence), + box absy::ExpressionNode::from(*expression.alternative), + absy::ConditionalKind::Ternary, ) .span(expression.span) } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index c709db28..e6c8ec69 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -460,6 +460,12 @@ impl<'ast> fmt::Display for Range<'ast> { } } +#[derive(Debug, Clone, PartialEq)] +pub enum ConditionalKind { + IfElse, + Ternary, +} + /// An expression #[derive(Debug, Clone, PartialEq)] pub enum Expression<'ast> { @@ -479,10 +485,11 @@ pub enum Expression<'ast> { Pow(Box>, Box>), Neg(Box>), Pos(Box>), - IfElse( + Conditional( Box>, Box>, Box>, + ConditionalKind, ), FunctionCall( Box>, @@ -530,11 +537,18 @@ impl<'ast> fmt::Display for Expression<'ast> { Expression::Neg(ref e) => write!(f, "(-{})", e), Expression::Pos(ref e) => write!(f, "(+{})", e), Expression::BooleanConstant(b) => write!(f, "{}", b), - Expression::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ), + Expression::Conditional(ref condition, ref consequent, ref alternative, ref kind) => { + match kind { + ConditionalKind::IfElse => write!( + f, + "if {} then {} else {} fi", + condition, consequent, alternative + ), + ConditionalKind::Ternary => { + write!(f, "{} ? {} : {}", condition, consequent, alternative) + } + } + } Expression::FunctionCall(ref i, ref g, ref p) => { if let Some(g) = g { write!( diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 0cbd7eaf..b0041ef0 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -2270,7 +2270,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::IfElse(box condition, box consequence, box alternative) => { + Expression::Conditional(box condition, box consequence, box alternative, kind) => { let condition_checked = self.check_expression(condition, module_id, types)?; let consequence_checked = self.check_expression(consequence, module_id, types)?; let alternative_checked = self.check_expression(alternative, module_id, types)?; @@ -2282,40 +2282,49 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .map_err(|(e1, e2)| ErrorInner { pos: Some(pos), - message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()), + message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()), })?; + let kind = match kind { + crate::absy::ConditionalKind::IfElse => { + crate::typed_absy::ConditionalKind::IfElse + } + crate::absy::ConditionalKind::Ternary => { + crate::typed_absy::ConditionalKind::Ternary + } + }; + match condition_checked { TypedExpression::Boolean(condition) => { match (consequence_checked, alternative_checked) { (TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => { - Ok(FieldElementExpression::if_else(condition, consequence, alternative).into()) + Ok(FieldElementExpression::conditional(condition, consequence, alternative, kind).into()) }, (TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => { - Ok(BooleanExpression::if_else(condition, consequence, alternative).into()) + Ok(BooleanExpression::conditional(condition, consequence, alternative, kind).into()) }, (TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => { - Ok(ArrayExpression::if_else(condition, consequence, alternative).into()) + Ok(ArrayExpression::conditional(condition, consequence, alternative, kind).into()) }, (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { - Ok(StructExpression::if_else(condition, consequence, alternative).into()) + Ok(StructExpression::conditional(condition, consequence, alternative, kind).into()) }, (TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => { - Ok(UExpression::if_else(condition, consequence, alternative).into()) + Ok(UExpression::conditional(condition, consequence, alternative, kind).into()) }, (TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => { - Ok(IntExpression::if_else(condition, consequence, alternative).into()) + Ok(IntExpression::conditional(condition, consequence, alternative, kind).into()) }, (c, a) => Err(ErrorInner { pos: Some(pos), - message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", c.get_type(), a.get_type()) + message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", c.get_type(), a.get_type()) }) } } c => Err(ErrorInner { pos: Some(pos), message: format!( - "{{condition}} after `if` should be a boolean, found {}", + "{{condition}} should be a boolean, found {}", c.get_type() ), }), diff --git a/zokrates_core/src/static_analysis/branch_isolator.rs b/zokrates_core/src/static_analysis/branch_isolator.rs index c9d848d8..77580386 100644 --- a/zokrates_core/src/static_analysis/branch_isolator.rs +++ b/zokrates_core/src/static_analysis/branch_isolator.rs @@ -17,17 +17,18 @@ impl Isolator { } impl<'ast, T: Field> Folder<'ast, T> for Isolator { - fn fold_if_else_expression< - E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + IfElse<'ast, T>, + fn fold_conditional_expression< + E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, >( &mut self, _: &E::Ty, - e: IfElseExpression<'ast, T, E>, - ) -> IfElseOrExpression<'ast, T, E> { - IfElseOrExpression::IfElse(IfElseExpression::new( + e: ConditionalExpression<'ast, T, E>, + ) -> ConditionalOrExpression<'ast, T, E> { + ConditionalOrExpression::Conditional(ConditionalExpression::new( self.fold_boolean_expression(*e.condition), E::block(vec![], e.consequence.fold(self)), E::block(vec![], e.alternative.fold(self)), + e.kind, )) } } diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 1986efde..7709fbb6 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -281,12 +281,12 @@ impl<'ast, T: Field> Flattener { } } - fn fold_if_else_expression>( + fn fold_conditional_expression>( &mut self, statements_buffer: &mut Vec>, - c: typed_absy::IfElseExpression<'ast, T, E>, + c: typed_absy::ConditionalExpression<'ast, T, E>, ) -> Vec> { - fold_if_else_expression(self, statements_buffer, c) + fold_conditional_expression(self, statements_buffer, c) } fn fold_member_expression( @@ -448,8 +448,8 @@ fn fold_array_expression_inner<'ast, T: Field>( exprs } typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(), - typed_absy::ArrayExpressionInner::IfElse(c) => { - f.fold_if_else_expression(statements_buffer, c) + typed_absy::ArrayExpressionInner::Conditional(c) => { + f.fold_conditional_expression(statements_buffer, c) } typed_absy::ArrayExpressionInner::Member(m) => { f.fold_member_expression(statements_buffer, m) @@ -523,8 +523,8 @@ fn fold_struct_expression_inner<'ast, T: Field>( .flat_map(|e| f.fold_expression(statements_buffer, e)) .collect(), typed_absy::StructExpressionInner::FunctionCall(..) => unreachable!(), - typed_absy::StructExpressionInner::IfElse(c) => { - f.fold_if_else_expression(statements_buffer, c) + typed_absy::StructExpressionInner::Conditional(c) => { + f.fold_conditional_expression(statements_buffer, c) } typed_absy::StructExpressionInner::Member(m) => { f.fold_member_expression(statements_buffer, m) @@ -646,10 +646,10 @@ fn fold_select_expression<'ast, T: Field, E>( } } -fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>( +fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( f: &mut Flattener, statements_buffer: &mut Vec>, - c: typed_absy::IfElseExpression<'ast, T, E>, + c: typed_absy::ConditionalExpression<'ast, T, E>, ) -> Vec> { let mut consequence_statements = vec![]; let mut alternative_statements = vec![]; @@ -742,8 +742,8 @@ fn fold_field_expression<'ast, T: Field>( typed_absy::FieldElementExpression::Pos(box e) => { f.fold_field_expression(statements_buffer, e) } - typed_absy::FieldElementExpression::IfElse(c) => f - .fold_if_else_expression(statements_buffer, c) + typed_absy::FieldElementExpression::Conditional(c) => f + .fold_conditional_expression(statements_buffer, c) .pop() .unwrap() .try_into() @@ -908,8 +908,8 @@ fn fold_boolean_expression<'ast, T: Field>( let e = f.fold_boolean_expression(statements_buffer, e); zir::BooleanExpression::Not(box e) } - typed_absy::BooleanExpression::IfElse(c) => f - .fold_if_else_expression(statements_buffer, c) + typed_absy::BooleanExpression::Conditional(c) => f + .fold_conditional_expression(statements_buffer, c) .pop() .unwrap() .try_into() @@ -1070,8 +1070,8 @@ fn fold_uint_expression_inner<'ast, T: Field>( ) .unwrap() .into_inner(), - typed_absy::UExpressionInner::IfElse(c) => zir::UExpression::try_from( - f.fold_if_else_expression(statements_buffer, c) + typed_absy::UExpressionInner::Conditional(c) => zir::UExpression::try_from( + f.fold_conditional_expression(statements_buffer, c) .pop() .unwrap(), ) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 8c9c6e8a..14e4520a 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -172,13 +172,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { fold_function(self, f) } - fn fold_if_else_expression< - E: Expr<'ast, T> + IfElse<'ast, T> + PartialEq + ResultFold<'ast, T>, + fn fold_conditional_expression< + E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>, >( &mut self, _: &E::Ty, - e: IfElseExpression<'ast, T, E>, - ) -> Result, Self::Error> { + e: ConditionalExpression<'ast, T, E>, + ) -> Result, Self::Error> { Ok( match ( self.fold_boolean_expression(*e.condition)?, @@ -186,16 +186,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { e.alternative.fold(self)?, ) { (BooleanExpression::Value(true), consequence, _) => { - IfElseOrExpression::Expression(consequence.into_inner()) + ConditionalOrExpression::Expression(consequence.into_inner()) } (BooleanExpression::Value(false), _, alternative) => { - IfElseOrExpression::Expression(alternative.into_inner()) + ConditionalOrExpression::Expression(alternative.into_inner()) } (_, consequence, alternative) if consequence == alternative => { - IfElseOrExpression::Expression(consequence.into_inner()) + ConditionalOrExpression::Expression(consequence.into_inner()) } - (condition, consequence, alternative) => IfElseOrExpression::IfElse( - IfElseExpression::new(condition, consequence, alternative), + (condition, consequence, alternative) => ConditionalOrExpression::Conditional( + ConditionalExpression::new(condition, consequence, alternative, e.kind), ), }, ) @@ -1431,10 +1431,11 @@ mod tests { #[test] fn if_else_true() { - let e = FieldElementExpression::if_else( + let e = FieldElementExpression::conditional( BooleanExpression::Value(true), FieldElementExpression::Number(Bn128Field::from(2)), FieldElementExpression::Number(Bn128Field::from(3)), + ConditionalKind::IfElse, ); assert_eq!( @@ -1445,10 +1446,11 @@ mod tests { #[test] fn if_else_false() { - let e = FieldElementExpression::if_else( + let e = FieldElementExpression::conditional( BooleanExpression::Value(false), FieldElementExpression::Number(Bn128Field::from(2)), FieldElementExpression::Number(Bn128Field::from(3)), + ConditionalKind::IfElse, ); assert_eq!( diff --git a/zokrates_core/src/static_analysis/variable_write_remover.rs b/zokrates_core/src/static_analysis/variable_write_remover.rs index bce07349..d5522641 100644 --- a/zokrates_core/src/static_analysis/variable_write_remover.rs +++ b/zokrates_core/src/static_analysis/variable_write_remover.rs @@ -56,7 +56,7 @@ impl<'ast> VariableWriteRemover { (0..size) .map(|i| match inner_ty { Type::Int => unreachable!(), - Type::Array(..) => ArrayExpression::if_else( + Type::Array(..) => ArrayExpression::conditional( BooleanExpression::UintEq( box i.into(), box head.clone(), @@ -74,9 +74,10 @@ impl<'ast> VariableWriteRemover { ), }, ArrayExpression::select(base.clone(), i), + ConditionalKind::IfElse, ) .into(), - Type::Struct(..) => StructExpression::if_else( + Type::Struct(..) => StructExpression::conditional( BooleanExpression::UintEq( box i.into(), box head.clone(), @@ -94,9 +95,10 @@ impl<'ast> VariableWriteRemover { ), }, StructExpression::select(base.clone(), i), + ConditionalKind::IfElse, ) .into(), - Type::FieldElement => FieldElementExpression::if_else( + Type::FieldElement => FieldElementExpression::conditional( BooleanExpression::UintEq( box i.into(), box head.clone(), @@ -115,9 +117,10 @@ impl<'ast> VariableWriteRemover { ), }, FieldElementExpression::select(base.clone(), i), + ConditionalKind::IfElse, ) .into(), - Type::Boolean => BooleanExpression::if_else( + Type::Boolean => BooleanExpression::conditional( BooleanExpression::UintEq( box i.into(), box head.clone(), @@ -135,9 +138,10 @@ impl<'ast> VariableWriteRemover { ), }, BooleanExpression::select(base.clone(), i), + ConditionalKind::IfElse, ) .into(), - Type::Uint(..) => UExpression::if_else( + Type::Uint(..) => UExpression::conditional( BooleanExpression::UintEq( box i.into(), box head.clone(), @@ -155,6 +159,7 @@ impl<'ast> VariableWriteRemover { ), }, UExpression::select(base.clone(), i), + ConditionalKind::IfElse, ) .into(), }) diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 6a60e65a..3447f664 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -284,18 +284,18 @@ pub trait Folder<'ast, T: Field>: Sized { fold_block_expression(self, block) } - fn fold_if_else_expression< + fn fold_conditional_expression< E: Expr<'ast, T> + Fold<'ast, T> + Block<'ast, T> - + IfElse<'ast, T> + + Conditional<'ast, T> + From>, >( &mut self, ty: &E::Ty, - e: IfElseExpression<'ast, T, E>, - ) -> IfElseOrExpression<'ast, T, E> { - fold_if_else_expression(self, ty, e) + e: ConditionalExpression<'ast, T, E>, + ) -> ConditionalOrExpression<'ast, T, E> { + fold_conditional_expression(self, ty, e) } fn fold_member_expression< @@ -319,7 +319,7 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_select_expression< - E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From>, + E: Expr<'ast, T> + Select<'ast, T> + Conditional<'ast, T> + From>, >( &mut self, ty: &E::Ty, @@ -506,9 +506,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) { - IfElseOrExpression::IfElse(s) => ArrayExpressionInner::IfElse(s), - IfElseOrExpression::Expression(u) => u, + ArrayExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c) { + ConditionalOrExpression::Conditional(s) => ArrayExpressionInner::Conditional(s), + ConditionalOrExpression::Expression(u) => u, }, ArrayExpressionInner::Select(select) => match f.fold_select_expression(ty, select) { SelectOrExpression::Select(s) => ArrayExpressionInner::Select(s), @@ -553,9 +553,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) { - IfElseOrExpression::IfElse(s) => StructExpressionInner::IfElse(s), - IfElseOrExpression::Expression(u) => u, + StructExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c) { + ConditionalOrExpression::Conditional(s) => StructExpressionInner::Conditional(s), + ConditionalOrExpression::Expression(u) => u, }, StructExpressionInner::Select(select) => match f.fold_select_expression(ty, select) { SelectOrExpression::Select(s) => StructExpressionInner::Select(s), @@ -615,10 +615,10 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( FieldElementExpression::Pos(box e) } - FieldElementExpression::IfElse(c) => { - match f.fold_if_else_expression(&Type::FieldElement, c) { - IfElseOrExpression::IfElse(s) => FieldElementExpression::IfElse(s), - IfElseOrExpression::Expression(u) => u, + FieldElementExpression::Conditional(c) => { + match f.fold_conditional_expression(&Type::FieldElement, c) { + ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s), + ConditionalOrExpression::Expression(u) => u, } } FieldElementExpression::FunctionCall(function_call) => { @@ -643,20 +643,21 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_if_else_expression< +pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + IfElse<'ast, T> + From>, + E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T> + From>, F: Folder<'ast, T>, >( f: &mut F, _: &E::Ty, - e: IfElseExpression<'ast, T, E>, -) -> IfElseOrExpression<'ast, T, E> { - IfElseOrExpression::IfElse(IfElseExpression::new( + e: ConditionalExpression<'ast, T, E>, +) -> ConditionalOrExpression<'ast, T, E> { + ConditionalOrExpression::Conditional(ConditionalExpression::new( f.fold_boolean_expression(*e.condition), e.consequence.fold(f), e.alternative.fold(f), + e.kind, )) } @@ -679,7 +680,7 @@ pub fn fold_member_expression< pub fn fold_select_expression< 'ast, T: Field, - E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From>, + E: Expr<'ast, T> + Select<'ast, T> + Conditional<'ast, T> + From>, F: Folder<'ast, T>, >( f: &mut F, @@ -794,9 +795,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c) { - IfElseOrExpression::IfElse(s) => BooleanExpression::IfElse(s), - IfElseOrExpression::Expression(u) => u, + BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c) + { + ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s), + ConditionalOrExpression::Expression(u) => u, }, BooleanExpression::Select(select) => match f.fold_select_expression(&Type::Boolean, select) { @@ -922,9 +924,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c) { - IfElseOrExpression::IfElse(s) => UExpressionInner::IfElse(s), - IfElseOrExpression::Expression(u) => u, + UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) { + ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s), + ConditionalOrExpression::Expression(u) => u, }, UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m) { MemberOrExpression::Member(m) => UExpressionInner::Member(m), diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index bd1b2444..63b3fb13 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -5,9 +5,10 @@ use crate::typed_absy::types::{ }; use crate::typed_absy::UBitwidth; use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse, - IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed, - TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, + ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalExpression, + Expr, FieldElementExpression, Select, SelectExpression, StructExpression, + StructExpressionInner, Typed, TypedExpression, TypedExpressionOrSpread, TypedSpread, + UExpression, UExpressionInner, }; use num_bigint::BigUint; use std::convert::TryFrom; @@ -239,7 +240,7 @@ pub enum IntExpression<'ast, T> { Div(Box>, Box>), Rem(Box>, Box>), Pow(Box>, Box>), - IfElse(IfElseExpression<'ast, T, IntExpression<'ast, T>>), + Conditional(ConditionalExpression<'ast, T, IntExpression<'ast, T>>), Select(SelectExpression<'ast, T, IntExpression<'ast, T>>), Xor(Box>, Box>), And(Box>, Box>), @@ -354,7 +355,7 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> { IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), IntExpression::Not(ref e) => write!(f, "!{}", e), - IntExpression::IfElse(ref c) => write!(f, "{}", c), + IntExpression::Conditional(ref c) => write!(f, "{}", c), } } } @@ -404,10 +405,11 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { )), IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), - IntExpression::IfElse(c) => Ok(Self::IfElse(IfElseExpression::new( + IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new( *c.condition, Self::try_from_int(*c.consequence)?, Self::try_from_int(*c.alternative)?, + c.kind, ))), IntExpression::Select(select) => { let array = *select.array; @@ -523,10 +525,11 @@ impl<'ast, T: Field> UExpression<'ast, T> { Self::try_from_int(e1, bitwidth)?, e2, )), - IfElse(c) => Ok(UExpression::if_else( + Conditional(c) => Ok(UExpression::conditional( *c.condition, Self::try_from_int(*c.consequence, bitwidth)?, Self::try_from_int(*c.alternative, bitwidth)?, + c.kind, )), Select(select) => { let array = *select.array; @@ -693,6 +696,7 @@ impl<'ast, T> From for IntExpression<'ast, T> { #[cfg(test)] mod tests { use super::*; + use crate::typed_absy::ConditionalKind; use zokrates_field::Bn128Field; #[test] @@ -714,7 +718,7 @@ mod tests { n.clone() * n.clone(), IntExpression::pow(n.clone(), n.clone()), n.clone() / n.clone(), - IntExpression::if_else(c.clone(), n.clone(), n.clone()), + IntExpression::conditional(c.clone(), n.clone(), n.clone(), ConditionalKind::IfElse), IntExpression::select(n_a.clone(), i.clone()), ]; @@ -725,7 +729,12 @@ mod tests { t.clone() * t.clone(), FieldElementExpression::pow(t.clone(), i.clone()), t.clone() / t.clone(), - FieldElementExpression::if_else(c.clone(), t.clone(), t.clone()), + FieldElementExpression::conditional( + c.clone(), + t.clone(), + t.clone(), + ConditionalKind::IfElse, + ), FieldElementExpression::select(t_a.clone(), i.clone()), ]; @@ -780,7 +789,7 @@ mod tests { IntExpression::left_shift(n.clone(), i.clone()), IntExpression::right_shift(n.clone(), i.clone()), !n.clone(), - IntExpression::if_else(c.clone(), n.clone(), n.clone()), + IntExpression::conditional(c.clone(), n.clone(), n.clone(), ConditionalKind::IfElse), IntExpression::select(n_a.clone(), i.clone()), ]; @@ -797,7 +806,7 @@ mod tests { UExpression::left_shift(t.clone(), i.clone()), UExpression::right_shift(t.clone(), i.clone()), !t.clone(), - UExpression::if_else(c.clone(), t.clone(), t.clone()), + UExpression::conditional(c.clone(), t.clone(), t.clone(), ConditionalKind::IfElse), UExpression::select(t_a.clone(), i.clone()), ]; diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 6c970ff3..a37777af 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -792,7 +792,7 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { StructExpressionInner::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } - StructExpressionInner::IfElse(ref c) => write!(f, "{}", c), + StructExpressionInner::Conditional(ref c) => write!(f, "{}", c), StructExpressionInner::Member(ref m) => write!(f, "{}", m), StructExpressionInner::Select(ref select) => write!(f, "{}", select), } @@ -937,30 +937,50 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct IfElseExpression<'ast, T, E> { +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +pub enum ConditionalKind { + IfElse, + Ternary, +} + +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +pub struct ConditionalExpression<'ast, T, E> { pub condition: Box>, pub consequence: Box, pub alternative: Box, + pub kind: ConditionalKind, } -impl<'ast, T, E> IfElseExpression<'ast, T, E> { - pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self { - IfElseExpression { +impl<'ast, T, E> ConditionalExpression<'ast, T, E> { + pub fn new( + condition: BooleanExpression<'ast, T>, + consequence: E, + alternative: E, + kind: ConditionalKind, + ) -> Self { + ConditionalExpression { condition: box condition, consequence: box consequence, alternative: box alternative, + kind, } } } -impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'ast, T, E> { +impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpression<'ast, T, E> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "if {} then {} else {} fi", - self.condition, self.consequence, self.alternative - ) + match self.kind { + ConditionalKind::IfElse => write!( + f, + "if {} then {} else {} fi", + self.condition, self.consequence, self.alternative + ), + ConditionalKind::Ternary => write!( + f, + "{} ? {} : {}", + self.condition, self.consequence, self.alternative + ), + } } } @@ -1042,7 +1062,7 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), - IfElse(IfElseExpression<'ast, T, Self>), + Conditional(ConditionalExpression<'ast, T, Self>), Neg(Box>), Pos(Box>), FunctionCall(FunctionCallExpression<'ast, T, Self>), @@ -1142,7 +1162,7 @@ pub enum BooleanExpression<'ast, T> { Box>, ), Not(Box>), - IfElse(IfElseExpression<'ast, T, Self>), + Conditional(ConditionalExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), FunctionCall(FunctionCallExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>), @@ -1249,7 +1269,7 @@ pub enum ArrayExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(ArrayValue<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>), - IfElse(IfElseExpression<'ast, T, ArrayExpression<'ast, T>>), + Conditional(ConditionalExpression<'ast, T, ArrayExpression<'ast, T>>), Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>), Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>), Slice( @@ -1313,7 +1333,7 @@ pub enum StructExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>), - IfElse(IfElseExpression<'ast, T, StructExpression<'ast, T>>), + Conditional(ConditionalExpression<'ast, T, StructExpression<'ast, T>>), Member(MemberExpression<'ast, T, StructExpression<'ast, T>>), Select(SelectExpression<'ast, T, StructExpression<'ast, T>>), } @@ -1455,7 +1475,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), - FieldElementExpression::IfElse(ref c) => write!(f, "{}", c), + FieldElementExpression::Conditional(ref c) => write!(f, "{}", c), FieldElementExpression::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } @@ -1489,7 +1509,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), UExpressionInner::Select(ref select) => write!(f, "{}", select), UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), - UExpressionInner::IfElse(ref c) => write!(f, "{}", c), + UExpressionInner::Conditional(ref c) => write!(f, "{}", c), UExpressionInner::Member(ref m) => write!(f, "{}", m), } } @@ -1518,7 +1538,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), BooleanExpression::Value(b) => write!(f, "{}", b), BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call), - BooleanExpression::IfElse(ref c) => write!(f, "{}", c), + BooleanExpression::Conditional(ref c) => write!(f, "{}", c), BooleanExpression::Member(ref m) => write!(f, "{}", m), BooleanExpression::Select(ref select) => write!(f, "{}", select), } @@ -1540,7 +1560,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { .join(", ") ), ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), - ArrayExpressionInner::IfElse(ref c) => write!(f, "{}", c), + ArrayExpressionInner::Conditional(ref c) => write!(f, "{}", c), ArrayExpressionInner::Member(ref m) => write!(f, "{}", m), ArrayExpressionInner::Select(ref select) => write!(f, "{}", select), ArrayExpressionInner::Slice(ref a, ref from, ref to) => { @@ -1780,81 +1800,121 @@ pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> { Expression(E::Inner), } -pub enum IfElseOrExpression<'ast, T, E: Expr<'ast, T>> { - IfElse(IfElseExpression<'ast, T, E>), +pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> { + Conditional(ConditionalExpression<'ast, T, E>), Expression(E::Inner), } -pub trait IfElse<'ast, T> { - fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self) - -> Self; -} - -impl<'ast, T> IfElse<'ast, T> for FieldElementExpression<'ast, T> { - fn if_else( +pub trait Conditional<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, + ) -> Self; +} + +impl<'ast, T> Conditional<'ast, T> for FieldElementExpression<'ast, T> { + fn conditional( + condition: BooleanExpression<'ast, T>, + consequence: Self, + alternative: Self, + kind: ConditionalKind, ) -> Self { - FieldElementExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) + FieldElementExpression::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) } } -impl<'ast, T> IfElse<'ast, T> for IntExpression<'ast, T> { - fn if_else( +impl<'ast, T> Conditional<'ast, T> for IntExpression<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, ) -> Self { - IntExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) + IntExpression::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) } } -impl<'ast, T> IfElse<'ast, T> for BooleanExpression<'ast, T> { - fn if_else( +impl<'ast, T> Conditional<'ast, T> for BooleanExpression<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, ) -> Self { - BooleanExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) + BooleanExpression::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) } } -impl<'ast, T> IfElse<'ast, T> for UExpression<'ast, T> { - fn if_else( +impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, ) -> Self { let bitwidth = consequence.bitwidth; - UExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) - .annotate(bitwidth) + UExpressionInner::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) + .annotate(bitwidth) } } -impl<'ast, T: Clone> IfElse<'ast, T> for ArrayExpression<'ast, T> { - fn if_else( +impl<'ast, T: Clone> Conditional<'ast, T> for ArrayExpression<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, ) -> Self { let ty = consequence.inner_type().clone(); let size = consequence.size(); - ArrayExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) - .annotate(ty, size) + ArrayExpressionInner::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) + .annotate(ty, size) } } -impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> { - fn if_else( +impl<'ast, T: Clone> Conditional<'ast, T> for StructExpression<'ast, T> { + fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, + kind: ConditionalKind, ) -> Self { let ty = consequence.ty().clone(); - StructExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) - .annotate(ty) + StructExpressionInner::Conditional(ConditionalExpression::new( + condition, + consequence, + alternative, + kind, + )) + .annotate(ty) } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 62a745af..5986ca22 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -184,14 +184,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_types(self, tys) } - fn fold_if_else_expression< - E: Expr<'ast, T> + PartialEq + IfElse<'ast, T> + ResultFold<'ast, T>, + fn fold_conditional_expression< + E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold<'ast, T>, >( &mut self, ty: &E::Ty, - e: IfElseExpression<'ast, T, E>, - ) -> Result, Self::Error> { - fold_if_else_expression(self, ty, e) + e: ConditionalExpression<'ast, T, E>, + ) -> Result, Self::Error> { + fold_conditional_expression(self, ty, e) } fn fold_block_expression>( @@ -507,9 +507,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? { - IfElseOrExpression::IfElse(c) => ArrayExpressionInner::IfElse(c), - IfElseOrExpression::Expression(u) => u, + ArrayExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c)? { + ConditionalOrExpression::Conditional(c) => ArrayExpressionInner::Conditional(c), + ConditionalOrExpression::Expression(u) => u, }, ArrayExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? { MemberOrExpression::Member(m) => ArrayExpressionInner::Member(m), @@ -572,9 +572,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? { - IfElseOrExpression::IfElse(c) => StructExpressionInner::IfElse(c), - IfElseOrExpression::Expression(u) => u, + StructExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c)? { + ConditionalOrExpression::Conditional(c) => StructExpressionInner::Conditional(c), + ConditionalOrExpression::Expression(u) => u, }, StructExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? { MemberOrExpression::Member(m) => StructExpressionInner::Member(m), @@ -635,10 +635,10 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( FieldElementExpression::Pos(box e) } - FieldElementExpression::IfElse(c) => { - match f.fold_if_else_expression(&Type::FieldElement, c)? { - IfElseOrExpression::IfElse(c) => FieldElementExpression::IfElse(c), - IfElseOrExpression::Expression(u) => u, + FieldElementExpression::Conditional(c) => { + match f.fold_conditional_expression(&Type::FieldElement, c)? { + ConditionalOrExpression::Conditional(c) => FieldElementExpression::Conditional(c), + ConditionalOrExpression::Expression(u) => u, } } FieldElementExpression::FunctionCall(function_call) => { @@ -689,11 +689,11 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo }) } -pub fn fold_if_else_expression< +pub fn fold_conditional_expression< 'ast, T: Field, E: Expr<'ast, T> - + IfElse<'ast, T> + + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T> + From>, @@ -701,13 +701,16 @@ pub fn fold_if_else_expression< >( f: &mut F, _: &E::Ty, - e: IfElseExpression<'ast, T, E>, -) -> Result, F::Error> { - Ok(IfElseOrExpression::IfElse(IfElseExpression::new( - f.fold_boolean_expression(*e.condition)?, - e.consequence.fold(f)?, - e.alternative.fold(f)?, - ))) + e: ConditionalExpression<'ast, T, E>, +) -> Result, F::Error> { + Ok(ConditionalOrExpression::Conditional( + ConditionalExpression::new( + f.fold_boolean_expression(*e.condition)?, + e.consequence.fold(f)?, + e.alternative.fold(f)?, + e.kind, + ), + )) } pub fn fold_member_expression< @@ -863,10 +866,12 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c)? { - IfElseOrExpression::IfElse(c) => BooleanExpression::IfElse(c), - IfElseOrExpression::Expression(u) => u, - }, + BooleanExpression::Conditional(c) => { + match f.fold_conditional_expression(&Type::Boolean, c)? { + ConditionalOrExpression::Conditional(c) => BooleanExpression::Conditional(c), + ConditionalOrExpression::Expression(u) => u, + } + } BooleanExpression::Select(select) => { match f.fold_select_expression(&Type::Boolean, select)? { SelectOrExpression::Select(s) => BooleanExpression::Select(s), @@ -991,9 +996,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c)? { - IfElseOrExpression::IfElse(c) => UExpressionInner::IfElse(c), - IfElseOrExpression::Expression(u) => u, + UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? { + ConditionalOrExpression::Conditional(c) => UExpressionInner::Conditional(c), + ConditionalOrExpression::Expression(u) => u, }, UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m)? { MemberOrExpression::Member(m) => UExpressionInner::Member(m), diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index df4e476e..dc5aa09b 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -193,7 +193,7 @@ pub enum UExpressionInner<'ast, T> { FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>), LeftShift(Box>, Box>), RightShift(Box>, Box>), - IfElse(IfElseExpression<'ast, T, UExpression<'ast, T>>), + Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>), Member(MemberExpression<'ast, T, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), } diff --git a/zokrates_core_test/tests/tests/ternary.json b/zokrates_core_test/tests/tests/ternary.json new file mode 100644 index 00000000..fb751f4d --- /dev/null +++ b/zokrates_core_test/tests/tests/ternary.json @@ -0,0 +1,36 @@ +{ + "entry_point": "./tests/tests/ternary.zok", + "max_constraint_count": 6, + "tests": [ + { + "input": { + "values": ["1", "0"] + }, + "output": { + "Ok": { + "values": ["1"] + } + } + }, + { + "input": { + "values": ["0", "1"] + }, + "output": { + "Ok": { + "values": ["2"] + } + } + }, + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Ok": { + "values": ["3"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/ternary.zok b/zokrates_core_test/tests/tests/ternary.zok new file mode 100644 index 00000000..7bd9486e --- /dev/null +++ b/zokrates_core_test/tests/tests/ternary.zok @@ -0,0 +1,5 @@ +def main(bool a, bool b) -> field: + field x = a ? 1 : b ? 2 : 3 // (a ? 1 : (b ? 2 : 3)) + field y = if a then 1 else if b then 2 else 3 fi fi + assert(x == y) + return x \ No newline at end of file diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index e766d084..06970e18 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -68,13 +68,13 @@ expression = { unaried_term ~ (op_binary ~ unaried_term)* } unaried_term = { op_unary? ~ powered_term } powered_term = { postfixed_term ~ (op_pow ~ exponent_expression)? } postfixed_term = { term ~ access* } -term = { ("(" ~ expression ~ ")") | inline_struct_expression | conditional_expression | primary_expression | inline_array_expression | array_initializer_expression } +term = { ("(" ~ expression ~ ")") | inline_struct_expression | if_else_expression | primary_expression | inline_array_expression | array_initializer_expression } spread = { "..." ~ expression } range = { from_expression? ~ ".." ~ to_expression? } from_expression = { expression } to_expression = { expression } -conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"} +if_else_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"} access = { array_access | call_access | member_access } array_access = { "[" ~ range_or_expression ~ "]" } @@ -155,8 +155,10 @@ op_neg = {"-"} op_pos = {"+"} op_left_shift = @{"<<"} op_right_shift = @{">>"} +op_ternary = {"?" ~ expression ~ ":"} + // `op_pow` is *not* in `op_binary` because its precedence is handled in this parser rather than down the line in precedence climbing -op_binary = _ { op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem } +op_binary = _ { op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem | op_ternary } op_unary = { op_pos | op_neg | op_not } WHITESPACE = _{ " " | "\t" | "\\" ~ COMMENT? ~ NEWLINE} diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 3bc24404..e35b3a42 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -13,7 +13,7 @@ pub use ast::{ CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, - IdentifierExpression, ImportDirective, ImportSymbol, InlineArrayExpression, + IdentifierExpression, IfElseExpression, ImportDirective, ImportSymbol, InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, Parameter, PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, SymbolDeclaration, TernaryExpression, ToExpression, @@ -38,6 +38,7 @@ mod ast { // based on https://docs.python.org/3/reference/expressions.html#operator-precedence fn build_precedence_climber() -> PrecClimber { PrecClimber::new(vec![ + Operator::new(Rule::op_ternary, Assoc::Right), Operator::new(Rule::op_or, Assoc::Left), Operator::new(Rule::op_and, Assoc::Left), Operator::new(Rule::op_lt, Assoc::Left) @@ -89,6 +90,12 @@ mod ast { Rule::op_bit_or => Expression::binary(BinaryOperator::BitOr, lhs, rhs, span), Rule::op_right_shift => Expression::binary(BinaryOperator::RightShift, lhs, rhs, span), Rule::op_left_shift => Expression::binary(BinaryOperator::LeftShift, lhs, rhs, span), + Rule::op_ternary => Expression::ternary( + lhs, + Box::new(Expression::from_pest(&mut pair.into_inner()).unwrap()), + rhs, + span, + ), _ => unreachable!(), }) } @@ -412,6 +419,7 @@ mod ast { #[derive(Debug, PartialEq, Clone)] pub enum Expression<'ast> { Ternary(TernaryExpression<'ast>), + IfElse(IfElseExpression<'ast>), Binary(BinaryExpression<'ast>), Unary(UnaryExpression<'ast>), Postfix(PostfixExpression<'ast>), @@ -427,7 +435,7 @@ mod ast { pub enum Term<'ast> { Expression(Expression<'ast>), InlineStruct(InlineStructExpression<'ast>), - Ternary(TernaryExpression<'ast>), + IfElse(IfElseExpression<'ast>), Primary(PrimaryExpression<'ast>), InlineArray(InlineArrayExpression<'ast>), ArrayInitializer(ArrayInitializerExpression<'ast>), @@ -543,7 +551,7 @@ mod ast { fn from(t: Term<'ast>) -> Self { match t { Term::Expression(e) => e, - Term::Ternary(e) => Expression::Ternary(e), + Term::IfElse(e) => Expression::IfElse(e), Term::Primary(e) => e.into(), Term::InlineArray(e) => Expression::InlineArray(e), Term::InlineStruct(e) => Expression::InlineStruct(e), @@ -761,27 +769,49 @@ mod ast { pub span: Span<'ast>, } - #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::conditional_expression))] + #[derive(Debug, PartialEq, Clone)] pub struct TernaryExpression<'ast> { - pub first: Box>, - pub second: Box>, - pub third: Box>, + pub condition: Box>, + pub consequence: Box>, + pub alternative: Box>, + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::if_else_expression))] + pub struct IfElseExpression<'ast> { + pub condition: Box>, + pub consequence: Box>, + pub alternative: Box>, #[pest_ast(outer())] pub span: Span<'ast>, } impl<'ast> Expression<'ast> { + pub fn if_else( + condition: Box>, + consequence: Box>, + alternative: Box>, + span: Span<'ast>, + ) -> Self { + Expression::IfElse(IfElseExpression { + condition, + consequence, + alternative, + span, + }) + } + pub fn ternary( - first: Box>, - second: Box>, - third: Box>, + condition: Box>, + consequence: Box>, + alternative: Box>, span: Span<'ast>, ) -> Self { Expression::Ternary(TernaryExpression { - first, - second, - third, + condition, + consequence, + alternative, span, }) } @@ -806,6 +836,7 @@ mod ast { Expression::Identifier(i) => &i.span, Expression::Literal(c) => c.span(), Expression::Ternary(t) => &t.span, + Expression::IfElse(ie) => &ie.span, Expression::Postfix(p) => &p.span, Expression::InlineArray(a) => &a.span, Expression::InlineStruct(s) => &s.span, @@ -1071,20 +1102,6 @@ mod tests { pub fn pow(left: Expression<'ast>, right: Expression<'ast>, span: Span<'ast>) -> Self { Self::binary(BinaryOperator::Pow, Box::new(left), Box::new(right), span) } - - pub fn if_else( - condition: Expression<'ast>, - consequence: Expression<'ast>, - alternative: Expression<'ast>, - span: Span<'ast>, - ) -> Self { - Self::ternary( - Box::new(condition), - Box::new(consequence), - Box::new(alternative), - span, - ) - } } #[test] @@ -1263,7 +1280,7 @@ mod tests { }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::if_else( - Expression::Literal(LiteralExpression::DecimalLiteral( + Box::new(Expression::Literal(LiteralExpression::DecimalLiteral( DecimalLiteralExpression { suffix: None, value: DecimalNumber { @@ -1271,8 +1288,8 @@ mod tests { }, span: Span::new(source, 62, 63).unwrap() } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( + ))), + Box::new(Expression::Literal(LiteralExpression::DecimalLiteral( DecimalLiteralExpression { suffix: None, value: DecimalNumber { @@ -1280,8 +1297,8 @@ mod tests { }, span: Span::new(source, 69, 70).unwrap() } - )), - Expression::Literal(LiteralExpression::DecimalLiteral( + ))), + Box::new(Expression::Literal(LiteralExpression::DecimalLiteral( DecimalLiteralExpression { suffix: None, value: DecimalNumber { @@ -1289,8 +1306,8 @@ mod tests { }, span: Span::new(source, 76, 77).unwrap() } - )), - Span::new(source, 59, 80).unwrap() + ))), + Span::new(&source, 59, 80).unwrap() )], span: Span::new(source, 52, 80).unwrap(), })],