1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
schaeff 2022-08-05 16:23:18 +02:00
parent 53b62f568b
commit 8caa6b4720
19 changed files with 660 additions and 400 deletions

View file

@ -37,6 +37,7 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
match error {
crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s),
crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck,
crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse,
}
}
}

View file

@ -264,6 +264,10 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_statement(self, s)
}
fn fold_definition_rhs(&mut self, rhs: DefinitionRhs<'ast, T>) -> DefinitionRhs<'ast, T> {
fold_definition_rhs(self, rhs)
}
fn fold_embed_call(&mut self, e: EmbedCall<'ast, T>) -> EmbedCall<'ast, T> {
fold_embed_call(self, e)
}
@ -491,6 +495,16 @@ pub fn fold_constant_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
rhs: DefinitionRhs<'ast, T>,
) -> DefinitionRhs<'ast, T> {
match rhs {
DefinitionRhs::EmbedCall(c) => DefinitionRhs::EmbedCall(f.fold_embed_call(c)),
DefinitionRhs::Expression(e) => DefinitionRhs::Expression(f.fold_expression(e)),
}
}
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedStatement<'ast, T>,
@ -498,7 +512,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
@ -515,12 +529,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
TypedStatement::Log(s, e) => {
TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect())
}
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
TypedStatement::EmbedCallDefinition(
f.fold_assignee(assignee),
f.fold_embed_call(embed_call),
)
}
s => s,
};
vec![res]

View file

@ -588,6 +588,7 @@ impl fmt::Display for AssertionMetadata {
pub enum RuntimeError {
SourceAssertion(AssertionMetadata),
SelectRangeCheck,
DivisionByZero,
}
impl fmt::Display for RuntimeError {
@ -595,6 +596,7 @@ impl fmt::Display for RuntimeError {
match self {
RuntimeError::SourceAssertion(metadata) => write!(f, "{}", metadata),
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
}
}
}
@ -646,12 +648,39 @@ impl<'ast, T: fmt::Display> fmt::Display for EmbedCall<'ast, T> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum DefinitionRhs<'ast, T> {
Expression(TypedExpression<'ast, T>),
EmbedCall(EmbedCall<'ast, T>),
}
impl<'ast, T> From<TypedExpression<'ast, T>> for DefinitionRhs<'ast, T> {
fn from(e: TypedExpression<'ast, T>) -> Self {
Self::Expression(e)
}
}
impl<'ast, T> From<EmbedCall<'ast, T>> for DefinitionRhs<'ast, T> {
fn from(c: EmbedCall<'ast, T>) -> Self {
Self::EmbedCall(c)
}
}
impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DefinitionRhs::EmbedCall(c) => write!(f, "{}", c),
DefinitionRhs::Expression(e) => write!(f, "{}", e),
}
}
}
/// A statement in a `TypedFunction`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedStatement<'ast, T> {
Return(TypedExpression<'ast, T>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
Definition(TypedAssignee<'ast, T>, DefinitionRhs<'ast, T>),
Assertion(BooleanExpression<'ast, T>, RuntimeError),
For(
Variable<'ast, T>,
@ -660,7 +689,6 @@ pub enum TypedStatement<'ast, T> {
Vec<TypedStatement<'ast, T>>,
),
Log(FormatString, Vec<TypedExpression<'ast, T>>),
EmbedCallDefinition(TypedAssignee<'ast, T>, EmbedCall<'ast, T>),
// Aux
PushCallLog(
DeclarationFunctionKey<'ast, T>,
@ -669,6 +697,16 @@ pub enum TypedStatement<'ast, T> {
PopCallLog,
}
impl<'ast, T> TypedStatement<'ast, T> {
pub fn definition(a: TypedAssignee<'ast, T>, e: TypedExpression<'ast, T>) -> Self {
Self::Definition(a, e.into())
}
pub fn embed_call_definition(a: TypedAssignee<'ast, T>, c: EmbedCall<'ast, T>) -> Self {
Self::Definition(a, c.into())
}
}
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
match self {
@ -710,9 +748,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
}
write!(f, "\t}}")
}
TypedStatement::EmbedCallDefinition(ref lhs, ref rhs) => {
write!(f, "{} = {};", lhs, rhs)
}
TypedStatement::Log(ref l, ref expressions) => write!(
f,
"log({}, {})",
@ -1642,7 +1677,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
UExpressionInner::Not(ref e) => write!(f, "!({})", e),
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
UExpressionInner::Select(ref select) => write!(f, "{}", select),
@ -1675,7 +1710,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::UintEq(ref e) => write!(f, "{}", e),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Not(ref exp) => write!(f, "!({})", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
BooleanExpression::Conditional(ref c) => write!(f, "{}", c),

View file

@ -385,6 +385,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_statement(self, s)
}
fn fold_definition_rhs(
&mut self,
rhs: DefinitionRhs<'ast, T>,
) -> Result<DefinitionRhs<'ast, T>, Self::Error> {
fold_definition_rhs(self, rhs)
}
fn fold_embed_call(
&mut self,
e: EmbedCall<'ast, T>,
@ -508,7 +515,7 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?)
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
@ -531,17 +538,21 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
.map(|e| f.fold_expression(e))
.collect::<Result<Vec<_>, _>>()?,
),
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
TypedStatement::EmbedCallDefinition(
f.fold_assignee(assignee)?,
f.fold_embed_call(embed_call)?,
)
}
s => s,
};
Ok(vec![res])
}
pub fn fold_definition_rhs<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
rhs: DefinitionRhs<'ast, T>,
) -> Result<DefinitionRhs<'ast, T>, F::Error> {
Ok(match rhs {
DefinitionRhs::EmbedCall(c) => DefinitionRhs::EmbedCall(f.fold_embed_call(c)?),
DefinitionRhs::Expression(e) => DefinitionRhs::Expression(f.fold_expression(e)?),
})
}
pub fn fold_embed_call<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: EmbedCall<'ast, T>,

