1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

introduce block into ast and implement isolation on that

This commit is contained in:
schaeff 2021-05-16 19:16:25 +02:00
parent 43b7457ec9
commit a1a65378a7
20 changed files with 448 additions and 207 deletions

View file

@ -1,32 +0,0 @@
def zero(field x) -> field:
assert(x == 0)
return 0
def inverse(field x) -> field:
assert(x != 0)
return 1/x
def main(field x) -> field:
return if x == 0 then zero(x) else inverse(x) fi
// def yes(bool x) -> bool:
// assert(x)
// return x
// def no(bool x) -> bool:
// assert(!x)
// return !x
// def main(bool x) -> bool:
// return if x then yes(x) else no(x) fi
// def ones(field[2] a) -> field[2]:
// assert(a == [1, 1])
// return a
// def twos(field[2] a) -> field[2]:
// assert(a == [2, 2])
// return a
// def main(bool condition, field[2] a, field[2] b) -> field[2]:
// return if condition then ones(a) else twos(b) fi

View file

@ -1,10 +1,10 @@
def bound(field x) -> u32:
return 41 + 1
return 41 + x
def main(field a) -> field:
field x = 7
x = x + 1
for u32 i in 0..bound(x) do
for u32 i in 0..bound(x) + bound(x + 1) do
// x = x + a
x = x + a
endfor

View file

@ -173,13 +173,9 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
let (typed_ast, abi) = check_with_arena(source, location, resolver, &arena)?;
println!("{}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast, config);
println!("{}", program_flattened);
// analyse (constant propagation after call resolution)
let program_flattened = program_flattened.analyse();

View file

@ -450,15 +450,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
println!(
"BEFORE\n {}\n",
alternative_statements
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join("\n")
);
let consequence_statements =
self.make_conditional(consequence_statements, condition_id.into());
let alternative_statements = self.make_conditional(
@ -469,15 +460,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
),
);
println!(
"AFTER\n {}\n",
alternative_statements
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join("\n")
);
statements_flattened.extend(consequence_statements);
statements_flattened.extend(alternative_statements);

View file

@ -1963,7 +1963,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
TypedExpression::Boolean(condition) => {
match (consequence_checked, alternative_checked) {
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
Ok(FieldElementExpression::IfElse(box condition, box FieldElementExpression::Block(vec![], box consequence), box FieldElementExpression::Block(vec![], box alternative)).into())
Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into())
},
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into())

View file

@ -19,7 +19,7 @@ impl BoundsChecker {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
match (array.get_array_type().size.as_inner(), index.as_inner()) {
match (array.ty().size.as_inner(), index.as_inner()) {
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
if index >= size {
return Err(format!(
@ -53,7 +53,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
let to = self.fold_uint_expression(to)?;
match (
array.get_array_type().size.as_inner(),
array.ty().size.as_inner(),
from.as_inner(),
to.as_inner(),
) {

View file

@ -0,0 +1,102 @@
// Isolate branches means making sure that any branch is enclosed in a block.
// This is important, because we want any statement resulting from inlining any branch to be isolated from the coller, so that its panics can be conditional to the branch being logically run
// `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks.
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct Isolator;
impl Isolator {
pub fn isolate<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
let mut isolator = Isolator;
isolator.fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
FieldElementExpression::IfElse(
box self.fold_boolean_expression(condition),
box FieldElementExpression::block(vec![],consequence),
box FieldElementExpression::block(vec![],alternative),
)
}
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
BooleanExpression::IfElse(
box self.fold_boolean_expression(condition),
box BooleanExpression::block(vec![],consequence),
box BooleanExpression::block(vec![],alternative),
)
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
UExpressionInner::IfElse(
box self.fold_boolean_expression(condition),
box UExpression::block(vec![],consequence),
box UExpression::block(vec![],alternative),
)
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_array_expression_inner(
&mut self,
array_ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
box self.fold_boolean_expression(condition),
box ArrayExpression::block(vec![],consequence),
box ArrayExpression::block(vec![],alternative),
)
}
e => fold_array_expression_inner(self, array_ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
struct_ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
StructExpressionInner::IfElse(
box self.fold_boolean_expression(condition),
box StructExpression::block(vec![],consequence),
box StructExpression::block(vec![],alternative),
)
}
e => fold_struct_expression_inner(self, struct_ty, e),
}
}
}

View file

@ -683,11 +683,12 @@ pub fn fold_field_expression<'ast, T: Field>(
_ => unreachable!(""),
}
}
typed_absy::FieldElementExpression::Block(statements, box value) => {
statements
typed_absy::FieldElementExpression::Block(block) => {
block
.statements
.into_iter()
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_field_expression(statements_buffer, value)
f.fold_field_expression(statements_buffer, *block.value)
}
}
}

