1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

add negative and positive operators

This commit is contained in:
schaeff 2021-03-16 19:22:32 +01:00
parent 676f6b844f
commit 5897a6f172
12 changed files with 188 additions and 1 deletions

View file

@ -507,6 +507,12 @@ impl<'ast> From<pest::UnaryExpression<'ast>> for absy::ExpressionNode<'ast> {
pest::UnaryOperator::Not(_) => {
absy::Expression::Not(Box::new(absy::ExpressionNode::from(*unary.expression)))
}
pest::UnaryOperator::Neg(_) => {
absy::Expression::Neg(Box::new(absy::ExpressionNode::from(*unary.expression)))
}
pest::UnaryOperator::Pos(_) => {
absy::Expression::Pos(Box::new(absy::ExpressionNode::from(*unary.expression)))
}
}
.span(unary.span)
}

View file

@ -484,6 +484,8 @@ pub enum Expression<'ast> {
Div(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Rem(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Pow(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Neg(Box<ExpressionNode<'ast>>),
Pos(Box<ExpressionNode<'ast>>),
IfElse(
Box<ExpressionNode<'ast>>,
Box<ExpressionNode<'ast>>,
@ -525,6 +527,8 @@ impl<'ast> fmt::Display for Expression<'ast> {
Expression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "({}**{})", lhs, rhs),
Expression::Neg(ref e) => write!(f, "(-{})", e),
Expression::Pos(ref e) => write!(f, "(+{})", e),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
@ -594,6 +598,8 @@ impl<'ast> fmt::Debug for Expression<'ast> {
Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
Expression::Neg(ref e) => write!(f, "Neg({:?})", e),
Expression::Pos(ref e) => write!(f, "Pos({:?})", e),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,

View file

@ -1423,6 +1423,42 @@ impl<'ast> Checker<'ast> {
}),
}
}
Expression::Neg(box e) => {
let e = self.check_expression(e, module_id, &types)?;
match e {
TypedExpression::FieldElement(e) => {
Ok(FieldElementExpression::Neg(box e).into())
}
TypedExpression::Uint(e) => Ok(UExpression::neg(e).into()),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Unary operator `-` cannot be applied to {} of type {}",
e,
e.get_type()
),
}),
}
}
Expression::Pos(box e) => {
let e = self.check_expression(e, module_id, &types)?;
match e {
TypedExpression::FieldElement(e) => {
Ok(FieldElementExpression::Pos(box e).into())
}
TypedExpression::Uint(e) => Ok(UExpression::pos(e).into()),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Unary operator `+` cannot be applied to {} of type {}",
e,
e.get_type()
),
}),
}
}
Expression::IfElse(box condition, box consequence, box alternative) => {
let condition_checked = self.check_expression(condition, module_id, &types)?;
let consequence_checked = self.check_expression(consequence, module_id, &types)?;

View file

@ -528,6 +528,19 @@ pub fn fold_field_expression<'ast, T: Field>(
let e2 = f.fold_field_expression(e2);
zir::FieldElementExpression::Pow(box e1, box e2)
}
typed_absy::FieldElementExpression::Neg(box e) => {
let e = f.fold_field_expression(e);
zir::FieldElementExpression::Sub(
box zir::FieldElementExpression::Number(T::zero()),
box e,
)
}
typed_absy::FieldElementExpression::Pos(box e) => {
let e = f.fold_field_expression(e);
e
}
typed_absy::FieldElementExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_field_expression(cons);
@ -811,6 +824,21 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
zir::UExpressionInner::Not(box e)
}
typed_absy::UExpressionInner::Neg(box e) => {
let bitwidth = e.bitwidth();
f.fold_uint_expression(typed_absy::UExpression::sub(
typed_absy::UExpressionInner::Value(0).annotate(bitwidth),
e,
))
.into_inner()
}
typed_absy::UExpressionInner::Pos(box e) => {
let e = f.fold_uint_expression(e);
e.into_inner()
}
typed_absy::UExpressionInner::FunctionCall(..) => {
unreachable!("function calls should have been removed")
}

View file

@ -722,6 +722,25 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
e => UExpressionInner::Not(box e.annotate(bitwidth)),
}
}
UExpressionInner::Neg(box e) => {
let e = self.fold_uint_expression(e).into_inner();
match e {
UExpressionInner::Value(v) => {
use std::convert::TryInto;
UExpressionInner::Value(
2_u128.pow(bitwidth.to_usize().try_into().unwrap()) - v,
)
}
e => UExpressionInner::Neg(box e.annotate(bitwidth)),
}
}
UExpressionInner::Pos(box e) => {
let e = self.fold_uint_expression(e).into_inner();
match e {
UExpressionInner::Value(v) => UExpressionInner::Value(v),
e => UExpressionInner::Pos(box e.annotate(bitwidth)),
}
}
UExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
@ -830,6 +849,14 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
(e1, e2) => FieldElementExpression::Div(box e1, box e2),
},
FieldElementExpression::Neg(box e) => match self.fold_field_expression(e) {
FieldElementExpression::Number(n) => FieldElementExpression::Number(T::zero() - n),
e => FieldElementExpression::Neg(box e),
},
FieldElementExpression::Pos(box e) => match self.fold_field_expression(e) {
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
e => FieldElementExpression::Pos(box e),
},
FieldElementExpression::Pow(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);

