implement blocks for other basic types outside multidef
This commit is contained in:
parent
2904568862
commit
ec3e8f6310
13 changed files with 366 additions and 139 deletions
25
test.zok
25
test.zok
|
@ -1,10 +1,21 @@
|
|||
def zero(field x) -> field:
|
||||
assert(x == 0)
|
||||
return 0
|
||||
// def zero(field x) -> field:
|
||||
// assert(x == 0)
|
||||
// return 0
|
||||
|
||||
def inverse(field x) -> field:
|
||||
assert(x != 0)
|
||||
// def inverse(field x) -> field:
|
||||
// assert(x != 0)
|
||||
// return 1/x
|
||||
|
||||
// def main(field x) -> field:
|
||||
// return if x == 0 then zero(x) else inverse(x) fi
|
||||
|
||||
def yes(bool x) -> bool:
|
||||
assert(x)
|
||||
return x
|
||||
|
||||
def main(field x) -> field:
|
||||
return if x == 0 then zero(x) else inverse(x) fi
|
||||
def no(bool x) -> bool:
|
||||
assert(!x)
|
||||
return !x
|
||||
|
||||
def main(bool x) -> bool:
|
||||
return if x then yes(x) else no(x) fi
|
|
@ -25,45 +25,6 @@ use zokrates_field::Field;
|
|||
|
||||
type FlatStatements<T> = Vec<FlatStatement<T>>;
|
||||
|
||||
struct Fallible<'ast, T, U> {
|
||||
pub inner: U,
|
||||
pub success: BooleanExpression<'ast, T>
|
||||
}
|
||||
|
||||
impl<'ast, U, T: Field> From<U> for Fallible<'ast, T, U> {
|
||||
fn from(e: U) -> Self {
|
||||
Fallible {
|
||||
inner: e,
|
||||
success: BooleanExpression::Value(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, U, T> Fallible<'ast, T, U> {
|
||||
fn split(self) -> (U, BooleanExpression<'ast, T>) {
|
||||
(self.inner, self.success)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, U, T: Field> Fallible<'ast, T, U> {
|
||||
fn unwrap(self) -> U {
|
||||
let (res, success) = self.split();
|
||||
assert_eq!(success, BooleanExpression::Value(true));
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Fallible<'ast, T, FlatUExpression<T>> {
|
||||
fn get_field_unchecked(self) -> Fallible<'ast, T, FlatExpression<T>> {
|
||||
let (res, success) = self.split();
|
||||
Fallible {
|
||||
inner: res.get_field_unchecked(),
|
||||
success
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Flattener, computes flattened program.
|
||||
#[derive(Debug)]
|
||||
pub struct Flattener<'ast, T: Field> {
|
||||
|
@ -92,12 +53,6 @@ impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field, U: FlattenOutput<T> + Clone> FlattenOutput<T> for Fallible<'ast, T, U> {
|
||||
fn flat(&self) -> FlatExpression<T> {
|
||||
self.inner.clone().flat()
|
||||
}
|
||||
}
|
||||
|
||||
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
|
||||
|
||||
trait Flatten<'ast, T: Field>: TryFrom<ZirExpression<'ast, T>, Error = ()> + IfElse<'ast, T> {
|
||||
|
@ -111,7 +66,7 @@ trait Flatten<'ast, T: Field>: TryFrom<ZirExpression<'ast, T>, Error = ()> + IfE
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
type Output = Fallible<'ast, T, FlatExpression<T>>;
|
||||
type Output = FlatExpression<T>;
|
||||
|
||||
fn flatten(
|
||||
self,
|
||||
|
@ -481,7 +436,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: U,
|
||||
alternative: U,
|
||||
) -> Fallible<'ast, T, FlatUExpression<T>> {
|
||||
) -> FlatUExpression<T> {
|
||||
let condition = self.flatten_boolean_expression(statements_flattened, condition);
|
||||
|
||||
let mut consequence_statements = vec![];
|
||||
|
@ -504,15 +459,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.join("\n")
|
||||
);
|
||||
|
||||
// let consequence_statements =
|
||||
// self.make_conditional(consequence_statements, condition_id.into());
|
||||
// let alternative_statements = self.make_conditional(
|
||||
// alternative_statements,
|
||||
// FlatExpression::Sub(
|
||||
// box FlatExpression::Number(T::one()),
|
||||
// box condition_id.into(),
|
||||
// ),
|
||||
// );
|
||||
let consequence_statements =
|
||||
self.make_conditional(consequence_statements, condition_id.into());
|
||||
let alternative_statements = self.make_conditional(
|
||||
alternative_statements,
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
);
|
||||
|
||||
println!(
|
||||
"AFTER\n {}\n",
|
||||
|
@ -568,7 +523,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatUExpression {
|
||||
field: Some(FlatExpression::Identifier(res)),
|
||||
bits: None,
|
||||
}.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a strict check against a constant
|
||||
|
@ -692,6 +647,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
) -> FlatExpression<T> {
|
||||
// those will be booleans in the future
|
||||
match expression {
|
||||
BooleanExpression::Block(statements, box value) => {
|
||||
for s in statements {
|
||||
self.flatten_statement(statements_flattened, s);
|
||||
}
|
||||
self.flatten_boolean_expression(statements_flattened, value)
|
||||
}
|
||||
BooleanExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
|
||||
}
|
||||
|
@ -702,8 +663,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// We know from semantic checking that lhs and rhs have the same type
|
||||
// What the expression will flatten to depends on that type
|
||||
|
||||
let (lhs_flattened, lhs_success) = self.flatten_field_expression(statements_flattened, lhs).split();
|
||||
let (rhs_flattened, rhs_success) = self.flatten_field_expression(statements_flattened, rhs).split();
|
||||
let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
match (lhs_flattened, rhs_flattened) {
|
||||
(x, FlatExpression::Number(constant)) => {
|
||||
|
@ -918,8 +879,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
unimplemented!()
|
||||
// self.eq_check(statements_flattened, lhs, rhs)
|
||||
self.eq_check(statements_flattened, lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintEq(box lhs, box rhs) => {
|
||||
// We reduce each side into range and apply the same approach as for field elements
|
||||
|
@ -1098,7 +1058,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
consequence,
|
||||
alternative,
|
||||
)
|
||||
.unwrap()
|
||||
.get_field_unchecked(),
|
||||
}
|
||||
}
|
||||
|
@ -1129,7 +1088,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.into_iter()
|
||||
.map(|p| {
|
||||
self.flatten_expression(statements_flattened, p)
|
||||
.unwrap()
|
||||
.get_field_unchecked()
|
||||
})
|
||||
.collect();
|
||||
|
@ -1197,7 +1155,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.into_iter()
|
||||
.map(|param_expr| self.flatten_expression(statements_flattened, param_expr))
|
||||
.into_iter()
|
||||
.map(|x| x.unwrap().get_field_unchecked())
|
||||
.map(|x| x.get_field_unchecked())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (concrete_argument, formal_argument) in
|
||||
|
@ -1280,22 +1238,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
expr: ZirExpression<'ast, T>,
|
||||
) -> Fallible<'ast, T, FlatUExpression<T>> {
|
||||
) -> FlatUExpression<T> {
|
||||
match expr {
|
||||
ZirExpression::FieldElement(e) => {
|
||||
let (e, s) = self.flatten_field_expression(statements_flattened, e).split();
|
||||
|
||||
Fallible {
|
||||
inner: FlatUExpression::with_field(e),
|
||||
success: s
|
||||
}
|
||||
FlatUExpression::with_field(self.flatten_field_expression(statements_flattened, e))
|
||||
}
|
||||
ZirExpression::Boolean(e) => {
|
||||
let e = self.flatten_boolean_expression(statements_flattened, e);
|
||||
|
||||
FlatUExpression::with_field(e).into()
|
||||
},
|
||||
ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e).into(),
|
||||
ZirExpression::Boolean(e) => FlatUExpression::with_field(
|
||||
self.flatten_boolean_expression(statements_flattened, e),
|
||||
),
|
||||
ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1477,6 +1428,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let should_reduce = should_reduce.to_bool();
|
||||
|
||||
let res = match expr.into_inner() {
|
||||
UExpressionInner::Block(statements, box value) => {
|
||||
for s in statements {
|
||||
self.flatten_statement(statements_flattened, s);
|
||||
}
|
||||
self.flatten_uint_expression(statements_flattened, value)
|
||||
}
|
||||
UExpressionInner::Value(x) => {
|
||||
FlatUExpression::with_field(FlatExpression::Number(T::from(x as usize)))
|
||||
} // force to be a field element
|
||||
|
@ -1650,7 +1607,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
).unwrap(),
|
||||
),
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
@ -2028,16 +1985,21 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
expr: FieldElementExpression<'ast, T>,
|
||||
) -> Fallible<'ast, T, FlatExpression<T>> {
|
||||
) -> FlatExpression<T> {
|
||||
match expr {
|
||||
FieldElementExpression::Block(statements, value) => todo!("flatten block with or without extracting panics depending on whether we're in a branch"),
|
||||
FieldElementExpression::Number(x) => FlatExpression::Number(x).into(), // force to be a field element
|
||||
FieldElementExpression::Block(statements, box value) => {
|
||||
for s in statements {
|
||||
self.flatten_statement(statements_flattened, s);
|
||||
}
|
||||
self.flatten_field_expression(statements_flattened, value)
|
||||
}
|
||||
FieldElementExpression::Number(x) => FlatExpression::Number(x), // force to be a field element
|
||||
FieldElementExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x))).into()
|
||||
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x)))
|
||||
}
|
||||
FieldElementExpression::Add(box left, box right) => {
|
||||
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
|
||||
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
|
||||
let left_flattened = self.flatten_field_expression(statements_flattened, left);
|
||||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
|
@ -2052,11 +2014,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Add(box new_left, box new_right).into()
|
||||
FlatExpression::Add(box new_left, box new_right)
|
||||
}
|
||||
FieldElementExpression::Sub(box left, box right) => {
|
||||
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
|
||||
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
|
||||
let left_flattened = self.flatten_field_expression(statements_flattened, left);
|
||||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
|
@ -2073,11 +2035,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
FlatExpression::Sub(box new_left, box new_right).into()
|
||||
FlatExpression::Sub(box new_left, box new_right)
|
||||
}
|
||||
FieldElementExpression::Mult(box left, box right) => {
|
||||
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
|
||||
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
|
||||
let left_flattened = self.flatten_field_expression(statements_flattened, left);
|
||||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
|
@ -2092,11 +2054,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Mult(box new_left, box new_right).into()
|
||||
FlatExpression::Mult(box new_left, box new_right)
|
||||
}
|
||||
FieldElementExpression::Div(box left, box right) => {
|
||||
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
|
||||
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
|
||||
let left_flattened = self.flatten_field_expression(statements_flattened, left);
|
||||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||
let new_left: FlatExpression<T> = {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
|
@ -2137,13 +2099,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Mult(box new_right, box inverse.into()),
|
||||
));
|
||||
|
||||
FlatExpression::from(inverse).into()
|
||||
inverse.into()
|
||||
}
|
||||
FieldElementExpression::Pow(box base, box exponent) => {
|
||||
match exponent.into_inner() {
|
||||
UExpressionInner::Value(ref e) => {
|
||||
// flatten the base expression
|
||||
let (base_flattened, base_success) = self.flatten_field_expression(statements_flattened, base).split();
|
||||
let base_flattened =
|
||||
self.flatten_field_expression(statements_flattened, base.clone());
|
||||
|
||||
// we require from the base to be linear
|
||||
// TODO change that
|
||||
|
@ -2207,7 +2170,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
false => acc, // this bit is false, keep the previous result
|
||||
},
|
||||
).into()
|
||||
)
|
||||
}
|
||||
_ => panic!("Expected number as pow exponent"),
|
||||
}
|
||||
|
@ -2239,7 +2202,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let flat_expressions = exprs
|
||||
.into_iter()
|
||||
.map(|expr| self.flatten_expression(statements_flattened, expr))
|
||||
.map(|x| x.unwrap().get_field_unchecked())
|
||||
.map(|x| x.get_field_unchecked())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
statements_flattened.push(FlatStatement::Return(FlatExpressionList {
|
||||
|
@ -2253,7 +2216,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// define n variables with n the number of primitive types for v_type
|
||||
// assign them to the n primitive types for expr
|
||||
|
||||
let (rhs, rhs_success) = self.flatten_expression(statements_flattened, expr).split();
|
||||
let rhs = self.flatten_expression(statements_flattened, expr);
|
||||
|
||||
let bits = rhs.bits.clone();
|
||||
|
||||
|
|
|
@ -299,6 +299,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
array: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match array {
|
||||
typed_absy::ArrayExpressionInner::Block(statements, box value) => unimplemented!(),
|
||||
typed_absy::ArrayExpressionInner::Identifier(id) => {
|
||||
let variables = flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
|
@ -427,6 +428,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
|||
struc: typed_absy::StructExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match struc {
|
||||
typed_absy::StructExpressionInner::Block(statements, box value) => unimplemented!(),
|
||||
typed_absy::StructExpressionInner::Identifier(id) => {
|
||||
let variables = flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
|
@ -623,6 +625,15 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
e: typed_absy::BooleanExpression<'ast, T>,
|
||||
) -> zir::BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
typed_absy::BooleanExpression::Block(statements, box value) => {
|
||||
zir::BooleanExpression::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_boolean_expression(value),
|
||||
)
|
||||
}
|
||||
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
|
||||
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0]
|
||||
|
@ -805,6 +816,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
e: typed_absy::UExpressionInner<'ast, T>,
|
||||
) -> zir::UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
typed_absy::UExpressionInner::Block(statements, box value) => zir::UExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_uint_expression(value),
|
||||
),
|
||||
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
|
||||
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
|
||||
flatten_identifier_rec(
|
||||
|
|
|
@ -74,8 +74,10 @@ fn get_canonical_function<'ast, T: Field>(
|
|||
}
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> =
|
||||
Result<Output<Vec<TypedExpression<'ast, T>>, Vec<Versions<'ast>>>, InlineError<'ast, T>>;
|
||||
type InlineResult<'ast, T> = Result<
|
||||
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
|
||||
InlineError<'ast, T>,
|
||||
>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field>(
|
||||
k: DeclarationFunctionKey<'ast>,
|
||||
|
@ -227,14 +229,7 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
.chain(std::iter::once(pop_log))
|
||||
.collect();
|
||||
|
||||
let e = match expressions[0].clone() {
|
||||
TypedExpression::FieldElement(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let res = crate::typed_absy::FieldElementExpression::Block(statements, box e);
|
||||
|
||||
Ok(incomplete_data
|
||||
.map(|d| Output::Incomplete(vec![res.clone().into()], d))
|
||||
.unwrap_or_else(|| Output::Complete(vec![res.into()])))
|
||||
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
|
||||
.unwrap_or_else(|| Output::Complete((statements, expressions))))
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ use crate::typed_absy::Folder;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
|
||||
ArrayExpression, ArrayExpressionInner, ArrayType, Block, BooleanExpression, CoreIdentifier,
|
||||
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
|
||||
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
|
||||
TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
|
||||
|
@ -197,10 +197,13 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_types: Vec<Type<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Result<E, Error>
|
||||
where
|
||||
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
|
||||
E: Block<'ast, T>
|
||||
+ FunctionCall<'ast, T>
|
||||
+ TryFrom<TypedExpression<'ast, T>, Error = ()>
|
||||
+ std::fmt::Debug,
|
||||
{
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
|
@ -216,18 +219,23 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
key.clone(),
|
||||
generics,
|
||||
arguments,
|
||||
output_types,
|
||||
vec![output_type.clone()],
|
||||
&self.program,
|
||||
&mut self.versions,
|
||||
);
|
||||
|
||||
match res {
|
||||
Ok(Output::Complete(expressions)) => {
|
||||
Ok(Output::Complete((statements, mut expressions))) => {
|
||||
self.complete &= true;
|
||||
Ok(expressions[0].clone().try_into().unwrap())
|
||||
Ok(E::block(
|
||||
statements,
|
||||
expressions.pop().unwrap().try_into().unwrap(),
|
||||
output_type,
|
||||
))
|
||||
}
|
||||
Ok(Output::Incomplete(expressions, delta_for_loop_versions)) => {
|
||||
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
|
||||
self.complete = false;
|
||||
self.statement_buffer.extend(statements);
|
||||
self.for_loop_versions_after.extend(delta_for_loop_versions);
|
||||
Ok(expressions[0].clone().try_into().unwrap())
|
||||
}
|
||||
|
@ -246,7 +254,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
output_types.pop().unwrap(),
|
||||
))
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
|
||||
Err(InlineError::Flat(embed, generics, arguments, mut output_types)) => {
|
||||
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
|
||||
*self
|
||||
.versions
|
||||
|
@ -254,7 +262,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
);
|
||||
let var = Variable::with_id_and_type(identifier, output_types[0].clone());
|
||||
let var = Variable::with_id_and_type(identifier, output_types.pop().unwrap());
|
||||
|
||||
let v = vec![var.clone().into()];
|
||||
|
||||
|
@ -291,6 +299,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
unimplemented!("multi def needs to be put in blocks");
|
||||
|
||||
match inline_call(
|
||||
key,
|
||||
generics,
|
||||
|
@ -299,25 +309,33 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
&self.program,
|
||||
&mut self.versions,
|
||||
) {
|
||||
Ok(Output::Complete(expressions)) => {
|
||||
Ok(Output::Complete((statements, expressions))) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
|
||||
self.complete &= true;
|
||||
|
||||
Ok(v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v, e))
|
||||
Ok(statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v, e)),
|
||||
)
|
||||
.collect())
|
||||
}
|
||||
Ok(Output::Incomplete(expressions, delta_for_loop_versions)) => {
|
||||
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
|
||||
self.complete = false;
|
||||
self.for_loop_versions_after.extend(delta_for_loop_versions);
|
||||
|
||||
Ok(v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v, e))
|
||||
Ok(statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v, e)),
|
||||
)
|
||||
.collect())
|
||||
}
|
||||
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
|
||||
|
@ -425,7 +443,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
BooleanExpression::FunctionCall(key, generics, arguments) => {
|
||||
self.fold_function_call(key, generics, arguments, vec![Type::Boolean])
|
||||
self.fold_function_call(key, generics, arguments, Type::Boolean)
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
|
@ -440,7 +458,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
key.clone(),
|
||||
generics.clone(),
|
||||
arguments.clone(),
|
||||
vec![e.get_type()],
|
||||
e.get_type(),
|
||||
),
|
||||
_ => fold_uint_expression(self, e),
|
||||
}
|
||||
|
@ -452,7 +470,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::FunctionCall(key, generic, arguments) => {
|
||||
self.fold_function_call(key, generic, arguments, vec![Type::FieldElement])
|
||||
self.fold_function_call(key, generic, arguments, Type::FieldElement)
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
|
@ -469,7 +487,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
key.clone(),
|
||||
generics,
|
||||
arguments.clone(),
|
||||
vec![Type::array(ty.clone())],
|
||||
Type::array(ty.clone()),
|
||||
)
|
||||
.map(|e| e.into_inner()),
|
||||
ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
|
@ -501,7 +519,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
key.clone(),
|
||||
generics.clone(),
|
||||
arguments.clone(),
|
||||
vec![e.get_type()],
|
||||
e.get_type(),
|
||||
),
|
||||
_ => fold_struct_expression(self, e),
|
||||
}
|
||||
|
|
|
@ -126,6 +126,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
use self::UExpressionInner::*;
|
||||
|
||||
let res = match inner {
|
||||
b @ Block(..) => force_no_reduce(
|
||||
fold_uint_expression_inner(self, e.bitwidth, b).annotate(e.bitwidth),
|
||||
),
|
||||
Value(v) => Value(v).annotate(range).with_max(v),
|
||||
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
|
||||
self.ids
|
||||
|
|
|
@ -257,6 +257,13 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> ArrayExpressionInner<'ast, T> {
|
||||
match e {
|
||||
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_array_expression(value),
|
||||
),
|
||||
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)),
|
||||
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
|
||||
exprs
|
||||
|
@ -308,6 +315,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: StructExpressionInner<'ast, T>,
|
||||
) -> StructExpressionInner<'ast, T> {
|
||||
match e {
|
||||
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_struct_expression(value),
|
||||
),
|
||||
StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)),
|
||||
StructExpressionInner::Value(exprs) => {
|
||||
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
|
||||
|
@ -428,6 +442,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_boolean_expression(value),
|
||||
),
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
|
@ -551,6 +572,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_uint_expression(value),
|
||||
),
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
|
|
|
@ -591,6 +591,16 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.inner {
|
||||
StructExpressionInner::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
StructExpressionInner::Value(ref values) => write!(
|
||||
f,
|
||||
|
@ -805,6 +815,10 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
|
|||
/// An expression of type `bool`
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum BooleanExpression<'ast, T> {
|
||||
Block(
|
||||
Vec<TypedStatement<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(bool),
|
||||
FieldLt(
|
||||
|
@ -961,6 +975,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
|
|||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum ArrayExpressionInner<'ast, T> {
|
||||
Block(Vec<TypedStatement<'ast, T>>, Box<ArrayExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(ArrayValue<'ast, T>),
|
||||
FunctionCall(
|
||||
|
@ -1069,6 +1084,7 @@ impl<'ast, T> StructExpression<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum StructExpressionInner<'ast, T> {
|
||||
Block(Vec<TypedStatement<'ast, T>>, Box<StructExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(
|
||||
|
@ -1276,6 +1292,16 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.inner {
|
||||
UExpressionInner::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
UExpressionInner::Value(ref v) => write!(f, "{}", v),
|
||||
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
|
@ -1333,6 +1359,16 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
BooleanExpression::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
|
@ -1390,6 +1426,16 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
ArrayExpressionInner::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
ArrayExpressionInner::Value(ref values) => write!(
|
||||
f,
|
||||
|
@ -1809,3 +1855,76 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
|
|||
StructExpressionInner::FunctionCall(key, generics, arguments).annotate(struct_ty)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Block<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
assert_eq!(output_type, Type::FieldElement);
|
||||
FieldElementExpression::Block(statements, box value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
assert_eq!(output_type, Type::Boolean);
|
||||
BooleanExpression::Block(statements, box value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let bitwidth = match output_type {
|
||||
Type::Uint(bitwidth) => bitwidth,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
UExpressionInner::Block(statements, box value).annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let array_ty = match output_type {
|
||||
Type::Array(array_ty) => array_ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
|
||||
fn block(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
value: Self,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let struct_ty = match output_type {
|
||||
Type::Struct(struct_ty) => struct_ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
StructExpressionInner::Block(statements, box value).annotate(struct_ty)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -302,6 +302,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|r| r.into_iter().flatten().collect())?,
|
||||
box f.fold_array_expression(value)?,
|
||||
),
|
||||
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?),
|
||||
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
|
||||
exprs
|
||||
|
@ -357,6 +365,14 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|r| r.into_iter().flatten().collect())?,
|
||||
box f.fold_struct_expression(value)?,
|
||||
),
|
||||
StructExpressionInner::Identifier(id) => {
|
||||
StructExpressionInner::Identifier(f.fold_name(id)?)
|
||||
}
|
||||
|
@ -491,6 +507,14 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|r| r.into_iter().flatten().collect())?,
|
||||
box f.fold_boolean_expression(value)?,
|
||||
),
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
|
@ -618,6 +642,14 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|r| r.into_iter().flatten().collect())?,
|
||||
box f.fold_uint_expression(value)?,
|
||||
),
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
|
|
|
@ -175,6 +175,7 @@ impl<'ast, T> PartialEq<usize> for UExpression<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum UExpressionInner<'ast, T> {
|
||||
Block(Vec<TypedStatement<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(u128),
|
||||
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
|
|
|
@ -169,6 +169,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_boolean_expression(value),
|
||||
),
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
|
@ -265,6 +272,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
box f.fold_uint_expression(value),
|
||||
),
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
|
|
|
@ -275,6 +275,7 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
/// An expression of type `bool`
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
pub enum BooleanExpression<'ast, T> {
|
||||
Block(Vec<ZirStatement<'ast, T>>, Box<BooleanExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(bool),
|
||||
FieldLt(
|
||||
|
@ -422,6 +423,16 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.inner {
|
||||
UExpressionInner::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
UExpressionInner::Value(ref v) => write!(f, "{}", v),
|
||||
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
|
@ -447,6 +458,16 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
BooleanExpression::Block(ref statements, ref value) => write!(
|
||||
f,
|
||||
"{{{}}}",
|
||||
statements
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.chain(std::iter::once(value.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
),
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
|
@ -475,6 +496,9 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
BooleanExpression::Block(ref statements, ref value) => {
|
||||
write!(f, "Block({:?}, {:?})", statements, value)
|
||||
}
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "Ide({:?})", var),
|
||||
BooleanExpression::Value(b) => write!(f, "Value({})", b),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::zir::identifier::Identifier;
|
||||
use crate::zir::types::UBitwidth;
|
||||
use crate::zir::BooleanExpression;
|
||||
use crate::zir::{BooleanExpression, ZirStatement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
|
@ -158,6 +158,7 @@ pub struct UExpression<'ast, T> {
|
|||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum UExpressionInner<'ast, T> {
|
||||
Block(Vec<ZirStatement<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(u128),
|
||||
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
|
|
Loading…
Reference in a new issue