1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

implement blocks for other basic types outside multidef

This commit is contained in:
schaeff 2021-05-12 20:51:39 +02:00
parent 2904568862
commit ec3e8f6310
13 changed files with 366 additions and 139 deletions

View file

@ -1,10 +1,21 @@
def zero(field x) -> field:
assert(x == 0)
return 0
// def zero(field x) -> field:
// assert(x == 0)
// return 0
def inverse(field x) -> field:
assert(x != 0)
// def inverse(field x) -> field:
// assert(x != 0)
// return 1/x
// def main(field x) -> field:
// return if x == 0 then zero(x) else inverse(x) fi
def yes(bool x) -> bool:
assert(x)
return x
def main(field x) -> field:
return if x == 0 then zero(x) else inverse(x) fi
def no(bool x) -> bool:
assert(!x)
return !x
def main(bool x) -> bool:
return if x then yes(x) else no(x) fi

View file

@ -25,45 +25,6 @@ use zokrates_field::Field;
type FlatStatements<T> = Vec<FlatStatement<T>>;
struct Fallible<'ast, T, U> {
pub inner: U,
pub success: BooleanExpression<'ast, T>
}
impl<'ast, U, T: Field> From<U> for Fallible<'ast, T, U> {
fn from(e: U) -> Self {
Fallible {
inner: e,
success: BooleanExpression::Value(true)
}
}
}
impl<'ast, U, T> Fallible<'ast, T, U> {
fn split(self) -> (U, BooleanExpression<'ast, T>) {
(self.inner, self.success)
}
}
impl<'ast, U, T: Field> Fallible<'ast, T, U> {
fn unwrap(self) -> U {
let (res, success) = self.split();
assert_eq!(success, BooleanExpression::Value(true));
return res
}
}
impl<'ast, T: Field> Fallible<'ast, T, FlatUExpression<T>> {
fn get_field_unchecked(self) -> Fallible<'ast, T, FlatExpression<T>> {
let (res, success) = self.split();
Fallible {
inner: res.get_field_unchecked(),
success
}
}
}
/// Flattener, computes flattened program.
#[derive(Debug)]
pub struct Flattener<'ast, T: Field> {
@ -92,12 +53,6 @@ impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
}
}
impl<'ast, T: Field, U: FlattenOutput<T> + Clone> FlattenOutput<T> for Fallible<'ast, T, U> {
fn flat(&self) -> FlatExpression<T> {
self.inner.clone().flat()
}
}
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
trait Flatten<'ast, T: Field>: TryFrom<ZirExpression<'ast, T>, Error = ()> + IfElse<'ast, T> {
@ -111,7 +66,7 @@ trait Flatten<'ast, T: Field>: TryFrom<ZirExpression<'ast, T>, Error = ()> + IfE
}
impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
type Output = Fallible<'ast, T, FlatExpression<T>>;
type Output = FlatExpression<T>;
fn flatten(
self,
@ -481,7 +436,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
condition: BooleanExpression<'ast, T>,
consequence: U,
alternative: U,
) -> Fallible<'ast, T, FlatUExpression<T>> {
) -> FlatUExpression<T> {
let condition = self.flatten_boolean_expression(statements_flattened, condition);
let mut consequence_statements = vec![];
@ -504,15 +459,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.join("\n")
);
// let consequence_statements =
// self.make_conditional(consequence_statements, condition_id.into());
// let alternative_statements = self.make_conditional(
// alternative_statements,
// FlatExpression::Sub(
// box FlatExpression::Number(T::one()),
// box condition_id.into(),
// ),
// );
let consequence_statements =
self.make_conditional(consequence_statements, condition_id.into());
let alternative_statements = self.make_conditional(
alternative_statements,
FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
);
println!(
"AFTER\n {}\n",
@ -568,7 +523,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatUExpression {
field: Some(FlatExpression::Identifier(res)),
bits: None,
}.into()
}
}
/// Compute a strict check against a constant
@ -692,6 +647,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
) -> FlatExpression<T> {
// those will be booleans in the future
match expression {
BooleanExpression::Block(statements, box value) => {
for s in statements {
self.flatten_statement(statements_flattened, s);
}
self.flatten_boolean_expression(statements_flattened, value)
}
BooleanExpression::Identifier(x) => {
FlatExpression::Identifier(*self.layout.get(&x).unwrap())
}
@ -702,8 +663,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
let (lhs_flattened, lhs_success) = self.flatten_field_expression(statements_flattened, lhs).split();
let (rhs_flattened, rhs_success) = self.flatten_field_expression(statements_flattened, rhs).split();
let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs);
let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs);
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Number(constant)) => {
@ -918,8 +879,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let rhs = self.flatten_field_expression(statements_flattened, rhs);
unimplemented!()
// self.eq_check(statements_flattened, lhs, rhs)
self.eq_check(statements_flattened, lhs, rhs)
}
BooleanExpression::UintEq(box lhs, box rhs) => {
// We reduce each side into range and apply the same approach as for field elements
@ -1098,7 +1058,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
consequence,
alternative,
)
.unwrap()
.get_field_unchecked(),
}
}
@ -1129,7 +1088,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.into_iter()
.map(|p| {
self.flatten_expression(statements_flattened, p)
.unwrap()
.get_field_unchecked()
})
.collect();
@ -1197,7 +1155,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.into_iter()
.map(|param_expr| self.flatten_expression(statements_flattened, param_expr))
.into_iter()
.map(|x| x.unwrap().get_field_unchecked())
.map(|x| x.get_field_unchecked())
.collect::<Vec<_>>();
for (concrete_argument, formal_argument) in
@ -1280,22 +1238,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
&mut self,
statements_flattened: &mut FlatStatements<T>,
expr: ZirExpression<'ast, T>,
) -> Fallible<'ast, T, FlatUExpression<T>> {
) -> FlatUExpression<T> {
match expr {
ZirExpression::FieldElement(e) => {
let (e, s) = self.flatten_field_expression(statements_flattened, e).split();
Fallible {
inner: FlatUExpression::with_field(e),
success: s
}
FlatUExpression::with_field(self.flatten_field_expression(statements_flattened, e))
}
ZirExpression::Boolean(e) => {
let e = self.flatten_boolean_expression(statements_flattened, e);
FlatUExpression::with_field(e).into()
},
ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e).into(),
ZirExpression::Boolean(e) => FlatUExpression::with_field(
self.flatten_boolean_expression(statements_flattened, e),
),
ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e),
}
}
@ -1477,6 +1428,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let should_reduce = should_reduce.to_bool();
let res = match expr.into_inner() {
UExpressionInner::Block(statements, box value) => {
for s in statements {
self.flatten_statement(statements_flattened, s);
}
self.flatten_uint_expression(statements_flattened, value)
}
UExpressionInner::Value(x) => {
FlatUExpression::with_field(FlatExpression::Number(T::from(x as usize)))
} // force to be a field element
@ -1650,7 +1607,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
condition,
consequence,
alternative,
).unwrap(),
),
UExpressionInner::Xor(box left, box right) => {
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
@ -2028,16 +1985,21 @@ impl<'ast, T: Field> Flattener<'ast, T> {
&mut self,
statements_flattened: &mut FlatStatements<T>,
expr: FieldElementExpression<'ast, T>,
) -> Fallible<'ast, T, FlatExpression<T>> {
) -> FlatExpression<T> {
match expr {
FieldElementExpression::Block(statements, value) => todo!("flatten block with or without extracting panics depending on whether we're in a branch"),
FieldElementExpression::Number(x) => FlatExpression::Number(x).into(), // force to be a field element
FieldElementExpression::Block(statements, box value) => {
for s in statements {
self.flatten_statement(statements_flattened, s);
}
self.flatten_field_expression(statements_flattened, value)
}
FieldElementExpression::Number(x) => FlatExpression::Number(x), // force to be a field element
FieldElementExpression::Identifier(x) => {
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x))).into()
FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x)))
}
FieldElementExpression::Add(box left, box right) => {
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
let left_flattened = self.flatten_field_expression(statements_flattened, left);
let right_flattened = self.flatten_field_expression(statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
} else {
@ -2052,11 +2014,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id)
};
FlatExpression::Add(box new_left, box new_right).into()
FlatExpression::Add(box new_left, box new_right)
}
FieldElementExpression::Sub(box left, box right) => {
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
let left_flattened = self.flatten_field_expression(statements_flattened, left);
let right_flattened = self.flatten_field_expression(statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
@ -2073,11 +2035,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Identifier(id)
};
FlatExpression::Sub(box new_left, box new_right).into()
FlatExpression::Sub(box new_left, box new_right)
}
FieldElementExpression::Mult(box left, box right) => {
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
let left_flattened = self.flatten_field_expression(statements_flattened, left);
let right_flattened = self.flatten_field_expression(statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
} else {
@ -2092,11 +2054,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id)
};
FlatExpression::Mult(box new_left, box new_right).into()
FlatExpression::Mult(box new_left, box new_right)
}
FieldElementExpression::Div(box left, box right) => {
let (left_flattened, left_success) = self.flatten_field_expression(statements_flattened, left).split();
let (right_flattened, right_success) = self.flatten_field_expression(statements_flattened, right).split();
let left_flattened = self.flatten_field_expression(statements_flattened, left);
let right_flattened = self.flatten_field_expression(statements_flattened, right);
let new_left: FlatExpression<T> = {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
@ -2137,13 +2099,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Mult(box new_right, box inverse.into()),
));
FlatExpression::from(inverse).into()
inverse.into()
}
FieldElementExpression::Pow(box base, box exponent) => {
match exponent.into_inner() {
UExpressionInner::Value(ref e) => {
// flatten the base expression
let (base_flattened, base_success) = self.flatten_field_expression(statements_flattened, base).split();
let base_flattened =
self.flatten_field_expression(statements_flattened, base.clone());
// we require from the base to be linear
// TODO change that
@ -2207,7 +2170,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
false => acc, // this bit is false, keep the previous result
},
).into()
)
}
_ => panic!("Expected number as pow exponent"),
}
@ -2239,7 +2202,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let flat_expressions = exprs
.into_iter()
.map(|expr| self.flatten_expression(statements_flattened, expr))
.map(|x| x.unwrap().get_field_unchecked())
.map(|x| x.get_field_unchecked())
.collect::<Vec<_>>();
statements_flattened.push(FlatStatement::Return(FlatExpressionList {
@ -2253,7 +2216,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// define n variables with n the number of primitive types for v_type
// assign them to the n primitive types for expr
let (rhs, rhs_success) = self.flatten_expression(statements_flattened, expr).split();
let rhs = self.flatten_expression(statements_flattened, expr);
let bits = rhs.bits.clone();

View file

@ -299,6 +299,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
array: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match array {
typed_absy::ArrayExpressionInner::Block(statements, box value) => unimplemented!(),
typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
@ -427,6 +428,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
struc: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match struc {
typed_absy::StructExpressionInner::Block(statements, box value) => unimplemented!(),
typed_absy::StructExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
@ -623,6 +625,15 @@ pub fn fold_boolean_expression<'ast, T: Field>(
e: typed_absy::BooleanExpression<'ast, T>,
) -> zir::BooleanExpression<'ast, T> {
match e {
typed_absy::BooleanExpression::Block(statements, box value) => {
zir::BooleanExpression::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_boolean_expression(value),
)
}
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0]
@ -805,6 +816,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
e: typed_absy::UExpressionInner<'ast, T>,
) -> zir::UExpressionInner<'ast, T> {
match e {
typed_absy::UExpressionInner::Block(statements, box value) => zir::UExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_uint_expression(value),
),
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(

View file

@ -74,8 +74,10 @@ fn get_canonical_function<'ast, T: Field>(
}
}
type InlineResult<'ast, T> =
Result<Output<Vec<TypedExpression<'ast, T>>, Vec<Versions<'ast>>>, InlineError<'ast, T>>;
type InlineResult<'ast, T> = Result<
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
InlineError<'ast, T>,
>;
pub fn inline_call<'a, 'ast, T: Field>(
k: DeclarationFunctionKey<'ast>,
@ -227,14 +229,7 @@ pub fn inline_call<'a, 'ast, T: Field>(
.chain(std::iter::once(pop_log))
.collect();
let e = match expressions[0].clone() {
TypedExpression::FieldElement(e) => e,
_ => unreachable!(),
};
let res = crate::typed_absy::FieldElementExpression::Block(statements, box e);
Ok(incomplete_data
.map(|d| Output::Incomplete(vec![res.clone().into()], d))
.unwrap_or_else(|| Output::Complete(vec![res.into()])))
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
.unwrap_or_else(|| Output::Complete((statements, expressions))))
}

View file

@ -22,7 +22,7 @@ use crate::typed_absy::Folder;
use std::collections::HashMap;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
ArrayExpression, ArrayExpressionInner, ArrayType, Block, BooleanExpression, CoreIdentifier,
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
@ -197,10 +197,13 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
output_type: Type<'ast, T>,
) -> Result<E, Error>
where
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
E: Block<'ast, T>
+ FunctionCall<'ast, T>
+ TryFrom<TypedExpression<'ast, T>, Error = ()>
+ std::fmt::Debug,
{
let generics = generics
.into_iter()
@ -216,18 +219,23 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
key.clone(),
generics,
arguments,
output_types,
vec![output_type.clone()],
&self.program,
&mut self.versions,
);
match res {
Ok(Output::Complete(expressions)) => {
Ok(Output::Complete((statements, mut expressions))) => {
self.complete &= true;
Ok(expressions[0].clone().try_into().unwrap())
Ok(E::block(
statements,
expressions.pop().unwrap().try_into().unwrap(),
output_type,
))
}
Ok(Output::Incomplete(expressions, delta_for_loop_versions)) => {
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
self.complete = false;
self.statement_buffer.extend(statements);
self.for_loop_versions_after.extend(delta_for_loop_versions);
Ok(expressions[0].clone().try_into().unwrap())
}
@ -246,7 +254,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
output_types.pop().unwrap(),
))
}
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
Err(InlineError::Flat(embed, generics, arguments, mut output_types)) => {
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
*self
.versions
@ -254,7 +262,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
);
let var = Variable::with_id_and_type(identifier, output_types[0].clone());
let var = Variable::with_id_and_type(identifier, output_types.pop().unwrap());
let v = vec![var.clone().into()];
@ -291,6 +299,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
unimplemented!("multi def needs to be put in blocks");
match inline_call(
key,
generics,
@ -299,25 +309,33 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
&self.program,
&mut self.versions,
) {
Ok(Output::Complete(expressions)) => {
Ok(Output::Complete((statements, expressions))) => {
assert_eq!(v.len(), expressions.len());
self.complete &= true;
Ok(v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v, e))
Ok(statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v, e)),
)
.collect())
}
Ok(Output::Incomplete(expressions, delta_for_loop_versions)) => {
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
assert_eq!(v.len(), expressions.len());
self.complete = false;
self.for_loop_versions_after.extend(delta_for_loop_versions);
Ok(v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v, e))
Ok(statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v, e)),
)
.collect())
}
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
@ -425,7 +443,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::FunctionCall(key, generics, arguments) => {
self.fold_function_call(key, generics, arguments, vec![Type::Boolean])
self.fold_function_call(key, generics, arguments, Type::Boolean)
}
e => fold_boolean_expression(self, e),
}
@ -440,7 +458,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
key.clone(),
generics.clone(),
arguments.clone(),
vec![e.get_type()],
e.get_type(),
),
_ => fold_uint_expression(self, e),
}
@ -452,7 +470,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::FunctionCall(key, generic, arguments) => {
self.fold_function_call(key, generic, arguments, vec![Type::FieldElement])
self.fold_function_call(key, generic, arguments, Type::FieldElement)
}
e => fold_field_expression(self, e),
}
@ -469,7 +487,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
key.clone(),
generics,
arguments.clone(),
vec![Type::array(ty.clone())],
Type::array(ty.clone()),
)
.map(|e| e.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
@ -501,7 +519,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
key.clone(),
generics.clone(),
arguments.clone(),
vec![e.get_type()],
e.get_type(),
),
_ => fold_struct_expression(self, e),
}