View file

@ -277,6 +277,16 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_field_expression(e2);
FieldElementExpression::Pow(box e1, box e2)
}
FieldElementExpression::Neg(box e) => {
let e = f.fold_field_expression(e);
FieldElementExpression::Neg(box e)
}
FieldElementExpression::Pos(box e) => {
let e = f.fold_field_expression(e);
FieldElementExpression::Pos(box e)
}
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_field_expression(cons);
@ -470,6 +480,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Not(box e)
}
UExpressionInner::Neg(box e) => {
let e = f.fold_uint_expression(e);
UExpressionInner::Neg(box e)
}
UExpressionInner::Pos(box e) => {
let e = f.fold_uint_expression(e);
UExpressionInner::Pos(box e)
}
UExpressionInner::FunctionCall(key, exps) => {
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
UExpressionInner::FunctionCall(key, exps)

View file

@ -626,6 +626,8 @@ pub enum FieldElementExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Neg(Box<FieldElementExpression<'ast, T>>),
Pos(Box<FieldElementExpression<'ast, T>>),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(
@ -876,6 +878,8 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
@ -915,6 +919,8 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
UExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
UExpressionInner::FunctionCall(ref k, ref p) => {
write!(f, "{}(", k.id,)?;
@ -1067,6 +1073,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for FieldElementExpression<'ast, T> {
}
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
FieldElementExpression::Neg(ref e) => write!(f, "Neg({:?})", e),
FieldElementExpression::Pos(ref e) => write!(f, "Pos({:?})", e),
FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,

View file

@ -58,6 +58,16 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::Not(box self).annotate(bitwidth)
}
pub fn neg(self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
UExpressionInner::Neg(box self).annotate(bitwidth)
}
pub fn pos(self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
UExpressionInner::Pos(box self).annotate(bitwidth)
}
pub fn left_shift(self, by: FieldElementExpression<'ast, T>) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
@ -107,6 +117,8 @@ pub enum UExpressionInner<'ast, T> {
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Not(Box<UExpression<'ast, T>>),
Neg(Box<UExpression<'ast, T>>),
Pos(Box<UExpression<'ast, T>>),
LeftShift(
Box<UExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,

View file

@ -0,0 +1,15 @@
{
"entry_point": "./tests/tests/neg_pos.zok",
"tests": [
{
"input": {
"values": ["1", "2", "1", "2"]
},
"output": {
"Ok": {
"values": ["21888242871839275222246405745257275088548364400416034343698204186575808495615", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "21888242871839275222246405745257275088548364400416034343698204186575808495616", "254", "255", "255", "255"]
}
}
}
]
}

View file

@ -0,0 +1,11 @@
def main(field x, field y, u8 z, u8 t) -> (field[4], u8[4]):
field a = -y // should parse to neg
field b = x - y // should parse to sub
field c = x + - y // should parse to add(neg)
field d = x - + y // should parse to sub(pos)
u8 e = -t // should parse to neg
u8 f = z - t // should parse to sub
u8 g = z + - t // should parse to add(neg)
u8 h = z - + t // should parse to sub(pos)
return [a, b, c, d], [e, f, g, h]

View file

@ -117,10 +117,12 @@ op_div = {"/"}
op_rem = {"%"}
op_pow = @{"**"}
op_not = {"!"}
op_neg = {"-"}
op_pos = {"+"}
op_left_shift = @{"<<"}
op_right_shift = @{">>"}
op_binary = _ { op_pow | op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem }
op_unary = { op_not }
op_unary = { op_not | op_neg | op_pos }
WHITESPACE = _{ " " | "\t" | "\\" ~ NEWLINE}

View file

@ -435,6 +435,8 @@ mod ast {
#[pest_ast(rule(Rule::op_unary))]
pub enum UnaryOperator<'ast> {
Not(Not<'ast>),
Neg(Neg<'ast>),
Pos(Pos<'ast>),
}
#[derive(Debug, PartialEq, FromPest, Clone)]
@ -444,6 +446,20 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, FromPest, Clone)]
#[pest_ast(rule(Rule::op_neg))]
pub struct Neg<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, FromPest, Clone)]
#[pest_ast(rule(Rule::op_pos))]
pub struct Pos<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, Clone)]
pub enum Expression<'ast> {
Ternary(TernaryExpression<'ast>),