fix shifting in zir propagation
This commit is contained in:
parent
bc791e043a
commit
c73c09e5c2
1 changed files with 139 additions and 4 deletions
|
@ -1,6 +1,8 @@
|
||||||
|
use num::traits::Pow;
|
||||||
|
use num_bigint::BigUint;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
|
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
|
||||||
use zokrates_ast::zir::types::UBitwidth;
|
use zokrates_ast::zir::types::UBitwidth;
|
||||||
use zokrates_ast::zir::{
|
use zokrates_ast::zir::{
|
||||||
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
||||||
|
@ -312,15 +314,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
) if by == 0 => Ok(e),
|
) if by == 0 => Ok(e),
|
||||||
|
(
|
||||||
|
_,
|
||||||
|
UExpression {
|
||||||
|
inner: UExpressionInner::Value(by),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
) if by as usize >= T::get_required_bits() => {
|
||||||
|
Ok(FieldElementExpression::Number(T::from(0)))
|
||||||
|
}
|
||||||
(
|
(
|
||||||
FieldElementExpression::Number(n),
|
FieldElementExpression::Number(n),
|
||||||
UExpression {
|
UExpression {
|
||||||
inner: UExpressionInner::Value(by),
|
inner: UExpressionInner::Value(by),
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
) => Ok(FieldElementExpression::Number(
|
) => {
|
||||||
T::try_from(n.to_biguint().shl(by as usize)).unwrap(),
|
let two = BigUint::from(2usize);
|
||||||
)),
|
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||||
|
|
||||||
|
Ok(FieldElementExpression::Number(
|
||||||
|
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
}
|
||||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -335,6 +352,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
) if by == 0 => Ok(e),
|
) if by == 0 => Ok(e),
|
||||||
|
(
|
||||||
|
_,
|
||||||
|
UExpression {
|
||||||
|
inner: UExpressionInner::Value(by),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
) if by as usize >= T::get_required_bits() => {
|
||||||
|
Ok(FieldElementExpression::Number(T::from(0)))
|
||||||
|
}
|
||||||
(
|
(
|
||||||
FieldElementExpression::Number(n),
|
FieldElementExpression::Number(n),
|
||||||
UExpression {
|
UExpression {
|
||||||
|
@ -948,6 +974,115 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn left_shift() {
|
||||||
|
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||||
|
box FieldElementExpression::identifier("a".into()),
|
||||||
|
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::identifier("a".into()))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||||
|
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||||
|
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||||
|
box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||||
|
box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128)
|
||||||
|
.annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn right_shift() {
|
||||||
|
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::identifier("a".into()),
|
||||||
|
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::identifier("a".into()))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::identifier("a".into()),
|
||||||
|
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||||
|
.annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||||
|
box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||||
|
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||||
|
box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||||
|
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128)
|
||||||
|
.annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||||
|
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||||
|
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||||
|
.annotate(UBitwidth::B32),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn if_else() {
|
fn if_else() {
|
||||||
let mut propagator = ZirPropagator::default();
|
let mut propagator = ZirPropagator::default();
|
||||||
|
|
Loading…
Reference in a new issue