View file

@ -5,6 +5,7 @@
//! @date 2018
mod bounds_checker;
mod branch_isolator;
mod constant_inliner;
mod flat_propagation;
mod flatten_complex_types;
@ -17,6 +18,7 @@ mod variable_read_remover;
mod variable_write_remover;
use self::bounds_checker::BoundsChecker;
use self::branch_isolator::Isolator;
use self::flatten_complex_types::Flattener;
use self::propagation::Propagator;
use self::reducer::reduce_program;
@ -75,6 +77,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// inline user-defined constants
let r = ConstantInliner::inline(self);
// isolate branches
let r = Isolator::isolate(r);
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
// generate abi

View file

@ -147,7 +147,15 @@ fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)),
TypedExpression::Uint(a) => {
matches!(a.as_inner(), UExpressionInner::Value(..))
|| match a.as_inner() {
UExpressionInner::Block(_, e) => {
is_constant(&TypedExpression::from(*e.clone()))
}
_ => false,
}
}
_ => false,
}
}
@ -167,7 +175,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
match e {
TypedExpression::Array(a) => {
let array_ty = a.get_array_type();
let array_ty = a.ty();
match a.into_inner() {
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
@ -353,8 +361,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let expression_list = self.fold_expression_list(expression_list)?;
match expression_list {
l @ TypedExpressionList::Block(..) => fold_expression_list(self, l)
.map(|l| vec![TypedStatement::MultipleDefinition(assignees, l)]),
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
let arguments: Vec<_> = arguments
.into_iter()

View file

@ -22,11 +22,11 @@ use crate::typed_absy::Folder;
use std::collections::HashMap;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, ArrayType, Block, BooleanExpression, CoreIdentifier,
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
Variable,
ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression,
CoreIdentifier, DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier,
StructExpression, StructExpressionInner, StructType, Type, TypedExpression,
TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram,
TypedStatement, UBitwidth, UExpression, UExpressionInner, Variable,
};
use std::convert::{TryFrom, TryInto};
@ -200,10 +200,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
output_type: Type<'ast, T>,
) -> Result<E, Error>
where
E: Block<'ast, T>
+ FunctionCall<'ast, T>
+ TryFrom<TypedExpression<'ast, T>, Error = ()>
+ std::fmt::Debug,
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
{
let generics = generics
.into_iter()
@ -227,11 +224,8 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
match res {
Ok(Output::Complete((statements, mut expressions))) => {
self.complete &= true;
Ok(E::block(
statements,
expressions.pop().unwrap().try_into().unwrap(),
output_type,
))
self.statement_buffer.extend(statements);
Ok(expressions.pop().unwrap().try_into().unwrap())
}
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
self.complete = false;
@ -280,6 +274,29 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_block_expression<E: ResultFold<'ast, T>>(
&mut self,
b: BlockExpression<'ast, T, E>,
) -> Result<BlockExpression<'ast, T, E>, Self::Error> {
// backup the statements and continue with a fresh state
let statement_buffer = std::mem::take(&mut self.statement_buffer);
let block = fold_block_expression(self, b)?;
// put the original statements back and extract the statements created by visiting the block
let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
// return the visited block, augmented with the statements created while visiting it
Ok(BlockExpression {
statements: block
.statements
.into_iter()
.chain(extra_statements)
.collect(),
..block
})
}
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
@ -433,7 +450,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
s => fold_statement(self, s),
};
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
//res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
res
}
fn fold_boolean_expression(
@ -448,18 +466,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
}
fn fold_uint_expression(
fn fold_uint_expression_inner(
&mut self,
e: UExpression<'ast, T>,
) -> Result<UExpression<'ast, T>, Self::Error> {
match e.as_inner() {
UExpressionInner::FunctionCall(key, generics, arguments) => self.fold_function_call(
key.clone(),
generics.clone(),
arguments.clone(),
e.get_type(),
),
_ => fold_uint_expression(self, e),
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call::<UExpression<'ast, T>>(
key,
generics,
arguments,
Type::Uint(bitwidth),
)
.map(|e| e.into_inner()),
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
@ -477,7 +498,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
array_ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
@ -486,7 +507,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
key.clone(),
generics,
arguments.clone(),
Type::array(ty.clone()),
Type::array(array_ty.clone()),
)
.map(|e| e.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
@ -504,23 +525,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
}
}
_ => fold_array_expression_inner(self, &ty, e),
_ => fold_array_expression_inner(self, &array_ty, e),
}
}
fn fold_struct_expression(
fn fold_struct_expression_inner(
&mut self,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, Self::Error> {
match e.as_inner() {
struct_ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call(
key.clone(),
generics.clone(),
arguments.clone(),
e.get_type(),
),
_ => fold_struct_expression(self, e),
.fold_function_call::<StructExpression<'ast, T>>(
key,
generics,
arguments,
Type::Struct(struct_ty.clone()),
)
.map(|e| e.into_inner()),
_ => fold_struct_expression_inner(self, struct_ty, e),
}
}
}
@ -597,7 +620,14 @@ fn reduce_function<'ast, T: Field>(
statements: f
.statements
.into_iter()
.map(|s| reducer.fold_statement(s))
.map(|s| {
let res = reducer.fold_statement(s)?;
Ok(reducer
.statement_buffer
.drain(..)
.chain(res)
.collect::<Vec<_>>())
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()

View file

@ -4,6 +4,16 @@ use crate::typed_absy::types::{ArrayType, StructMember, StructType};
use crate::typed_absy::*;
use zokrates_field::Field;
pub trait Fold<'ast, T: Field>: Sized {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self;
}
impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_field_expression(self)
}
}
pub trait Folder<'ast, T: Field>: Sized {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
fold_program(self, p)
@ -137,6 +147,13 @@ pub trait Folder<'ast, T: Field>: Sized {
}
}
fn fold_block_expression<E: Fold<'ast, T>>(
&mut self,
block: BlockExpression<'ast, T, E>,
) -> BlockExpression<'ast, T, E> {
fold_block_expression(self, block)
}
fn fold_array_expression(&mut self, e: ArrayExpression<'ast, T>) -> ArrayExpression<'ast, T> {
fold_array_expression(self, e)
}
@ -358,13 +375,9 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Block(statements, box value) => FieldElementExpression::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_field_expression(value),
),
FieldElementExpression::Block(block) => {
FieldElementExpression::Block(f.fold_block_expression(block))
}
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
FieldElementExpression::Identifier(id) => {
FieldElementExpression::Identifier(f.fold_name(id))
@ -688,6 +701,20 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>(
f: &mut F,
block: BlockExpression<'ast, T, E>,
) -> BlockExpression<'ast, T, E> {
BlockExpression {
statements: block
.statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
value: box block.value.fold(f),
}
}
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -749,13 +776,6 @@ pub fn fold_expression_list<'ast, T: Field, F: Folder<'ast, T>>(
types.into_iter().map(|t| f.fold_type(t)).collect(),
)
}
TypedExpressionList::Block(statements, values) => TypedExpressionList::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
values.into_iter().map(|v| f.fold_expression(v)).collect(),
),
}
}

