1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

refactor zir to introduce conditional and select expressions

This commit is contained in:
schaeff 2022-08-22 12:01:23 +02:00
parent 53b62f568b
commit e4fbc6d35a
8 changed files with 634 additions and 654 deletions

View file

@ -4,6 +4,27 @@ use crate::zir::types::UBitwidth;
use crate::zir::*;
use zokrates_field::Field;
pub trait Fold<'ast, T: Field>: Sized {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self;
}
impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_field_expression(self)
}
}
impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_boolean_expression(self)
}
}
impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_uint_expression(self)
}
}
pub trait Folder<'ast, T: Field>: Sized {
fn fold_program(&mut self, p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
fold_program(self, p)
@ -39,6 +60,22 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_statement(self, s)
}
fn fold_conditional_expression<E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>>(
&mut self,
ty: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> ConditionalOrExpression<'ast, T, E> {
fold_conditional_expression(self, ty, e)
}
fn fold_select_expression<E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>>(
&mut self,
ty: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> SelectOrExpression<'ast, T, E> {
fold_select_expression(self, ty, e)
}
fn fold_expression(&mut self, e: ZirExpression<'ast, T>) -> ZirExpression<'ast, T> {
match e {
ZirExpression::FieldElement(e) => self.fold_field_expression(e).into(),
@ -141,10 +178,12 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
FieldElementExpression::Identifier(id) => {
FieldElementExpression::Identifier(f.fold_name(id))
}
FieldElementExpression::Select(a, box i) => FieldElementExpression::Select(
a.into_iter().map(|a| f.fold_field_expression(a)).collect(),
box f.fold_uint_expression(i),
),
FieldElementExpression::Select(e) => {
match f.fold_select_expression(&Type::FieldElement, e) {
SelectOrExpression::Select(s) => FieldElementExpression::Select(s),
SelectOrExpression::Expression(u) => u,
}
}
FieldElementExpression::Add(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
@ -170,11 +209,11 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_uint_expression(e2);
FieldElementExpression::Pow(box e1, box e2)
}
FieldElementExpression::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_field_expression(cons);
let alt = f.fold_field_expression(alt);
FieldElementExpression::Conditional(box cond, box cons, box alt)
FieldElementExpression::Conditional(c) => {
match f.fold_conditional_expression(&Type::FieldElement, c) {
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
}
}
}
}
@ -186,12 +225,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
match e {
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
BooleanExpression::Select(a, box i) => BooleanExpression::Select(
a.into_iter()
.map(|a| f.fold_boolean_expression(a))
.collect(),
box f.fold_uint_expression(i),
),
BooleanExpression::Select(e) => match f.fold_select_expression(&Type::Boolean, e) {
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
SelectOrExpression::Expression(u) => u,
},
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
@ -212,41 +249,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
let e2 = f.fold_boolean_expression(e2);
@ -261,12 +278,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e = f.fold_boolean_expression(e);
BooleanExpression::Not(box e)
}
BooleanExpression::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_boolean_expression(cons);
let alt = f.fold_boolean_expression(alt);
BooleanExpression::Conditional(box cond, box cons, box alt)
}
BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c)
{
ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
},
}
}
@ -282,16 +298,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>(
pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
_: UBitwidth,
ty: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
UExpressionInner::Select(a, box i) => UExpressionInner::Select(
a.into_iter().map(|a| f.fold_uint_expression(a)).collect(),
box f.fold_uint_expression(i),
),
UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e) {
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
SelectOrExpression::Expression(u) => u,
},
UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);
@ -355,12 +371,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Not(box e)
}
UExpressionInner::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_uint_expression(cons);
let alt = f.fold_uint_expression(alt);
UExpressionInner::Conditional(box cond, box cons, box alt)
}
UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) {
ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
},
}
}
@ -391,3 +405,36 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
main: f.fold_function(p.main),
}
}
pub fn fold_conditional_expression<
'ast,
T: Field,
E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>,
F: Folder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> ConditionalOrExpression<'ast, T, E> {
ConditionalOrExpression::Conditional(ConditionalExpression::new(
f.fold_boolean_expression(*e.condition),
e.consequence.fold(f),
e.alternative.fold(f),
))
}
pub fn fold_select_expression<
'ast,
T: Field,
E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>,
F: Folder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> SelectOrExpression<'ast, T, E> {
SelectOrExpression::Select(SelectExpression::new(
e.array.into_iter().map(|e| e.fold(f)).collect(),
e.index.fold(f),
))
}

View file