View file

@ -126,6 +126,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
use self::UExpressionInner::*;
let res = match inner {
b @ Block(..) => force_no_reduce(
fold_uint_expression_inner(self, e.bitwidth, b).annotate(e.bitwidth),
),
Value(v) => Value(v).annotate(range).with_max(v),
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
self.ids

View file

@ -257,6 +257,13 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_array_expression(value),
),
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)),
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
exprs
@ -308,6 +315,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_struct_expression(value),
),
StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)),
StructExpressionInner::Value(exprs) => {
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
@ -428,6 +442,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_boolean_expression(value),
),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
BooleanExpression::FieldEq(box e1, box e2) => {
@ -551,6 +572,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_uint_expression(value),
),
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
UExpressionInner::Add(box left, box right) => {

View file

@ -591,6 +591,16 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
StructExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
StructExpressionInner::Value(ref values) => write!(
f,
@ -805,6 +815,10 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
/// An expression of type `bool`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum BooleanExpression<'ast, T> {
Block(
Vec<TypedStatement<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Identifier(Identifier<'ast>),
Value(bool),
FieldLt(
@ -961,6 +975,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum ArrayExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<ArrayExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(ArrayValue<'ast, T>),
FunctionCall(
@ -1069,6 +1084,7 @@ impl<'ast, T> StructExpression<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum StructExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<StructExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(
@ -1276,6 +1292,16 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
UExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
UExpressionInner::Value(ref v) => write!(f, "{}", v),
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
@ -1333,6 +1359,16 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
@ -1390,6 +1426,16 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ArrayExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var),
ArrayExpressionInner::Value(ref values) => write!(
f,
@ -1809,3 +1855,76 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
StructExpressionInner::FunctionCall(key, generics, arguments).annotate(struct_ty)
}
}
pub trait Block<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self;
}
impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::FieldElement);
FieldElementExpression::Block(statements, box value)
}
}
impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::Boolean);
BooleanExpression::Block(statements, box value)
}
}
impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let bitwidth = match output_type {
Type::Uint(bitwidth) => bitwidth,
_ => unreachable!(),
};
UExpressionInner::Block(statements, box value).annotate(bitwidth)
}
}
impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let array_ty = match output_type {
Type::Array(array_ty) => array_ty,
_ => unreachable!(),
};
ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size)
}
}
impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let struct_ty = match output_type {
Type::Struct(struct_ty) => struct_ty,
_ => unreachable!(),
};
StructExpressionInner::Block(statements, box value).annotate(struct_ty)
}
}

