1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

add remainder operation

This commit is contained in:
schaeff 2020-11-19 12:35:42 +00:00
parent 9481ead354
commit 9f3ccdb7f2
20 changed files with 251 additions and 37 deletions

View file

@ -5,7 +5,7 @@ The following table lists the precedence and associativity of all available oper
| Operator | Description | Associativity | Remarks |
|------------------------------|--------------------------------------------------------------|------------------------------------|---------|
| ** <br> | Power | Left | [^1] |
| * <br> /<br> | Multiplication <br> Division <br> | Left <br> Left | |
| *<br> /<br> %<br> | Multiplication <br> Division <br> Remainder | Left <br> Left <br>Left | |
| + <br> - <br> | Addition <br> Subtraction <br> | Left <br> Left | |
| << <br> >> <br> | Left shift <br> Right shift <br> | Left <br> Left | [^2] |
| & | Bitwise AND | Left <br> Left | |
@ -22,4 +22,4 @@ The following table lists the precedence and associativity of all available oper
[^2]: The right operand must be a compile time constant
[^3]: Both operands are be asserted to be strictly lower than the biggest power of 2 lower than `p/2`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`

View file

@ -320,6 +320,10 @@ impl<'ast> From<pest::BinaryExpression<'ast>> for absy::ExpressionNode<'ast> {
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
),
pest::BinaryOperator::Rem => absy::Expression::Rem(
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
),
pest::BinaryOperator::Eq => absy::Expression::Eq(
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),

View file

@ -482,6 +482,7 @@ pub enum Expression<'ast> {
Sub(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Mult(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Div(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Rem(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Pow(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
IfElse(
Box<ExpressionNode<'ast>>,
@ -522,6 +523,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
Expression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
Expression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "({}**{})", lhs, rhs),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
@ -590,6 +592,7 @@ impl<'ast> fmt::Debug for Expression<'ast> {
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
Expression::Mult(ref lhs, ref rhs) => write!(f, "Mult({:?}, {:?})", lhs, rhs),
Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(

View file

@ -1307,6 +1307,91 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatUExpression::with_field(FlatExpression::Identifier(q))
}
UExpressionInner::Rem(box left, box right) => {
let left_flattened = self
.flatten_uint_expression(symbols, statements_flattened, left)
.get_field_unchecked();
let right_flattened = self
.flatten_uint_expression(symbols, statements_flattened, right)
.get_field_unchecked();
let n = if left_flattened.is_linear() {
left_flattened
} else {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id)
};
let d = if right_flattened.is_linear() {
right_flattened
} else {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id)
};
// first check that the d is not 0 by giving its inverse
let invd = self.use_sym();
// # invd = 1/d
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![invd],
Solver::Div,
vec![FlatExpression::Number(T::one()), d.clone()],
)));
// assert(invd * d == 1)
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::one()),
FlatExpression::Mult(box invd.into(), box d.clone().into()),
));
// now introduce the quotient and remainder
let q = self.use_sym();
let r = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective {
inputs: vec![n.clone(), d.clone()],
outputs: vec![q.clone(), r.clone()],
solver: Solver::EuclideanDiv,
}));
// q in range
let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::from(q)),
target_bitwidth.to_usize(),
target_bitwidth,
statements_flattened,
);
// r in range
let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::from(r)),
target_bitwidth.to_usize(),
target_bitwidth,
statements_flattened,
);
// r < d <=> r - d + 2**w < 2**w
let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::Add(
box FlatExpression::Sub(box r.into(), box d.clone().into()),
box FlatExpression::Number(T::from(
2usize.pow(target_bitwidth.to_usize() as u32),
)),
)),
target_bitwidth.to_usize(),
target_bitwidth,
statements_flattened,
);
// q*d == n - r
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Sub(box n, box r.into()),
FlatExpression::Mult(box q.into(), box d),
));
FlatUExpression::with_field(FlatExpression::Identifier(r))
}
UExpressionInner::IfElse(box condition, box consequence, box alternative) => self
.flatten_if_else_expression(
symbols,

View file

@ -1363,6 +1363,37 @@ impl<'ast> Checker<'ast> {
}),
}
}
Expression::Rem(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
match (e1_checked, e2_checked) {
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(UExpression::rem(e1, e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot apply `%` to {}, {}",
e1.get_type(),
e2.get_type()
),
})
}
}
(t1, t2) => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot apply `%` to {}, {}",
t1.get_type(),
t2.get_type()
),
}),
}
}
Expression::Pow(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;

View file

@ -730,6 +730,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
zir::UExpressionInner::Div(box left, box right)
}
typed_absy::UExpressionInner::Rem(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);
zir::UExpressionInner::Rem(box left, box right)
}
typed_absy::UExpressionInner::Xor(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);

View file

@ -432,7 +432,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
(e, UExpressionInner::Value(v)) => match v {
1 => e,
_ => UExpressionInner::Mult(
_ => UExpressionInner::Div(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
@ -441,6 +441,27 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
UExpressionInner::Div(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::Rem(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1 % v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) => match v {
1 => UExpressionInner::Value(0),
_ => UExpressionInner::Rem(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Rem(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);

View file

@ -263,39 +263,14 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
// let left_max = left.metadata.clone().unwrap().max;
// let right_max = right.metadata.clone().unwrap().max;
UExpression::div(force_reduce(left), force_reduce(right)).with_max(range_max)
}
Rem(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
// let (should_reduce_left, should_reduce_right, max) = left_max
// .checked_mul(&right_max)
// .map(|max| (false, false, max))
// .unwrap_or_else(|| {
// range_max
// .clone()
// .checked_mul(&right_max)
// .map(|max| (true, false, max))
// .unwrap_or_else(|| {
// left_max
// .checked_mul(&range_max.clone())
// .map(|max| (false, true, max))
// .unwrap_or_else(|| (true, true, range_max.clone() * range_max))
// })
// });
// let left = if should_reduce_left {
// force_reduce(left)
// } else {
// force_no_reduce(left)
// };
// let right = if should_reduce_right {
// force_reduce(right)
// } else {
// force_no_reduce(right)
// };
let max = range_max;
UExpression::div(force_reduce(left), force_reduce(right)).with_max(max)
UExpression::rem(force_reduce(left), force_reduce(right)).with_max(range_max)
}
Not(box e) => {
let e = self.fold_uint_expression(e);

View file

@ -436,6 +436,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Div(box left, box right)
}
UExpressionInner::Rem(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);
UExpressionInner::Rem(box left, box right)
}
UExpressionInner::Xor(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);

View file

@ -891,6 +891,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
UExpressionInner::Not(ref e) => write!(f, "!{}", e),

View file

@ -29,6 +29,12 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::Div(box self, box other).annotate(bitwidth)
}
pub fn rem(self, other: Self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(bitwidth, other.bitwidth);
UExpressionInner::Rem(box self, box other).annotate(bitwidth)
}
pub fn xor(self, other: Self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(bitwidth, other.bitwidth);
@ -96,6 +102,7 @@ pub enum UExpressionInner<'ast, T> {
Sub(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Mult(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Div(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Rem(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -289,6 +289,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Div(box left, box right)
}
UExpressionInner::Rem(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);
UExpressionInner::Rem(box left, box right)
}
UExpressionInner::Xor(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);

View file

@ -488,6 +488,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs),
UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs),
UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs),

View file

@ -28,6 +28,12 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::Div(box self, box other).annotate(bitwidth)
}
pub fn rem(self, other: Self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(bitwidth, other.bitwidth);
UExpressionInner::Rem(box self, box other).annotate(bitwidth)
}
pub fn xor(self, other: Self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(bitwidth, other.bitwidth);
@ -155,6 +161,7 @@ pub enum UExpressionInner<'ast, T> {
Sub(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Mult(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Div(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Rem(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -0,0 +1,26 @@
{
"entry_point": "./tests/tests/uint/div_rem.zok",
"max_constraint_count": 43,
"tests": [
{
"input": {
"values": ["255", "1"]
},
"output": {
"Ok": {
"values": ["255", "0"]
}
}
},
{
"input": {
"values": ["42", "10"]
},
"output": {
"Ok": {
"values": ["4", "2"]
}
}
}
]
}

View file

@ -0,0 +1,2 @@
def main(u8 n, u8 d) -> (u8, u8):
return n / d, n % d

View file

@ -0,0 +1,26 @@
{
"entry_point": "./tests/tests/uint/rem.zok",
"max_constraint_count": 43,
"tests": [
{
"input": {
"values": ["255", "1"]
},
"output": {
"Ok": {
"values": ["0"]
}
}
},
{
"input": {
"values": ["42", "10"]
},
"output": {
"Ok": {
"values": ["2"]
}
}
}
]
}

View file

@ -0,0 +1,2 @@
def main(u8 x, u8 y) -> u8:
return x % y

View file

@ -114,11 +114,12 @@ op_add = {"+"}
op_sub = {"-"}
op_mul = {"*"}
op_div = {"/"}
op_rem = {"%"}
op_pow = @{"**"}
op_not = {"!"}
op_left_shift = @{"<<"}
op_right_shift = @{">>"}
op_binary = _ { op_pow | op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div }
op_binary = _ { op_pow | op_or | op_and | op_bit_xor | op_bit_and | op_bit_or | op_left_shift | op_right_shift | op_equal | op_not_equal | op_lte | op_lt | op_gte | op_gt | op_add | op_sub | op_mul | op_div | op_rem }
op_unary = { op_not }

View file

@ -49,7 +49,9 @@ mod ast {
Operator::new(Rule::op_left_shift, Assoc::Left)
| Operator::new(Rule::op_right_shift, Assoc::Left),
Operator::new(Rule::op_add, Assoc::Left) | Operator::new(Rule::op_sub, Assoc::Left),
Operator::new(Rule::op_mul, Assoc::Left) | Operator::new(Rule::op_div, Assoc::Left),
Operator::new(Rule::op_mul, Assoc::Left)
| Operator::new(Rule::op_div, Assoc::Left)
| Operator::new(Rule::op_rem, Assoc::Left),
Operator::new(Rule::op_pow, Assoc::Left),
])
}
@ -71,6 +73,7 @@ mod ast {
Rule::op_sub => Expression::binary(BinaryOperator::Sub, lhs, rhs, span),
Rule::op_mul => Expression::binary(BinaryOperator::Mul, lhs, rhs, span),
Rule::op_div => Expression::binary(BinaryOperator::Div, lhs, rhs, span),
Rule::op_rem => Expression::binary(BinaryOperator::Rem, lhs, rhs, span),
Rule::op_pow => Expression::binary(BinaryOperator::Pow, lhs, rhs, span),
Rule::op_equal => Expression::binary(BinaryOperator::Eq, lhs, rhs, span),
Rule::op_not_equal => Expression::binary(BinaryOperator::NotEq, lhs, rhs, span),
@ -418,6 +421,7 @@ mod ast {
Sub,
Mul,
Div,
Rem,
Eq,
NotEq,
Lt,