@ -8,7 +8,7 @@ mod uint;
mod variable;
pub use self::parameter::Parameter;
pub use self::types::Type;
pub use self::types::{Type, UBitwidth};
pub use self::variable::Variable;
use crate::common::{FlatEmbed, FormatString};
use crate::typed::ConcreteType;
@ -23,7 +23,7 @@ pub use self::folder::Folder;
pub use self::identifier::{Identifier, SourceIdentifier};
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct ZirProgram<'ast, T> {
pub main: ZirFunction<'ast, T>,
}
@ -138,14 +138,15 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
write!(f, "{}", "\t".repeat(depth))?;
match self {
ZirStatement::Return(ref exprs) => {
write!(f, "return ")?;
for (i, expr) in exprs.iter().enumerate() {
write!(f, "{}", expr)?;
if i < exprs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ";")
write!(
f,
"return {};",
exprs
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ")
)
}
ZirStatement::Definition(ref lhs, ref rhs) => {
write!(f, "{} = {};", lhs, rhs)
@ -181,7 +182,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
}
ZirStatement::Log(ref l, ref expressions) => write!(
f,
"log(\"{}\"), {})",
"log(\"{}\"), {});",
l,
expressions
.iter()
@ -203,6 +204,63 @@ pub trait Typed {
fn get_type(&self) -> Type;
}
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub struct ConditionalExpression<'ast, T, E> {
pub condition: Box<BooleanExpression<'ast, T>>,
pub consequence: Box<E>,
pub alternative: Box<E>,
}
impl<'ast, T, E> ConditionalExpression<'ast, T, E> {
pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self {
ConditionalExpression {
condition: box condition,
consequence: box consequence,
alternative: box alternative,
}
}
}
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpression<'ast, T, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{} ? {} : {}",
self.condition, self.consequence, self.alternative
)
}
}
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub struct SelectExpression<'ast, T, E> {
pub array: Vec<E>,
pub index: Box<UExpression<'ast, T>>,
}
impl<'ast, T, E> SelectExpression<'ast, T, E> {
pub fn new(array: Vec<E>, index: UExpression<'ast, T>) -> Self {
SelectExpression {
array,
index: box index,
}
}
}
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<'ast, T, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}[{}]",
self.array
.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(", "),
self.index
)
}
}
/// A typed expression
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum ZirExpression<'ast, T> {
@ -291,7 +349,7 @@ pub enum ZirExpressionList<'ast, T> {
pub enum FieldElementExpression<'ast, T> {
Number(T),
Identifier(Identifier<'ast>),
Select(Vec<Self>, Box<UExpression<'ast, T>>),
Select(SelectExpression<'ast, T, Self>),
Add(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -312,11 +370,7 @@ pub enum FieldElementExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<UExpression<'ast, T>>,
),
Conditional(
Box<BooleanExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
}
/// An expression of type `bool`
@ -324,7 +378,7 @@ pub enum FieldElementExpression<'ast, T> {
pub enum BooleanExpression<'ast, T> {
Value(bool),
Identifier(Identifier<'ast>),
Select(Vec<Self>, Box<UExpression<'ast, T>>),
Select(SelectExpression<'ast, T, Self>),
FieldLt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -333,22 +387,12 @@ pub enum BooleanExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGe(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldEq(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
BoolEq(
Box<BooleanExpression<'ast, T>>,
@ -363,11 +407,7 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
),
Not(Box<BooleanExpression<'ast, T>>),
Conditional(
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Conditional(ConditionalExpression<'ast, T, BooleanExpression<'ast, T>>),
}
pub struct ConjunctionIterator<T> {
@ -438,26 +478,14 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
match *self {
FieldElementExpression::Number(ref i) => write!(f, "{}", i),
FieldElementExpression::Identifier(ref var) => write!(f, "{}", var),
FieldElementExpression::Select(ref a, ref i) => write!(
f,
"[{}][{}]",
a.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(", "),
i
),
FieldElementExpression::Select(ref e) => write!(f, "{}", e),
FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
FieldElementExpression::Conditional(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} {{ {} }} else {{ {} }}",
condition, consequent, alternative
)
FieldElementExpression::Conditional(ref c) => {
write!(f, "{}", c)
}
}
}
@ -468,15 +496,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
match self.inner {
UExpressionInner::Value(ref v) => write!(f, "{}", v),
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
UExpressionInner::Select(ref a, ref i) => write!(
f,
"[{}][{}]",
a.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(", "),
i
),
UExpressionInner::Select(ref e) => write!(f, "{}", e),
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
@ -488,12 +508,8 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
UExpressionInner::Conditional(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} {{ {} }} else {{ {} }}",
condition, consequent, alternative
)
UExpressionInner::Conditional(ref c) => {
write!(f, "{}", c)
}
}
}
@ -504,35 +520,19 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
match *self {
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::Select(ref a, ref i) => write!(
f,
"[{}][{}]",
a.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(", "),
i
),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Select(ref e) => write!(f, "{}", e),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Conditional(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} {{ {} }} else {{ {} }}",
condition, consequent, alternative
)
BooleanExpression::Conditional(ref c) => {
write!(f, "{}", c)
}
}
}
@ -584,7 +584,81 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirExpressionList<'ast, T> {
}
// Common behaviour accross expressions
pub trait Expr<'ast, T>: fmt::Display + PartialEq {
type Inner;
type Ty: Clone + IntoType;
fn ty(&self) -> &Self::Ty;
fn into_inner(self) -> Self::Inner;
fn as_inner(&self) -> &Self::Inner;
fn as_inner_mut(&mut self) -> &mut Self::Inner;
}
impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
type Inner = Self;
type Ty = Type;
fn ty(&self) -> &Self::Ty {
&Type::FieldElement
}
fn into_inner(self) -> Self::Inner {
self
}
fn as_inner(&self) -> &Self::Inner {
self
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
self
}
}
impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
type Inner = Self;
type Ty = Type;
fn ty(&self) -> &Self::Ty {
&Type::Boolean
}
fn into_inner(self) -> Self::Inner {
self
}
fn as_inner(&self) -> &Self::Inner {
self
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
self
}
}
impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
type Inner = UExpressionInner<'ast, T>;
type Ty = UBitwidth;
fn ty(&self) -> &Self::Ty {
&self.bitwidth
}
fn into_inner(self) -> Self::Inner {
self.inner
}
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
&mut self.inner
}
}
pub trait Conditional<'ast, T> {
fn conditional(
condition: BooleanExpression<'ast, T>,
@ -593,13 +667,22 @@ pub trait Conditional<'ast, T> {
) -> Self;
}
pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> {
Conditional(ConditionalExpression<'ast, T, E>),
Expression(E::Inner),
}
impl<'ast, T> Conditional<'ast, T> for FieldElementExpression<'ast, T> {
fn conditional(
condition: BooleanExpression<'ast, T>,
consequence: Self,
alternative: Self,
) -> Self {
FieldElementExpression::Conditional(box condition, box consequence, box alternative)
FieldElementExpression::Conditional(ConditionalExpression::new(
condition,
consequence,
alternative,
))
}
}
@ -609,7 +692,11 @@ impl<'ast, T> Conditional<'ast, T> for BooleanExpression<'ast, T> {
consequence: Self,
alternative: Self,
) -> Self {
BooleanExpression::Conditional(box condition, box consequence, box alternative)
BooleanExpression::Conditional(ConditionalExpression::new(
condition,
consequence,
alternative,
))
}
}
@ -621,7 +708,56 @@ impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> {
) -> Self {
let bitwidth = consequence.bitwidth;
UExpressionInner::Conditional(box condition, box consequence, box alternative)
.annotate(bitwidth)
UExpressionInner::Conditional(ConditionalExpression::new(
condition,
consequence,
alternative,
))
.annotate(bitwidth)
}
}
pub trait Select<'ast, T>: Sized {
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self;
}
pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
Select(SelectExpression<'ast, T, E>),
Expression(E::Inner),
}
impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
FieldElementExpression::Select(SelectExpression::new(array, index))
}
}
impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
BooleanExpression::Select(SelectExpression::new(array, index))
}
}
impl<'ast, T> Select<'ast, T> for UExpression<'ast, T> {
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
let bitwidth = array[0].bitwidth;
UExpressionInner::Select(SelectExpression::new(array, index)).annotate(bitwidth)
}
}
pub trait IntoType {
fn into_type(self) -> Type;
}
impl IntoType for Type {
fn into_type(self) -> Type {
self
}
}
impl IntoType for UBitwidth {
fn into_type(self) -> Type {
Type::Uint(self)
}
}

