refactor zir to introduce conditional and select expressions
This commit is contained in:
parent
53b62f568b
commit
e4fbc6d35a
8 changed files with 634 additions and 654 deletions
|
@ -4,6 +4,27 @@ use crate::zir::types::UBitwidth;
|
|||
use crate::zir::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait Fold<'ast, T: Field>: Sized {
|
||||
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
|
||||
f.fold_field_expression(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
|
||||
f.fold_boolean_expression(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> {
|
||||
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
|
||||
f.fold_uint_expression(self)
|
||||
}
|
||||
}
|
||||
pub trait Folder<'ast, T: Field>: Sized {
|
||||
fn fold_program(&mut self, p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
|
||||
fold_program(self, p)
|
||||
|
@ -39,6 +60,22 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_conditional_expression<E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> ConditionalOrExpression<'ast, T, E> {
|
||||
fold_conditional_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_select_expression<E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> SelectOrExpression<'ast, T, E> {
|
||||
fold_select_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_expression(&mut self, e: ZirExpression<'ast, T>) -> ZirExpression<'ast, T> {
|
||||
match e {
|
||||
ZirExpression::FieldElement(e) => self.fold_field_expression(e).into(),
|
||||
|
@ -141,10 +178,12 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
FieldElementExpression::Identifier(id) => {
|
||||
FieldElementExpression::Identifier(f.fold_name(id))
|
||||
}
|
||||
FieldElementExpression::Select(a, box i) => FieldElementExpression::Select(
|
||||
a.into_iter().map(|a| f.fold_field_expression(a)).collect(),
|
||||
box f.fold_uint_expression(i),
|
||||
),
|
||||
FieldElementExpression::Select(e) => {
|
||||
match f.fold_select_expression(&Type::FieldElement, e) {
|
||||
SelectOrExpression::Select(s) => FieldElementExpression::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Add(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
|
@ -170,11 +209,11 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2);
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
let cons = f.fold_field_expression(cons);
|
||||
let alt = f.fold_field_expression(alt);
|
||||
FieldElementExpression::Conditional(box cond, box cons, box alt)
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c) {
|
||||
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -186,12 +225,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
match e {
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
|
||||
BooleanExpression::Select(a, box i) => BooleanExpression::Select(
|
||||
a.into_iter()
|
||||
.map(|a| f.fold_boolean_expression(a))
|
||||
.collect(),
|
||||
box f.fold_uint_expression(i),
|
||||
),
|
||||
BooleanExpression::Select(e) => match f.fold_select_expression(&Type::Boolean, e) {
|
||||
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
},
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
|
@ -212,41 +249,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
let e2 = f.fold_boolean_expression(e2);
|
||||
|
@ -261,12 +278,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e = f.fold_boolean_expression(e);
|
||||
BooleanExpression::Not(box e)
|
||||
}
|
||||
BooleanExpression::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
let cons = f.fold_boolean_expression(cons);
|
||||
let alt = f.fold_boolean_expression(alt);
|
||||
BooleanExpression::Conditional(box cond, box cons, box alt)
|
||||
}
|
||||
BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c)
|
||||
{
|
||||
ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -282,16 +298,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
_: UBitwidth,
|
||||
ty: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
|
||||
UExpressionInner::Select(a, box i) => UExpressionInner::Select(
|
||||
a.into_iter().map(|a| f.fold_uint_expression(a)).collect(),
|
||||
box f.fold_uint_expression(i),
|
||||
),
|
||||
UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e) {
|
||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
},
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left);
|
||||
let right = f.fold_uint_expression(right);
|
||||
|
@ -355,12 +371,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Not(box e)
|
||||
}
|
||||
UExpressionInner::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
let cons = f.fold_uint_expression(cons);
|
||||
let alt = f.fold_uint_expression(alt);
|
||||
UExpressionInner::Conditional(box cond, box cons, box alt)
|
||||
}
|
||||
UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) {
|
||||
ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -391,3 +405,36 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
main: f.fold_function(p.main),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_conditional_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>,
|
||||
F: Folder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> ConditionalOrExpression<'ast, T, E> {
|
||||
ConditionalOrExpression::Conditional(ConditionalExpression::new(
|
||||
f.fold_boolean_expression(*e.condition),
|
||||
e.consequence.fold(f),
|
||||
e.alternative.fold(f),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn fold_select_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>,
|
||||
F: Folder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> SelectOrExpression<'ast, T, E> {
|
||||
SelectOrExpression::Select(SelectExpression::new(
|
||||
e.array.into_iter().map(|e| e.fold(f)).collect(),
|
||||
e.index.fold(f),
|
||||
))
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ mod uint;
|
|||
mod variable;
|
||||
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::types::Type;
|
||||
pub use self::types::{Type, UBitwidth};
|
||||
pub use self::variable::Variable;
|
||||
use crate::common::{FlatEmbed, FormatString};
|
||||
use crate::typed::ConcreteType;
|
||||
|
@ -23,7 +23,7 @@ pub use self::folder::Folder;
|
|||
pub use self::identifier::{Identifier, SourceIdentifier};
|
||||
|
||||
/// A typed program as a collection of modules, one of them being the main
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
pub struct ZirProgram<'ast, T> {
|
||||
pub main: ZirFunction<'ast, T>,
|
||||
}
|
||||
|
@ -138,14 +138,15 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
match self {
|
||||
ZirStatement::Return(ref exprs) => {
|
||||
write!(f, "return ")?;
|
||||
for (i, expr) in exprs.iter().enumerate() {
|
||||
write!(f, "{}", expr)?;
|
||||
if i < exprs.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ";")
|
||||
write!(
|
||||
f,
|
||||
"return {};",
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
}
|
||||
ZirStatement::Definition(ref lhs, ref rhs) => {
|
||||
write!(f, "{} = {};", lhs, rhs)
|
||||
|
@ -181,7 +182,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
}
|
||||
ZirStatement::Log(ref l, ref expressions) => write!(
|
||||
f,
|
||||
"log(\"{}\"), {})",
|
||||
"log(\"{}\"), {});",
|
||||
l,
|
||||
expressions
|
||||
.iter()
|
||||
|
@ -203,6 +204,63 @@ pub trait Typed {
|
|||
fn get_type(&self) -> Type;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
pub struct ConditionalExpression<'ast, T, E> {
|
||||
pub condition: Box<BooleanExpression<'ast, T>>,
|
||||
pub consequence: Box<E>,
|
||||
pub alternative: Box<E>,
|
||||
}
|
||||
|
||||
impl<'ast, T, E> ConditionalExpression<'ast, T, E> {
|
||||
pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self {
|
||||
ConditionalExpression {
|
||||
condition: box condition,
|
||||
consequence: box consequence,
|
||||
alternative: box alternative,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpression<'ast, T, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{} ? {} : {}",
|
||||
self.condition, self.consequence, self.alternative
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
pub struct SelectExpression<'ast, T, E> {
|
||||
pub array: Vec<E>,
|
||||
pub index: Box<UExpression<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T, E> SelectExpression<'ast, T, E> {
|
||||
pub fn new(array: Vec<E>, index: UExpression<'ast, T>) -> Self {
|
||||
SelectExpression {
|
||||
array,
|
||||
index: box index,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<'ast, T, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}[{}]",
|
||||
self.array
|
||||
.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
self.index
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// A typed expression
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
pub enum ZirExpression<'ast, T> {
|
||||
|
@ -291,7 +349,7 @@ pub enum ZirExpressionList<'ast, T> {
|
|||
pub enum FieldElementExpression<'ast, T> {
|
||||
Number(T),
|
||||
Identifier(Identifier<'ast>),
|
||||
Select(Vec<Self>, Box<UExpression<'ast, T>>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
Add(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
|
@ -312,11 +370,7 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
/// An expression of type `bool`
|
||||
|
@ -324,7 +378,7 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
pub enum BooleanExpression<'ast, T> {
|
||||
Value(bool),
|
||||
Identifier(Identifier<'ast>),
|
||||
Select(Vec<Self>, Box<UExpression<'ast, T>>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
FieldLt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
|
@ -333,22 +387,12 @@ pub enum BooleanExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGe(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldEq(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
BoolEq(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
|
@ -363,11 +407,7 @@ pub enum BooleanExpression<'ast, T> {
|
|||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Not(Box<BooleanExpression<'ast, T>>),
|
||||
Conditional(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, BooleanExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
pub struct ConjunctionIterator<T> {
|
||||
|
@ -438,26 +478,14 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
match *self {
|
||||
FieldElementExpression::Number(ref i) => write!(f, "{}", i),
|
||||
FieldElementExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
FieldElementExpression::Select(ref a, ref i) => write!(
|
||||
f,
|
||||
"[{}][{}]",
|
||||
a.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
i
|
||||
),
|
||||
FieldElementExpression::Select(ref e) => write!(f, "{}", e),
|
||||
FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
|
||||
FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
|
||||
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
|
||||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
||||
FieldElementExpression::Conditional(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
"if {} {{ {} }} else {{ {} }}",
|
||||
condition, consequent, alternative
|
||||
)
|
||||
FieldElementExpression::Conditional(ref c) => {
|
||||
write!(f, "{}", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -468,15 +496,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
match self.inner {
|
||||
UExpressionInner::Value(ref v) => write!(f, "{}", v),
|
||||
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
UExpressionInner::Select(ref a, ref i) => write!(
|
||||
f,
|
||||
"[{}][{}]",
|
||||
a.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
i
|
||||
),
|
||||
UExpressionInner::Select(ref e) => write!(f, "{}", e),
|
||||
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
|
||||
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
|
||||
|
@ -488,12 +508,8 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
|
||||
UExpressionInner::Conditional(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
"if {} {{ {} }} else {{ {} }}",
|
||||
condition, consequent, alternative
|
||||
)
|
||||
UExpressionInner::Conditional(ref c) => {
|
||||
write!(f, "{}", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -504,35 +520,19 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
match *self {
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||
BooleanExpression::Select(ref a, ref i) => write!(
|
||||
f,
|
||||
"[{}][{}]",
|
||||
a.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
i
|
||||
),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Select(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
BooleanExpression::Conditional(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
"if {} {{ {} }} else {{ {} }}",
|
||||
condition, consequent, alternative
|
||||
)
|
||||
BooleanExpression::Conditional(ref c) => {
|
||||
write!(f, "{}", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -584,7 +584,81 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirExpressionList<'ast, T> {
|
|||
}
|
||||
|
||||
// Common behaviour accross expressions
|
||||
pub trait Expr<'ast, T>: fmt::Display + PartialEq {
|
||||
type Inner;
|
||||
type Ty: Clone + IntoType;
|
||||
|
||||
fn ty(&self) -> &Self::Ty;
|
||||
|
||||
fn into_inner(self) -> Self::Inner;
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner;
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&Type::FieldElement
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&Type::Boolean
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
type Inner = UExpressionInner<'ast, T>;
|
||||
type Ty = UBitwidth;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&self.bitwidth
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
pub trait Conditional<'ast, T> {
|
||||
fn conditional(
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
|
@ -593,13 +667,22 @@ pub trait Conditional<'ast, T> {
|
|||
) -> Self;
|
||||
}
|
||||
|
||||
pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||
Conditional(ConditionalExpression<'ast, T, E>),
|
||||
Expression(E::Inner),
|
||||
}
|
||||
|
||||
impl<'ast, T> Conditional<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn conditional(
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: Self,
|
||||
alternative: Self,
|
||||
) -> Self {
|
||||
FieldElementExpression::Conditional(box condition, box consequence, box alternative)
|
||||
FieldElementExpression::Conditional(ConditionalExpression::new(
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -609,7 +692,11 @@ impl<'ast, T> Conditional<'ast, T> for BooleanExpression<'ast, T> {
|
|||
consequence: Self,
|
||||
alternative: Self,
|
||||
) -> Self {
|
||||
BooleanExpression::Conditional(box condition, box consequence, box alternative)
|
||||
BooleanExpression::Conditional(ConditionalExpression::new(
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -621,7 +708,56 @@ impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> {
|
|||
) -> Self {
|
||||
let bitwidth = consequence.bitwidth;
|
||||
|
||||
UExpressionInner::Conditional(box condition, box consequence, box alternative)
|
||||
.annotate(bitwidth)
|
||||
UExpressionInner::Conditional(ConditionalExpression::new(
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
))
|
||||
.annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Select<'ast, T>: Sized {
|
||||
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self;
|
||||
}
|
||||
|
||||
pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||
Select(SelectExpression<'ast, T, E>),
|
||||
Expression(E::Inner),
|
||||
}
|
||||
|
||||
impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
|
||||
FieldElementExpression::Select(SelectExpression::new(array, index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
|
||||
BooleanExpression::Select(SelectExpression::new(array, index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Select<'ast, T> for UExpression<'ast, T> {
|
||||
fn select(array: Vec<Self>, index: UExpression<'ast, T>) -> Self {
|
||||
let bitwidth = array[0].bitwidth;
|
||||
|
||||
UExpressionInner::Select(SelectExpression::new(array, index)).annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoType {
|
||||
fn into_type(self) -> Type;
|
||||
}
|
||||
|
||||
impl IntoType for Type {
|
||||
fn into_type(self) -> Type {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoType for UBitwidth {
|
||||
fn into_type(self) -> Type {
|
||||
Type::Uint(self)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,27 @@ use crate::zir::types::UBitwidth;
|
|||
use crate::zir::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait ResultFold<'ast, T: Field>: Sized {
|
||||
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error>;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
|
||||
f.fold_field_expression(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
|
||||
f.fold_boolean_expression(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> {
|
||||
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
|
||||
f.fold_uint_expression(self)
|
||||
}
|
||||
}
|
||||
pub trait ResultFolder<'ast, T: Field>: Sized {
|
||||
type Error;
|
||||
|
||||
|
@ -76,6 +97,24 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_conditional_expression<
|
||||
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> Result<ConditionalOrExpression<'ast, T, E>, Self::Error> {
|
||||
fold_conditional_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_select_expression<E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
|
||||
fold_select_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
|
@ -173,12 +212,12 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
FieldElementExpression::Identifier(id) => {
|
||||
FieldElementExpression::Identifier(f.fold_name(id)?)
|
||||
}
|
||||
FieldElementExpression::Select(a, box i) => FieldElementExpression::Select(
|
||||
a.into_iter()
|
||||
.map(|a| f.fold_field_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
box f.fold_uint_expression(i)?,
|
||||
),
|
||||
FieldElementExpression::Select(e) => {
|
||||
match f.fold_select_expression(&Type::FieldElement, e)? {
|
||||
SelectOrExpression::Select(s) => FieldElementExpression::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Add(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
|
@ -204,11 +243,11 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2)?;
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_field_expression(cons)?;
|
||||
let alt = f.fold_field_expression(alt)?;
|
||||
FieldElementExpression::Conditional(box cond, box cons, box alt)
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c)? {
|
||||
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -220,12 +259,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
Ok(match e {
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
|
||||
BooleanExpression::Select(a, box i) => BooleanExpression::Select(
|
||||
a.into_iter()
|
||||
.map(|a| f.fold_boolean_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
box f.fold_uint_expression(i)?,
|
||||
),
|
||||
BooleanExpression::Select(e) => match f.fold_select_expression(&Type::Boolean, e)? {
|
||||
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
},
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
|
@ -246,41 +283,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
|
@ -295,11 +312,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e = f.fold_boolean_expression(e)?;
|
||||
BooleanExpression::Not(box e)
|
||||
}
|
||||
BooleanExpression::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_boolean_expression(cons)?;
|
||||
let alt = f.fold_boolean_expression(alt)?;
|
||||
BooleanExpression::Conditional(box cond, box cons, box alt)
|
||||
BooleanExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::Boolean, c)? {
|
||||
ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -316,18 +333,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
_: UBitwidth,
|
||||
ty: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, F::Error> {
|
||||
Ok(match e {
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
|
||||
UExpressionInner::Select(a, box i) => UExpressionInner::Select(
|
||||
a.into_iter()
|
||||
.map(|a| f.fold_uint_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
box f.fold_uint_expression(i)?,
|
||||
),
|
||||
UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e)? {
|
||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||
SelectOrExpression::Expression(u) => u,
|
||||
},
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
@ -391,13 +406,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Not(box e)
|
||||
}
|
||||
UExpressionInner::Conditional(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_uint_expression(cons)?;
|
||||
let alt = f.fold_uint_expression(alt)?;
|
||||
|
||||
UExpressionInner::Conditional(box cond, box cons, box alt)
|
||||
}
|
||||
UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? {
|
||||
ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -431,3 +443,41 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
main: f.fold_function(p.main)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_conditional_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
|
||||
F: ResultFolder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> Result<ConditionalOrExpression<'ast, T, E>, F::Error> {
|
||||
Ok(ConditionalOrExpression::Conditional(
|
||||
ConditionalExpression::new(
|
||||
f.fold_boolean_expression(*e.condition)?,
|
||||
e.consequence.fold(f)?,
|
||||
e.alternative.fold(f)?,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn fold_select_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>,
|
||||
F: ResultFolder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> Result<SelectOrExpression<'ast, T, E>, F::Error> {
|
||||
Ok(SelectOrExpression::Select(SelectExpression::new(
|
||||
e.array
|
||||
.into_iter()
|
||||
.map(|e| e.fold(f))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
e.index.fold(f)?,
|
||||
)))
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use crate::zir::identifier::Identifier;
|
||||
use crate::zir::types::UBitwidth;
|
||||
use crate::zir::BooleanExpression;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::{ConditionalExpression, SelectExpression};
|
||||
|
||||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn add(self, other: Self) -> UExpression<'ast, T> {
|
||||
|
@ -20,7 +21,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
|
||||
pub fn select(values: Vec<Self>, index: Self) -> UExpression<'ast, T> {
|
||||
let bitwidth = values[0].bitwidth;
|
||||
UExpressionInner::Select(values, box index).annotate(bitwidth)
|
||||
UExpressionInner::Select(SelectExpression::new(values, index)).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn mult(self, other: Self) -> UExpression<'ast, T> {
|
||||
|
@ -178,7 +179,7 @@ pub struct UExpression<'ast, T> {
|
|||
pub enum UExpressionInner<'ast, T> {
|
||||
Value(u128),
|
||||
Identifier(Identifier<'ast>),
|
||||
Select(Vec<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
|
||||
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Sub(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Mult(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
|
@ -190,11 +191,7 @@ pub enum UExpressionInner<'ast, T> {
|
|||
LeftShift(Box<UExpression<'ast, T>>, u32),
|
||||
RightShift(Box<UExpression<'ast, T>>, u32),
|
||||
Not(Box<UExpression<'ast, T>>),
|
||||
Conditional(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
impl<'ast, T> UExpressionInner<'ast, T> {
|
||||
|
|
|
@ -8,7 +8,9 @@
|
|||
mod utils;
|
||||
|
||||
use self::utils::flat_expression_from_bits;
|
||||
use zokrates_ast::zir::{ShouldReduce, UMetadata, ZirExpressionList};
|
||||
use zokrates_ast::zir::{
|
||||
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList,
|
||||
};
|
||||
use zokrates_interpreter::Interpreter;
|
||||
|
||||
use crate::compile::CompileConfig;
|
||||
|
@ -558,13 +560,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `alternative` - the alternative of type U.
|
||||
/// # Remarks
|
||||
/// * U is the type of the expression
|
||||
fn flatten_if_else_expression<U: Flatten<'ast, T>>(
|
||||
fn flatten_conditional_expression<U: Flatten<'ast, T>>(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: U,
|
||||
alternative: U,
|
||||
e: ConditionalExpression<'ast, T, U>,
|
||||
) -> FlatUExpression<T> {
|
||||
let condition = *e.condition;
|
||||
let consequence = *e.consequence;
|
||||
let alternative = *e.alternative;
|
||||
|
||||
let condition_flat =
|
||||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||
|
||||
|
@ -859,8 +863,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
BooleanExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
|
||||
}
|
||||
BooleanExpression::Select(a, box index) => self
|
||||
.flatten_select_expression(statements_flattened, a, index)
|
||||
BooleanExpression::Select(e) => self
|
||||
.flatten_select_expression(statements_flattened, e)
|
||||
.get_field_unchecked(),
|
||||
BooleanExpression::FieldLt(box lhs, box rhs) => {
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
|
@ -949,14 +953,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLe(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintLt(box lhs, box rhs) => {
|
||||
let bit_width = lhs.bitwidth.to_usize();
|
||||
assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
|
||||
|
@ -987,14 +983,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLe(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::Or(box lhs, box rhs) => {
|
||||
let x = self.flatten_boolean_expression(statements_flattened, lhs);
|
||||
let y = self.flatten_boolean_expression(statements_flattened, rhs);
|
||||
|
@ -1036,13 +1024,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
true => T::from(1),
|
||||
false => T::from(0),
|
||||
}),
|
||||
BooleanExpression::Conditional(box condition, box consequence, box alternative) => self
|
||||
.flatten_if_else_expression(
|
||||
statements_flattened,
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
)
|
||||
BooleanExpression::Conditional(e) => self
|
||||
.flatten_conditional_expression(statements_flattened, e)
|
||||
.get_field_unchecked(),
|
||||
}
|
||||
}
|
||||
|
@ -1473,9 +1456,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
});
|
||||
FlatUExpression::with_field(field).bits(bits)
|
||||
}
|
||||
UExpressionInner::Select(a, box index) => {
|
||||
self.flatten_select_expression(statements_flattened, a, index)
|
||||
}
|
||||
UExpressionInner::Select(e) => self.flatten_select_expression(statements_flattened, e),
|
||||
UExpressionInner::Not(box e) => {
|
||||
let e = self.flatten_uint_expression(statements_flattened, e);
|
||||
|
||||
|
@ -1633,13 +1614,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
FlatUExpression::with_field(r)
|
||||
}
|
||||
UExpressionInner::Conditional(box condition, box consequence, box alternative) => self
|
||||
.flatten_if_else_expression(
|
||||
statements_flattened,
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
),
|
||||
UExpressionInner::Conditional(e) => {
|
||||
self.flatten_conditional_expression(statements_flattened, e)
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
@ -2039,10 +2016,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
fn flatten_select_expression<U: Flatten<'ast, T>>(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: Vec<U>,
|
||||
index: UExpression<'ast, T>,
|
||||
e: SelectExpression<'ast, T, U>,
|
||||
) -> FlatUExpression<T> {
|
||||
let (range_check, result) = a
|
||||
let array = e.array;
|
||||
let index = *e.index;
|
||||
|
||||
let (range_check, result) = array
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| {
|
||||
|
@ -2108,8 +2087,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FieldElementExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x)))
|
||||
}
|
||||
FieldElementExpression::Select(a, box index) => self
|
||||
.flatten_select_expression(statements_flattened, a, index)
|
||||
FieldElementExpression::Select(e) => self
|
||||
.flatten_select_expression(statements_flattened, e)
|
||||
.get_field_unchecked(),
|
||||
FieldElementExpression::Add(box left, box right) => {
|
||||
let left_flattened = self.flatten_field_expression(statements_flattened, left);
|
||||
|
@ -2294,17 +2273,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
_ => panic!("Expected number as pow exponent"),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Conditional(
|
||||
box condition,
|
||||
box consequence,
|
||||
box alternative,
|
||||
) => self
|
||||
.flatten_if_else_expression(
|
||||
statements_flattened,
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
)
|
||||
FieldElementExpression::Conditional(e) => self
|
||||
.flatten_conditional_expression(statements_flattened, e)
|
||||
.get_field_unchecked(),
|
||||
}
|
||||
}
|
||||
|
@ -2428,8 +2398,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
error.into(),
|
||||
)
|
||||
}
|
||||
BooleanExpression::FieldLt(box lhs, box rhs)
|
||||
| BooleanExpression::FieldGt(box rhs, box lhs) => {
|
||||
BooleanExpression::FieldLt(box lhs, box rhs) => {
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
|
@ -2459,8 +2428,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldLe(box lhs, box rhs)
|
||||
| BooleanExpression::FieldGe(box rhs, box lhs) => {
|
||||
BooleanExpression::FieldLe(box lhs, box rhs) => {
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
|
@ -3531,13 +3499,13 @@ mod tests {
|
|||
#[test]
|
||||
fn if_else() {
|
||||
let config = CompileConfig::default();
|
||||
let expression = FieldElementExpression::Conditional(
|
||||
box BooleanExpression::FieldEq(
|
||||
let expression = FieldElementExpression::conditional(
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
),
|
||||
box FieldElementExpression::Number(Bn128Field::from(12)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(51)),
|
||||
FieldElementExpression::Number(Bn128Field::from(12)),
|
||||
FieldElementExpression::Number(Bn128Field::from(51)),
|
||||
);
|
||||
|
||||
let mut flattener = Flattener::new(config);
|
||||
|
@ -3554,13 +3522,6 @@ mod tests {
|
|||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le);
|
||||
|
||||
let mut flattener = Flattener::new(config);
|
||||
let expression_ge = BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_ge);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -3568,8 +3529,8 @@ mod tests {
|
|||
let config = CompileConfig::default();
|
||||
let mut flattener = Flattener::new(config);
|
||||
|
||||
let expression = FieldElementExpression::Conditional(
|
||||
box BooleanExpression::And(
|
||||
let expression = FieldElementExpression::conditional(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
|
@ -3579,8 +3540,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(Bn128Field::from(20)),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Number(Bn128Field::from(12)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(51)),
|
||||
FieldElementExpression::Number(Bn128Field::from(12)),
|
||||
FieldElementExpression::Number(Bn128Field::from(51)),
|
||||
);
|
||||
|
||||
flattener.flatten_field_expression(&mut FlatStatements::new(), expression);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::marker::PhantomData;
|
||||
use zokrates_ast::typed::types::UBitwidth;
|
||||
use zokrates_ast::typed::{self, Expr, Typed};
|
||||
use zokrates_ast::zir;
|
||||
use zokrates_ast::zir::{self, Select};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
@ -746,36 +746,35 @@ fn fold_select_expression<'ast, T: Field, E>(
|
|||
let ty = a[0].get_type();
|
||||
|
||||
match ty {
|
||||
zir::Type::Boolean => zir::BooleanExpression::Select(
|
||||
zir::Type::Boolean => zir::BooleanExpression::select(
|
||||
a.into_iter()
|
||||
.map(|e| match e {
|
||||
zir::ZirExpression::Boolean(e) => e.clone(),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect(),
|
||||
box index.clone(),
|
||||
index.clone(),
|
||||
)
|
||||
.into(),
|
||||
zir::Type::FieldElement => zir::FieldElementExpression::Select(
|
||||
zir::Type::FieldElement => zir::FieldElementExpression::select(
|
||||
a.into_iter()
|
||||
.map(|e| match e {
|
||||
zir::ZirExpression::FieldElement(e) => e.clone(),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect(),
|
||||
box index.clone(),
|
||||
index.clone(),
|
||||
)
|
||||
.into(),
|
||||
zir::Type::Uint(bitwidth) => zir::UExpressionInner::Select(
|
||||
zir::Type::Uint(_) => zir::UExpression::select(
|
||||
a.into_iter()
|
||||
.map(|e| match e {
|
||||
zir::ZirExpression::Uint(e) => e.clone(),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect(),
|
||||
box index.clone(),
|
||||
index.clone(),
|
||||
)
|
||||
.annotate(bitwidth)
|
||||
.into(),
|
||||
}
|
||||
})
|
||||
|
@ -987,12 +986,12 @@ fn fold_boolean_expression<'ast, T: Field>(
|
|||
typed::BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::FieldGt(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLt(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::FieldGe(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLe(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
|
@ -1007,12 +1006,12 @@ fn fold_boolean_expression<'ast, T: Field>(
|
|||
typed::BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_uint_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::UintGt(box e1, box e2)
|
||||
zir::BooleanExpression::UintLt(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_uint_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::UintGe(box e1, box e2)
|
||||
zir::BooleanExpression::UintLe(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(statements_buffer, e1);
|
||||
|
|
|
@ -55,22 +55,15 @@ fn force_no_reduce<T: Field>(e: UExpression<T>) -> UExpression<T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||
fn fold_field_expression(
|
||||
fn fold_select_expression<E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>>(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Select(a, box i) => {
|
||||
let a = a
|
||||
.into_iter()
|
||||
.map(|e| self.fold_field_expression(e))
|
||||
.collect();
|
||||
let i = self.fold_uint_expression(i);
|
||||
_: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> SelectOrExpression<'ast, T, E> {
|
||||
let array = e.array.into_iter().map(|e| e.fold(self)).collect();
|
||||
let index = e.index.fold(self);
|
||||
|
||||
FieldElementExpression::Select(a, box force_reduce(i))
|
||||
}
|
||||
_ => fold_field_expression(self, e),
|
||||
}
|
||||
SelectOrExpression::Select(SelectExpression::new(array, force_reduce(index)))
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
|
@ -78,15 +71,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Select(a, box i) => {
|
||||
let a = a
|
||||
.into_iter()
|
||||
.map(|e| self.fold_boolean_expression(e))
|
||||
.collect();
|
||||
let i = self.fold_uint_expression(i);
|
||||
|
||||
BooleanExpression::Select(a, box force_reduce(i))
|
||||
}
|
||||
BooleanExpression::UintEq(box left, box right) => {
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
@ -114,24 +98,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
BooleanExpression::UintLe(box left, box right)
|
||||
}
|
||||
BooleanExpression::UintGt(box left, box right) => {
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left = force_reduce(left);
|
||||
let right = force_reduce(right);
|
||||
|
||||
BooleanExpression::UintGt(box left, box right)
|
||||
}
|
||||
BooleanExpression::UintGe(box left, box right) => {
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left = force_reduce(left);
|
||||
let right = force_reduce(right);
|
||||
|
||||
BooleanExpression::UintGe(box left, box right)
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
@ -161,12 +127,15 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.cloned()
|
||||
.unwrap_or_else(|| panic!("identifier should have been defined: {}", id)),
|
||||
),
|
||||
Select(values, box index) => {
|
||||
Select(e) => {
|
||||
let index = *e.index;
|
||||
let array = e.array;
|
||||
|
||||
let index = self.fold_uint_expression(index);
|
||||
|
||||
let index = force_reduce(index);
|
||||
|
||||
let values: Vec<_> = values
|
||||
let values: Vec<_> = array
|
||||
.into_iter()
|
||||
.map(|v| force_no_reduce(self.fold_uint_expression(v)))
|
||||
.collect();
|
||||
|
@ -389,10 +358,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
UExpression::right_shift(force_reduce(e), by).with_max(max)
|
||||
}
|
||||
Conditional(box condition, box consequence, box alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition);
|
||||
let consequence = self.fold_uint_expression(consequence);
|
||||
let alternative = self.fold_uint_expression(alternative);
|
||||
Conditional(e) => {
|
||||
let condition = self.fold_boolean_expression(*e.condition);
|
||||
let consequence = e.consequence.fold(self);
|
||||
let alternative = e.alternative.fold(self);
|
||||
|
||||
let consequence_max = consequence.metadata.clone().unwrap().max;
|
||||
let alternative_max = alternative.metadata.clone().unwrap().max;
|
||||
|
|
|
@ -1,8 +1,17 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use zokrates_ast::zir::result_folder::fold_boolean_expression;
|
||||
use zokrates_ast::zir::result_folder::fold_field_expression;
|
||||
use zokrates_ast::zir::result_folder::fold_statement;
|
||||
use zokrates_ast::zir::result_folder::ResultFold;
|
||||
use zokrates_ast::zir::result_folder::ResultFolder;
|
||||
use zokrates_ast::zir::types::UBitwidth;
|
||||
use zokrates_ast::zir::Conditional;
|
||||
use zokrates_ast::zir::ConditionalExpression;
|
||||
use zokrates_ast::zir::ConditionalOrExpression;
|
||||
use zokrates_ast::zir::Expr;
|
||||
use zokrates_ast::zir::SelectExpression;
|
||||
use zokrates_ast::zir::SelectOrExpression;
|
||||
use zokrates_ast::zir::{
|
||||
BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression,
|
||||
UExpressionInner, ZirExpression, ZirProgram, ZirStatement,
|
||||
|
@ -136,24 +145,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
}
|
||||
_ => Ok(FieldElementExpression::Identifier(id)),
|
||||
},
|
||||
FieldElementExpression::Select(e, box index) => {
|
||||
let index = self.fold_uint_expression(index)?;
|
||||
let e: Vec<FieldElementExpression<'ast, T>> = e
|
||||
.into_iter()
|
||||
.map(|e| self.fold_field_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match index.into_inner() {
|
||||
UExpressionInner::Value(v) => e
|
||||
.get(v as usize)
|
||||
.cloned()
|
||||
.ok_or(Error::OutOfBounds(v as usize, e.len())),
|
||||
i => Ok(FieldElementExpression::Select(
|
||||
e,
|
||||
box i.annotate(UBitwidth::B32),
|
||||
)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Add(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
|
@ -237,28 +228,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Conditional(
|
||||
box condition,
|
||||
box consequence,
|
||||
box alternative,
|
||||
) => {
|
||||
let condition = self.fold_boolean_expression(condition)?;
|
||||
let consequence = self.fold_field_expression(consequence)?;
|
||||
let alternative = self.fold_field_expression(alternative)?;
|
||||
|
||||
match (condition, consequence, alternative) {
|
||||
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
|
||||
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
|
||||
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
|
||||
(condition, consequence, alternative) => {
|
||||
Ok(FieldElementExpression::Conditional(
|
||||
box condition,
|
||||
box consequence,
|
||||
box alternative,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -274,21 +244,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
}
|
||||
_ => Ok(BooleanExpression::Identifier(id)),
|
||||
},
|
||||
BooleanExpression::Select(e, box index) => {
|
||||
let index = self.fold_uint_expression(index)?;
|
||||
let e: Vec<BooleanExpression<'ast, T>> = e
|
||||
.into_iter()
|
||||
.map(|e| self.fold_boolean_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match index.as_inner() {
|
||||
UExpressionInner::Value(v) => e
|
||||
.get(*v as usize)
|
||||
.cloned()
|
||||
.ok_or(Error::OutOfBounds(*v as usize, e.len())),
|
||||
_ => Ok(BooleanExpression::Select(e, box index)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
|
@ -317,34 +272,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
(e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
self.fold_field_expression(e2)?,
|
||||
) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 >= n2))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
self.fold_field_expression(e2)?,
|
||||
) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 > n2))
|
||||
}
|
||||
(_, FieldElementExpression::Number(c)) if c == T::max_value() => {
|
||||
Ok(BooleanExpression::Value(false))
|
||||
}
|
||||
(FieldElementExpression::Number(c), _) if c == T::zero() => {
|
||||
Ok(BooleanExpression::Value(false))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
|
@ -384,28 +311,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
_ => Ok(BooleanExpression::UintLe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
||||
match (e1.as_inner(), e2.as_inner()) {
|
||||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
Ok(BooleanExpression::Value(v1 >= v2))
|
||||
}
|
||||
_ => Ok(BooleanExpression::UintGe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
||||
match (e1.as_inner(), e2.as_inner()) {
|
||||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
Ok(BooleanExpression::Value(v1 > v2))
|
||||
}
|
||||
_ => Ok(BooleanExpression::UintGt(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
@ -475,22 +380,33 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)),
|
||||
e => Ok(BooleanExpression::Not(box e)),
|
||||
},
|
||||
BooleanExpression::Conditional(box condition, box consequence, box alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition)?;
|
||||
let consequence = self.fold_boolean_expression(consequence)?;
|
||||
let alternative = self.fold_boolean_expression(alternative)?;
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
match (condition, consequence, alternative) {
|
||||
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
|
||||
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
|
||||
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
|
||||
(condition, consequence, alternative) => Ok(BooleanExpression::Conditional(
|
||||
box condition,
|
||||
box consequence,
|
||||
box alternative,
|
||||
)),
|
||||
}
|
||||
}
|
||||
fn fold_select_expression<
|
||||
E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + zokrates_ast::zir::Select<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
_: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> Result<zokrates_ast::zir::SelectOrExpression<'ast, T, E>, Self::Error> {
|
||||
let index = self.fold_uint_expression(*e.index)?;
|
||||
let array = e
|
||||
.array
|
||||
.into_iter()
|
||||
.map(|e| e.fold(self))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
match index.as_inner() {
|
||||
UExpressionInner::Value(v) => array
|
||||
.get(*v as usize)
|
||||
.cloned()
|
||||
.ok_or(Error::OutOfBounds(*v as usize, array.len()))
|
||||
.map(|e| SelectOrExpression::Expression(e.into_inner())),
|
||||
_ => Ok(SelectOrExpression::Expression(
|
||||
E::select(array, index).into_inner(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -505,22 +421,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
Some(ZirExpression::Uint(e)) => Ok(e.as_inner().clone()),
|
||||
_ => Ok(UExpressionInner::Identifier(id)),
|
||||
},
|
||||
UExpressionInner::Select(e, box index) => {
|
||||
let index = self.fold_uint_expression(index)?;
|
||||
let e: Vec<UExpression<'ast, T>> = e
|
||||
.into_iter()
|
||||
.map(|e| self.fold_uint_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match index.into_inner() {
|
||||
UExpressionInner::Value(v) => e
|
||||
.get(v as usize)
|
||||
.cloned()
|
||||
.ok_or(Error::OutOfBounds(v as usize, e.len()))
|
||||
.map(|e| e.into_inner()),
|
||||
i => Ok(UExpressionInner::Select(e, box i.annotate(UBitwidth::B32))),
|
||||
}
|
||||
}
|
||||
UExpressionInner::Add(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
@ -687,22 +587,34 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))),
|
||||
}
|
||||
}
|
||||
UExpressionInner::Conditional(box condition, box consequence, box alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition)?;
|
||||
let consequence = self.fold_uint_expression(consequence)?.into_inner();
|
||||
let alternative = self.fold_uint_expression(alternative)?.into_inner();
|
||||
e => self.fold_uint_expression_inner(bitwidth, e),
|
||||
}
|
||||
}
|
||||
|
||||
match (condition, consequence, alternative) {
|
||||
(_, consequence, alternative) if consequence == alternative => Ok(consequence),
|
||||
(BooleanExpression::Value(true), consequence, _) => Ok(consequence),
|
||||
(BooleanExpression::Value(false), _, alternative) => Ok(alternative),
|
||||
(condition, consequence, alternative) => Ok(UExpressionInner::Conditional(
|
||||
box condition,
|
||||
box consequence.annotate(bitwidth),
|
||||
box alternative.annotate(bitwidth),
|
||||
)),
|
||||
}
|
||||
}
|
||||
fn fold_conditional_expression<
|
||||
E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
_: &E::Ty,
|
||||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> Result<ConditionalOrExpression<'ast, T, E>, Self::Error> {
|
||||
let condition = self.fold_boolean_expression(*e.condition)?;
|
||||
let consequence = e.consequence.fold(self)?;
|
||||
let alternative = e.alternative.fold(self)?;
|
||||
|
||||
match (condition, consequence, alternative) {
|
||||
(_, consequence, alternative) if consequence == alternative => Ok(
|
||||
ConditionalOrExpression::Expression(consequence.into_inner()),
|
||||
),
|
||||
(BooleanExpression::Value(true), consequence, _) => Ok(
|
||||
ConditionalOrExpression::Expression(consequence.into_inner()),
|
||||
),
|
||||
(BooleanExpression::Value(false), _, alternative) => Ok(
|
||||
ConditionalOrExpression::Expression(alternative.into_inner()),
|
||||
),
|
||||
(condition, consequence, alternative) => Ok(ConditionalOrExpression::Conditional(
|
||||
ConditionalExpression::new(condition, consequence, alternative),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -754,6 +666,8 @@ mod tests {
|
|||
|
||||
#[cfg(test)]
|
||||
mod field {
|
||||
use zokrates_ast::zir::Select;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -761,23 +675,23 @@ mod tests {
|
|||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::Select(
|
||||
propagator.fold_field_expression(FieldElementExpression::select(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
],
|
||||
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::Select(
|
||||
propagator.fold_field_expression(FieldElementExpression::select(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
],
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Err(Error::OutOfBounds(3, 2))
|
||||
);
|
||||
|
@ -923,28 +837,28 @@ mod tests {
|
|||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::Conditional(
|
||||
box BooleanExpression::Value(true),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
propagator.fold_field_expression(FieldElementExpression::conditional(
|
||||
BooleanExpression::Value(true),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::Conditional(
|
||||
box BooleanExpression::Value(false),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
propagator.fold_field_expression(FieldElementExpression::conditional(
|
||||
BooleanExpression::Value(false),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::Conditional(
|
||||
box BooleanExpression::Identifier("a".into()),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
propagator.fold_field_expression(FieldElementExpression::conditional(
|
||||
BooleanExpression::Identifier("a".into()),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
|
||||
);
|
||||
|
@ -953,6 +867,8 @@ mod tests {
|
|||
|
||||
#[cfg(test)]
|
||||
mod bool {
|
||||
use zokrates_ast::zir::Select;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -960,23 +876,23 @@ mod tests {
|
|||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::Select(
|
||||
propagator.fold_boolean_expression(BooleanExpression::select(
|
||||
vec![
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true),
|
||||
],
|
||||
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::Select(
|
||||
propagator.fold_boolean_expression(BooleanExpression::select(
|
||||
vec![
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true),
|
||||
],
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Err(Error::OutOfBounds(3, 2))
|
||||
);
|
||||
|
@ -1040,64 +956,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_ge() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_gt() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_eq() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
@ -1161,48 +1019,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_ge() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGe(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGe(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_gt() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGt(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGt(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_eq() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
@ -1327,28 +1143,28 @@ mod tests {
|
|||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::Conditional(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false)
|
||||
propagator.fold_boolean_expression(BooleanExpression::conditional(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(false)
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::Conditional(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false)
|
||||
propagator.fold_boolean_expression(BooleanExpression::conditional(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(false)
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::Conditional(
|
||||
box BooleanExpression::Identifier("a".into()),
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true)
|
||||
propagator.fold_boolean_expression(BooleanExpression::conditional(
|
||||
BooleanExpression::Identifier("a".into()),
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(true)
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
@ -1366,13 +1182,14 @@ mod tests {
|
|||
assert_eq!(
|
||||
propagator.fold_uint_expression_inner(
|
||||
UBitwidth::B32,
|
||||
UExpressionInner::Select(
|
||||
UExpression::select(
|
||||
vec![
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
],
|
||||
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner()
|
||||
),
|
||||
Ok(UExpressionInner::Value(2))
|
||||
);
|
||||
|
@ -1380,13 +1197,14 @@ mod tests {
|
|||
assert_eq!(
|
||||
propagator.fold_uint_expression_inner(
|
||||
UBitwidth::B32,
|
||||
UExpressionInner::Select(
|
||||
UExpression::select(
|
||||
vec![
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
],
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner()
|
||||
),
|
||||
Err(Error::OutOfBounds(3, 2))
|
||||
);
|
||||
|
@ -1752,11 +1570,12 @@ mod tests {
|
|||
assert_eq!(
|
||||
propagator.fold_uint_expression_inner(
|
||||
UBitwidth::B32,
|
||||
UExpressionInner::Conditional(
|
||||
box BooleanExpression::Value(true),
|
||||
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
UExpression::conditional(
|
||||
BooleanExpression::Value(true),
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner()
|
||||
),
|
||||
Ok(UExpressionInner::Value(1))
|
||||
);
|
||||
|
@ -1764,11 +1583,12 @@ mod tests {
|
|||
assert_eq!(
|
||||
propagator.fold_uint_expression_inner(
|
||||
UBitwidth::B32,
|
||||
UExpressionInner::Conditional(
|
||||
box BooleanExpression::Value(false),
|
||||
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
UExpression::conditional(
|
||||
BooleanExpression::Value(false),
|
||||
UExpressionInner::Value(1).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner()
|
||||
),
|
||||
Ok(UExpressionInner::Value(2))
|
||||
);
|
||||
|
@ -1776,11 +1596,12 @@ mod tests {
|
|||
assert_eq!(
|
||||
propagator.fold_uint_expression_inner(
|
||||
UBitwidth::B32,
|
||||
UExpressionInner::Conditional(
|
||||
box BooleanExpression::Identifier("a".into()),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
UExpression::conditional(
|
||||
BooleanExpression::Identifier("a".into()),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner()
|
||||
),
|
||||
Ok(UExpressionInner::Value(2))
|
||||
);
|
||||
|
|
Loading…
Reference in a new issue