View file

@ -481,7 +481,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
array: Self,
target_inner_ty: Type<'ast, T>,
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.get_array_type();
let array_ty = array.ty();
// elements must fit in the target type
match array.into_inner() {

View file

@ -489,13 +489,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
TypedStatement::Declaration(ref var) => write!(f, "{}", var),
TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
TypedStatement::Assertion(ref e) => write!(f, "assert({})", e),
TypedStatement::For(ref var, ref start, ref stop, ref list) => {
writeln!(f, "for {} in {}..{} do", var, start, stop)?;
for l in list {
writeln!(f, "\t\t{}", l)?;
}
write!(f, "\tendfor")
}
TypedStatement::For(..) => unreachable!("fmt_indented should be called instead"),
TypedStatement::MultipleDefinition(ref ids, ref rhs) => {
for (i, id) in ids.iter().enumerate() {
write!(f, "{}", id)?;
@ -713,7 +707,6 @@ pub enum TypedExpressionList<'ast, T> {
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
Block(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>),
}
impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> {
@ -721,9 +714,22 @@ impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> {
match *self {
TypedExpressionList::FunctionCall(_, _, _, ref types) => types.clone(),
TypedExpressionList::EmbedCall(_, _, _, ref types) => types.clone(),
TypedExpressionList::Block(_, ref values) => {
values.iter().map(|v| v.get_type()).collect()
}
}
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
// a block expression which returns an `E`
pub struct BlockExpression<'ast, T, E> {
pub statements: Vec<TypedStatement<'ast, T>>,
pub value: Box<E>,
}
impl<'ast, T, E> BlockExpression<'ast, T, E> {
pub fn new(statements: Vec<TypedStatement<'ast, T>>, value: E) -> Self {
BlockExpression {
statements,
value: box value,
}
}
}
@ -731,10 +737,7 @@ impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> {
/// An expression of type `field`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum FieldElementExpression<'ast, T> {
Block(
Vec<TypedStatement<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Block(BlockExpression<'ast, T, Self>),
Number(T),
Identifier(Identifier<'ast>),
Add(
@ -925,7 +928,7 @@ impl<'ast, T: Clone> ArrayValue<'ast, T> {
TypedExpressionOrSpread::Expression(e) => vec![Some(e.clone())],
TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() {
UExpressionInner::Value(size) => {
let array_ty = s.array.get_array_type().clone();
let array_ty = s.array.ty().clone();
match s.array.into_inner() {
ArrayExpressionInner::Value(v) => v
@ -1036,7 +1039,7 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> {
self.inner
}
pub fn get_array_type(&self) -> ArrayType<'ast, T> {
pub fn ty(&self) -> ArrayType<'ast, T> {
ArrayType {
size: self.size(),
ty: box self.inner_type().clone(),
@ -1233,19 +1236,25 @@ impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for IntExpression<'ast, T> {
}
}
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for BlockExpression<'ast, T, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{{\n{}\n}}",
self.statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(self.value.to_string()))
.collect::<Vec<_>>()
.join("\n")
)
}
}
impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
FieldElementExpression::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
FieldElementExpression::Block(ref block) => write!(f, "{}", block),
FieldElementExpression::Number(ref i) => write!(f, "{}f", i),
FieldElementExpression::Identifier(ref var) => write!(f, "{}", var),
FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
@ -1542,22 +1551,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
}
write!(f, ")")
}
TypedExpressionList::Block(ref statements, ref values) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(
values
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(", ")
))
.collect::<Vec<_>>()
.join("\n")
),
}
}
}
@ -1676,7 +1669,7 @@ impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
impl<'ast, T: Clone> Select<'ast, T> for TypedExpression<'ast, T> {
fn select<I: Into<UExpression<'ast, T>>>(array: ArrayExpression<'ast, T>, index: I) -> Self {
match *array.get_array_type().ty {
match *array.ty().ty {
Type::Array(..) => ArrayExpression::select(array, index).into(),
Type::Struct(..) => StructExpression::select(array, index).into(),
Type::FieldElement => FieldElementExpression::select(array, index).into(),
@ -1880,7 +1873,6 @@ pub trait Block<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self;
}
@ -1888,10 +1880,8 @@ impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::FieldElement);
FieldElementExpression::Block(statements, box value)
FieldElementExpression::Block(BlockExpression::new(statements, value))
}
}
@ -1899,9 +1889,7 @@ impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::Boolean);
BooleanExpression::Block(statements, box value)
}
}
@ -1910,12 +1898,8 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let bitwidth = match output_type {
Type::Uint(bitwidth) => bitwidth,
_ => unreachable!(),
};
let bitwidth = value.bitwidth();
UExpressionInner::Block(statements, box value).annotate(bitwidth)
}
}
@ -1924,12 +1908,8 @@ impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let array_ty = match output_type {
Type::Array(array_ty) => array_ty,
_ => unreachable!(),
};
let array_ty = value.ty();
ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size)
}
}
@ -1938,12 +1918,8 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
fn block(
statements: Vec<TypedStatement<'ast, T>>,
value: Self,
output_type: Type<'ast, T>,
) -> Self {
let struct_ty = match output_type {
Type::Struct(struct_ty) => struct_ty,
_ => unreachable!(),
};
let struct_ty = value.ty().clone();
StructExpressionInner::Block(statements, box value).annotate(struct_ty)
}