View file

@ -4,6 +4,27 @@ use crate::zir::types::UBitwidth;
use crate::zir::*;
use zokrates_field::Field;
pub trait ResultFold<'ast, T: Field>: Sized {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error>;
}
impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_field_expression(self)
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_boolean_expression(self)
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_uint_expression(self)
}
}
pub trait ResultFolder<'ast, T: Field>: Sized {
type Error;
@ -76,6 +97,24 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}
}
fn fold_conditional_expression<
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
>(
&mut self,
ty: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> Result<ConditionalOrExpression<'ast, T, E>, Self::Error> {
fold_conditional_expression(self, ty, e)
}
fn fold_select_expression<E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>>(
&mut self,
ty: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
fold_select_expression(self, ty, e)
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
@ -173,12 +212,12 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
FieldElementExpression::Identifier(id) => {
FieldElementExpression::Identifier(f.fold_name(id)?)
}
FieldElementExpression::Select(a, box i) => FieldElementExpression::Select(
a.into_iter()
.map(|a| f.fold_field_expression(a))
.collect::<Result<_, _>>()?,
box f.fold_uint_expression(i)?,
),
FieldElementExpression::Select(e) => {
match f.fold_select_expression(&Type::FieldElement, e)? {
SelectOrExpression::Select(s) => FieldElementExpression::Select(s),
SelectOrExpression::Expression(u) => u,
}
}
FieldElementExpression::Add(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
@ -204,11 +243,11 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e2 = f.fold_uint_expression(e2)?;
FieldElementExpression::Pow(box e1, box e2)
}
FieldElementExpression::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_field_expression(cons)?;
let alt = f.fold_field_expression(alt)?;
FieldElementExpression::Conditional(box cond, box cons, box alt)
FieldElementExpression::Conditional(c) => {
match f.fold_conditional_expression(&Type::FieldElement, c)? {
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
}
}
})
}
@ -220,12 +259,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
Ok(match e {
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
BooleanExpression::Select(a, box i) => BooleanExpression::Select(
a.into_iter()
.map(|a| f.fold_boolean_expression(a))
.collect::<Result<_, _>>()?,
box f.fold_uint_expression(i)?,
),
BooleanExpression::Select(e) => match f.fold_select_expression(&Type::Boolean, e)? {
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
SelectOrExpression::Expression(u) => u,
},
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
@ -246,41 +283,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGe(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;
@ -295,11 +312,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e = f.fold_boolean_expression(e)?;
BooleanExpression::Not(box e)
}
BooleanExpression::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_boolean_expression(cons)?;
let alt = f.fold_boolean_expression(alt)?;
BooleanExpression::Conditional(box cond, box cons, box alt)
BooleanExpression::Conditional(c) => {
match f.fold_conditional_expression(&Type::Boolean, c)? {
ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
}
}
})
}
@ -316,18 +333,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
_: UBitwidth,
ty: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, F::Error> {
Ok(match e {
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
UExpressionInner::Select(a, box i) => UExpressionInner::Select(
a.into_iter()
.map(|a| f.fold_uint_expression(a))
.collect::<Result<_, _>>()?,
box f.fold_uint_expression(i)?,
),
UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e)? {
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
SelectOrExpression::Expression(u) => u,
},
UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
@ -391,13 +406,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
UExpressionInner::Not(box e)
}
UExpressionInner::Conditional(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_uint_expression(cons)?;
let alt = f.fold_uint_expression(alt)?;
UExpressionInner::Conditional(box cond, box cons, box alt)
}
UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? {
ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s),
ConditionalOrExpression::Expression(u) => u,
},
})
}
@ -431,3 +443,41 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
main: f.fold_function(p.main)?,
})
}
pub fn fold_conditional_expression<
'ast,
T: Field,
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
F: ResultFolder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> Result<ConditionalOrExpression<'ast, T, E>, F::Error> {
Ok(ConditionalOrExpression::Conditional(
ConditionalExpression::new(
f.fold_boolean_expression(*e.condition)?,
e.consequence.fold(f)?,
e.alternative.fold(f)?,
),
))
}
pub fn fold_select_expression<
'ast,
T: Field,
E: Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>,
F: ResultFolder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> Result<SelectOrExpression<'ast, T, E>, F::Error> {
Ok(SelectOrExpression::Select(SelectExpression::new(
e.array
.into_iter()
.map(|e| e.fold(f))
.collect::<Result<Vec<_>, _>>()?,
e.index.fold(f)?,
)))
}

