1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fix shifting in zir propagation

This commit is contained in:
dark64 2022-11-23 20:31:49 +01:00
parent bc791e043a
commit c73c09e5c2

View file

@ -1,6 +1,8 @@
use num::traits::Pow;
use num_bigint::BigUint;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr}; use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::types::UBitwidth;
use zokrates_ast::zir::{ use zokrates_ast::zir::{
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id, result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
@ -312,15 +314,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
.. ..
}, },
) if by == 0 => Ok(e), ) if by == 0 => Ok(e),
(
_,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by as usize >= T::get_required_bits() => {
Ok(FieldElementExpression::Number(T::from(0)))
}
( (
FieldElementExpression::Number(n), FieldElementExpression::Number(n),
UExpression { UExpression {
inner: UExpressionInner::Value(by), inner: UExpressionInner::Value(by),
.. ..
}, },
) => Ok(FieldElementExpression::Number( ) => {
T::try_from(n.to_biguint().shl(by as usize)).unwrap(), let two = BigUint::from(2usize);
)), let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
.unwrap(),
))
}
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
} }
} }
@ -335,6 +352,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
.. ..
}, },
) if by == 0 => Ok(e), ) if by == 0 => Ok(e),
(
_,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by as usize >= T::get_required_bits() => {
Ok(FieldElementExpression::Number(T::from(0)))
}
( (
FieldElementExpression::Number(n), FieldElementExpression::Number(n),
UExpression { UExpression {
@ -948,6 +974,115 @@ mod tests {
); );
} }
#[test]
fn left_shift() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::identifier("a".into()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(1)),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(3)),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(1)),
box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
}
#[test]
fn right_shift() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::identifier("a".into()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(3)),
box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::max_value()),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::max_value()),
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
}
#[test] #[test]
fn if_else() { fn if_else() {
let mut propagator = ZirPropagator::default(); let mut propagator = ZirPropagator::default();