diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 5fbed664..be2cec1b 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -507,6 +507,12 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { pest::UnaryOperator::Not(_) => { absy::Expression::Not(Box::new(absy::ExpressionNode::from(*unary.expression))) } + pest::UnaryOperator::Neg(_) => { + absy::Expression::Neg(Box::new(absy::ExpressionNode::from(*unary.expression))) + } + pest::UnaryOperator::Pos(_) => { + absy::Expression::Pos(Box::new(absy::ExpressionNode::from(*unary.expression))) + } } .span(unary.span) } diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index a497c374..71d8703f 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -484,6 +484,8 @@ pub enum Expression<'ast> { Div(Box>, Box>), Rem(Box>, Box>), Pow(Box>, Box>), + Neg(Box>), + Pos(Box>), IfElse( Box>, Box>, @@ -525,6 +527,8 @@ impl<'ast> fmt::Display for Expression<'ast> { Expression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), Expression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), Expression::Pow(ref lhs, ref rhs) => write!(f, "({}**{})", lhs, rhs), + Expression::Neg(ref e) => write!(f, "(-{})", e), + Expression::Pos(ref e) => write!(f, "(+{})", e), Expression::BooleanConstant(b) => write!(f, "{}", b), Expression::IfElse(ref condition, ref consequent, ref alternative) => write!( f, @@ -594,6 +598,8 @@ impl<'ast> fmt::Debug for Expression<'ast> { Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs), Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs), Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs), + Expression::Neg(ref e) => write!(f, "Neg({:?})", e), + Expression::Pos(ref e) => write!(f, "Pos({:?})", e), Expression::BooleanConstant(b) => write!(f, "{}", b), Expression::IfElse(ref condition, ref consequent, ref alternative) => write!( f, diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 4e074645..614bdd27 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1423,6 +1423,42 @@ impl<'ast> Checker<'ast> { }), } } + Expression::Neg(box e) => { + let e = self.check_expression(e, module_id, &types)?; + + match e { + TypedExpression::FieldElement(e) => { + Ok(FieldElementExpression::Neg(box e).into()) + } + TypedExpression::Uint(e) => Ok(UExpression::neg(e).into()), + e => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Unary operator `-` cannot be applied to {} of type {}", + e, + e.get_type() + ), + }), + } + } + Expression::Pos(box e) => { + let e = self.check_expression(e, module_id, &types)?; + + match e { + TypedExpression::FieldElement(e) => { + Ok(FieldElementExpression::Pos(box e).into()) + } + TypedExpression::Uint(e) => Ok(UExpression::pos(e).into()), + e => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Unary operator `+` cannot be applied to {} of type {}", + e, + e.get_type() + ), + }), + } + } Expression::IfElse(box condition, box consequence, box alternative) => { let condition_checked = self.check_expression(condition, module_id, &types)?; let consequence_checked = self.check_expression(consequence, module_id, &types)?; diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 708bd68d..2f516a81 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -528,6 +528,19 @@ pub fn fold_field_expression<'ast, T: Field>( let e2 = f.fold_field_expression(e2); zir::FieldElementExpression::Pow(box e1, box e2) } + typed_absy::FieldElementExpression::Neg(box e) => { + let e = f.fold_field_expression(e); + + zir::FieldElementExpression::Sub( + box zir::FieldElementExpression::Number(T::zero()), + box e, + ) + } + typed_absy::FieldElementExpression::Pos(box e) => { + let e = f.fold_field_expression(e); + + e + } typed_absy::FieldElementExpression::IfElse(box cond, box cons, box alt) => { let cond = f.fold_boolean_expression(cond); let cons = f.fold_field_expression(cons); @@ -811,6 +824,21 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( zir::UExpressionInner::Not(box e) } + typed_absy::UExpressionInner::Neg(box e) => { + let bitwidth = e.bitwidth(); + + f.fold_uint_expression(typed_absy::UExpression::sub( + typed_absy::UExpressionInner::Value(0).annotate(bitwidth), + e, + )) + .into_inner() + } + + typed_absy::UExpressionInner::Pos(box e) => { + let e = f.fold_uint_expression(e); + + e.into_inner() + } typed_absy::UExpressionInner::FunctionCall(..) => { unreachable!("function calls should have been removed") } diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 2f2993ce..3388ba32 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -722,6 +722,25 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { e => UExpressionInner::Not(box e.annotate(bitwidth)), } } + UExpressionInner::Neg(box e) => { + let e = self.fold_uint_expression(e).into_inner(); + match e { + UExpressionInner::Value(v) => { + use std::convert::TryInto; + UExpressionInner::Value( + 2_u128.pow(bitwidth.to_usize().try_into().unwrap()) - v, + ) + } + e => UExpressionInner::Neg(box e.annotate(bitwidth)), + } + } + UExpressionInner::Pos(box e) => { + let e = self.fold_uint_expression(e).into_inner(); + match e { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + e => UExpressionInner::Pos(box e.annotate(bitwidth)), + } + } UExpressionInner::Select(box array, box index) => { let array = self.fold_array_expression(array); let index = self.fold_field_expression(index); @@ -830,6 +849,14 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } (e1, e2) => FieldElementExpression::Div(box e1, box e2), }, + FieldElementExpression::Neg(box e) => match self.fold_field_expression(e) { + FieldElementExpression::Number(n) => FieldElementExpression::Number(T::zero() - n), + e => FieldElementExpression::Neg(box e), + }, + FieldElementExpression::Pos(box e) => match self.fold_field_expression(e) { + FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + e => FieldElementExpression::Pos(box e), + }, FieldElementExpression::Pow(box e1, box e2) => { let e1 = self.fold_field_expression(e1); let e2 = self.fold_field_expression(e2); diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 76a81ff2..a6a8f8b4 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -277,6 +277,16 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( let e2 = f.fold_field_expression(e2); FieldElementExpression::Pow(box e1, box e2) } + FieldElementExpression::Neg(box e) => { + let e = f.fold_field_expression(e); + + FieldElementExpression::Neg(box e) + } + FieldElementExpression::Pos(box e) => { + let e = f.fold_field_expression(e); + + FieldElementExpression::Pos(box e) + } FieldElementExpression::IfElse(box cond, box cons, box alt) => { let cond = f.fold_boolean_expression(cond); let cons = f.fold_field_expression(cons); @@ -470,6 +480,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( UExpressionInner::Not(box e) } + UExpressionInner::Neg(box e) => { + let e = f.fold_uint_expression(e); + + UExpressionInner::Neg(box e) + } + UExpressionInner::Pos(box e) => { + let e = f.fold_uint_expression(e); + + UExpressionInner::Pos(box e) + } UExpressionInner::FunctionCall(key, exps) => { let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); UExpressionInner::FunctionCall(key, exps) diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 0bb7f7c9..e158b7ae 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -626,6 +626,8 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), + Neg(Box>), + Pos(Box>), FunctionCall(FunctionKey<'ast>, Vec>), Member(Box>, MemberId), Select( @@ -876,6 +878,8 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), + FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), + FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => { write!( f, @@ -915,6 +919,8 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), UExpressionInner::Not(ref e) => write!(f, "!{}", e), + UExpressionInner::Neg(ref e) => write!(f, "(-{})", e), + UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), UExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), UExpressionInner::FunctionCall(ref k, ref p) => { write!(f, "{}(", k.id,)?; @@ -1067,6 +1073,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for FieldElementExpression<'ast, T> { } FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs), FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs), + FieldElementExpression::Neg(ref e) => write!(f, "Neg({:?})", e), + FieldElementExpression::Pos(ref e) => write!(f, "Pos({:?})", e), FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => { write!( f, diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 0403d15e..de0c7055 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -58,6 +58,16 @@ impl<'ast, T: Field> UExpression<'ast, T> { UExpressionInner::Not(box self).annotate(bitwidth) } + pub fn neg(self) -> UExpression<'ast, T> { + let bitwidth = self.bitwidth; + UExpressionInner::Neg(box self).annotate(bitwidth) + } + + pub fn pos(self) -> UExpression<'ast, T> { + let bitwidth = self.bitwidth; + UExpressionInner::Pos(box self).annotate(bitwidth) + } + pub fn left_shift(self, by: FieldElementExpression<'ast, T>) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; UExpressionInner::LeftShift(box self, box by).annotate(bitwidth) @@ -107,6 +117,8 @@ pub enum UExpressionInner<'ast, T> { And(Box>, Box>), Or(Box>, Box>), Not(Box>), + Neg(Box>), + Pos(Box>), LeftShift( Box>, Box>, diff --git a/zokrates_core_test/tests/tests/neg_pos.json b/zokrates_core_test/tests/tests/neg_pos.json new file mode 100644 index 00000000..a5e6d782 --- /dev/null +++ b/zokrates_core_test/tests/tests/neg_pos.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/neg_pos.zok", + "tests": [ + { + "input": { + "values": ["1", "2", "1", "2"] + }, + "output": { + "Ok": { + "values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "254", "255", "255", "255"] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/neg_pos.zok b/zokrates_core_test/tests/tests/neg_pos.zok new file mode 100644 index 00000000..560db199 --- /dev/null +++ b/zokrates_core_test/tests/tests/neg_pos.zok @@ -0,0 +1,11 @@ +def main(field x, field y, u8 z, u8 t) -> (field[4], u8[4]): + field a = -y // should parse to neg + field b = x - y // should parse to sub + field c = x + - y // should parse to add(neg) + field d = x - + y // should parse to sub(pos) + + u8 e = -t // should parse to neg + u8 f = z - t // should parse to sub + u8 g = z + - t // should parse to add(neg) + u8 h = z - + t // should parse to sub(pos) + return [a, b, c, d], [e, f, g, h] \ No newline at end of file diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index b6afdd04..612ce20d 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -117,10 +117,12 @@ op_div = {"/"} op_rem = {"%"} op_pow = @{"**"} op_not = {"!"} +op_neg = {"-"} +op_pos = {"+"} op_left_shift = @{"<<"} op_right_shift = @{">>"} op_binary = _ { op_pow | op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem } -op_unary = { op_not } +op_unary = { op_not | op_neg | op_pos } WHITESPACE = _{ " " | "\t" | "\\" ~ NEWLINE} diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 4ba8e299..bbff5489 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -435,6 +435,8 @@ mod ast { #[pest_ast(rule(Rule::op_unary))] pub enum UnaryOperator<'ast> { Not(Not<'ast>), + Neg(Neg<'ast>), + Pos(Pos<'ast>), } #[derive(Debug, PartialEq, FromPest, Clone)] @@ -444,6 +446,20 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, PartialEq, FromPest, Clone)] + #[pest_ast(rule(Rule::op_neg))] + pub struct Neg<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, PartialEq, FromPest, Clone)] + #[pest_ast(rule(Rule::op_pos))] + pub struct Pos<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, PartialEq, Clone)] pub enum Expression<'ast> { Ternary(TernaryExpression<'ast>),