fix shifting in zir propagation
This commit is contained in:
parent
bc791e043a
commit
c73c09e5c2
1 changed files with 139 additions and 4 deletions
|
@ -1,6 +1,8 @@
|
|||
use num::traits::Pow;
|
||||
use num_bigint::BigUint;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
|
||||
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
|
||||
use zokrates_ast::zir::types::UBitwidth;
|
||||
use zokrates_ast::zir::{
|
||||
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
||||
|
@ -312,15 +314,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shl(by as usize)).unwrap(),
|
||||
)),
|
||||
) => {
|
||||
let two = BigUint::from(2usize);
|
||||
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
|
@ -335,6 +352,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
|
@ -948,6 +974,115 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_else() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
|
Loading…
Reference in a new issue