View file

@ -302,6 +302,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, F::Error> {
let e = match e {
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_array_expression(value)?,
),
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?),
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
exprs
@ -357,6 +365,14 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, F::Error> {
let e = match e {
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_struct_expression(value)?,
),
StructExpressionInner::Identifier(id) => {
StructExpressionInner::Identifier(f.fold_name(id)?)
}
@ -491,6 +507,14 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, F::Error> {
let e = match e {
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_boolean_expression(value)?,
),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
BooleanExpression::FieldEq(box e1, box e2) => {
@ -618,6 +642,14 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, F::Error> {
let e = match e {
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_uint_expression(value)?,
),
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
UExpressionInner::Add(box left, box right) => {

View file

@ -175,6 +175,7 @@ impl<'ast, T> PartialEq<usize> for UExpression<'ast, T> {
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum UExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<UExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(u128),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -169,6 +169,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_boolean_expression(value),
),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
BooleanExpression::FieldEq(box e1, box e2) => {
@ -265,6 +272,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_uint_expression(value),
),
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
UExpressionInner::Add(box left, box right) => {

View file

@ -275,6 +275,7 @@ pub enum FieldElementExpression<'ast, T> {
/// An expression of type `bool`
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum BooleanExpression<'ast, T> {
Block(Vec<ZirStatement<'ast, T>>, Box<BooleanExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(bool),
FieldLt(
@ -422,6 +423,16 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
UExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
UExpressionInner::Value(ref v) => write!(f, "{}", v),
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
@ -447,6 +458,16 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
@ -475,6 +496,9 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Block(ref statements, ref value) => {
write!(f, "Block({:?}, {:?})", statements, value)
}
BooleanExpression::Identifier(ref var) => write!(f, "Ide({:?})", var),
BooleanExpression::Value(b) => write!(f, "Value({})", b),
BooleanExpression::FieldLt(ref lhs, ref rhs) => {

View file

@ -1,6 +1,6 @@
use crate::zir::identifier::Identifier;
use crate::zir::types::UBitwidth;
use crate::zir::BooleanExpression;
use crate::zir::{BooleanExpression, ZirStatement};
use zokrates_field::Field;
impl<'ast, T: Field> UExpression<'ast, T> {
@ -158,6 +158,7 @@ pub struct UExpression<'ast, T> {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UExpressionInner<'ast, T> {
Block(Vec<ZirStatement<'ast, T>>, Box<UExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(u128),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),