View file

@ -4,6 +4,16 @@ use crate::typed_absy::types::{ArrayType, StructMember, StructType};
use crate::typed_absy::*;
use zokrates_field::Field;
pub trait ResultFold<'ast, T: Field>: Sized {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error>;
}
impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_field_expression(self)
}
}
pub trait ResultFolder<'ast, T: Field>: Sized {
type Error;
@ -90,6 +100,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}
}
fn fold_block_expression<E: ResultFold<'ast, T>>(
&mut self,
block: BlockExpression<'ast, T, E>,
) -> Result<BlockExpression<'ast, T, E>, Self::Error> {
fold_block_expression(self, block)
}
fn fold_array_type(
&mut self,
t: ArrayType<'ast, T>,
@ -418,14 +435,9 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, F::Error> {
let e = match e {
FieldElementExpression::Block(statements, box value) => FieldElementExpression::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_field_expression(value)?,
),
FieldElementExpression::Block(block) => {
FieldElementExpression::Block(f.fold_block_expression(block)?)
}
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
FieldElementExpression::Identifier(id) => {
FieldElementExpression::Identifier(f.fold_name(id)?)
@ -502,6 +514,23 @@ pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
unreachable!()
}
pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>(
f: &mut F,
block: BlockExpression<'ast, T, E>,
) -> Result<BlockExpression<'ast, T, E>, F::Error> {
Ok(BlockExpression {
statements: block
.statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
value: box block.value.fold(f)?,
})
}
pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: BooleanExpression<'ast, T>,
@ -834,17 +863,6 @@ pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>(
.collect::<Result<_, _>>()?,
))
}
TypedExpressionList::Block(statements, values) => Ok(TypedExpressionList::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|v| v.into_iter().flatten().collect())?,
values
.into_iter()
.map(|v| f.fold_expression(v))
.collect::<Result<_, _>>()?,
)),
}
}

