add negative and positive operators
This commit is contained in:
parent
676f6b844f
commit
5897a6f172
12 changed files with 188 additions and 1 deletions
|
@ -507,6 +507,12 @@ impl<'ast> From<pest::UnaryExpression<'ast>> for absy::ExpressionNode<'ast> {
|
|||
pest::UnaryOperator::Not(_) => {
|
||||
absy::Expression::Not(Box::new(absy::ExpressionNode::from(*unary.expression)))
|
||||
}
|
||||
pest::UnaryOperator::Neg(_) => {
|
||||
absy::Expression::Neg(Box::new(absy::ExpressionNode::from(*unary.expression)))
|
||||
}
|
||||
pest::UnaryOperator::Pos(_) => {
|
||||
absy::Expression::Pos(Box::new(absy::ExpressionNode::from(*unary.expression)))
|
||||
}
|
||||
}
|
||||
.span(unary.span)
|
||||
}
|
||||
|
|
|
@ -484,6 +484,8 @@ pub enum Expression<'ast> {
|
|||
Div(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Rem(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Pow(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Neg(Box<ExpressionNode<'ast>>),
|
||||
Pos(Box<ExpressionNode<'ast>>),
|
||||
IfElse(
|
||||
Box<ExpressionNode<'ast>>,
|
||||
Box<ExpressionNode<'ast>>,
|
||||
|
@ -525,6 +527,8 @@ impl<'ast> fmt::Display for Expression<'ast> {
|
|||
Expression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
|
||||
Expression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
|
||||
Expression::Pow(ref lhs, ref rhs) => write!(f, "({}**{})", lhs, rhs),
|
||||
Expression::Neg(ref e) => write!(f, "(-{})", e),
|
||||
Expression::Pos(ref e) => write!(f, "(+{})", e),
|
||||
Expression::BooleanConstant(b) => write!(f, "{}", b),
|
||||
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
||||
f,
|
||||
|
@ -594,6 +598,8 @@ impl<'ast> fmt::Debug for Expression<'ast> {
|
|||
Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
|
||||
Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs),
|
||||
Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
|
||||
Expression::Neg(ref e) => write!(f, "Neg({:?})", e),
|
||||
Expression::Pos(ref e) => write!(f, "Pos({:?})", e),
|
||||
Expression::BooleanConstant(b) => write!(f, "{}", b),
|
||||
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
||||
f,
|
||||
|
|
|
@ -1423,6 +1423,42 @@ impl<'ast> Checker<'ast> {
|
|||
}),
|
||||
}
|
||||
}
|
||||
Expression::Neg(box e) => {
|
||||
let e = self.check_expression(e, module_id, &types)?;
|
||||
|
||||
match e {
|
||||
TypedExpression::FieldElement(e) => {
|
||||
Ok(FieldElementExpression::Neg(box e).into())
|
||||
}
|
||||
TypedExpression::Uint(e) => Ok(UExpression::neg(e).into()),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Unary operator `-` cannot be applied to {} of type {}",
|
||||
e,
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
Expression::Pos(box e) => {
|
||||
let e = self.check_expression(e, module_id, &types)?;
|
||||
|
||||
match e {
|
||||
TypedExpression::FieldElement(e) => {
|
||||
Ok(FieldElementExpression::Pos(box e).into())
|
||||
}
|
||||
TypedExpression::Uint(e) => Ok(UExpression::pos(e).into()),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Unary operator `+` cannot be applied to {} of type {}",
|
||||
e,
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
Expression::IfElse(box condition, box consequence, box alternative) => {
|
||||
let condition_checked = self.check_expression(condition, module_id, &types)?;
|
||||
let consequence_checked = self.check_expression(consequence, module_id, &types)?;
|
||||
|
|
|
@ -528,6 +528,19 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
let e2 = f.fold_field_expression(e2);
|
||||
zir::FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
typed_absy::FieldElementExpression::Neg(box e) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
|
||||
zir::FieldElementExpression::Sub(
|
||||
box zir::FieldElementExpression::Number(T::zero()),
|
||||
box e,
|
||||
)
|
||||
}
|
||||
typed_absy::FieldElementExpression::Pos(box e) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
|
||||
e
|
||||
}
|
||||
typed_absy::FieldElementExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
let cons = f.fold_field_expression(cons);
|
||||
|
@ -811,6 +824,21 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
|
||||
zir::UExpressionInner::Not(box e)
|
||||
}
|
||||
typed_absy::UExpressionInner::Neg(box e) => {
|
||||
let bitwidth = e.bitwidth();
|
||||
|
||||
f.fold_uint_expression(typed_absy::UExpression::sub(
|
||||
typed_absy::UExpressionInner::Value(0).annotate(bitwidth),
|
||||
e,
|
||||
))
|
||||
.into_inner()
|
||||
}
|
||||
|
||||
typed_absy::UExpressionInner::Pos(box e) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
|
||||
e.into_inner()
|
||||
}
|
||||
typed_absy::UExpressionInner::FunctionCall(..) => {
|
||||
unreachable!("function calls should have been removed")
|
||||
}
|
||||
|
|
|
@ -722,6 +722,25 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
e => UExpressionInner::Not(box e.annotate(bitwidth)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::Neg(box e) => {
|
||||
let e = self.fold_uint_expression(e).into_inner();
|
||||
match e {
|
||||
UExpressionInner::Value(v) => {
|
||||
use std::convert::TryInto;
|
||||
UExpressionInner::Value(
|
||||
2_u128.pow(bitwidth.to_usize().try_into().unwrap()) - v,
|
||||
)
|
||||
}
|
||||
e => UExpressionInner::Neg(box e.annotate(bitwidth)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::Pos(box e) => {
|
||||
let e = self.fold_uint_expression(e).into_inner();
|
||||
match e {
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
e => UExpressionInner::Pos(box e.annotate(bitwidth)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
@ -830,6 +849,14 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
(e1, e2) => FieldElementExpression::Div(box e1, box e2),
|
||||
},
|
||||
FieldElementExpression::Neg(box e) => match self.fold_field_expression(e) {
|
||||
FieldElementExpression::Number(n) => FieldElementExpression::Number(T::zero() - n),
|
||||
e => FieldElementExpression::Neg(box e),
|
||||
},
|
||||
FieldElementExpression::Pos(box e) => match self.fold_field_expression(e) {
|
||||
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
|
||||
e => FieldElementExpression::Pos(box e),
|
||||
},
|
||||
FieldElementExpression::Pow(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1);
|
||||
let e2 = self.fold_field_expression(e2);
|
||||
|
|
|
@ -277,6 +277,16 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_field_expression(e2);
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Neg(box e) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
|
||||
FieldElementExpression::Neg(box e)
|
||||
}
|
||||
FieldElementExpression::Pos(box e) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
|
||||
FieldElementExpression::Pos(box e)
|
||||
}
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
let cons = f.fold_field_expression(cons);
|
||||
|
@ -470,6 +480,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Not(box e)
|
||||
}
|
||||
UExpressionInner::Neg(box e) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
|
||||
UExpressionInner::Neg(box e)
|
||||
}
|
||||
UExpressionInner::Pos(box e) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
|
||||
UExpressionInner::Pos(box e)
|
||||
}
|
||||
UExpressionInner::FunctionCall(key, exps) => {
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
UExpressionInner::FunctionCall(key, exps)
|
||||
|
|
|
@ -626,6 +626,8 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Neg(Box<FieldElementExpression<'ast, T>>),
|
||||
Pos(Box<FieldElementExpression<'ast, T>>),
|
||||
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(
|
||||
|
@ -876,6 +878,8 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
|
||||
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
|
||||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
||||
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
|
||||
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
|
||||
FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
|
@ -915,6 +919,8 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
|
||||
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
|
||||
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
|
||||
UExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
|
||||
UExpressionInner::FunctionCall(ref k, ref p) => {
|
||||
write!(f, "{}(", k.id,)?;
|
||||
|
@ -1067,6 +1073,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
|
||||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
|
||||
FieldElementExpression::Neg(ref e) => write!(f, "Neg({:?})", e),
|
||||
FieldElementExpression::Pos(ref e) => write!(f, "Pos({:?})", e),
|
||||
FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
|
|
|
@ -58,6 +58,16 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
UExpressionInner::Not(box self).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn neg(self) -> UExpression<'ast, T> {
|
||||
let bitwidth = self.bitwidth;
|
||||
UExpressionInner::Neg(box self).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn pos(self) -> UExpression<'ast, T> {
|
||||
let bitwidth = self.bitwidth;
|
||||
UExpressionInner::Pos(box self).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn left_shift(self, by: FieldElementExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
let bitwidth = self.bitwidth;
|
||||
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
|
||||
|
@ -107,6 +117,8 @@ pub enum UExpressionInner<'ast, T> {
|
|||
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Not(Box<UExpression<'ast, T>>),
|
||||
Neg(Box<UExpression<'ast, T>>),
|
||||
Pos(Box<UExpression<'ast, T>>),
|
||||
LeftShift(
|
||||
Box<UExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
|
|
15
zokrates_core_test/tests/tests/neg_pos.json
Normal file
15
zokrates_core_test/tests/tests/neg_pos.json
Normal file
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/neg_pos.zok",
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "2", "1", "2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "254", "255", "255", "255"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
11
zokrates_core_test/tests/tests/neg_pos.zok
Normal file
11
zokrates_core_test/tests/tests/neg_pos.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
def main(field x, field y, u8 z, u8 t) -> (field[4], u8[4]):
|
||||
field a = -y // should parse to neg
|
||||
field b = x - y // should parse to sub
|
||||
field c = x + - y // should parse to add(neg)
|
||||
field d = x - + y // should parse to sub(pos)
|
||||
|
||||
u8 e = -t // should parse to neg
|
||||
u8 f = z - t // should parse to sub
|
||||
u8 g = z + - t // should parse to add(neg)
|
||||
u8 h = z - + t // should parse to sub(pos)
|
||||
return [a, b, c, d], [e, f, g, h]
|
|
@ -117,10 +117,12 @@ op_div = {"/"}
|
|||
op_rem = {"%"}
|
||||
op_pow = @{"**"}
|
||||
op_not = {"!"}
|
||||
op_neg = {"-"}
|
||||
op_pos = {"+"}
|
||||
op_left_shift = @{"<<"}
|
||||
op_right_shift = @{">>"}
|
||||
op_binary = _ { op_pow | op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem }
|
||||
op_unary = { op_not }
|
||||
op_unary = { op_not | op_neg | op_pos }
|
||||
|
||||
|
||||
WHITESPACE = _{ " " | "\t" | "\\" ~ NEWLINE}
|
||||
|
|
|
@ -435,6 +435,8 @@ mod ast {
|
|||
#[pest_ast(rule(Rule::op_unary))]
|
||||
pub enum UnaryOperator<'ast> {
|
||||
Not(Not<'ast>),
|
||||
Neg(Neg<'ast>),
|
||||
Pos(Pos<'ast>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, FromPest, Clone)]
|
||||
|
@ -444,6 +446,20 @@ mod ast {
|
|||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, FromPest, Clone)]
|
||||
#[pest_ast(rule(Rule::op_neg))]
|
||||
pub struct Neg<'ast> {
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, FromPest, Clone)]
|
||||
#[pest_ast(rule(Rule::op_pos))]
|
||||
pub struct Pos<'ast> {
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum Expression<'ast> {
|
||||
Ternary(TernaryExpression<'ast>),
|
||||
|
|
Loading…
Reference in a new issue