1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fix shifting in zir propagation

This commit is contained in:
dark64 2022-11-23 20:31:49 +01:00
parent bc791e043a
commit c73c09e5c2

View file

@ -1,6 +1,8 @@
use num::traits::Pow;
use num_bigint::BigUint;
use std::collections::HashMap;
use std::fmt;
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
use zokrates_ast::zir::types::UBitwidth;
use zokrates_ast::zir::{
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
@ -312,15 +314,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
..
},
) if by == 0 => Ok(e),
(
_,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by as usize >= T::get_required_bits() => {
Ok(FieldElementExpression::Number(T::from(0)))
}
(
FieldElementExpression::Number(n),
UExpression {
inner: UExpressionInner::Value(by),
..
},
) => Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().shl(by as usize)).unwrap(),
)),
) => {
let two = BigUint::from(2usize);
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
.unwrap(),
))
}
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
}
}
@ -335,6 +352,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
..
},
) if by == 0 => Ok(e),
(
_,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by as usize >= T::get_required_bits() => {
Ok(FieldElementExpression::Number(T::from(0)))
}
(
FieldElementExpression::Number(n),
UExpression {
@ -948,6 +974,115 @@ mod tests {
);
}
#[test]
fn left_shift() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::identifier("a".into()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(1)),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(3)),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
box FieldElementExpression::Number(Bn128Field::from(1)),
box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
}
#[test]
fn right_shift() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::identifier("a".into()))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::identifier("a".into()),
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(3)),
box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::from(2)),
box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::max_value()),
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
box FieldElementExpression::Number(Bn128Field::max_value()),
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
.annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
);
}
#[test]
fn if_else() {
let mut propagator = ZirPropagator::default();