1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

implement blocks for all types, add breaking example for edge case

This commit is contained in:
schaeff 2021-05-16 23:03:05 +02:00
parent 7b46a33c5b
commit 09b1e52608
9 changed files with 111 additions and 131 deletions

View file

@ -0,0 +1,10 @@
def throwing_bound<N>(u32 x) -> u32:
assert(x == N)
return 1
// this should compile: the conditional, even though it can throw, has a constant compile-time value `1`
// the value of the blocks should be propagated out, so that `if x == 0 then 1 else 1 fi` can be determined to be `1`
def main(u32 x):
for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do
endfor
return

View file

@ -1,5 +1,5 @@
def bound(field x) -> u32:
return 41 + x
return 41 + 1
def main(field a) -> field:
field x = 7

View file

@ -206,7 +206,6 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_expression_list(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
es: typed_absy::TypedExpressionList<'ast, T>,
) -> zir::ZirExpressionList<'ast, T> {
match es {
@ -227,7 +226,6 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_field_expression(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
e: typed_absy::FieldElementExpression<'ast, T>,
) -> zir::FieldElementExpression<'ast, T> {
fold_field_expression(self, statements_buffer, e)
@ -328,11 +326,12 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
array: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match array {
typed_absy::ArrayExpressionInner::Block(statements, box value) => {
statements
typed_absy::ArrayExpressionInner::Block(block) => {
block
.statements
.into_iter()
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_array_expression(statements_buffer, value)
f.fold_array_expression(statements_buffer, *block.value)
}
typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
@ -472,11 +471,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
struc: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match struc {
typed_absy::StructExpressionInner::Block(statements, box value) => {
statements
typed_absy::StructExpressionInner::Block(block) => {
block
.statements
.into_iter()
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_struct_expression(statements_buffer, value)
f.fold_struct_expression(statements_buffer, *block.value)
}
typed_absy::StructExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
@ -699,11 +699,12 @@ pub fn fold_boolean_expression<'ast, T: Field>(
e: typed_absy::BooleanExpression<'ast, T>,
) -> zir::BooleanExpression<'ast, T> {
match e {
typed_absy::BooleanExpression::Block(statements, box value) => {
statements
typed_absy::BooleanExpression::Block(block) => {
block
.statements
.into_iter()
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_boolean_expression(statements_buffer, value)
f.fold_boolean_expression(statements_buffer, *block.value)
}
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
@ -899,11 +900,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
e: typed_absy::UExpressionInner<'ast, T>,
) -> zir::UExpressionInner<'ast, T> {
match e {
typed_absy::UExpressionInner::Block(statements, box value) => {
statements
typed_absy::UExpressionInner::Block(block) => {
block
.statements
.into_iter()
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_uint_expression(statements_buffer, value)
f.fold_uint_expression(statements_buffer, *block.value)
.into_inner()
}
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),

View file

@ -150,8 +150,8 @@ fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
TypedExpression::Uint(a) => {
matches!(a.as_inner(), UExpressionInner::Value(..))
|| match a.as_inner() {
UExpressionInner::Block(_, e) => {
is_constant(&TypedExpression::from(*e.clone()))
UExpressionInner::Block(block) => {
is_constant(&TypedExpression::from(*block.value.clone()))
}
_ => false,
}

View file

@ -14,6 +14,30 @@ impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_boolean_expression(self)
}
}
impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_uint_expression(self)
}
}
impl<'ast, T: Field> Fold<'ast, T> for StructExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_struct_expression(self)
}
}
impl<'ast, T: Field> Fold<'ast, T> for ArrayExpression<'ast, T> {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self {
f.fold_array_expression(self)
}
}
pub trait Folder<'ast, T: Field>: Sized {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
fold_program(self, p)
@ -274,13 +298,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_array_expression(value),
),
ArrayExpressionInner::Block(block) => {
ArrayExpressionInner::Block(f.fold_block_expression(block))
}
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)),
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
exprs
@ -332,13 +352,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_struct_expression(value),
),
StructExpressionInner::Block(block) => {
StructExpressionInner::Block(f.fold_block_expression(block))
}
StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)),
StructExpressionInner::Value(exprs) => {
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
@ -455,13 +471,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_boolean_expression(value),
),
BooleanExpression::Block(block) => BooleanExpression::Block(f.fold_block_expression(block)),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
BooleanExpression::FieldEq(box e1, box e2) => {
@ -585,13 +595,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
box f.fold_uint_expression(value),
),
UExpressionInner::Block(block) => UExpressionInner::Block(f.fold_block_expression(block)),
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
UExpressionInner::Add(box left, box right) => {

View file

@ -585,16 +585,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
StructExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
StructExpressionInner::Block(ref block) => write!(f, "{}", block),
StructExpressionInner::Identifier(ref var) => write!(f, "{}", var),
StructExpressionInner::Value(ref values) => write!(
f,
@ -822,10 +813,7 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
/// An expression of type `bool`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum BooleanExpression<'ast, T> {
Block(
Vec<TypedStatement<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Block(BlockExpression<'ast, T, Self>),
Identifier(Identifier<'ast>),
Value(bool),
FieldLt(
@ -982,7 +970,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum ArrayExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<ArrayExpression<'ast, T>>),
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(ArrayValue<'ast, T>),
FunctionCall(
@ -1091,7 +1079,7 @@ impl<'ast, T> StructExpression<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum StructExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<StructExpression<'ast, T>>),
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(
@ -1305,16 +1293,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner {
UExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
UExpressionInner::Block(ref block) => write!(f, "{}", block,),
UExpressionInner::Value(ref v) => write!(f, "{}", v),
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
@ -1372,16 +1351,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
BooleanExpression::Block(ref block) => write!(f, "{}", block,),
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
@ -1439,16 +1409,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ArrayExpressionInner::Block(ref statements, ref value) => write!(
f,
"{{{}}}",
statements
.iter()
.map(|s| s.to_string())
.chain(std::iter::once(value.to_string()))
.collect::<Vec<_>>()
.join("\n")
),
ArrayExpressionInner::Block(ref block) => write!(f, "{}", block,),
ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var),
ArrayExpressionInner::Value(ref values) => write!(
f,
@ -1881,21 +1842,22 @@ impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> {
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
BooleanExpression::Block(statements, box value)
BooleanExpression::Block(BlockExpression::new(statements, value))
}
}
impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
let bitwidth = value.bitwidth();
UExpressionInner::Block(statements, box value).annotate(bitwidth)
UExpressionInner::Block(BlockExpression::new(statements, value)).annotate(bitwidth)
}
}
impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
let array_ty = value.ty();
ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size)
ArrayExpressionInner::Block(BlockExpression::new(statements, value))
.annotate(*array_ty.ty, array_ty.size)
}
}
@ -1903,6 +1865,6 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
let struct_ty = value.ty().clone();
StructExpressionInner::Block(statements, box value).annotate(struct_ty)
StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty)
}
}

