diff --git a/changelogs/unreleased/1010-dark64 b/changelogs/unreleased/1010-dark64
new file mode 100644
index 00000000..6aa232ca
--- /dev/null
+++ b/changelogs/unreleased/1010-dark64
@@ -0,0 +1 @@
+Introduce ternary operator
\ No newline at end of file
diff --git a/zokrates_book/src/language/control_flow.md b/zokrates_book/src/language/control_flow.md
index 9e47d289..38a57465 100644
--- a/zokrates_book/src/language/control_flow.md
+++ b/zokrates_book/src/language/control_flow.md
@@ -26,6 +26,12 @@ An if-expression allows you to branch your code depending on a boolean condition
{{#include ../../../zokrates_cli/examples/book/if_else.zok}}
```
+The conditional expression can also be written using a ternary operator:
+
+```zokrates
+{{#include ../../../zokrates_cli/examples/book/ternary.zok}}
+```
+
There are two important caveats when it comes to conditional expressions. Before we go into them, let's define two concepts:
- for an execution of the program, *an executed branch* is a branch which has to be paid for when executing the program, generating proofs, etc.
- for an execution of the program, *a logically executed branch* is a branch which is "chosen" by the condition of an if-expression. This is the more intuitive notion of execution, and there is only one for each if-expression.
diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md
index b82d59b7..7f013b22 100644
--- a/zokrates_book/src/language/operators.md
+++ b/zokrates_book/src/language/operators.md
@@ -16,10 +16,12 @@ The following table lists the precedence and associativity of all operators. Ope
| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | |
| `&&` | Boolean AND | | | ✓ | Left | |
| ||
| Boolean OR | | | ✓ | Left | |
-| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | |
+| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | [^4] |
[^1]: The exponent must be a compile-time constant of type `u32`
[^2]: The right operand must be a compile time constant of type `u32`
-[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant
\ No newline at end of file
+[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant
+
+[^4]: Conditional expression can also be written using a ternary operator: `c ? x : y`
\ No newline at end of file
diff --git a/zokrates_cli/examples/book/ternary.zok b/zokrates_cli/examples/book/ternary.zok
new file mode 100644
index 00000000..c90ecc5c
--- /dev/null
+++ b/zokrates_cli/examples/book/ternary.zok
@@ -0,0 +1,3 @@
+def main(field x) -> field:
+ field y = x + 2 == 3 ? 1 : 5
+ return y
\ No newline at end of file
diff --git a/zokrates_cli/examples/compile_errors/ternary_precedence.zok b/zokrates_cli/examples/compile_errors/ternary_precedence.zok
new file mode 100644
index 00000000..180b081e
--- /dev/null
+++ b/zokrates_cli/examples/compile_errors/ternary_precedence.zok
@@ -0,0 +1,5 @@
+def main(bool a) -> field:
+ // ternary expression should be wrapped in parentheses 1 + (a ? 2 : 3)
+ // otherwise the whole expression is parsed as (1 + a) ? 2 : 3
+ field x = 1 + a ? 2 : 3
+ return x
\ No newline at end of file
diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs
index 64d723f0..81f0b903 100644
--- a/zokrates_core/src/absy/from_ast.rs
+++ b/zokrates_core/src/absy/from_ast.rs
@@ -378,6 +378,7 @@ impl<'ast> From> for absy::ExpressionNode<'ast> {
match expression {
pest::Expression::Binary(e) => absy::ExpressionNode::from(e),
pest::Expression::Ternary(e) => absy::ExpressionNode::from(e),
+ pest::Expression::IfElse(e) => absy::ExpressionNode::from(e),
pest::Expression::Literal(e) => absy::ExpressionNode::from(e),
pest::Expression::Identifier(e) => absy::ExpressionNode::from(e),
pest::Expression::Postfix(e) => absy::ExpressionNode::from(e),
@@ -478,13 +479,27 @@ impl<'ast> From> for absy::ExpressionNode<'ast> {
}
}
+impl<'ast> From> for absy::ExpressionNode<'ast> {
+ fn from(expression: pest::IfElseExpression<'ast>) -> absy::ExpressionNode<'ast> {
+ use crate::absy::NodeValue;
+ absy::Expression::Conditional(
+ box absy::ExpressionNode::from(*expression.condition),
+ box absy::ExpressionNode::from(*expression.consequence),
+ box absy::ExpressionNode::from(*expression.alternative),
+ absy::ConditionalKind::IfElse,
+ )
+ .span(expression.span)
+ }
+}
+
impl<'ast> From> for absy::ExpressionNode<'ast> {
fn from(expression: pest::TernaryExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
- absy::Expression::IfElse(
- box absy::ExpressionNode::from(*expression.first),
- box absy::ExpressionNode::from(*expression.second),
- box absy::ExpressionNode::from(*expression.third),
+ absy::Expression::Conditional(
+ box absy::ExpressionNode::from(*expression.condition),
+ box absy::ExpressionNode::from(*expression.consequence),
+ box absy::ExpressionNode::from(*expression.alternative),
+ absy::ConditionalKind::Ternary,
)
.span(expression.span)
}
diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs
index c709db28..e6c8ec69 100644
--- a/zokrates_core/src/absy/mod.rs
+++ b/zokrates_core/src/absy/mod.rs
@@ -460,6 +460,12 @@ impl<'ast> fmt::Display for Range<'ast> {
}
}
+#[derive(Debug, Clone, PartialEq)]
+pub enum ConditionalKind {
+ IfElse,
+ Ternary,
+}
+
/// An expression
#[derive(Debug, Clone, PartialEq)]
pub enum Expression<'ast> {
@@ -479,10 +485,11 @@ pub enum Expression<'ast> {
Pow(Box>, Box>),
Neg(Box>),
Pos(Box>),
- IfElse(
+ Conditional(
Box>,
Box>,
Box>,
+ ConditionalKind,
),
FunctionCall(
Box>,
@@ -530,11 +537,18 @@ impl<'ast> fmt::Display for Expression<'ast> {
Expression::Neg(ref e) => write!(f, "(-{})", e),
Expression::Pos(ref e) => write!(f, "(+{})", e),
Expression::BooleanConstant(b) => write!(f, "{}", b),
- Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
- f,
- "if {} then {} else {} fi",
- condition, consequent, alternative
- ),
+ Expression::Conditional(ref condition, ref consequent, ref alternative, ref kind) => {
+ match kind {
+ ConditionalKind::IfElse => write!(
+ f,
+ "if {} then {} else {} fi",
+ condition, consequent, alternative
+ ),
+ ConditionalKind::Ternary => {
+ write!(f, "{} ? {} : {}", condition, consequent, alternative)
+ }
+ }
+ }
Expression::FunctionCall(ref i, ref g, ref p) => {
if let Some(g) = g {
write!(
diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs
index 0cbd7eaf..b0041ef0 100644
--- a/zokrates_core/src/semantics.rs
+++ b/zokrates_core/src/semantics.rs
@@ -2270,7 +2270,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
}
}
- Expression::IfElse(box condition, box consequence, box alternative) => {
+ Expression::Conditional(box condition, box consequence, box alternative, kind) => {
let condition_checked = self.check_expression(condition, module_id, types)?;
let consequence_checked = self.check_expression(consequence, module_id, types)?;
let alternative_checked = self.check_expression(alternative, module_id, types)?;
@@ -2282,40 +2282,49 @@ impl<'ast, T: Field> Checker<'ast, T> {
)
.map_err(|(e1, e2)| ErrorInner {
pos: Some(pos),
- message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()),
+ message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()),
})?;
+ let kind = match kind {
+ crate::absy::ConditionalKind::IfElse => {
+ crate::typed_absy::ConditionalKind::IfElse
+ }
+ crate::absy::ConditionalKind::Ternary => {
+ crate::typed_absy::ConditionalKind::Ternary
+ }
+ };
+
match condition_checked {
TypedExpression::Boolean(condition) => {
match (consequence_checked, alternative_checked) {
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
- Ok(FieldElementExpression::if_else(condition, consequence, alternative).into())
+ Ok(FieldElementExpression::conditional(condition, consequence, alternative, kind).into())
},
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
- Ok(BooleanExpression::if_else(condition, consequence, alternative).into())
+ Ok(BooleanExpression::conditional(condition, consequence, alternative, kind).into())
},
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
- Ok(ArrayExpression::if_else(condition, consequence, alternative).into())
+ Ok(ArrayExpression::conditional(condition, consequence, alternative, kind).into())
},
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
- Ok(StructExpression::if_else(condition, consequence, alternative).into())
+ Ok(StructExpression::conditional(condition, consequence, alternative, kind).into())
},
(TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => {
- Ok(UExpression::if_else(condition, consequence, alternative).into())
+ Ok(UExpression::conditional(condition, consequence, alternative, kind).into())
},
(TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => {
- Ok(IntExpression::if_else(condition, consequence, alternative).into())
+ Ok(IntExpression::conditional(condition, consequence, alternative, kind).into())
},
(c, a) => Err(ErrorInner {
pos: Some(pos),
- message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", c.get_type(), a.get_type())
+ message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", c.get_type(), a.get_type())
})
}
}
c => Err(ErrorInner {
pos: Some(pos),
message: format!(
- "{{condition}} after `if` should be a boolean, found {}",
+ "{{condition}} should be a boolean, found {}",
c.get_type()
),
}),
diff --git a/zokrates_core/src/static_analysis/branch_isolator.rs b/zokrates_core/src/static_analysis/branch_isolator.rs
index c9d848d8..77580386 100644
--- a/zokrates_core/src/static_analysis/branch_isolator.rs
+++ b/zokrates_core/src/static_analysis/branch_isolator.rs
@@ -17,17 +17,18 @@ impl Isolator {
}
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
- fn fold_if_else_expression<
- E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + IfElse<'ast, T>,
+ fn fold_conditional_expression<
+ E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>,
>(
&mut self,
_: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
- ) -> IfElseOrExpression<'ast, T, E> {
- IfElseOrExpression::IfElse(IfElseExpression::new(
+ e: ConditionalExpression<'ast, T, E>,
+ ) -> ConditionalOrExpression<'ast, T, E> {
+ ConditionalOrExpression::Conditional(ConditionalExpression::new(
self.fold_boolean_expression(*e.condition),
E::block(vec![], e.consequence.fold(self)),
E::block(vec![], e.alternative.fold(self)),
+ e.kind,
))
}
}
diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs
index 1986efde..7709fbb6 100644
--- a/zokrates_core/src/static_analysis/flatten_complex_types.rs
+++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs
@@ -281,12 +281,12 @@ impl<'ast, T: Field> Flattener {
}
}
- fn fold_if_else_expression>(
+ fn fold_conditional_expression>(
&mut self,
statements_buffer: &mut Vec>,
- c: typed_absy::IfElseExpression<'ast, T, E>,
+ c: typed_absy::ConditionalExpression<'ast, T, E>,
) -> Vec> {
- fold_if_else_expression(self, statements_buffer, c)
+ fold_conditional_expression(self, statements_buffer, c)
}
fn fold_member_expression(
@@ -448,8 +448,8 @@ fn fold_array_expression_inner<'ast, T: Field>(
exprs
}
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
- typed_absy::ArrayExpressionInner::IfElse(c) => {
- f.fold_if_else_expression(statements_buffer, c)
+ typed_absy::ArrayExpressionInner::Conditional(c) => {
+ f.fold_conditional_expression(statements_buffer, c)
}
typed_absy::ArrayExpressionInner::Member(m) => {
f.fold_member_expression(statements_buffer, m)
@@ -523,8 +523,8 @@ fn fold_struct_expression_inner<'ast, T: Field>(
.flat_map(|e| f.fold_expression(statements_buffer, e))
.collect(),
typed_absy::StructExpressionInner::FunctionCall(..) => unreachable!(),
- typed_absy::StructExpressionInner::IfElse(c) => {
- f.fold_if_else_expression(statements_buffer, c)
+ typed_absy::StructExpressionInner::Conditional(c) => {
+ f.fold_conditional_expression(statements_buffer, c)
}
typed_absy::StructExpressionInner::Member(m) => {
f.fold_member_expression(statements_buffer, m)
@@ -646,10 +646,10 @@ fn fold_select_expression<'ast, T: Field, E>(
}
}
-fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>(
+fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>(
f: &mut Flattener,
statements_buffer: &mut Vec>,
- c: typed_absy::IfElseExpression<'ast, T, E>,
+ c: typed_absy::ConditionalExpression<'ast, T, E>,
) -> Vec> {
let mut consequence_statements = vec![];
let mut alternative_statements = vec![];
@@ -742,8 +742,8 @@ fn fold_field_expression<'ast, T: Field>(
typed_absy::FieldElementExpression::Pos(box e) => {
f.fold_field_expression(statements_buffer, e)
}
- typed_absy::FieldElementExpression::IfElse(c) => f
- .fold_if_else_expression(statements_buffer, c)
+ typed_absy::FieldElementExpression::Conditional(c) => f
+ .fold_conditional_expression(statements_buffer, c)
.pop()
.unwrap()
.try_into()
@@ -908,8 +908,8 @@ fn fold_boolean_expression<'ast, T: Field>(
let e = f.fold_boolean_expression(statements_buffer, e);
zir::BooleanExpression::Not(box e)
}
- typed_absy::BooleanExpression::IfElse(c) => f
- .fold_if_else_expression(statements_buffer, c)
+ typed_absy::BooleanExpression::Conditional(c) => f
+ .fold_conditional_expression(statements_buffer, c)
.pop()
.unwrap()
.try_into()
@@ -1070,8 +1070,8 @@ fn fold_uint_expression_inner<'ast, T: Field>(
)
.unwrap()
.into_inner(),
- typed_absy::UExpressionInner::IfElse(c) => zir::UExpression::try_from(
- f.fold_if_else_expression(statements_buffer, c)
+ typed_absy::UExpressionInner::Conditional(c) => zir::UExpression::try_from(
+ f.fold_conditional_expression(statements_buffer, c)
.pop()
.unwrap(),
)
diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs
index 8c9c6e8a..14e4520a 100644
--- a/zokrates_core/src/static_analysis/propagation.rs
+++ b/zokrates_core/src/static_analysis/propagation.rs
@@ -172,13 +172,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
fold_function(self, f)
}
- fn fold_if_else_expression<
- E: Expr<'ast, T> + IfElse<'ast, T> + PartialEq + ResultFold<'ast, T>,
+ fn fold_conditional_expression<
+ E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>,
>(
&mut self,
_: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
- ) -> Result, Self::Error> {
+ e: ConditionalExpression<'ast, T, E>,
+ ) -> Result, Self::Error> {
Ok(
match (
self.fold_boolean_expression(*e.condition)?,
@@ -186,16 +186,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e.alternative.fold(self)?,
) {
(BooleanExpression::Value(true), consequence, _) => {
- IfElseOrExpression::Expression(consequence.into_inner())
+ ConditionalOrExpression::Expression(consequence.into_inner())
}
(BooleanExpression::Value(false), _, alternative) => {
- IfElseOrExpression::Expression(alternative.into_inner())
+ ConditionalOrExpression::Expression(alternative.into_inner())
}
(_, consequence, alternative) if consequence == alternative => {
- IfElseOrExpression::Expression(consequence.into_inner())
+ ConditionalOrExpression::Expression(consequence.into_inner())
}
- (condition, consequence, alternative) => IfElseOrExpression::IfElse(
- IfElseExpression::new(condition, consequence, alternative),
+ (condition, consequence, alternative) => ConditionalOrExpression::Conditional(
+ ConditionalExpression::new(condition, consequence, alternative, e.kind),
),
},
)
@@ -1431,10 +1431,11 @@ mod tests {
#[test]
fn if_else_true() {
- let e = FieldElementExpression::if_else(
+ let e = FieldElementExpression::conditional(
BooleanExpression::Value(true),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(3)),
+ ConditionalKind::IfElse,
);
assert_eq!(
@@ -1445,10 +1446,11 @@ mod tests {
#[test]
fn if_else_false() {
- let e = FieldElementExpression::if_else(
+ let e = FieldElementExpression::conditional(
BooleanExpression::Value(false),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(3)),
+ ConditionalKind::IfElse,
);
assert_eq!(
diff --git a/zokrates_core/src/static_analysis/variable_write_remover.rs b/zokrates_core/src/static_analysis/variable_write_remover.rs
index bce07349..d5522641 100644
--- a/zokrates_core/src/static_analysis/variable_write_remover.rs
+++ b/zokrates_core/src/static_analysis/variable_write_remover.rs
@@ -56,7 +56,7 @@ impl<'ast> VariableWriteRemover {
(0..size)
.map(|i| match inner_ty {
Type::Int => unreachable!(),
- Type::Array(..) => ArrayExpression::if_else(
+ Type::Array(..) => ArrayExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
@@ -74,9 +74,10 @@ impl<'ast> VariableWriteRemover {
),
},
ArrayExpression::select(base.clone(), i),
+ ConditionalKind::IfElse,
)
.into(),
- Type::Struct(..) => StructExpression::if_else(
+ Type::Struct(..) => StructExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
@@ -94,9 +95,10 @@ impl<'ast> VariableWriteRemover {
),
},
StructExpression::select(base.clone(), i),
+ ConditionalKind::IfElse,
)
.into(),
- Type::FieldElement => FieldElementExpression::if_else(
+ Type::FieldElement => FieldElementExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
@@ -115,9 +117,10 @@ impl<'ast> VariableWriteRemover {
),
},
FieldElementExpression::select(base.clone(), i),
+ ConditionalKind::IfElse,
)
.into(),
- Type::Boolean => BooleanExpression::if_else(
+ Type::Boolean => BooleanExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
@@ -135,9 +138,10 @@ impl<'ast> VariableWriteRemover {
),
},
BooleanExpression::select(base.clone(), i),
+ ConditionalKind::IfElse,
)
.into(),
- Type::Uint(..) => UExpression::if_else(
+ Type::Uint(..) => UExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
@@ -155,6 +159,7 @@ impl<'ast> VariableWriteRemover {
),
},
UExpression::select(base.clone(), i),
+ ConditionalKind::IfElse,
)
.into(),
})
diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs
index 6a60e65a..3447f664 100644
--- a/zokrates_core/src/typed_absy/folder.rs
+++ b/zokrates_core/src/typed_absy/folder.rs
@@ -284,18 +284,18 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_block_expression(self, block)
}
- fn fold_if_else_expression<
+ fn fold_conditional_expression<
E: Expr<'ast, T>
+ Fold<'ast, T>
+ Block<'ast, T>
- + IfElse<'ast, T>
+ + Conditional<'ast, T>
+ From>,
>(
&mut self,
ty: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
- ) -> IfElseOrExpression<'ast, T, E> {
- fold_if_else_expression(self, ty, e)
+ e: ConditionalExpression<'ast, T, E>,
+ ) -> ConditionalOrExpression<'ast, T, E> {
+ fold_conditional_expression(self, ty, e)
}
fn fold_member_expression<
@@ -319,7 +319,7 @@ pub trait Folder<'ast, T: Field>: Sized {
}
fn fold_select_expression<
- E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From>,
+ E: Expr<'ast, T> + Select<'ast, T> + Conditional<'ast, T> + From>,
>(
&mut self,
ty: &E::Ty,
@@ -506,9 +506,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) {
- IfElseOrExpression::IfElse(s) => ArrayExpressionInner::IfElse(s),
- IfElseOrExpression::Expression(u) => u,
+ ArrayExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c) {
+ ConditionalOrExpression::Conditional(s) => ArrayExpressionInner::Conditional(s),
+ ConditionalOrExpression::Expression(u) => u,
},
ArrayExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
SelectOrExpression::Select(s) => ArrayExpressionInner::Select(s),
@@ -553,9 +553,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) {
- IfElseOrExpression::IfElse(s) => StructExpressionInner::IfElse(s),
- IfElseOrExpression::Expression(u) => u,
+ StructExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c) {
+ ConditionalOrExpression::Conditional(s) => StructExpressionInner::Conditional(s),
+ ConditionalOrExpression::Expression(u) => u,
},
StructExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
SelectOrExpression::Select(s) => StructExpressionInner::Select(s),
@@ -615,10 +615,10 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
FieldElementExpression::Pos(box e)
}
- FieldElementExpression::IfElse(c) => {
- match f.fold_if_else_expression(&Type::FieldElement, c) {
- IfElseOrExpression::IfElse(s) => FieldElementExpression::IfElse(s),
- IfElseOrExpression::Expression(u) => u,
+ FieldElementExpression::Conditional(c) => {
+ match f.fold_conditional_expression(&Type::FieldElement, c) {
+ ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
+ ConditionalOrExpression::Expression(u) => u,
}
}
FieldElementExpression::FunctionCall(function_call) => {
@@ -643,20 +643,21 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
}
-pub fn fold_if_else_expression<
+pub fn fold_conditional_expression<
'ast,
T: Field,
- E: Expr<'ast, T> + Fold<'ast, T> + IfElse<'ast, T> + From>,
+ E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T> + From>,
F: Folder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
-) -> IfElseOrExpression<'ast, T, E> {
- IfElseOrExpression::IfElse(IfElseExpression::new(
+ e: ConditionalExpression<'ast, T, E>,
+) -> ConditionalOrExpression<'ast, T, E> {
+ ConditionalOrExpression::Conditional(ConditionalExpression::new(
f.fold_boolean_expression(*e.condition),
e.consequence.fold(f),
e.alternative.fold(f),
+ e.kind,
))
}
@@ -679,7 +680,7 @@ pub fn fold_member_expression<
pub fn fold_select_expression<
'ast,
T: Field,
- E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From>,
+ E: Expr<'ast, T> + Select<'ast, T> + Conditional<'ast, T> + From>,
F: Folder<'ast, T>,
>(
f: &mut F,
@@ -794,9 +795,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c) {
- IfElseOrExpression::IfElse(s) => BooleanExpression::IfElse(s),
- IfElseOrExpression::Expression(u) => u,
+ BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c)
+ {
+ ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s),
+ ConditionalOrExpression::Expression(u) => u,
},
BooleanExpression::Select(select) => match f.fold_select_expression(&Type::Boolean, select)
{
@@ -922,9 +924,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
SelectOrExpression::Expression(u) => u,
},
- UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c) {
- IfElseOrExpression::IfElse(s) => UExpressionInner::IfElse(s),
- IfElseOrExpression::Expression(u) => u,
+ UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) {
+ ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s),
+ ConditionalOrExpression::Expression(u) => u,
},
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m) {
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs
index bd1b2444..63b3fb13 100644
--- a/zokrates_core/src/typed_absy/integer.rs
+++ b/zokrates_core/src/typed_absy/integer.rs
@@ -5,9 +5,10 @@ use crate::typed_absy::types::{
};
use crate::typed_absy::UBitwidth;
use crate::typed_absy::{
- ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse,
- IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
- TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
+ ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalExpression,
+ Expr, FieldElementExpression, Select, SelectExpression, StructExpression,
+ StructExpressionInner, Typed, TypedExpression, TypedExpressionOrSpread, TypedSpread,
+ UExpression, UExpressionInner,
};
use num_bigint::BigUint;
use std::convert::TryFrom;
@@ -239,7 +240,7 @@ pub enum IntExpression<'ast, T> {
Div(Box>, Box>),
Rem(Box>, Box>),
Pow(Box>, Box>),
- IfElse(IfElseExpression<'ast, T, IntExpression<'ast, T>>),
+ Conditional(ConditionalExpression<'ast, T, IntExpression<'ast, T>>),
Select(SelectExpression<'ast, T, IntExpression<'ast, T>>),
Xor(Box>, Box>),
And(Box>, Box>),
@@ -354,7 +355,7 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> {
IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
IntExpression::Not(ref e) => write!(f, "!{}", e),
- IntExpression::IfElse(ref c) => write!(f, "{}", c),
+ IntExpression::Conditional(ref c) => write!(f, "{}", c),
}
}
}
@@ -404,10 +405,11 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
)),
IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)),
IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)),
- IntExpression::IfElse(c) => Ok(Self::IfElse(IfElseExpression::new(
+ IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new(
*c.condition,
Self::try_from_int(*c.consequence)?,
Self::try_from_int(*c.alternative)?,
+ c.kind,
))),
IntExpression::Select(select) => {
let array = *select.array;
@@ -523,10 +525,11 @@ impl<'ast, T: Field> UExpression<'ast, T> {
Self::try_from_int(e1, bitwidth)?,
e2,
)),
- IfElse(c) => Ok(UExpression::if_else(
+ Conditional(c) => Ok(UExpression::conditional(
*c.condition,
Self::try_from_int(*c.consequence, bitwidth)?,
Self::try_from_int(*c.alternative, bitwidth)?,
+ c.kind,
)),
Select(select) => {
let array = *select.array;
@@ -693,6 +696,7 @@ impl<'ast, T> From for IntExpression<'ast, T> {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::typed_absy::ConditionalKind;
use zokrates_field::Bn128Field;
#[test]
@@ -714,7 +718,7 @@ mod tests {
n.clone() * n.clone(),
IntExpression::pow(n.clone(), n.clone()),
n.clone() / n.clone(),
- IntExpression::if_else(c.clone(), n.clone(), n.clone()),
+ IntExpression::conditional(c.clone(), n.clone(), n.clone(), ConditionalKind::IfElse),
IntExpression::select(n_a.clone(), i.clone()),
];
@@ -725,7 +729,12 @@ mod tests {
t.clone() * t.clone(),
FieldElementExpression::pow(t.clone(), i.clone()),
t.clone() / t.clone(),
- FieldElementExpression::if_else(c.clone(), t.clone(), t.clone()),
+ FieldElementExpression::conditional(
+ c.clone(),
+ t.clone(),
+ t.clone(),
+ ConditionalKind::IfElse,
+ ),
FieldElementExpression::select(t_a.clone(), i.clone()),
];
@@ -780,7 +789,7 @@ mod tests {
IntExpression::left_shift(n.clone(), i.clone()),
IntExpression::right_shift(n.clone(), i.clone()),
!n.clone(),
- IntExpression::if_else(c.clone(), n.clone(), n.clone()),
+ IntExpression::conditional(c.clone(), n.clone(), n.clone(), ConditionalKind::IfElse),
IntExpression::select(n_a.clone(), i.clone()),
];
@@ -797,7 +806,7 @@ mod tests {
UExpression::left_shift(t.clone(), i.clone()),
UExpression::right_shift(t.clone(), i.clone()),
!t.clone(),
- UExpression::if_else(c.clone(), t.clone(), t.clone()),
+ UExpression::conditional(c.clone(), t.clone(), t.clone(), ConditionalKind::IfElse),
UExpression::select(t_a.clone(), i.clone()),
];
diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs
index 6c970ff3..a37777af 100644
--- a/zokrates_core/src/typed_absy/mod.rs
+++ b/zokrates_core/src/typed_absy/mod.rs
@@ -792,7 +792,7 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
StructExpressionInner::FunctionCall(ref function_call) => {
write!(f, "{}", function_call)
}
- StructExpressionInner::IfElse(ref c) => write!(f, "{}", c),
+ StructExpressionInner::Conditional(ref c) => write!(f, "{}", c),
StructExpressionInner::Member(ref m) => write!(f, "{}", m),
StructExpressionInner::Select(ref select) => write!(f, "{}", select),
}
@@ -937,30 +937,50 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
}
}
-#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
-pub struct IfElseExpression<'ast, T, E> {
+#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
+pub enum ConditionalKind {
+ IfElse,
+ Ternary,
+}
+
+#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
+pub struct ConditionalExpression<'ast, T, E> {
pub condition: Box>,
pub consequence: Box,
pub alternative: Box,
+ pub kind: ConditionalKind,
}
-impl<'ast, T, E> IfElseExpression<'ast, T, E> {
- pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self {
- IfElseExpression {
+impl<'ast, T, E> ConditionalExpression<'ast, T, E> {
+ pub fn new(
+ condition: BooleanExpression<'ast, T>,
+ consequence: E,
+ alternative: E,
+ kind: ConditionalKind,
+ ) -> Self {
+ ConditionalExpression {
condition: box condition,
consequence: box consequence,
alternative: box alternative,
+ kind,
}
}
}
-impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'ast, T, E> {
+impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpression<'ast, T, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(
- f,
- "if {} then {} else {} fi",
- self.condition, self.consequence, self.alternative
- )
+ match self.kind {
+ ConditionalKind::IfElse => write!(
+ f,
+ "if {} then {} else {} fi",
+ self.condition, self.consequence, self.alternative
+ ),
+ ConditionalKind::Ternary => write!(
+ f,
+ "{} ? {} : {}",
+ self.condition, self.consequence, self.alternative
+ ),
+ }
}
}
@@ -1042,7 +1062,7 @@ pub enum FieldElementExpression<'ast, T> {
Box>,
Box>,
),
- IfElse(IfElseExpression<'ast, T, Self>),
+ Conditional(ConditionalExpression<'ast, T, Self>),
Neg(Box>),
Pos(Box>),
FunctionCall(FunctionCallExpression<'ast, T, Self>),
@@ -1142,7 +1162,7 @@ pub enum BooleanExpression<'ast, T> {
Box>,
),
Not(Box>),
- IfElse(IfElseExpression<'ast, T, Self>),
+ Conditional(ConditionalExpression<'ast, T, Self>),
Member(MemberExpression<'ast, T, Self>),
FunctionCall(FunctionCallExpression<'ast, T, Self>),
Select(SelectExpression<'ast, T, Self>),
@@ -1249,7 +1269,7 @@ pub enum ArrayExpressionInner<'ast, T> {
Identifier(Identifier<'ast>),
Value(ArrayValue<'ast, T>),
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
- IfElse(IfElseExpression<'ast, T, ArrayExpression<'ast, T>>),
+ Conditional(ConditionalExpression<'ast, T, ArrayExpression<'ast, T>>),
Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>),
Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>),
Slice(
@@ -1313,7 +1333,7 @@ pub enum StructExpressionInner<'ast, T> {
Identifier(Identifier<'ast>),
Value(Vec>),
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
- IfElse(IfElseExpression<'ast, T, StructExpression<'ast, T>>),
+ Conditional(ConditionalExpression<'ast, T, StructExpression<'ast, T>>),
Member(MemberExpression<'ast, T, StructExpression<'ast, T>>),
Select(SelectExpression<'ast, T, StructExpression<'ast, T>>),
}
@@ -1455,7 +1475,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
- FieldElementExpression::IfElse(ref c) => write!(f, "{}", c),
+ FieldElementExpression::Conditional(ref c) => write!(f, "{}", c),
FieldElementExpression::FunctionCall(ref function_call) => {
write!(f, "{}", function_call)
}
@@ -1489,7 +1509,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
UExpressionInner::Select(ref select) => write!(f, "{}", select),
UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
- UExpressionInner::IfElse(ref c) => write!(f, "{}", c),
+ UExpressionInner::Conditional(ref c) => write!(f, "{}", c),
UExpressionInner::Member(ref m) => write!(f, "{}", m),
}
}
@@ -1518,7 +1538,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
- BooleanExpression::IfElse(ref c) => write!(f, "{}", c),
+ BooleanExpression::Conditional(ref c) => write!(f, "{}", c),
BooleanExpression::Member(ref m) => write!(f, "{}", m),
BooleanExpression::Select(ref select) => write!(f, "{}", select),
}
@@ -1540,7 +1560,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
.join(", ")
),
ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
- ArrayExpressionInner::IfElse(ref c) => write!(f, "{}", c),
+ ArrayExpressionInner::Conditional(ref c) => write!(f, "{}", c),
ArrayExpressionInner::Member(ref m) => write!(f, "{}", m),
ArrayExpressionInner::Select(ref select) => write!(f, "{}", select),
ArrayExpressionInner::Slice(ref a, ref from, ref to) => {
@@ -1780,81 +1800,121 @@ pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> {
Expression(E::Inner),
}
-pub enum IfElseOrExpression<'ast, T, E: Expr<'ast, T>> {
- IfElse(IfElseExpression<'ast, T, E>),
+pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> {
+ Conditional(ConditionalExpression<'ast, T, E>),
Expression(E::Inner),
}
-pub trait IfElse<'ast, T> {
- fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
- -> Self;
-}
-
-impl<'ast, T> IfElse<'ast, T> for FieldElementExpression<'ast, T> {
- fn if_else(
+pub trait Conditional<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
+ ) -> Self;
+}
+
+impl<'ast, T> Conditional<'ast, T> for FieldElementExpression<'ast, T> {
+ fn conditional(
+ condition: BooleanExpression<'ast, T>,
+ consequence: Self,
+ alternative: Self,
+ kind: ConditionalKind,
) -> Self {
- FieldElementExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
+ FieldElementExpression::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
}
}
-impl<'ast, T> IfElse<'ast, T> for IntExpression<'ast, T> {
- fn if_else(
+impl<'ast, T> Conditional<'ast, T> for IntExpression<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
) -> Self {
- IntExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
+ IntExpression::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
}
}
-impl<'ast, T> IfElse<'ast, T> for BooleanExpression<'ast, T> {
- fn if_else(
+impl<'ast, T> Conditional<'ast, T> for BooleanExpression<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
) -> Self {
- BooleanExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
+ BooleanExpression::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
}
}
-impl<'ast, T> IfElse<'ast, T> for UExpression<'ast, T> {
- fn if_else(
+impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
) -> Self {
let bitwidth = consequence.bitwidth;
- UExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
- .annotate(bitwidth)
+ UExpressionInner::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
+ .annotate(bitwidth)
}
}
-impl<'ast, T: Clone> IfElse<'ast, T> for ArrayExpression<'ast, T> {
- fn if_else(
+impl<'ast, T: Clone> Conditional<'ast, T> for ArrayExpression<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
) -> Self {
let ty = consequence.inner_type().clone();
let size = consequence.size();
- ArrayExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
- .annotate(ty, size)
+ ArrayExpressionInner::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
+ .annotate(ty, size)
}
}
-impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> {
- fn if_else(
+impl<'ast, T: Clone> Conditional<'ast, T> for StructExpression<'ast, T> {
+ fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
+ kind: ConditionalKind,
) -> Self {
let ty = consequence.ty().clone();
- StructExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
- .annotate(ty)
+ StructExpressionInner::Conditional(ConditionalExpression::new(
+ condition,
+ consequence,
+ alternative,
+ kind,
+ ))
+ .annotate(ty)
}
}
diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs
index 62a745af..5986ca22 100644
--- a/zokrates_core/src/typed_absy/result_folder.rs
+++ b/zokrates_core/src/typed_absy/result_folder.rs
@@ -184,14 +184,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_types(self, tys)
}
- fn fold_if_else_expression<
- E: Expr<'ast, T> + PartialEq + IfElse<'ast, T> + ResultFold<'ast, T>,
+ fn fold_conditional_expression<
+ E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold<'ast, T>,
>(
&mut self,
ty: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
- ) -> Result, Self::Error> {
- fold_if_else_expression(self, ty, e)
+ e: ConditionalExpression<'ast, T, E>,
+ ) -> Result, Self::Error> {
+ fold_conditional_expression(self, ty, e)
}
fn fold_block_expression>(
@@ -507,9 +507,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? {
- IfElseOrExpression::IfElse(c) => ArrayExpressionInner::IfElse(c),
- IfElseOrExpression::Expression(u) => u,
+ ArrayExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c)? {
+ ConditionalOrExpression::Conditional(c) => ArrayExpressionInner::Conditional(c),
+ ConditionalOrExpression::Expression(u) => u,
},
ArrayExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
MemberOrExpression::Member(m) => ArrayExpressionInner::Member(m),
@@ -572,9 +572,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? {
- IfElseOrExpression::IfElse(c) => StructExpressionInner::IfElse(c),
- IfElseOrExpression::Expression(u) => u,
+ StructExpressionInner::Conditional(c) => match f.fold_conditional_expression(ty, c)? {
+ ConditionalOrExpression::Conditional(c) => StructExpressionInner::Conditional(c),
+ ConditionalOrExpression::Expression(u) => u,
},
StructExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
MemberOrExpression::Member(m) => StructExpressionInner::Member(m),
@@ -635,10 +635,10 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
FieldElementExpression::Pos(box e)
}
- FieldElementExpression::IfElse(c) => {
- match f.fold_if_else_expression(&Type::FieldElement, c)? {
- IfElseOrExpression::IfElse(c) => FieldElementExpression::IfElse(c),
- IfElseOrExpression::Expression(u) => u,
+ FieldElementExpression::Conditional(c) => {
+ match f.fold_conditional_expression(&Type::FieldElement, c)? {
+ ConditionalOrExpression::Conditional(c) => FieldElementExpression::Conditional(c),
+ ConditionalOrExpression::Expression(u) => u,
}
}
FieldElementExpression::FunctionCall(function_call) => {
@@ -689,11 +689,11 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo
})
}
-pub fn fold_if_else_expression<
+pub fn fold_conditional_expression<
'ast,
T: Field,
E: Expr<'ast, T>
- + IfElse<'ast, T>
+ + Conditional<'ast, T>
+ PartialEq
+ ResultFold<'ast, T>
+ From>,
@@ -701,13 +701,16 @@ pub fn fold_if_else_expression<
>(
f: &mut F,
_: &E::Ty,
- e: IfElseExpression<'ast, T, E>,
-) -> Result, F::Error> {
- Ok(IfElseOrExpression::IfElse(IfElseExpression::new(
- f.fold_boolean_expression(*e.condition)?,
- e.consequence.fold(f)?,
- e.alternative.fold(f)?,
- )))
+ e: ConditionalExpression<'ast, T, E>,
+) -> Result, F::Error> {
+ Ok(ConditionalOrExpression::Conditional(
+ ConditionalExpression::new(
+ f.fold_boolean_expression(*e.condition)?,
+ e.consequence.fold(f)?,
+ e.alternative.fold(f)?,
+ e.kind,
+ ),
+ ))
}
pub fn fold_member_expression<
@@ -863,10 +866,12 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
FunctionCallOrExpression::Expression(u) => u,
}
}
- BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c)? {
- IfElseOrExpression::IfElse(c) => BooleanExpression::IfElse(c),
- IfElseOrExpression::Expression(u) => u,
- },
+ BooleanExpression::Conditional(c) => {
+ match f.fold_conditional_expression(&Type::Boolean, c)? {
+ ConditionalOrExpression::Conditional(c) => BooleanExpression::Conditional(c),
+ ConditionalOrExpression::Expression(u) => u,
+ }
+ }
BooleanExpression::Select(select) => {
match f.fold_select_expression(&Type::Boolean, select)? {
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
@@ -991,9 +996,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
SelectOrExpression::Expression(u) => u,
},
- UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c)? {
- IfElseOrExpression::IfElse(c) => UExpressionInner::IfElse(c),
- IfElseOrExpression::Expression(u) => u,
+ UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? {
+ ConditionalOrExpression::Conditional(c) => UExpressionInner::Conditional(c),
+ ConditionalOrExpression::Expression(u) => u,
},
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m)? {
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs
index df4e476e..dc5aa09b 100644
--- a/zokrates_core/src/typed_absy/uint.rs
+++ b/zokrates_core/src/typed_absy/uint.rs
@@ -193,7 +193,7 @@ pub enum UExpressionInner<'ast, T> {
FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>),
LeftShift(Box>, Box>),
RightShift(Box>, Box>),
- IfElse(IfElseExpression<'ast, T, UExpression<'ast, T>>),
+ Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>),
Member(MemberExpression<'ast, T, UExpression<'ast, T>>),
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
}
diff --git a/zokrates_core_test/tests/tests/ternary.json b/zokrates_core_test/tests/tests/ternary.json
new file mode 100644
index 00000000..fb751f4d
--- /dev/null
+++ b/zokrates_core_test/tests/tests/ternary.json
@@ -0,0 +1,36 @@
+{
+ "entry_point": "./tests/tests/ternary.zok",
+ "max_constraint_count": 6,
+ "tests": [
+ {
+ "input": {
+ "values": ["1", "0"]
+ },
+ "output": {
+ "Ok": {
+ "values": ["1"]
+ }
+ }
+ },
+ {
+ "input": {
+ "values": ["0", "1"]
+ },
+ "output": {
+ "Ok": {
+ "values": ["2"]
+ }
+ }
+ },
+ {
+ "input": {
+ "values": ["0", "0"]
+ },
+ "output": {
+ "Ok": {
+ "values": ["3"]
+ }
+ }
+ }
+ ]
+}
diff --git a/zokrates_core_test/tests/tests/ternary.zok b/zokrates_core_test/tests/tests/ternary.zok
new file mode 100644
index 00000000..7bd9486e
--- /dev/null
+++ b/zokrates_core_test/tests/tests/ternary.zok
@@ -0,0 +1,5 @@
+def main(bool a, bool b) -> field:
+ field x = a ? 1 : b ? 2 : 3 // (a ? 1 : (b ? 2 : 3))
+ field y = if a then 1 else if b then 2 else 3 fi fi
+ assert(x == y)
+ return x
\ No newline at end of file
diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest
index e766d084..06970e18 100644
--- a/zokrates_parser/src/zokrates.pest
+++ b/zokrates_parser/src/zokrates.pest
@@ -68,13 +68,13 @@ expression = { unaried_term ~ (op_binary ~ unaried_term)* }
unaried_term = { op_unary? ~ powered_term }
powered_term = { postfixed_term ~ (op_pow ~ exponent_expression)? }
postfixed_term = { term ~ access* }
-term = { ("(" ~ expression ~ ")") | inline_struct_expression | conditional_expression | primary_expression | inline_array_expression | array_initializer_expression }
+term = { ("(" ~ expression ~ ")") | inline_struct_expression | if_else_expression | primary_expression | inline_array_expression | array_initializer_expression }
spread = { "..." ~ expression }
range = { from_expression? ~ ".." ~ to_expression? }
from_expression = { expression }
to_expression = { expression }
-conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"}
+if_else_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"}
access = { array_access | call_access | member_access }
array_access = { "[" ~ range_or_expression ~ "]" }
@@ -155,8 +155,10 @@ op_neg = {"-"}
op_pos = {"+"}
op_left_shift = @{"<<"}
op_right_shift = @{">>"}
+op_ternary = {"?" ~ expression ~ ":"}
+
// `op_pow` is *not* in `op_binary` because its precedence is handled in this parser rather than down the line in precedence climbing
-op_binary = _ { op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem }
+op_binary = _ { op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem | op_ternary }
op_unary = { op_pos | op_neg | op_not }
WHITESPACE = _{ " " | "\t" | "\\" ~ COMMENT? ~ NEWLINE}
diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs
index 3bc24404..e35b3a42 100644
--- a/zokrates_pest_ast/src/lib.rs
+++ b/zokrates_pest_ast/src/lib.rs
@@ -13,7 +13,7 @@ pub use ast::{
CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber,
DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File,
FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression,
- IdentifierExpression, ImportDirective, ImportSymbol, InlineArrayExpression,
+ IdentifierExpression, IfElseExpression, ImportDirective, ImportSymbol, InlineArrayExpression,
InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, Parameter,
PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression,
Statement, StructDefinition, StructField, SymbolDeclaration, TernaryExpression, ToExpression,
@@ -38,6 +38,7 @@ mod ast {
// based on https://docs.python.org/3/reference/expressions.html#operator-precedence
fn build_precedence_climber() -> PrecClimber {
PrecClimber::new(vec![
+ Operator::new(Rule::op_ternary, Assoc::Right),
Operator::new(Rule::op_or, Assoc::Left),
Operator::new(Rule::op_and, Assoc::Left),
Operator::new(Rule::op_lt, Assoc::Left)
@@ -89,6 +90,12 @@ mod ast {
Rule::op_bit_or => Expression::binary(BinaryOperator::BitOr, lhs, rhs, span),
Rule::op_right_shift => Expression::binary(BinaryOperator::RightShift, lhs, rhs, span),
Rule::op_left_shift => Expression::binary(BinaryOperator::LeftShift, lhs, rhs, span),
+ Rule::op_ternary => Expression::ternary(
+ lhs,
+ Box::new(Expression::from_pest(&mut pair.into_inner()).unwrap()),
+ rhs,
+ span,
+ ),
_ => unreachable!(),
})
}
@@ -412,6 +419,7 @@ mod ast {
#[derive(Debug, PartialEq, Clone)]
pub enum Expression<'ast> {
Ternary(TernaryExpression<'ast>),
+ IfElse(IfElseExpression<'ast>),
Binary(BinaryExpression<'ast>),
Unary(UnaryExpression<'ast>),
Postfix(PostfixExpression<'ast>),
@@ -427,7 +435,7 @@ mod ast {
pub enum Term<'ast> {
Expression(Expression<'ast>),
InlineStruct(InlineStructExpression<'ast>),
- Ternary(TernaryExpression<'ast>),
+ IfElse(IfElseExpression<'ast>),
Primary(PrimaryExpression<'ast>),
InlineArray(InlineArrayExpression<'ast>),
ArrayInitializer(ArrayInitializerExpression<'ast>),
@@ -543,7 +551,7 @@ mod ast {
fn from(t: Term<'ast>) -> Self {
match t {
Term::Expression(e) => e,
- Term::Ternary(e) => Expression::Ternary(e),
+ Term::IfElse(e) => Expression::IfElse(e),
Term::Primary(e) => e.into(),
Term::InlineArray(e) => Expression::InlineArray(e),
Term::InlineStruct(e) => Expression::InlineStruct(e),
@@ -761,27 +769,49 @@ mod ast {
pub span: Span<'ast>,
}
- #[derive(Debug, FromPest, PartialEq, Clone)]
- #[pest_ast(rule(Rule::conditional_expression))]
+ #[derive(Debug, PartialEq, Clone)]
pub struct TernaryExpression<'ast> {
- pub first: Box>,
- pub second: Box>,
- pub third: Box>,
+ pub condition: Box>,
+ pub consequence: Box>,
+ pub alternative: Box>,
+ pub span: Span<'ast>,
+ }
+
+ #[derive(Debug, FromPest, PartialEq, Clone)]
+ #[pest_ast(rule(Rule::if_else_expression))]
+ pub struct IfElseExpression<'ast> {
+ pub condition: Box>,
+ pub consequence: Box>,
+ pub alternative: Box>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
impl<'ast> Expression<'ast> {
+ pub fn if_else(
+ condition: Box>,
+ consequence: Box>,
+ alternative: Box>,
+ span: Span<'ast>,
+ ) -> Self {
+ Expression::IfElse(IfElseExpression {
+ condition,
+ consequence,
+ alternative,
+ span,
+ })
+ }
+
pub fn ternary(
- first: Box>,
- second: Box>,
- third: Box>,
+ condition: Box>,
+ consequence: Box>,
+ alternative: Box>,
span: Span<'ast>,
) -> Self {
Expression::Ternary(TernaryExpression {
- first,
- second,
- third,
+ condition,
+ consequence,
+ alternative,
span,
})
}
@@ -806,6 +836,7 @@ mod ast {
Expression::Identifier(i) => &i.span,
Expression::Literal(c) => c.span(),
Expression::Ternary(t) => &t.span,
+ Expression::IfElse(ie) => &ie.span,
Expression::Postfix(p) => &p.span,
Expression::InlineArray(a) => &a.span,
Expression::InlineStruct(s) => &s.span,
@@ -1071,20 +1102,6 @@ mod tests {
pub fn pow(left: Expression<'ast>, right: Expression<'ast>, span: Span<'ast>) -> Self {
Self::binary(BinaryOperator::Pow, Box::new(left), Box::new(right), span)
}
-
- pub fn if_else(
- condition: Expression<'ast>,
- consequence: Expression<'ast>,
- alternative: Expression<'ast>,
- span: Span<'ast>,
- ) -> Self {
- Self::ternary(
- Box::new(condition),
- Box::new(consequence),
- Box::new(alternative),
- span,
- )
- }
}
#[test]
@@ -1263,7 +1280,7 @@ mod tests {
}))],
statements: vec![Statement::Return(ReturnStatement {
expressions: vec![Expression::if_else(
- Expression::Literal(LiteralExpression::DecimalLiteral(
+ Box::new(Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
@@ -1271,8 +1288,8 @@ mod tests {
},
span: Span::new(source, 62, 63).unwrap()
}
- )),
- Expression::Literal(LiteralExpression::DecimalLiteral(
+ ))),
+ Box::new(Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
@@ -1280,8 +1297,8 @@ mod tests {
},
span: Span::new(source, 69, 70).unwrap()
}
- )),
- Expression::Literal(LiteralExpression::DecimalLiteral(
+ ))),
+ Box::new(Expression::Literal(LiteralExpression::DecimalLiteral(
DecimalLiteralExpression {
suffix: None,
value: DecimalNumber {
@@ -1289,8 +1306,8 @@ mod tests {
},
span: Span::new(source, 76, 77).unwrap()
}
- )),
- Span::new(source, 59, 80).unwrap()
+ ))),
+ Span::new(&source, 59, 80).unwrap()
)],
span: Span::new(source, 52, 80).unwrap(),
})],