View file

@ -1,8 +1,9 @@
use crate::zir::identifier::Identifier;
use crate::zir::types::UBitwidth;
use crate::zir::BooleanExpression;
use zokrates_field::Field;
use super::{ConditionalExpression, SelectExpression};
impl<'ast, T: Field> UExpression<'ast, T> {
#[allow(clippy::should_implement_trait)]
pub fn add(self, other: Self) -> UExpression<'ast, T> {
@ -20,7 +21,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
pub fn select(values: Vec<Self>, index: Self) -> UExpression<'ast, T> {
let bitwidth = values[0].bitwidth;
UExpressionInner::Select(values, box index).annotate(bitwidth)
UExpressionInner::Select(SelectExpression::new(values, index)).annotate(bitwidth)
}
pub fn mult(self, other: Self) -> UExpression<'ast, T> {
@ -178,7 +179,7 @@ pub struct UExpression<'ast, T> {
pub enum UExpressionInner<'ast, T> {
Value(u128),
Identifier(Identifier<'ast>),
Select(Vec<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Sub(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Mult(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
@ -190,11 +191,7 @@ pub enum UExpressionInner<'ast, T> {
LeftShift(Box<UExpression<'ast, T>>, u32),
RightShift(Box<UExpression<'ast, T>>, u32),
Not(Box<UExpression<'ast, T>>),
Conditional(
Box<BooleanExpression<'ast, T>>,
Box<UExpression<'ast, T>>,
Box<UExpression<'ast, T>>,
),
Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>),
}
impl<'ast, T> UExpressionInner<'ast, T> {

View file

@ -8,7 +8,9 @@
mod utils;
use self::utils::flat_expression_from_bits;
use zokrates_ast::zir::{ShouldReduce, UMetadata, ZirExpressionList};
use zokrates_ast::zir::{
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList,
};
use zokrates_interpreter::Interpreter;
use crate::compile::CompileConfig;
@ -558,13 +560,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `alternative` - the alternative of type U.
/// # Remarks
/// * U is the type of the expression
fn flatten_if_else_expression<U: Flatten<'ast, T>>(
fn flatten_conditional_expression<U: Flatten<'ast, T>>(
&mut self,
statements_flattened: &mut FlatStatements<T>,
condition: BooleanExpression<'ast, T>,
consequence: U,
alternative: U,
e: ConditionalExpression<'ast, T, U>,
) -> FlatUExpression<T> {
let condition = *e.condition;
let consequence = *e.consequence;
let alternative = *e.alternative;
let condition_flat =
self.flatten_boolean_expression(statements_flattened, condition.clone());
@ -859,8 +863,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::Identifier(x) => {
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
}
BooleanExpression::Select(a, box index) => self
.flatten_select_expression(statements_flattened, a, index)
BooleanExpression::Select(e) => self
.flatten_select_expression(statements_flattened, e)
.get_field_unchecked(),
BooleanExpression::FieldLt(box lhs, box rhs) => {
// Get the bit width to know the size of the binary decompositions for this Field
@ -949,14 +953,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::FieldLt(rhs, lhs),
),
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::FieldLe(rhs, lhs),
),
BooleanExpression::UintLt(box lhs, box rhs) => {
let bit_width = lhs.bitwidth.to_usize();
assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
@ -987,14 +983,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::UintLt(rhs, lhs),
),
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::UintLe(rhs, lhs),
),
BooleanExpression::Or(box lhs, box rhs) => {
let x = self.flatten_boolean_expression(statements_flattened, lhs);
let y = self.flatten_boolean_expression(statements_flattened, rhs);
@ -1036,13 +1024,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
true => T::from(1),
false => T::from(0),
}),
BooleanExpression::Conditional(box condition, box consequence, box alternative) => self
.flatten_if_else_expression(
statements_flattened,
condition,
consequence,
alternative,
)
BooleanExpression::Conditional(e) => self
.flatten_conditional_expression(statements_flattened, e)
.get_field_unchecked(),
}
}
@ -1473,9 +1456,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
});
FlatUExpression::with_field(field).bits(bits)
}
UExpressionInner::Select(a, box index) => {
self.flatten_select_expression(statements_flattened, a, index)
}
UExpressionInner::Select(e) => self.flatten_select_expression(statements_flattened, e),
UExpressionInner::Not(box e) => {
let e = self.flatten_uint_expression(statements_flattened, e);
@ -1633,13 +1614,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatUExpression::with_field(r)
}
UExpressionInner::Conditional(box condition, box consequence, box alternative) => self
.flatten_if_else_expression(
statements_flattened,
condition,
consequence,
alternative,
),
UExpressionInner::Conditional(e) => {
self.flatten_conditional_expression(statements_flattened, e)
}
UExpressionInner::Xor(box left, box right) => {
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
@ -2039,10 +2016,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn flatten_select_expression<U: Flatten<'ast, T>>(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: Vec<U>,
index: UExpression<'ast, T>,
e: SelectExpression<'ast, T, U>,
) -> FlatUExpression<T> {
let (range_check, result) = a
let array = e.array;
let index = *e.index;
let (range_check, result) = array
.into_iter()
.enumerate()
.map(|(i, e)| {
@ -2108,8 +2087,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FieldElementExpression::Identifier(x) => {
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x)))
}
FieldElementExpression::Select(a, box index) => self
.flatten_select_expression(statements_flattened, a, index)
FieldElementExpression::Select(e) => self
.flatten_select_expression(statements_flattened, e)
.get_field_unchecked(),
FieldElementExpression::Add(box left, box right) => {
let left_flattened = self.flatten_field_expression(statements_flattened, left);
@ -2294,17 +2273,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
_ => panic!("Expected number as pow exponent"),
}
}
FieldElementExpression::Conditional(
box condition,
box consequence,
box alternative,
) => self
.flatten_if_else_expression(
statements_flattened,
condition,
consequence,
alternative,
)
FieldElementExpression::Conditional(e) => self
.flatten_conditional_expression(statements_flattened, e)
.get_field_unchecked(),
}
}
@ -2428,8 +2398,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
error.into(),
)
}
BooleanExpression::FieldLt(box lhs, box rhs)
| BooleanExpression::FieldGt(box rhs, box lhs) => {
BooleanExpression::FieldLt(box lhs, box rhs) => {
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let rhs = self.flatten_field_expression(statements_flattened, rhs);
@ -2459,8 +2428,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
}
BooleanExpression::FieldLe(box lhs, box rhs)
| BooleanExpression::FieldGe(box rhs, box lhs) => {
BooleanExpression::FieldLe(box lhs, box rhs) => {
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let rhs = self.flatten_field_expression(statements_flattened, rhs);
@ -3531,13 +3499,13 @@ mod tests {
#[test]
fn if_else() {
let config = CompileConfig::default();
let expression = FieldElementExpression::Conditional(
box BooleanExpression::FieldEq(
let expression = FieldElementExpression::conditional(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
),
box FieldElementExpression::Number(Bn128Field::from(12)),
box FieldElementExpression::Number(Bn128Field::from(51)),
FieldElementExpression::Number(Bn128Field::from(12)),
FieldElementExpression::Number(Bn128Field::from(51)),
);
let mut flattener = Flattener::new(config);
@ -3554,13 +3522,6 @@ mod tests {
box FieldElementExpression::Number(Bn128Field::from(4)),
);
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le);
let mut flattener = Flattener::new(config);
let expression_ge = BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_ge);
}
#[test]
@ -3568,8 +3529,8 @@ mod tests {
let config = CompileConfig::default();
let mut flattener = Flattener::new(config);
let expression = FieldElementExpression::Conditional(
box BooleanExpression::And(
let expression = FieldElementExpression::conditional(
BooleanExpression::And(
box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(4)),
@ -3579,8 +3540,8 @@ mod tests {
box FieldElementExpression::Number(Bn128Field::from(20)),
),
),
box FieldElementExpression::Number(Bn128Field::from(12)),
box FieldElementExpression::Number(Bn128Field::from(51)),
FieldElementExpression::Number(Bn128Field::from(12)),
FieldElementExpression::Number(Bn128Field::from(51)),
);
flattener.flatten_field_expression(&mut FlatStatements::new(), expression);

View file

@ -1,7 +1,7 @@
use std::marker::PhantomData;
use zokrates_ast::typed::types::UBitwidth;
use zokrates_ast::typed::{self, Expr, Typed};
use zokrates_ast::zir;
use zokrates_ast::zir::{self, Select};
use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
@ -746,36 +746,35 @@ fn fold_select_expression<'ast, T: Field, E>(
let ty = a[0].get_type();
match ty {
zir::Type::Boolean => zir::BooleanExpression::Select(
zir::Type::Boolean => zir::BooleanExpression::select(
a.into_iter()
.map(|e| match e {
zir::ZirExpression::Boolean(e) => e.clone(),
_ => unreachable!(),
})
.collect(),
box index.clone(),
index.clone(),
)
.into(),
zir::Type::FieldElement => zir::FieldElementExpression::Select(
zir::Type::FieldElement => zir::FieldElementExpression::select(
a.into_iter()
.map(|e| match e {
zir::ZirExpression::FieldElement(e) => e.clone(),
_ => unreachable!(),
})
.collect(),
box index.clone(),
index.clone(),
)
.into(),
zir::Type::Uint(bitwidth) => zir::UExpressionInner::Select(
zir::Type::Uint(_) => zir::UExpression::select(
a.into_iter()
.map(|e| match e {
zir::ZirExpression::Uint(e) => e.clone(),
_ => unreachable!(),
})
.collect(),
box index.clone(),
index.clone(),
)
.annotate(bitwidth)
.into(),
}
})
@ -987,12 +986,12 @@ fn fold_boolean_expression<'ast, T: Field>(
typed::BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
zir::BooleanExpression::FieldGt(box e1, box e2)
zir::BooleanExpression::FieldLt(box e2, box e1)
}
typed::BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
zir::BooleanExpression::FieldGe(box e1, box e2)
zir::BooleanExpression::FieldLe(box e2, box e1)
}
typed::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
@ -1007,12 +1006,12 @@ fn fold_boolean_expression<'ast, T: Field>(
typed::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
let e2 = f.fold_uint_expression(statements_buffer, e2);
zir::BooleanExpression::UintGt(box e1, box e2)
zir::BooleanExpression::UintLt(box e2, box e1)
}
typed::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
let e2 = f.fold_uint_expression(statements_buffer, e2);
zir::BooleanExpression::UintGe(box e1, box e2)
zir::BooleanExpression::UintLe(box e2, box e1)
}
typed::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(statements_buffer, e1);

View file

@ -55,22 +55,15 @@ fn force_no_reduce<T: Field>(e: UExpression<T>) -> UExpression<T> {
}
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
fn fold_field_expression(
fn fold_select_expression<E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>>(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Select(a, box i) => {
let a = a
.into_iter()
.map(|e| self.fold_field_expression(e))
.collect();
let i = self.fold_uint_expression(i);
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> SelectOrExpression<'ast, T, E> {
let array = e.array.into_iter().map(|e| e.fold(self)).collect();
let index = e.index.fold(self);
FieldElementExpression::Select(a, box force_reduce(i))
}
_ => fold_field_expression(self, e),
}
SelectOrExpression::Select(SelectExpression::new(array, force_reduce(index)))
}
fn fold_boolean_expression(
@ -78,15 +71,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Select(a, box i) => {
let a = a
.into_iter()
.map(|e| self.fold_boolean_expression(e))
.collect();
let i = self.fold_uint_expression(i);
BooleanExpression::Select(a, box force_reduce(i))
}
BooleanExpression::UintEq(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
@ -114,24 +98,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
BooleanExpression::UintLe(box left, box right)
}
BooleanExpression::UintGt(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left = force_reduce(left);
let right = force_reduce(right);
BooleanExpression::UintGt(box left, box right)
}
BooleanExpression::UintGe(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left = force_reduce(left);
let right = force_reduce(right);
BooleanExpression::UintGe(box left, box right)
}
e => fold_boolean_expression(self, e),
}
}
@ -161,12 +127,15 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.cloned()
.unwrap_or_else(|| panic!("identifier should have been defined: {}", id)),
),
Select(values, box index) => {
Select(e) => {
let index = *e.index;
let array = e.array;
let index = self.fold_uint_expression(index);
let index = force_reduce(index);
let values: Vec<_> = values
let values: Vec<_> = array
.into_iter()
.map(|v| force_no_reduce(self.fold_uint_expression(v)))
.collect();
@ -389,10 +358,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
UExpression::right_shift(force_reduce(e), by).with_max(max)
}
Conditional(box condition, box consequence, box alternative) => {
let condition = self.fold_boolean_expression(condition);
let consequence = self.fold_uint_expression(consequence);
let alternative = self.fold_uint_expression(alternative);
Conditional(e) => {
let condition = self.fold_boolean_expression(*e.condition);
let consequence = e.consequence.fold(self);
let alternative = e.alternative.fold(self);
let consequence_max = consequence.metadata.clone().unwrap().max;
let alternative_max = alternative.metadata.clone().unwrap().max;

View file

@ -1,8 +1,17 @@
use std::collections::HashMap;
use std::fmt;
use zokrates_ast::zir::result_folder::fold_boolean_expression;
use zokrates_ast::zir::result_folder::fold_field_expression;
use zokrates_ast::zir::result_folder::fold_statement;
use zokrates_ast::zir::result_folder::ResultFold;
use zokrates_ast::zir::result_folder::ResultFolder;
use zokrates_ast::zir::types::UBitwidth;
use zokrates_ast::zir::Conditional;
use zokrates_ast::zir::ConditionalExpression;
use zokrates_ast::zir::ConditionalOrExpression;
use zokrates_ast::zir::Expr;
use zokrates_ast::zir::SelectExpression;
use zokrates_ast::zir::SelectOrExpression;
use zokrates_ast::zir::{
BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression,
UExpressionInner, ZirExpression, ZirProgram, ZirStatement,
@ -136,24 +145,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
}
_ => Ok(FieldElementExpression::Identifier(id)),
},
FieldElementExpression::Select(e, box index) => {
let index = self.fold_uint_expression(index)?;
let e: Vec<FieldElementExpression<'ast, T>> = e
.into_iter()
.map(|e| self.fold_field_expression(e))
.collect::<Result<_, _>>()?;
match index.into_inner() {
UExpressionInner::Value(v) => e
.get(v as usize)
.cloned()
.ok_or(Error::OutOfBounds(v as usize, e.len())),
i => Ok(FieldElementExpression::Select(
e,
box i.annotate(UBitwidth::B32),
)),
}
}
FieldElementExpression::Add(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
@ -237,28 +228,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
)),
}
}
FieldElementExpression::Conditional(
box condition,
box consequence,
box alternative,
) => {
let condition = self.fold_boolean_expression(condition)?;
let consequence = self.fold_field_expression(consequence)?;
let alternative = self.fold_field_expression(alternative)?;
match (condition, consequence, alternative) {
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
(condition, consequence, alternative) => {
Ok(FieldElementExpression::Conditional(
box condition,
box consequence,
box alternative,
))
}
}
}
e => fold_field_expression(self, e),
}
}
@ -274,21 +244,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
}
_ => Ok(BooleanExpression::Identifier(id)),
},
BooleanExpression::Select(e, box index) => {
let index = self.fold_uint_expression(index)?;
let e: Vec<BooleanExpression<'ast, T>> = e
.into_iter()
.map(|e| self.fold_boolean_expression(e))
.collect::<Result<_, _>>()?;
match index.as_inner() {
UExpressionInner::Value(v) => e
.get(*v as usize)
.cloned()
.ok_or(Error::OutOfBounds(*v as usize, e.len())),
_ => Ok(BooleanExpression::Select(e, box index)),
}
}
BooleanExpression::FieldLt(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
@ -317,34 +272,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)),
}
}
BooleanExpression::FieldGe(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
self.fold_field_expression(e2)?,
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
Ok(BooleanExpression::Value(n1 >= n2))
}
(e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)),
}
}
BooleanExpression::FieldGt(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
self.fold_field_expression(e2)?,
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
Ok(BooleanExpression::Value(n1 > n2))
}
(_, FieldElementExpression::Number(c)) if c == T::max_value() => {
Ok(BooleanExpression::Value(false))
}
(FieldElementExpression::Number(c), _) if c == T::zero() => {
Ok(BooleanExpression::Value(false))
}
(e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)),
}
}
BooleanExpression::FieldEq(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
@ -384,28 +311,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
_ => Ok(BooleanExpression::UintLe(box e1, box e2)),
}
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(BooleanExpression::Value(v1 >= v2))
}
_ => Ok(BooleanExpression::UintGe(box e1, box e2)),
}
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(BooleanExpression::Value(v1 > v2))
}
_ => Ok(BooleanExpression::UintGt(box e1, box e2)),
}
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
@ -475,22 +380,33 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)),
e => Ok(BooleanExpression::Not(box e)),
},
BooleanExpression::Conditional(box condition, box consequence, box alternative) => {
let condition = self.fold_boolean_expression(condition)?;
let consequence = self.fold_boolean_expression(consequence)?;
let alternative = self.fold_boolean_expression(alternative)?;
e => fold_boolean_expression(self, e),
}
}
match (condition, consequence, alternative) {
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
(condition, consequence, alternative) => Ok(BooleanExpression::Conditional(
box condition,
box consequence,
box alternative,
)),
}
}
fn fold_select_expression<
E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + zokrates_ast::zir::Select<'ast, T>,
>(
&mut self,
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> Result<zokrates_ast::zir::SelectOrExpression<'ast, T, E>, Self::Error> {
let index = self.fold_uint_expression(*e.index)?;
let array = e
.array
.into_iter()
.map(|e| e.fold(self))
.collect::<Result<Vec<_>, _>>()?;
match index.as_inner() {
UExpressionInner::Value(v) => array
.get(*v as usize)
.cloned()
.ok_or(Error::OutOfBounds(*v as usize, array.len()))
.map(|e| SelectOrExpression::Expression(e.into_inner())),
_ => Ok(SelectOrExpression::Expression(
E::select(array, index).into_inner(),
)),
}
}
@ -505,22 +421,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
Some(ZirExpression::Uint(e)) => Ok(e.as_inner().clone()),
_ => Ok(UExpressionInner::Identifier(id)),
},
UExpressionInner::Select(e, box index) => {
let index = self.fold_uint_expression(index)?;
let e: Vec<UExpression<'ast, T>> = e
.into_iter()
.map(|e| self.fold_uint_expression(e))
.collect::<Result<_, _>>()?;
match index.into_inner() {
UExpressionInner::Value(v) => e
.get(v as usize)
.cloned()
.ok_or(Error::OutOfBounds(v as usize, e.len()))
.map(|e| e.into_inner()),
i => Ok(UExpressionInner::Select(e, box i.annotate(UBitwidth::B32))),
}
}
UExpressionInner::Add(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
@ -687,22 +587,34 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))),
}
}
UExpressionInner::Conditional(box condition, box consequence, box alternative) => {
let condition = self.fold_boolean_expression(condition)?;
let consequence = self.fold_uint_expression(consequence)?.into_inner();
let alternative = self.fold_uint_expression(alternative)?.into_inner();
e => self.fold_uint_expression_inner(bitwidth, e),
}
}
match (condition, consequence, alternative) {
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
(condition, consequence, alternative) => Ok(UExpressionInner::Conditional(
box condition,
box consequence.annotate(bitwidth),
box alternative.annotate(bitwidth),
)),
}
}
fn fold_conditional_expression<
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
>(
&mut self,
_: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> Result<ConditionalOrExpression<'ast, T, E>, Self::Error> {
let condition = self.fold_boolean_expression(*e.condition)?;
let consequence = e.consequence.fold(self)?;
let alternative = e.alternative.fold(self)?;
match (condition, consequence, alternative) {
(_, consequence, alternative) if consequence == alternative => Ok(
ConditionalOrExpression::Expression(consequence.into_inner()),
),
(BooleanExpression::Value(true), consequence, _) => Ok(
ConditionalOrExpression::Expression(consequence.into_inner()),
),
(BooleanExpression::Value(false), _, alternative) => Ok(
ConditionalOrExpression::Expression(alternative.into_inner()),
),
(condition, consequence, alternative) => Ok(ConditionalOrExpression::Conditional(
ConditionalExpression::new(condition, consequence, alternative),
)),
}
}
}
@ -754,6 +666,8 @@ mod tests {
#[cfg(test)]
mod field {
use zokrates_ast::zir::Select;
use super::*;
#[test]
@ -761,23 +675,23 @@ mod tests {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::Select(
propagator.fold_field_expression(FieldElementExpression::select(
vec![
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
],
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::Select(
propagator.fold_field_expression(FieldElementExpression::select(
vec![
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
],
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Err(Error::OutOfBounds(3, 2))
);
@ -923,28 +837,28 @@ mod tests {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::Conditional(
box BooleanExpression::Value(true),
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2)),
propagator.fold_field_expression(FieldElementExpression::conditional(
BooleanExpression::Value(true),
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::Conditional(
box BooleanExpression::Value(false),
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2)),
propagator.fold_field_expression(FieldElementExpression::conditional(
BooleanExpression::Value(false),
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
);
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::Conditional(
box BooleanExpression::Identifier("a".into()),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)),
propagator.fold_field_expression(FieldElementExpression::conditional(
BooleanExpression::Identifier("a".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
);
@ -953,6 +867,8 @@ mod tests {
#[cfg(test)]
mod bool {
use zokrates_ast::zir::Select;
use super::*;
#[test]
@ -960,23 +876,23 @@ mod tests {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::Select(
propagator.fold_boolean_expression(BooleanExpression::select(
vec![
BooleanExpression::Value(false),
BooleanExpression::Value(true),
],
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::Select(
propagator.fold_boolean_expression(BooleanExpression::select(
vec![
BooleanExpression::Value(false),
BooleanExpression::Value(true),
],
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Err(Error::OutOfBounds(3, 2))
);
@ -1040,64 +956,6 @@ mod tests {
);
}
#[test]
fn field_ge() {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn field_gt() {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Identifier("a".into()),
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Identifier("a".into()),
box FieldElementExpression::Number(Bn128Field::max_value()),
)),
Ok(BooleanExpression::Value(false))
);
}
#[test]
fn field_eq() {
let mut propagator = ZirPropagator::default();
@ -1161,48 +1019,6 @@ mod tests {
);
}
#[test]
fn uint_ge() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGe(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGe(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn uint_gt() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGt(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGt(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(false))
);
}
#[test]
fn uint_eq() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
@ -1327,28 +1143,28 @@ mod tests {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::Conditional(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
box BooleanExpression::Value(false)
propagator.fold_boolean_expression(BooleanExpression::conditional(
BooleanExpression::Value(true),
BooleanExpression::Value(true),
BooleanExpression::Value(false)
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::Conditional(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
box BooleanExpression::Value(false)
propagator.fold_boolean_expression(BooleanExpression::conditional(
BooleanExpression::Value(false),
BooleanExpression::Value(true),
BooleanExpression::Value(false)
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::Conditional(
box BooleanExpression::Identifier("a".into()),
box BooleanExpression::Value(true),
box BooleanExpression::Value(true)
propagator.fold_boolean_expression(BooleanExpression::conditional(
BooleanExpression::Identifier("a".into()),
BooleanExpression::Value(true),
BooleanExpression::Value(true)
)),
Ok(BooleanExpression::Value(true))
);
@ -1366,13 +1182,14 @@ mod tests {
assert_eq!(
propagator.fold_uint_expression_inner(
UBitwidth::B32,
UExpressionInner::Select(
UExpression::select(
vec![
UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
],
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)
.into_inner()
),
Ok(UExpressionInner::Value(2))
);
@ -1380,13 +1197,14 @@ mod tests {
assert_eq!(
propagator.fold_uint_expression_inner(
UBitwidth::B32,
UExpressionInner::Select(
UExpression::select(
vec![
UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
],
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
UExpressionInner::Value(3).annotate(UBitwidth::B32),
)
.into_inner()
),
Err(Error::OutOfBounds(3, 2))
);
@ -1752,11 +1570,12 @@ mod tests {
assert_eq!(
propagator.fold_uint_expression_inner(
UBitwidth::B32,
UExpressionInner::Conditional(
box BooleanExpression::Value(true),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
UExpression::conditional(
BooleanExpression::Value(true),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
)
.into_inner()
),
Ok(UExpressionInner::Value(1))
);
@ -1764,11 +1583,12 @@ mod tests {
assert_eq!(
propagator.fold_uint_expression_inner(
UBitwidth::B32,
UExpressionInner::Conditional(
box BooleanExpression::Value(false),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
UExpression::conditional(
BooleanExpression::Value(false),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
)
.into_inner()
),
Ok(UExpressionInner::Value(2))
);
@ -1776,11 +1596,12 @@ mod tests {
assert_eq!(
propagator.fold_uint_expression_inner(
UBitwidth::B32,
UExpressionInner::Conditional(
box BooleanExpression::Identifier("a".into()),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
UExpression::conditional(
BooleanExpression::Identifier("a".into()),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
)
.into_inner()
),
Ok(UExpressionInner::Value(2))
);