View file

@ -14,6 +14,30 @@ impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_boolean_expression(self)
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_uint_expression(self)
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for ArrayExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_array_expression(self)
}
}
impl<'ast, T: Field> ResultFold<'ast, T> for StructExpression<'ast, T> {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error> {
f.fold_struct_expression(self)
}
}
pub trait ResultFolder<'ast, T: Field>: Sized {
type Error;
@ -319,14 +343,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, F::Error> {
let e = match e {
ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_array_expression(value)?,
),
ArrayExpressionInner::Block(block) => {
ArrayExpressionInner::Block(f.fold_block_expression(block)?)
}
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?),
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
exprs
@ -382,14 +401,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, F::Error> {
let e = match e {
StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_struct_expression(value)?,
),
StructExpressionInner::Block(block) => {
StructExpressionInner::Block(f.fold_block_expression(block)?)
}
StructExpressionInner::Identifier(id) => {
StructExpressionInner::Identifier(f.fold_name(id)?)
}
@ -536,14 +550,9 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, F::Error> {
let e = match e {
BooleanExpression::Block(statements, box value) => BooleanExpression::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_boolean_expression(value)?,
),
BooleanExpression::Block(block) => {
BooleanExpression::Block(f.fold_block_expression(block)?)
}
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
BooleanExpression::FieldEq(box e1, box e2) => {
@ -671,14 +680,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, F::Error> {
let e = match e {
UExpressionInner::Block(statements, box value) => UExpressionInner::Block(
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()
.map(|r| r.into_iter().flatten().collect())?,
box f.fold_uint_expression(value)?,
),
UExpressionInner::Block(block) => UExpressionInner::Block(f.fold_block_expression(block)?),
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
UExpressionInner::Add(box left, box right) => {

View file

@ -175,7 +175,7 @@ impl<'ast, T> PartialEq<usize> for UExpression<'ast, T> {
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum UExpressionInner<'ast, T> {
Block(Vec<TypedStatement<'ast, T>>, Box<UExpression<'ast, T>>),
Block(BlockExpression<'ast, T, UExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(u128),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -1,9 +1,9 @@
def throwing_bound(u32 x) -> u32:
assert(x == 1)
def throwing_bound<N>(u32 x) -> u32:
assert(x == N)
return 1
// Even if the bound is constant at compile time, it can throw at runtime
def main(u32 x):
for u32 i in 0..throwing_bound(x) do
for u32 i in 0..throwing_bound::<1>(x) do
endfor
return