View file

@ -0,0 +1,33 @@
{
"entry_point": "./tests/tests/panics/loop_bound.zok",
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"tests": [
{
"input": {
"values": [
"0"
]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1"
}
}
}
},
{
"input": {
"values": [
"1"
]
},
"output": {
"Ok": {
"values": []
}
}
}
]
}

View file

@ -0,0 +1,9 @@
def throwing_bound(u32 x) -> u32:
assert(x == 1)
return 1
// Even if the bound is constant at compile time, it can throw at runtime
def main(u32 x):
for u32 i in 0..throwing_bound(x) do
endfor
return

View file

@ -0,0 +1,64 @@
{
"entry_point": "./tests/tests/panics/panic_isolation.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": [
"1",
"42",
"42",
"0"
]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "1",
"right": "21888242871839275222246405745257275088548364400416034343698204186575808495577"
}
}
}
},
{
"input": {
"values": [
"1",
"1",
"1",
"1"
]
},
"output": {
"Ok": {
"values": [
"1",
"1",
"1",
"1"
]
}
}
},
{
"input": {
"values": [
"0",
"2",
"2",
"0"
]
},
"output": {
"Ok": {
"values": [
"0",
"2",
"2",
"0"
]
}
}
}
]
}

View file

@ -0,0 +1,31 @@
def zero(field x) -> field:
assert(x == 0)
return 0
def inverse(field x) -> field:
assert(x != 0)
return 1/x
def yes(bool x) -> bool:
assert(x)
return x
def no(bool x) -> bool:
assert(!x)
return x
def ones(field[2] a) -> field[2]:
assert(a == [1, 1])
return a
def twos(field[2] a) -> field[2]:
assert(a == [2, 2])
return a
def main(bool condition, field[2] a, field x) -> (bool, field[2], field):
// first branch asserts that `condition` is true, second branch asserts that `condition` is false. This should never throw.
// first branch asserts that all elements in `a` are 1, 2 in the second branch. This should throw only if `a` is neither ones or zeroes
// first branch asserts that `x` is zero and returns it, second branch asserts that `x` isn't 0 and returns its inverse (which internally generates a failing assert if x is 0). This should never throw
return if condition then yes(condition) else no(condition) fi,\
if condition then ones(a) else twos(a) fi,\
if x == 0 then zero(x) else inverse(x) fi

View file

@ -163,8 +163,9 @@ fn compile_and_run<T: Field>(t: Tests) {
let mut s = String::new();
code.read_to_string(&mut s).unwrap();
let context = format!(
"\n{}\nCalled with input ({})\n",
"\n{}\nCalled on curve {} with input ({})\n",
s,
T::name(),
input
.iter()
.map(|i| i.to_string())