View file

@ -92,6 +92,7 @@ pub type ZirAssignee<'ast> = Variable<'ast>;
pub enum RuntimeError {
SourceAssertion(String),
SelectRangeCheck,
DivisionByZero,
}
impl fmt::Display for RuntimeError {
@ -99,6 +100,7 @@ impl fmt::Display for RuntimeError {
match self {
RuntimeError::SourceAssertion(message) => write!(f, "{}", message),
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
}
}
}

View file

@ -4,9 +4,10 @@ from "field" import FIELD_SIZE_IN_BITS;
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> bool {
def main(field a) {
u32 pbits = FIELD_SIZE_IN_BITS;
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1);
return a < maxvalue + 1;
bool c = a < maxvalue + 1;
return
}

View file

@ -4,9 +4,10 @@ from "field" import FIELD_SIZE_IN_BITS;
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> bool {
def main(field a) {
u32 pbits = FIELD_SIZE_IN_BITS;
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1);
return maxvalue + 1 < a;
bool c = maxvalue + 1 < a;
return
}

View file

@ -1739,7 +1739,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
types,
)
.map_err(|e| vec![e])?;
Ok(TypedStatement::Definition(assignee, e))
Ok(TypedStatement::definition(assignee, e))
}
_ => {
// check the expression to be assigned
@ -1780,7 +1780,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
assignee_type
),
})
.map(|rhs| TypedStatement::Definition(assignee, rhs))
.map(|rhs| TypedStatement::definition(assignee, rhs))
.map_err(|e| vec![e])
}
}
@ -4430,7 +4430,7 @@ mod tests {
let mut checker: Checker<Bn128Field> = new_with_args(scope, 1, HashSet::new());
assert_eq!(
checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()),
Ok(TypedStatement::Definition(
Ok(TypedStatement::definition(
TypedAssignee::Identifier(typed::Variable::field_element("a")),
FieldElementExpression::Identifier("b".into()).into()
))
@ -4659,7 +4659,7 @@ mod tests {
Statement::Return(None).mock(),
];
let for_statements_checked = vec![TypedStatement::Definition(
let for_statements_checked = vec![TypedStatement::definition(
TypedAssignee::Identifier(typed::Variable::uint("a", UBitwidth::B32)),
UExpressionInner::Identifier("i".into())
.annotate(UBitwidth::B32)

View file

@ -1,7 +1,7 @@
use zokrates_ast::typed::{
folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression,
ConditionalOrExpression, CoreIdentifier, Expr, Identifier, Type, TypedProgram, TypedStatement,
Variable,
ConditionalOrExpression, CoreIdentifier, Expr, Identifier, Type, TypedExpression, TypedProgram,
TypedStatement, Variable,
};
use zokrates_field::Field;
@ -65,9 +65,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
| condition @ BooleanExpression::Identifier(_) => condition,
condition => {
let condition_id = Identifier::from(CoreIdentifier::Condition(self.index));
self.buffer.push(TypedStatement::Definition(
self.buffer.push(TypedStatement::definition(
Variable::immutable(condition_id.clone(), Type::Boolean).into(),
condition.into(),
TypedExpression::from(condition).into(),
));
self.index += 1;
BooleanExpression::Identifier(condition_id)
@ -99,7 +99,7 @@ mod tests {
// field foo = if true { 1 } else { 2 };
// should be left unchanged
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Value(true),
@ -120,7 +120,7 @@ mod tests {
// field foo = if c { 1 } else { 2 };
// should be left unchanged
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier("c".into()),
@ -148,7 +148,7 @@ mod tests {
box BooleanExpression::Identifier("d".into()),
);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
condition.clone(),
@ -163,12 +163,12 @@ mod tests {
let expected = vec![
// define condition
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
condition.into(),
),
// rewrite statement
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
@ -212,7 +212,7 @@ mod tests {
box BooleanExpression::Identifier("f".into()),
);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
condition_0.clone(),
@ -232,16 +232,16 @@ mod tests {
let expected = vec![
// define conditions
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
condition_0.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean).into(),
condition_1.into(),
),
// rewrite statement
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
@ -303,12 +303,12 @@ mod tests {
let condition_id_1 = BooleanExpression::Identifier(CoreIdentifier::Condition(1).into());
let condition_id_2 = BooleanExpression::Identifier(CoreIdentifier::Condition(2).into());
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
condition_0.clone(),
FieldElementExpression::block(
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
)],
@ -320,7 +320,7 @@ mod tests {
),
),
FieldElementExpression::block(
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("b").into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
)],
@ -340,22 +340,22 @@ mod tests {
let expected = vec![
// define conditions
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
condition_0.into(),
),
// rewrite statement
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
condition_id_0.clone(),
FieldElementExpression::block(
vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean)
.into(),
condition_1.into(),
@ -370,11 +370,11 @@ mod tests {
),
FieldElementExpression::block(
vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("b").into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::immutable(CoreIdentifier::Condition(2), Type::Boolean)
.into(),
condition_2.into(),

View file

@ -1,11 +1,11 @@
use std::fmt;
use zokrates_ast::common::FlatEmbed;
use zokrates_ast::typed::TypedProgram;
use zokrates_ast::typed::{
result_folder::ResultFolder,
result_folder::{fold_statement, fold_uint_expression_inner},
Constant, EmbedCall, TypedStatement, UBitwidth, UExpressionInner,
};
use zokrates_ast::typed::{DefinitionRhs, TypedProgram};
use zokrates_field::Field;
pub struct ConstantArgumentChecker;
@ -33,37 +33,41 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
match s {
TypedStatement::EmbedCallDefinition(assignee, embed_call) => match embed_call {
EmbedCall {
embed: FlatEmbed::BitArrayLe,
..
} => {
let arguments = embed_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<Vec<_>, _>>()?;
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
match embed_call {
EmbedCall {
embed: FlatEmbed::BitArrayLe,
..
} => {
let arguments = embed_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<Vec<_>, _>>()?;
if arguments[1].is_constant() {
Ok(vec![TypedStatement::EmbedCallDefinition(
assignee,
EmbedCall {
embed: FlatEmbed::BitArrayLe,
generics: embed_call.generics,
arguments,
},
)])
} else {
Err(Error(format!(
"Cannot compare to a variable value, found `{}`",
arguments[1]
)))
if arguments[1].is_constant() {
Ok(vec![TypedStatement::Definition(
assignee,
EmbedCall {
embed: FlatEmbed::BitArrayLe,
generics: embed_call.generics,
arguments,
}
.into(),
)])
} else {
Err(Error(format!(
"Cannot compare to a variable value, found `{}`",
arguments[1]
)))
}
}
embed_call => Ok(vec![TypedStatement::Definition(
assignee,
embed_call.into(),
)]),
}
embed_call => Ok(vec![TypedStatement::EmbedCallDefinition(
assignee, embed_call,
)]),
},
}
s => fold_statement(self, s),
}
}

View file

@ -0,0 +1,79 @@
use std::collections::HashSet;
use zokrates_ast::typed::{
folder::*, BlockExpression, Identifier, TypedAssignee, TypedFunction, TypedProgram,
TypedStatement,
};
use zokrates_field::Field;
#[derive(Default)]
pub struct DeadCodeEliminator<'ast> {
used: HashSet<Identifier<'ast>>,
in_block: usize,
}
impl<'ast> DeadCodeEliminator<'ast> {
pub fn eliminate<T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
Self::default().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
// iterate on the statements starting from the end, as we want to see usage before definition
let mut statements: Vec<_> = f
.statements
.into_iter()
.rev()
.flat_map(|s| self.fold_statement(s))
.collect();
statements.reverse();
TypedFunction { statements, ..f }
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(a, e) => match a {
TypedAssignee::Identifier(ref id) => {
// if the lhs is used later in the program and we're in a block
if self.used.remove(&id.id) {
// include this statement
fold_statement(self, TypedStatement::Definition(a, e))
} else {
// otherwise remove it
vec![]
}
}
_ => fold_statement(self, TypedStatement::Definition(a, e)),
},
TypedStatement::For(..) => {
unreachable!("for loops should be removed before dead code elimination is run")
}
s => fold_statement(self, s),
}
}
fn fold_block_expression<E: Fold<'ast, T>>(
&mut self,
block: BlockExpression<'ast, T, E>,
) -> BlockExpression<'ast, T, E> {
self.in_block += 1;
let value = box block.value.fold(self);
let mut statements: Vec<_> = block
.statements
.into_iter()
.rev()
.flat_map(|s| self.fold_statement(s))
.collect();
statements.reverse();
let block = BlockExpression { value, statements };
self.in_block -= 1;
block
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
self.used.insert(n.clone());
n
}
}

View file

@ -402,7 +402,7 @@ fn fold_statement<'ast, T: Field>(
typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return(
f.fold_expression(statements_buffer, expression),
)],
typed::TypedStatement::Definition(a, e) => {
typed::TypedStatement::Definition(a, typed::DefinitionRhs::Expression(e)) => {
let a = f.fold_assignee(a);
let e = f.fold_expression(statements_buffer, e);
assert_eq!(a.len(), e.len());
@ -418,10 +418,14 @@ fn fold_statement<'ast, T: Field>(
zir::RuntimeError::SourceAssertion(metadata.to_string())
}
typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck,
typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero,
};
vec![zir::ZirStatement::Assertion(e, error)]
}
typed::TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
typed::TypedStatement::Definition(
assignee,
typed::DefinitionRhs::EmbedCall(embed_call),
) => {
vec![zir::ZirStatement::MultipleDefinition(
f.fold_assignee(assignee),
zir::ZirExpressionList::EmbedCall(

View file

@ -8,10 +8,12 @@ mod branch_isolator;
mod condition_redefiner;
mod constant_argument_checker;
mod constant_resolver;
mod dead_code;
mod flat_propagation;
mod flatten_complex_types;
mod log_ignorer;
mod out_of_bounds;
mod panic_extractor;
mod propagation;
mod reducer;
mod struct_concretizer;
@ -32,6 +34,8 @@ use self::uint_optimizer::UintOptimizer;
use self::variable_write_remover::VariableWriteRemover;
use crate::compile::CompileConfig;
use crate::static_analysis::constant_resolver::ConstantResolver;
use crate::static_analysis::dead_code::DeadCodeEliminator;
use crate::static_analysis::panic_extractor::PanicExtractor;
use crate::static_analysis::zir_propagation::ZirPropagator;
use std::fmt;
use zokrates_ast::typed::{abi::Abi, TypedProgram};
@ -162,6 +166,14 @@ pub fn analyse<'ast, T: Field>(
let r = ConditionRedefiner::redefine(r);
log::trace!("\n{}", r);
log::debug!("Static analyser: Extract panics");
let r = PanicExtractor::extract(r);
log::trace!("\n{}", r);
log::debug!("Static analyser: Remove dead code");
let r = DeadCodeEliminator::eliminate(r);
log::trace!("\n{}", r);
// convert to zir, removing complex types
log::debug!("Static analyser: Convert to zir");
let zir = Flattener::flatten(r);

View file

@ -0,0 +1,72 @@
use zokrates_ast::typed::{
folder::*, BooleanExpression, EqExpression, FieldElementExpression, RuntimeError, TypedProgram,
TypedStatement, UBitwidth, UExpressionInner,
};
use zokrates_field::Field;
// a static analyser pass to extract the failure modes into separate assert statements, so that a statement can panic iff it's an assertion
#[derive(Default)]
pub struct PanicExtractor<'ast, T> {
panic_buffer: Vec<(BooleanExpression<'ast, T>, RuntimeError)>,
}
impl<'ast, T: Field> PanicExtractor<'ast, T> {
pub fn extract(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
Self::default().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let s = fold_statement(self, s);
self.panic_buffer
.drain(..)
.map(|(b, e)| TypedStatement::Assertion(b, e))
.chain(s)
.collect()
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Div(box n, box d) => {
let n = self.fold_field_expression(n);
let d = self.fold_field_expression(d);
self.panic_buffer.push((
BooleanExpression::Not(box BooleanExpression::FieldEq(EqExpression::new(
d.clone(),
T::zero().into(),
))),
RuntimeError::DivisionByZero,
));
FieldElementExpression::Div(box n, box d)
}
e => fold_field_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
b: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Div(box n, box d) => {
let n = self.fold_uint_expression(n);
let d = self.fold_uint_expression(d);
self.panic_buffer.push((
BooleanExpression::Not(box BooleanExpression::UintEq(EqExpression::new(
d.clone(),
UExpressionInner::Value(0).annotate(b),
))),
RuntimeError::DivisionByZero,
));
UExpressionInner::Div(box n, box d)
}
e => fold_uint_expression_inner(self, b, e),
}
}
}

View file

@ -222,83 +222,81 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
match s {
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr)?;
let assignee = self.fold_assignee(assignee)?;
if let (Ok(a), Ok(e)) = (
ConcreteType::try_from(assignee.get_type()),
ConcreteType::try_from(expr.get_type()),
) {
if a != e {
return Err(Error::Type(format!(
"Cannot assign {} of type {} to {} of type {}",
expr, e, assignee, a
)));
}
};
match expr {
DefinitionRhs::Expression(expr) => {
let expr = self.fold_expression(expr)?;
if expr.is_constant() {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = expr.into_canonical_constant();
assert!(self.constants.insert(var.id, expr).is_none());
Ok(vec![])
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr.into_canonical_constant();
Ok(vec![])
if let (Ok(a), Ok(e)) = (
ConcreteType::try_from(assignee.get_type()),
ConcreteType::try_from(expr.get_type()),
) {
if a != e {
return Err(Error::Type(format!(
"Cannot assign {} of type {} to {} of type {}",
expr, e, assignee, a
)));
}
Err(v) => match self.constants.remove(&v.id) {
// invalidate the cache for this identifier, and define the latest
// version of the constant in the program, if any
};
if expr.is_constant() {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = expr.into_canonical_constant();
assert!(self.constants.insert(var.id, expr).is_none());
Ok(vec![])
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr.into_canonical_constant();
Ok(vec![])
}
Err(v) => match self.constants.remove(&v.id) {
// invalidate the cache for this identifier, and define the latest
// version of the constant in the program, if any
Some(c) => Ok(vec![
TypedStatement::Definition(v.clone().into(), c.into()),
TypedStatement::Definition(assignee, expr.into()),
]),
None => Ok(vec![TypedStatement::Definition(
assignee,
expr.into(),
)]),
},
},
}
} else {
// the expression being assigned is not constant, invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => Ok(vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::Definition(assignee, expr),
TypedStatement::Definition(v.clone().into(), c.into()),
TypedStatement::Definition(assignee, expr.into()),
]),
None => Ok(vec![TypedStatement::Definition(assignee, expr)]),
},
},
None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]),
}
}
}
} else {
// the expression being assigned is not constant, invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
DefinitionRhs::EmbedCall(embed_call) => {
let embed_call = self.fold_embed_call(embed_call)?;
match self.constants.remove(&v.id) {
Some(c) => Ok(vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::Definition(assignee, expr),
]),
None => Ok(vec![TypedStatement::Definition(assignee, expr)]),
}
}
}
// we do not visit the for-loop statements
TypedStatement::For(v, from, to, statements) => {
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
fn process_u_from_bits<'ast, T: Field>(
arguments: &[TypedExpression<'ast, T>],
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(arguments.len(), 1);
Ok(vec![TypedStatement::For(v, from, to, statements)])
}
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
let assignee = self.fold_assignee(assignee)?;
let embed_call = self.fold_embed_call(embed_call)?;
let argument = arguments.last().cloned().unwrap();
let argument = argument.into_canonical_constant();
fn process_u_from_bits<'ast, T: Field>(
arguments: &[TypedExpression<'ast, T>],
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(arguments.len(), 1);
let argument = arguments.last().cloned().unwrap();
let argument = argument.into_canonical_constant();
match ArrayExpression::try_from(argument)
match ArrayExpression::try_from(argument)
.unwrap()
.into_inner()
{
@ -330,196 +328,228 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
.into(),
_ => unreachable!("should be an array value"),
}
}
fn process_u_to_bits<'ast, T: Field>(
arguments: &[TypedExpression<'ast, T>],
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(arguments.len(), 1);
match UExpression::try_from(arguments[0].clone())
.unwrap()
.into_inner()
{
UExpressionInner::Value(v) => {
let mut num = v;
let mut res = vec![];
for i in (0..bitwidth as u32).rev() {
if 2u128.pow(i) <= num {
num -= 2u128.pow(i);
res.push(true);
} else {
res.push(false);
}
}
assert_eq!(num, 0);
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| BooleanExpression::Value(v).into())
.collect::<Vec<_>>()
.into(),
)
.annotate(Type::Boolean, bitwidth.to_usize() as u32)
.into()
}
_ => unreachable!("should be a uint value"),
}
}
match embed_call.arguments.iter().all(|a| a.is_constant()) {
true => {
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
FlatEmbed::BitArrayLe => Ok(None), // todo
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B64,
))),
FlatEmbed::U32FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B32,
))),
FlatEmbed::U16FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B16,
))),
FlatEmbed::U8FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B8,
))),
FlatEmbed::U64ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B64,
))),
FlatEmbed::U32ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B32,
))),
FlatEmbed::U16ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B16,
))),
FlatEmbed::U8ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B8,
))),
FlatEmbed::Unpack => {
assert_eq!(embed_call.arguments.len(), 1);
assert_eq!(embed_call.generics.len(), 1);
fn process_u_to_bits<'ast, T: Field>(
arguments: &[TypedExpression<'ast, T>],
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(arguments.len(), 1);
let bit_width = embed_call.generics[0];
match FieldElementExpression::<T>::try_from(
embed_call.arguments[0].clone(),
)
match UExpression::try_from(arguments[0].clone())
.unwrap()
{
FieldElementExpression::Number(num) => {
let mut acc = num.clone();
let mut res = vec![];
.into_inner()
{
UExpressionInner::Value(v) => {
let mut num = v;
let mut res = vec![];
for i in (0..bit_width as usize).rev() {
if T::from(2).pow(i) <= acc {
acc = acc - T::from(2).pow(i);
res.push(true);
} else {
res.push(false);
}
for i in (0..bitwidth as u32).rev() {
if 2u128.pow(i) <= num {
num -= 2u128.pow(i);
res.push(true);
} else {
res.push(false);
}
}
assert_eq!(num, 0);
if acc != T::zero() {
Err(Error::ValueTooLarge(format!(
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| BooleanExpression::Value(v).into())
.collect::<Vec<_>>()
.into(),
)
.annotate(Type::Boolean, bitwidth.to_usize() as u32)
.into()
}
_ => unreachable!("should be a uint value"),
}
}
match embed_call.arguments.iter().all(|a| a.is_constant()) {
true => {
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
FlatEmbed::BitArrayLe => Ok(None), // todo
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B64,
))),
FlatEmbed::U32FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B32,
))),
FlatEmbed::U16FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B16,
))),
FlatEmbed::U8FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments,
UBitwidth::B8,
))),
FlatEmbed::U64ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B64,
))),
FlatEmbed::U32ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B32,
))),
FlatEmbed::U16ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B16,
))),
FlatEmbed::U8ToBits => Ok(Some(process_u_to_bits(
&embed_call.arguments,
UBitwidth::B8,
))),
FlatEmbed::Unpack => {
assert_eq!(embed_call.arguments.len(), 1);
assert_eq!(embed_call.generics.len(), 1);
let bit_width = embed_call.generics[0];
match FieldElementExpression::<T>::try_from(
embed_call.arguments[0].clone(),
)
.unwrap()
{
FieldElementExpression::Number(num) => {
let mut acc = num.clone();
let mut res = vec![];
for i in (0..bit_width as usize).rev() {
if T::from(2).pow(i) <= acc {
acc = acc - T::from(2).pow(i);
res.push(true);
} else {
res.push(false);
}
}
if acc != T::zero() {
Err(Error::ValueTooLarge(format!(
"Cannot unpack `{}` to `{}`: value is too large",
num,
assignee.get_type()
)))
} else {
Ok(Some(
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| BooleanExpression::Value(v).into())
.collect::<Vec<_>>()
} else {
Ok(Some(
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| {
BooleanExpression::Value(v)
.into()
})
.collect::<Vec<_>>()
.into(),
)
.annotate(Type::Boolean, bit_width)
.into(),
)
.annotate(Type::Boolean, bit_width)
.into(),
))
))
}
}
_ => unreachable!("should be a field value"),
}
}
_ => unreachable!("should be a field value"),
}
}
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => Ok(None),
#[cfg(feature = "ark")]
FlatEmbed::SnarkVerifyBls12377 => Ok(None),
}?;
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => Ok(None),
#[cfg(feature = "ark")]
FlatEmbed::SnarkVerifyBls12377 => Ok(None),
}?;
Ok(match r {
// if the function call returns a constant
Some(expr) => match assignee {
TypedAssignee::Identifier(var) => {
self.constants.insert(var.id, expr);
vec![]
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr;
vec![]
}
Err(v) => match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::Definition(assignee, expr),
],
None => {
vec![TypedStatement::Definition(assignee, expr)]
Ok(match r {
// if the function call returns a constant
Some(expr) => match assignee {
TypedAssignee::Identifier(var) => {
self.constants.insert(var.id, expr);
vec![]
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr;
vec![]
}
Err(v) => match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(
v.clone().into(),
c.into(),
),
TypedStatement::Definition(
assignee,
expr.into(),
),
],
None => {
vec![TypedStatement::Definition(
assignee,
expr.into(),
)]
}
},
},
},
},
},
None => {
// if the function call does not return a constant, invalidate the cache
// this happens because we only propagate certain calls here
None => {
// if the function call does not return a constant, invalidate the cache
// this happens because we only propagate certain calls here
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(
v.clone().into(),
c.into(),
),
TypedStatement::Definition(
assignee,
embed_call.into(),
),
],
None => vec![TypedStatement::Definition(
assignee,
embed_call.into(),
)],
}
}
})
}
false => {
// if the function arguments are not constant, invalidate the cache
// for the return assignees
let def =
TypedStatement::Definition(assignee.clone(), embed_call.into());
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::EmbedCallDefinition(assignee, embed_call),
],
None => vec![TypedStatement::EmbedCallDefinition(
assignee, embed_call,
)],
}
Ok(match self.constants.remove(&v.id) {
Some(c) => {
vec![
TypedStatement::Definition(v.clone().into(), c.into()),
def,
]
}
None => vec![def],
})
}
})
}
false => {
// if the function arguments are not constant, invalidate the cache
// for the return assignees
let def = TypedStatement::EmbedCallDefinition(assignee.clone(), embed_call);
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
Ok(match self.constants.remove(&v.id) {
Some(c) => {
vec![TypedStatement::Definition(v.clone().into(), c), def]
}
None => vec![def],
})
}
}
}
}
// we do not visit the for-loop statements
TypedStatement::For(v, from, to, statements) => {
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
Ok(vec![TypedStatement::For(v, from, to, statements)])
}
TypedStatement::Assertion(e, ty) => {
let e_str = e.to_string();
let expr = self.fold_boolean_expression(e)?;

View file

@ -178,7 +178,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
.zip(arguments.clone())
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.map(|(v, a)| TypedStatement::definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
@ -207,7 +207,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
let expression = TypedExpression::from(Variable::from(v.clone()));
let output_binding =
TypedStatement::Definition(TypedAssignee::Identifier(v.into()), return_expression);
TypedStatement::definition(TypedAssignee::Identifier(v.into()), return_expression);
let pop_log = TypedStatement::PopCallLog;

View file

@ -278,7 +278,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let v = var.clone().into();
self.statement_buffer
.push(TypedStatement::EmbedCallDefinition(
.push(TypedStatement::embed_call_definition(
v,
EmbedCall::new(embed, generics, arguments),
));
@ -352,7 +352,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
for index in *from..*to {
let statements: Vec<TypedStatement<_>> =
std::iter::once(TypedStatement::Definition(
std::iter::once(TypedStatement::definition(
v.clone().into(),
UExpression::from(index as u32).into(),
))
@ -611,21 +611,21 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
@ -638,7 +638,7 @@ mod tests {
)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
@ -687,7 +687,7 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
),
@ -699,17 +699,17 @@ mod tests {
),
GGenericsAssignment::default(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
.into(),
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::Identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
@ -804,17 +804,17 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
@ -822,7 +822,7 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo")
@ -835,7 +835,7 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
@ -889,7 +889,7 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
@ -906,14 +906,14 @@ mod tests {
.collect(),
),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
@ -925,7 +925,7 @@ mod tests {
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier(
@ -1028,17 +1028,17 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(2u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(
"b",
Type::FieldElement,
@ -1055,7 +1055,7 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo")
@ -1068,7 +1068,7 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
@ -1122,7 +1122,7 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
@ -1139,14 +1139,14 @@ mod tests {
.collect(),
),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
@ -1158,7 +1158,7 @@ mod tests {
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier(
@ -1254,7 +1254,7 @@ mod tests {
)
.into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::array(
"ret",
Type::FieldElement,
@ -1333,7 +1333,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo")
@ -1485,7 +1485,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo")

View file

@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
.0
.iter()
.map(|(g, v)| {
TypedStatement::Definition(
TypedStatement::definition(
TypedAssignee::Identifier(Variable::new(
g.name(),
Type::Uint(UBitwidth::B32),
@ -128,7 +128,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
let e = self.fold_expression(e);
let a = match a {
@ -139,9 +139,9 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
a => fold_assignee(self, a),
};
vec![TypedStatement::Definition(a, e)]
vec![TypedStatement::definition(a, e)]
}
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
let assignee = match assignee {
TypedAssignee::Identifier(v) => {
let v = self.issue_next_ssa_variable(v);
@ -150,7 +150,7 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
a => fold_assignee(self, a),
};
let embed_call = self.fold_embed_call(embed_call);
vec![TypedStatement::EmbedCallDefinition(assignee, embed_call)]
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
}
TypedStatement::For(v, from, to, stats) => {
let from = self.fold_uint_expression(from);
@ -238,13 +238,13 @@ mod tests {
let mut u = ShallowTransformer::with_versions(&mut versions);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(0)
)),
@ -252,13 +252,13 @@ mod tests {
)]
);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Number(Bn128Field::from(6)).into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(1)
)),
@ -288,13 +288,13 @@ mod tests {
let mut u = ShallowTransformer::with_versions(&mut versions);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(0)
)),
@ -302,7 +302,7 @@ mod tests {
)]
);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier("a".into()),
@ -312,7 +312,7 @@ mod tests {
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(1)
)),
@ -339,13 +339,13 @@ mod tests {
let mut u = ShallowTransformer::with_versions(&mut versions);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(0)
)),
@ -353,7 +353,7 @@ mod tests {
)]
);
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
@ -368,7 +368,7 @@ mod tests {
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
@ -400,7 +400,7 @@ mod tests {
let mut u = ShallowTransformer::with_versions(&mut versions);
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
ArrayExpressionInner::Value(
vec![
@ -415,7 +415,7 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::array(
Identifier::from("a").version(0),
Type::FieldElement,
@ -433,7 +433,7 @@ mod tests {
)]
);
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
box UExpression::from(1u32),
@ -459,7 +459,7 @@ mod tests {
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
let s = TypedStatement::Definition(
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone(), true)),
ArrayExpressionInner::Value(
vec![
@ -490,7 +490,7 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::new(
Identifier::from("a").version(0),
array_of_array_ty.clone(),
@ -524,7 +524,7 @@ mod tests {
)]
);
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::new(
"a",
@ -590,17 +590,17 @@ mod tests {
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
),
@ -609,12 +609,12 @@ mod tests {
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
)],
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
),
@ -623,12 +623,12 @@ mod tests {
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
)],
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
),
@ -657,21 +657,21 @@ mod tests {
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
),
@ -683,12 +683,12 @@ mod tests {
.annotate(UBitwidth::B32)
* UExpressionInner::Identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
)],
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
),
@ -700,12 +700,12 @@ mod tests {
.annotate(UBitwidth::B32)
* UExpressionInner::Identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
vec![TypedStatement::Definition(
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
)],
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(5)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(4)).into(),
),
@ -775,21 +775,21 @@ mod tests {
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
@ -800,13 +800,13 @@ mod tests {
)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element("a").into(),
(FieldElementExpression::Identifier("a".into())
* FieldElementExpression::function_call(
@ -844,25 +844,25 @@ mod tests {
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
@ -877,13 +877,13 @@ mod tests {
)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
UExpressionInner::Identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Definition(
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
* FieldElementExpression::function_call(

View file

@ -457,11 +457,11 @@ fn is_constant<T>(assignee: &TypedAssignee<T>) -> bool {
impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(assignee, expr) => {
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {
let expr = self.fold_expression(expr);
if is_constant(&assignee) {
vec![TypedStatement::Definition(assignee, expr)]
vec![TypedStatement::definition(assignee, expr)]
} else {
// Note: here we redefine the whole object, ideally we would only redefine some of it
// Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]`
@ -511,7 +511,7 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
range_checks
.into_iter()
.chain(std::iter::once(TypedStatement::Definition(
.chain(std::iter::once(TypedStatement::definition(
TypedAssignee::Identifier(variable),
e,
)))