wip
This commit is contained in:
parent
53b62f568b
commit
8caa6b4720
19 changed files with 660 additions and 400 deletions
|
@ -37,6 +37,7 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
|
|||
match error {
|
||||
crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s),
|
||||
crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck,
|
||||
crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -264,6 +264,10 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_definition_rhs(&mut self, rhs: DefinitionRhs<'ast, T>) -> DefinitionRhs<'ast, T> {
|
||||
fold_definition_rhs(self, rhs)
|
||||
}
|
||||
|
||||
fn fold_embed_call(&mut self, e: EmbedCall<'ast, T>) -> EmbedCall<'ast, T> {
|
||||
fold_embed_call(self, e)
|
||||
}
|
||||
|
@ -491,6 +495,16 @@ pub fn fold_constant_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
rhs: DefinitionRhs<'ast, T>,
|
||||
) -> DefinitionRhs<'ast, T> {
|
||||
match rhs {
|
||||
DefinitionRhs::EmbedCall(c) => DefinitionRhs::EmbedCall(f.fold_embed_call(c)),
|
||||
DefinitionRhs::Expression(e) => DefinitionRhs::Expression(f.fold_expression(e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -498,7 +512,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
|
||||
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
|
||||
|
@ -515,12 +529,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
TypedStatement::Log(s, e) => {
|
||||
TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect())
|
||||
}
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
|
||||
TypedStatement::EmbedCallDefinition(
|
||||
f.fold_assignee(assignee),
|
||||
f.fold_embed_call(embed_call),
|
||||
)
|
||||
}
|
||||
s => s,
|
||||
};
|
||||
vec![res]
|
||||
|
|
|
@ -588,6 +588,7 @@ impl fmt::Display for AssertionMetadata {
|
|||
pub enum RuntimeError {
|
||||
SourceAssertion(AssertionMetadata),
|
||||
SelectRangeCheck,
|
||||
DivisionByZero,
|
||||
}
|
||||
|
||||
impl fmt::Display for RuntimeError {
|
||||
|
@ -595,6 +596,7 @@ impl fmt::Display for RuntimeError {
|
|||
match self {
|
||||
RuntimeError::SourceAssertion(metadata) => write!(f, "{}", metadata),
|
||||
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
|
||||
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -646,12 +648,39 @@ impl<'ast, T: fmt::Display> fmt::Display for EmbedCall<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum DefinitionRhs<'ast, T> {
|
||||
Expression(TypedExpression<'ast, T>),
|
||||
EmbedCall(EmbedCall<'ast, T>),
|
||||
}
|
||||
|
||||
impl<'ast, T> From<TypedExpression<'ast, T>> for DefinitionRhs<'ast, T> {
|
||||
fn from(e: TypedExpression<'ast, T>) -> Self {
|
||||
Self::Expression(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<EmbedCall<'ast, T>> for DefinitionRhs<'ast, T> {
|
||||
fn from(c: EmbedCall<'ast, T>) -> Self {
|
||||
Self::EmbedCall(c)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
DefinitionRhs::EmbedCall(c) => write!(f, "{}", c),
|
||||
DefinitionRhs::Expression(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A statement in a `TypedFunction`
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedStatement<'ast, T> {
|
||||
Return(TypedExpression<'ast, T>),
|
||||
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
|
||||
Definition(TypedAssignee<'ast, T>, DefinitionRhs<'ast, T>),
|
||||
Assertion(BooleanExpression<'ast, T>, RuntimeError),
|
||||
For(
|
||||
Variable<'ast, T>,
|
||||
|
@ -660,7 +689,6 @@ pub enum TypedStatement<'ast, T> {
|
|||
Vec<TypedStatement<'ast, T>>,
|
||||
),
|
||||
Log(FormatString, Vec<TypedExpression<'ast, T>>),
|
||||
EmbedCallDefinition(TypedAssignee<'ast, T>, EmbedCall<'ast, T>),
|
||||
// Aux
|
||||
PushCallLog(
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
|
@ -669,6 +697,16 @@ pub enum TypedStatement<'ast, T> {
|
|||
PopCallLog,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedStatement<'ast, T> {
|
||||
pub fn definition(a: TypedAssignee<'ast, T>, e: TypedExpression<'ast, T>) -> Self {
|
||||
Self::Definition(a, e.into())
|
||||
}
|
||||
|
||||
pub fn embed_call_definition(a: TypedAssignee<'ast, T>, c: EmbedCall<'ast, T>) -> Self {
|
||||
Self::Definition(a, c.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
match self {
|
||||
|
@ -710,9 +748,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
|||
}
|
||||
write!(f, "\t}}")
|
||||
}
|
||||
TypedStatement::EmbedCallDefinition(ref lhs, ref rhs) => {
|
||||
write!(f, "{} = {};", lhs, rhs)
|
||||
}
|
||||
TypedStatement::Log(ref l, ref expressions) => write!(
|
||||
f,
|
||||
"log({}, {})",
|
||||
|
@ -1642,7 +1677,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs),
|
||||
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!({})", e),
|
||||
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
|
||||
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
|
||||
UExpressionInner::Select(ref select) => write!(f, "{}", select),
|
||||
|
@ -1675,7 +1710,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
BooleanExpression::UintEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!({})", exp),
|
||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||
BooleanExpression::Conditional(ref c) => write!(f, "{}", c),
|
||||
|
|
|
@ -385,6 +385,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_definition_rhs(
|
||||
&mut self,
|
||||
rhs: DefinitionRhs<'ast, T>,
|
||||
) -> Result<DefinitionRhs<'ast, T>, Self::Error> {
|
||||
fold_definition_rhs(self, rhs)
|
||||
}
|
||||
|
||||
fn fold_embed_call(
|
||||
&mut self,
|
||||
e: EmbedCall<'ast, T>,
|
||||
|
@ -508,7 +515,7 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?)
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
|
||||
|
@ -531,17 +538,21 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
|
||||
TypedStatement::EmbedCallDefinition(
|
||||
f.fold_assignee(assignee)?,
|
||||
f.fold_embed_call(embed_call)?,
|
||||
)
|
||||
}
|
||||
s => s,
|
||||
};
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
pub fn fold_definition_rhs<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
rhs: DefinitionRhs<'ast, T>,
|
||||
) -> Result<DefinitionRhs<'ast, T>, F::Error> {
|
||||
Ok(match rhs {
|
||||
DefinitionRhs::EmbedCall(c) => DefinitionRhs::EmbedCall(f.fold_embed_call(c)?),
|
||||
DefinitionRhs::Expression(e) => DefinitionRhs::Expression(f.fold_expression(e)?),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_embed_call<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: EmbedCall<'ast, T>,
|
||||
|
|
|
@ -92,6 +92,7 @@ pub type ZirAssignee<'ast> = Variable<'ast>;
|
|||
pub enum RuntimeError {
|
||||
SourceAssertion(String),
|
||||
SelectRangeCheck,
|
||||
DivisionByZero,
|
||||
}
|
||||
|
||||
impl fmt::Display for RuntimeError {
|
||||
|
@ -99,6 +100,7 @@ impl fmt::Display for RuntimeError {
|
|||
match self {
|
||||
RuntimeError::SourceAssertion(message) => write!(f, "{}", message),
|
||||
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
|
||||
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,10 @@ from "field" import FIELD_SIZE_IN_BITS;
|
|||
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
|
||||
// /!\ should be called with a = 0
|
||||
|
||||
def main(field a) -> bool {
|
||||
def main(field a) {
|
||||
u32 pbits = FIELD_SIZE_IN_BITS;
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
field maxvalue = a + (2**(pbits - 2) - 1);
|
||||
return a < maxvalue + 1;
|
||||
bool c = a < maxvalue + 1;
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,9 +4,10 @@ from "field" import FIELD_SIZE_IN_BITS;
|
|||
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
|
||||
// /!\ should be called with a = 0
|
||||
|
||||
def main(field a) -> bool {
|
||||
def main(field a) {
|
||||
u32 pbits = FIELD_SIZE_IN_BITS;
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
field maxvalue = a + (2**(pbits - 2) - 1);
|
||||
return maxvalue + 1 < a;
|
||||
bool c = maxvalue + 1 < a;
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1739,7 +1739,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
types,
|
||||
)
|
||||
.map_err(|e| vec![e])?;
|
||||
Ok(TypedStatement::Definition(assignee, e))
|
||||
Ok(TypedStatement::definition(assignee, e))
|
||||
}
|
||||
_ => {
|
||||
// check the expression to be assigned
|
||||
|
@ -1780,7 +1780,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
assignee_type
|
||||
),
|
||||
})
|
||||
.map(|rhs| TypedStatement::Definition(assignee, rhs))
|
||||
.map(|rhs| TypedStatement::definition(assignee, rhs))
|
||||
.map_err(|e| vec![e])
|
||||
}
|
||||
}
|
||||
|
@ -4430,7 +4430,7 @@ mod tests {
|
|||
let mut checker: Checker<Bn128Field> = new_with_args(scope, 1, HashSet::new());
|
||||
assert_eq!(
|
||||
checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()),
|
||||
Ok(TypedStatement::Definition(
|
||||
Ok(TypedStatement::definition(
|
||||
TypedAssignee::Identifier(typed::Variable::field_element("a")),
|
||||
FieldElementExpression::Identifier("b".into()).into()
|
||||
))
|
||||
|
@ -4659,7 +4659,7 @@ mod tests {
|
|||
Statement::Return(None).mock(),
|
||||
];
|
||||
|
||||
let for_statements_checked = vec![TypedStatement::Definition(
|
||||
let for_statements_checked = vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(typed::Variable::uint("a", UBitwidth::B32)),
|
||||
UExpressionInner::Identifier("i".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use zokrates_ast::typed::{
|
||||
folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression,
|
||||
ConditionalOrExpression, CoreIdentifier, Expr, Identifier, Type, TypedProgram, TypedStatement,
|
||||
Variable,
|
||||
ConditionalOrExpression, CoreIdentifier, Expr, Identifier, Type, TypedExpression, TypedProgram,
|
||||
TypedStatement, Variable,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -65,9 +65,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
|
|||
| condition @ BooleanExpression::Identifier(_) => condition,
|
||||
condition => {
|
||||
let condition_id = Identifier::from(CoreIdentifier::Condition(self.index));
|
||||
self.buffer.push(TypedStatement::Definition(
|
||||
self.buffer.push(TypedStatement::definition(
|
||||
Variable::immutable(condition_id.clone(), Type::Boolean).into(),
|
||||
condition.into(),
|
||||
TypedExpression::from(condition).into(),
|
||||
));
|
||||
self.index += 1;
|
||||
BooleanExpression::Identifier(condition_id)
|
||||
|
@ -99,7 +99,7 @@ mod tests {
|
|||
// field foo = if true { 1 } else { 2 };
|
||||
// should be left unchanged
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
BooleanExpression::Value(true),
|
||||
|
@ -120,7 +120,7 @@ mod tests {
|
|||
// field foo = if c { 1 } else { 2 };
|
||||
// should be left unchanged
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
BooleanExpression::Identifier("c".into()),
|
||||
|
@ -148,7 +148,7 @@ mod tests {
|
|||
box BooleanExpression::Identifier("d".into()),
|
||||
);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
condition.clone(),
|
||||
|
@ -163,12 +163,12 @@ mod tests {
|
|||
|
||||
let expected = vec![
|
||||
// define condition
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
|
||||
condition.into(),
|
||||
),
|
||||
// rewrite statement
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
|
||||
|
@ -212,7 +212,7 @@ mod tests {
|
|||
box BooleanExpression::Identifier("f".into()),
|
||||
);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
condition_0.clone(),
|
||||
|
@ -232,16 +232,16 @@ mod tests {
|
|||
|
||||
let expected = vec![
|
||||
// define conditions
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
|
||||
condition_0.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean).into(),
|
||||
condition_1.into(),
|
||||
),
|
||||
// rewrite statement
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
|
||||
|
@ -303,12 +303,12 @@ mod tests {
|
|||
let condition_id_1 = BooleanExpression::Identifier(CoreIdentifier::Condition(1).into());
|
||||
let condition_id_2 = BooleanExpression::Identifier(CoreIdentifier::Condition(2).into());
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
condition_0.clone(),
|
||||
FieldElementExpression::block(
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
)],
|
||||
|
@ -320,7 +320,7 @@ mod tests {
|
|||
),
|
||||
),
|
||||
FieldElementExpression::block(
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("b").into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
)],
|
||||
|
@ -340,22 +340,22 @@ mod tests {
|
|||
|
||||
let expected = vec![
|
||||
// define conditions
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(),
|
||||
condition_0.into(),
|
||||
),
|
||||
// rewrite statement
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("foo").into(),
|
||||
FieldElementExpression::conditional(
|
||||
condition_id_0.clone(),
|
||||
FieldElementExpression::block(
|
||||
vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean)
|
||||
.into(),
|
||||
condition_1.into(),
|
||||
|
@ -370,11 +370,11 @@ mod tests {
|
|||
),
|
||||
FieldElementExpression::block(
|
||||
vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("b").into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::immutable(CoreIdentifier::Condition(2), Type::Boolean)
|
||||
.into(),
|
||||
condition_2.into(),
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
use std::fmt;
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::TypedProgram;
|
||||
use zokrates_ast::typed::{
|
||||
result_folder::ResultFolder,
|
||||
result_folder::{fold_statement, fold_uint_expression_inner},
|
||||
Constant, EmbedCall, TypedStatement, UBitwidth, UExpressionInner,
|
||||
};
|
||||
use zokrates_ast::typed::{DefinitionRhs, TypedProgram};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ConstantArgumentChecker;
|
||||
|
@ -33,37 +33,41 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
|
|||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call) => match embed_call {
|
||||
EmbedCall {
|
||||
embed: FlatEmbed::BitArrayLe,
|
||||
..
|
||||
} => {
|
||||
let arguments = embed_call
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
|
||||
match embed_call {
|
||||
EmbedCall {
|
||||
embed: FlatEmbed::BitArrayLe,
|
||||
..
|
||||
} => {
|
||||
let arguments = embed_call
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
if arguments[1].is_constant() {
|
||||
Ok(vec![TypedStatement::EmbedCallDefinition(
|
||||
assignee,
|
||||
EmbedCall {
|
||||
embed: FlatEmbed::BitArrayLe,
|
||||
generics: embed_call.generics,
|
||||
arguments,
|
||||
},
|
||||
)])
|
||||
} else {
|
||||
Err(Error(format!(
|
||||
"Cannot compare to a variable value, found `{}`",
|
||||
arguments[1]
|
||||
)))
|
||||
if arguments[1].is_constant() {
|
||||
Ok(vec![TypedStatement::Definition(
|
||||
assignee,
|
||||
EmbedCall {
|
||||
embed: FlatEmbed::BitArrayLe,
|
||||
generics: embed_call.generics,
|
||||
arguments,
|
||||
}
|
||||
.into(),
|
||||
)])
|
||||
} else {
|
||||
Err(Error(format!(
|
||||
"Cannot compare to a variable value, found `{}`",
|
||||
arguments[1]
|
||||
)))
|
||||
}
|
||||
}
|
||||
embed_call => Ok(vec![TypedStatement::Definition(
|
||||
assignee,
|
||||
embed_call.into(),
|
||||
)]),
|
||||
}
|
||||
embed_call => Ok(vec![TypedStatement::EmbedCallDefinition(
|
||||
assignee, embed_call,
|
||||
)]),
|
||||
},
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
|
79
zokrates_core/src/static_analysis/dead_code.rs
Normal file
79
zokrates_core/src/static_analysis/dead_code.rs
Normal file
|
@ -0,0 +1,79 @@
|
|||
use std::collections::HashSet;
|
||||
use zokrates_ast::typed::{
|
||||
folder::*, BlockExpression, Identifier, TypedAssignee, TypedFunction, TypedProgram,
|
||||
TypedStatement,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DeadCodeEliminator<'ast> {
|
||||
used: HashSet<Identifier<'ast>>,
|
||||
in_block: usize,
|
||||
}
|
||||
|
||||
impl<'ast> DeadCodeEliminator<'ast> {
|
||||
pub fn eliminate<T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
Self::default().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
// iterate on the statements starting from the end, as we want to see usage before definition
|
||||
let mut statements: Vec<_> = f
|
||||
.statements
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|s| self.fold_statement(s))
|
||||
.collect();
|
||||
statements.reverse();
|
||||
TypedFunction { statements, ..f }
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, e) => match a {
|
||||
TypedAssignee::Identifier(ref id) => {
|
||||
// if the lhs is used later in the program and we're in a block
|
||||
if self.used.remove(&id.id) {
|
||||
// include this statement
|
||||
fold_statement(self, TypedStatement::Definition(a, e))
|
||||
} else {
|
||||
// otherwise remove it
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
_ => fold_statement(self, TypedStatement::Definition(a, e)),
|
||||
},
|
||||
TypedStatement::For(..) => {
|
||||
unreachable!("for loops should be removed before dead code elimination is run")
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_block_expression<E: Fold<'ast, T>>(
|
||||
&mut self,
|
||||
block: BlockExpression<'ast, T, E>,
|
||||
) -> BlockExpression<'ast, T, E> {
|
||||
self.in_block += 1;
|
||||
|
||||
let value = box block.value.fold(self);
|
||||
let mut statements: Vec<_> = block
|
||||
.statements
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|s| self.fold_statement(s))
|
||||
.collect();
|
||||
statements.reverse();
|
||||
|
||||
let block = BlockExpression { value, statements };
|
||||
self.in_block -= 1;
|
||||
block
|
||||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
self.used.insert(n.clone());
|
||||
n
|
||||
}
|
||||
}
|
|
@ -402,7 +402,7 @@ fn fold_statement<'ast, T: Field>(
|
|||
typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return(
|
||||
f.fold_expression(statements_buffer, expression),
|
||||
)],
|
||||
typed::TypedStatement::Definition(a, e) => {
|
||||
typed::TypedStatement::Definition(a, typed::DefinitionRhs::Expression(e)) => {
|
||||
let a = f.fold_assignee(a);
|
||||
let e = f.fold_expression(statements_buffer, e);
|
||||
assert_eq!(a.len(), e.len());
|
||||
|
@ -418,10 +418,14 @@ fn fold_statement<'ast, T: Field>(
|
|||
zir::RuntimeError::SourceAssertion(metadata.to_string())
|
||||
}
|
||||
typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck,
|
||||
typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero,
|
||||
};
|
||||
vec![zir::ZirStatement::Assertion(e, error)]
|
||||
}
|
||||
typed::TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
|
||||
typed::TypedStatement::Definition(
|
||||
assignee,
|
||||
typed::DefinitionRhs::EmbedCall(embed_call),
|
||||
) => {
|
||||
vec![zir::ZirStatement::MultipleDefinition(
|
||||
f.fold_assignee(assignee),
|
||||
zir::ZirExpressionList::EmbedCall(
|
||||
|
|
|
@ -8,10 +8,12 @@ mod branch_isolator;
|
|||
mod condition_redefiner;
|
||||
mod constant_argument_checker;
|
||||
mod constant_resolver;
|
||||
mod dead_code;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod log_ignorer;
|
||||
mod out_of_bounds;
|
||||
mod panic_extractor;
|
||||
mod propagation;
|
||||
mod reducer;
|
||||
mod struct_concretizer;
|
||||
|
@ -32,6 +34,8 @@ use self::uint_optimizer::UintOptimizer;
|
|||
use self::variable_write_remover::VariableWriteRemover;
|
||||
use crate::compile::CompileConfig;
|
||||
use crate::static_analysis::constant_resolver::ConstantResolver;
|
||||
use crate::static_analysis::dead_code::DeadCodeEliminator;
|
||||
use crate::static_analysis::panic_extractor::PanicExtractor;
|
||||
use crate::static_analysis::zir_propagation::ZirPropagator;
|
||||
use std::fmt;
|
||||
use zokrates_ast::typed::{abi::Abi, TypedProgram};
|
||||
|
@ -162,6 +166,14 @@ pub fn analyse<'ast, T: Field>(
|
|||
let r = ConditionRedefiner::redefine(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Extract panics");
|
||||
let r = PanicExtractor::extract(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Remove dead code");
|
||||
let r = DeadCodeEliminator::eliminate(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// convert to zir, removing complex types
|
||||
log::debug!("Static analyser: Convert to zir");
|
||||
let zir = Flattener::flatten(r);
|
||||
|
|
72
zokrates_core/src/static_analysis/panic_extractor.rs
Normal file
72
zokrates_core/src/static_analysis/panic_extractor.rs
Normal file
|
@ -0,0 +1,72 @@
|
|||
use zokrates_ast::typed::{
|
||||
folder::*, BooleanExpression, EqExpression, FieldElementExpression, RuntimeError, TypedProgram,
|
||||
TypedStatement, UBitwidth, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
// a static analyser pass to extract the failure modes into separate assert statements, so that a statement can panic iff it's an assertion
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct PanicExtractor<'ast, T> {
|
||||
panic_buffer: Vec<(BooleanExpression<'ast, T>, RuntimeError)>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> PanicExtractor<'ast, T> {
|
||||
pub fn extract(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
Self::default().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
let s = fold_statement(self, s);
|
||||
self.panic_buffer
|
||||
.drain(..)
|
||||
.map(|(b, e)| TypedStatement::Assertion(b, e))
|
||||
.chain(s)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Div(box n, box d) => {
|
||||
let n = self.fold_field_expression(n);
|
||||
let d = self.fold_field_expression(d);
|
||||
self.panic_buffer.push((
|
||||
BooleanExpression::Not(box BooleanExpression::FieldEq(EqExpression::new(
|
||||
d.clone(),
|
||||
T::zero().into(),
|
||||
))),
|
||||
RuntimeError::DivisionByZero,
|
||||
));
|
||||
FieldElementExpression::Div(box n, box d)
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
b: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Div(box n, box d) => {
|
||||
let n = self.fold_uint_expression(n);
|
||||
let d = self.fold_uint_expression(d);
|
||||
self.panic_buffer.push((
|
||||
BooleanExpression::Not(box BooleanExpression::UintEq(EqExpression::new(
|
||||
d.clone(),
|
||||
UExpressionInner::Value(0).annotate(b),
|
||||
))),
|
||||
RuntimeError::DivisionByZero,
|
||||
));
|
||||
UExpressionInner::Div(box n, box d)
|
||||
}
|
||||
e => fold_uint_expression_inner(self, b, e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -222,83 +222,81 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
match s {
|
||||
// propagation to the defined variable if rhs is a constant
|
||||
TypedStatement::Definition(assignee, expr) => {
|
||||
let expr = self.fold_expression(expr)?;
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
|
||||
if let (Ok(a), Ok(e)) = (
|
||||
ConcreteType::try_from(assignee.get_type()),
|
||||
ConcreteType::try_from(expr.get_type()),
|
||||
) {
|
||||
if a != e {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot assign {} of type {} to {} of type {}",
|
||||
expr, e, assignee, a
|
||||
)));
|
||||
}
|
||||
};
|
||||
match expr {
|
||||
DefinitionRhs::Expression(expr) => {
|
||||
let expr = self.fold_expression(expr)?;
|
||||
|
||||
if expr.is_constant() {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
let expr = expr.into_canonical_constant();
|
||||
|
||||
assert!(self.constants.insert(var.id, expr).is_none());
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = expr.into_canonical_constant();
|
||||
Ok(vec![])
|
||||
if let (Ok(a), Ok(e)) = (
|
||||
ConcreteType::try_from(assignee.get_type()),
|
||||
ConcreteType::try_from(expr.get_type()),
|
||||
) {
|
||||
if a != e {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot assign {} of type {} to {} of type {}",
|
||||
expr, e, assignee, a
|
||||
)));
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
// invalidate the cache for this identifier, and define the latest
|
||||
// version of the constant in the program, if any
|
||||
};
|
||||
|
||||
if expr.is_constant() {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
let expr = expr.into_canonical_constant();
|
||||
|
||||
assert!(self.constants.insert(var.id, expr).is_none());
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = expr.into_canonical_constant();
|
||||
Ok(vec![])
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
// invalidate the cache for this identifier, and define the latest
|
||||
// version of the constant in the program, if any
|
||||
Some(c) => Ok(vec![
|
||||
TypedStatement::Definition(v.clone().into(), c.into()),
|
||||
TypedStatement::Definition(assignee, expr.into()),
|
||||
]),
|
||||
None => Ok(vec![TypedStatement::Definition(
|
||||
assignee,
|
||||
expr.into(),
|
||||
)]),
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// the expression being assigned is not constant, invalidate the cache
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
|
||||
match self.constants.remove(&v.id) {
|
||||
Some(c) => Ok(vec![
|
||||
TypedStatement::Definition(v.clone().into(), c),
|
||||
TypedStatement::Definition(assignee, expr),
|
||||
TypedStatement::Definition(v.clone().into(), c.into()),
|
||||
TypedStatement::Definition(assignee, expr.into()),
|
||||
]),
|
||||
None => Ok(vec![TypedStatement::Definition(assignee, expr)]),
|
||||
},
|
||||
},
|
||||
None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// the expression being assigned is not constant, invalidate the cache
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
DefinitionRhs::EmbedCall(embed_call) => {
|
||||
let embed_call = self.fold_embed_call(embed_call)?;
|
||||
|
||||
match self.constants.remove(&v.id) {
|
||||
Some(c) => Ok(vec![
|
||||
TypedStatement::Definition(v.clone().into(), c),
|
||||
TypedStatement::Definition(assignee, expr),
|
||||
]),
|
||||
None => Ok(vec![TypedStatement::Definition(assignee, expr)]),
|
||||
}
|
||||
}
|
||||
}
|
||||
// we do not visit the for-loop statements
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
fn process_u_from_bits<'ast, T: Field>(
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
bitwidth: UBitwidth,
|
||||
) -> TypedExpression<'ast, T> {
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
Ok(vec![TypedStatement::For(v, from, to, statements)])
|
||||
}
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
let embed_call = self.fold_embed_call(embed_call)?;
|
||||
let argument = arguments.last().cloned().unwrap();
|
||||
let argument = argument.into_canonical_constant();
|
||||
|
||||
fn process_u_from_bits<'ast, T: Field>(
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
bitwidth: UBitwidth,
|
||||
) -> TypedExpression<'ast, T> {
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
let argument = arguments.last().cloned().unwrap();
|
||||
let argument = argument.into_canonical_constant();
|
||||
|
||||
match ArrayExpression::try_from(argument)
|
||||
match ArrayExpression::try_from(argument)
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
{
|
||||
|
@ -330,196 +328,228 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
.into(),
|
||||
_ => unreachable!("should be an array value"),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_u_to_bits<'ast, T: Field>(
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
bitwidth: UBitwidth,
|
||||
) -> TypedExpression<'ast, T> {
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
match UExpression::try_from(arguments[0].clone())
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
{
|
||||
UExpressionInner::Value(v) => {
|
||||
let mut num = v;
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..bitwidth as u32).rev() {
|
||||
if 2u128.pow(i) <= num {
|
||||
num -= 2u128.pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
}
|
||||
assert_eq!(num, 0);
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::Boolean, bitwidth.to_usize() as u32)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be a uint value"),
|
||||
}
|
||||
}
|
||||
|
||||
match embed_call.arguments.iter().all(|a| a.is_constant()) {
|
||||
true => {
|
||||
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
|
||||
FlatEmbed::BitArrayLe => Ok(None), // todo
|
||||
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B64,
|
||||
))),
|
||||
FlatEmbed::U32FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B32,
|
||||
))),
|
||||
FlatEmbed::U16FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B16,
|
||||
))),
|
||||
FlatEmbed::U8FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B8,
|
||||
))),
|
||||
FlatEmbed::U64ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B64,
|
||||
))),
|
||||
FlatEmbed::U32ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B32,
|
||||
))),
|
||||
FlatEmbed::U16ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B16,
|
||||
))),
|
||||
FlatEmbed::U8ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B8,
|
||||
))),
|
||||
FlatEmbed::Unpack => {
|
||||
assert_eq!(embed_call.arguments.len(), 1);
|
||||
assert_eq!(embed_call.generics.len(), 1);
|
||||
fn process_u_to_bits<'ast, T: Field>(
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
bitwidth: UBitwidth,
|
||||
) -> TypedExpression<'ast, T> {
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
let bit_width = embed_call.generics[0];
|
||||
|
||||
match FieldElementExpression::<T>::try_from(
|
||||
embed_call.arguments[0].clone(),
|
||||
)
|
||||
match UExpression::try_from(arguments[0].clone())
|
||||
.unwrap()
|
||||
{
|
||||
FieldElementExpression::Number(num) => {
|
||||
let mut acc = num.clone();
|
||||
let mut res = vec![];
|
||||
.into_inner()
|
||||
{
|
||||
UExpressionInner::Value(v) => {
|
||||
let mut num = v;
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..bit_width as usize).rev() {
|
||||
if T::from(2).pow(i) <= acc {
|
||||
acc = acc - T::from(2).pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
for i in (0..bitwidth as u32).rev() {
|
||||
if 2u128.pow(i) <= num {
|
||||
num -= 2u128.pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
}
|
||||
assert_eq!(num, 0);
|
||||
|
||||
if acc != T::zero() {
|
||||
Err(Error::ValueTooLarge(format!(
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::Boolean, bitwidth.to_usize() as u32)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be a uint value"),
|
||||
}
|
||||
}
|
||||
|
||||
match embed_call.arguments.iter().all(|a| a.is_constant()) {
|
||||
true => {
|
||||
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
|
||||
FlatEmbed::BitArrayLe => Ok(None), // todo
|
||||
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B64,
|
||||
))),
|
||||
FlatEmbed::U32FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B32,
|
||||
))),
|
||||
FlatEmbed::U16FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B16,
|
||||
))),
|
||||
FlatEmbed::U8FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B8,
|
||||
))),
|
||||
FlatEmbed::U64ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B64,
|
||||
))),
|
||||
FlatEmbed::U32ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B32,
|
||||
))),
|
||||
FlatEmbed::U16ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B16,
|
||||
))),
|
||||
FlatEmbed::U8ToBits => Ok(Some(process_u_to_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B8,
|
||||
))),
|
||||
FlatEmbed::Unpack => {
|
||||
assert_eq!(embed_call.arguments.len(), 1);
|
||||
assert_eq!(embed_call.generics.len(), 1);
|
||||
|
||||
let bit_width = embed_call.generics[0];
|
||||
|
||||
match FieldElementExpression::<T>::try_from(
|
||||
embed_call.arguments[0].clone(),
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
FieldElementExpression::Number(num) => {
|
||||
let mut acc = num.clone();
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..bit_width as usize).rev() {
|
||||
if T::from(2).pow(i) <= acc {
|
||||
acc = acc - T::from(2).pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
}
|
||||
|
||||
if acc != T::zero() {
|
||||
Err(Error::ValueTooLarge(format!(
|
||||
"Cannot unpack `{}` to `{}`: value is too large",
|
||||
num,
|
||||
assignee.get_type()
|
||||
)))
|
||||
} else {
|
||||
Ok(Some(
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Ok(Some(
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| {
|
||||
BooleanExpression::Value(v)
|
||||
.into()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::Boolean, bit_width)
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::Boolean, bit_width)
|
||||
.into(),
|
||||
))
|
||||
))
|
||||
}
|
||||
}
|
||||
_ => unreachable!("should be a field value"),
|
||||
}
|
||||
}
|
||||
_ => unreachable!("should be a field value"),
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => Ok(None),
|
||||
#[cfg(feature = "ark")]
|
||||
FlatEmbed::SnarkVerifyBls12377 => Ok(None),
|
||||
}?;
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => Ok(None),
|
||||
#[cfg(feature = "ark")]
|
||||
FlatEmbed::SnarkVerifyBls12377 => Ok(None),
|
||||
}?;
|
||||
|
||||
Ok(match r {
|
||||
// if the function call returns a constant
|
||||
Some(expr) => match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
self.constants.insert(var.id, expr);
|
||||
vec![]
|
||||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = expr;
|
||||
vec![]
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
Some(c) => vec![
|
||||
TypedStatement::Definition(v.clone().into(), c),
|
||||
TypedStatement::Definition(assignee, expr),
|
||||
],
|
||||
None => {
|
||||
vec![TypedStatement::Definition(assignee, expr)]
|
||||
Ok(match r {
|
||||
// if the function call returns a constant
|
||||
Some(expr) => match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
self.constants.insert(var.id, expr);
|
||||
vec![]
|
||||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = expr;
|
||||
vec![]
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
Some(c) => vec![
|
||||
TypedStatement::Definition(
|
||||
v.clone().into(),
|
||||
c.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
assignee,
|
||||
expr.into(),
|
||||
),
|
||||
],
|
||||
None => {
|
||||
vec![TypedStatement::Definition(
|
||||
assignee,
|
||||
expr.into(),
|
||||
)]
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
None => {
|
||||
// if the function call does not return a constant, invalidate the cache
|
||||
// this happens because we only propagate certain calls here
|
||||
None => {
|
||||
// if the function call does not return a constant, invalidate the cache
|
||||
// this happens because we only propagate certain calls here
|
||||
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
|
||||
match self.constants.remove(&v.id) {
|
||||
Some(c) => vec![
|
||||
TypedStatement::Definition(
|
||||
v.clone().into(),
|
||||
c.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
assignee,
|
||||
embed_call.into(),
|
||||
),
|
||||
],
|
||||
None => vec![TypedStatement::Definition(
|
||||
assignee,
|
||||
embed_call.into(),
|
||||
)],
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
false => {
|
||||
// if the function arguments are not constant, invalidate the cache
|
||||
// for the return assignees
|
||||
let def =
|
||||
TypedStatement::Definition(assignee.clone(), embed_call.into());
|
||||
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
|
||||
match self.constants.remove(&v.id) {
|
||||
Some(c) => vec![
|
||||
TypedStatement::Definition(v.clone().into(), c),
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call),
|
||||
],
|
||||
None => vec![TypedStatement::EmbedCallDefinition(
|
||||
assignee, embed_call,
|
||||
)],
|
||||
}
|
||||
Ok(match self.constants.remove(&v.id) {
|
||||
Some(c) => {
|
||||
vec![
|
||||
TypedStatement::Definition(v.clone().into(), c.into()),
|
||||
def,
|
||||
]
|
||||
}
|
||||
None => vec![def],
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
false => {
|
||||
// if the function arguments are not constant, invalidate the cache
|
||||
// for the return assignees
|
||||
let def = TypedStatement::EmbedCallDefinition(assignee.clone(), embed_call);
|
||||
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
|
||||
Ok(match self.constants.remove(&v.id) {
|
||||
Some(c) => {
|
||||
vec![TypedStatement::Definition(v.clone().into(), c), def]
|
||||
}
|
||||
None => vec![def],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// we do not visit the for-loop statements
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
|
||||
Ok(vec![TypedStatement::For(v, from, to, statements)])
|
||||
}
|
||||
TypedStatement::Assertion(e, ty) => {
|
||||
let e_str = e.to_string();
|
||||
let expr = self.fold_boolean_expression(e)?;
|
||||
|
|
|
@ -178,7 +178,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
|
||||
.zip(arguments.clone())
|
||||
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.map(|(v, a)| TypedStatement::definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
|
||||
|
@ -207,7 +207,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
let expression = TypedExpression::from(Variable::from(v.clone()));
|
||||
|
||||
let output_binding =
|
||||
TypedStatement::Definition(TypedAssignee::Identifier(v.into()), return_expression);
|
||||
TypedStatement::definition(TypedAssignee::Identifier(v.into()), return_expression);
|
||||
|
||||
let pop_log = TypedStatement::PopCallLog;
|
||||
|
||||
|
|
|
@ -278,7 +278,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
let v = var.clone().into();
|
||||
|
||||
self.statement_buffer
|
||||
.push(TypedStatement::EmbedCallDefinition(
|
||||
.push(TypedStatement::embed_call_definition(
|
||||
v,
|
||||
EmbedCall::new(embed, generics, arguments),
|
||||
));
|
||||
|
@ -352,7 +352,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
for index in *from..*to {
|
||||
let statements: Vec<TypedStatement<_>> =
|
||||
std::iter::once(TypedStatement::Definition(
|
||||
std::iter::once(TypedStatement::definition(
|
||||
v.clone().into(),
|
||||
UExpression::from(index as u32).into(),
|
||||
))
|
||||
|
@ -611,21 +611,21 @@ mod tests {
|
|||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
|
@ -638,7 +638,7 @@ mod tests {
|
|||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
|
@ -687,7 +687,7 @@ mod tests {
|
|||
let expected_main = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
|
@ -699,17 +699,17 @@ mod tests {
|
|||
),
|
||||
GGenericsAssignment::default(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
|
||||
.into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::Identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
|
@ -804,17 +804,17 @@ mod tests {
|
|||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
|
||||
|
@ -822,7 +822,7 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
|
@ -835,7 +835,7 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
|
@ -889,7 +889,7 @@ mod tests {
|
|||
let expected_main = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
|
||||
|
@ -906,14 +906,14 @@ mod tests {
|
|||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
|
@ -925,7 +925,7 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier(
|
||||
|
@ -1028,17 +1028,17 @@ mod tests {
|
|||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(2u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
"b",
|
||||
Type::FieldElement,
|
||||
|
@ -1055,7 +1055,7 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
|
@ -1068,7 +1068,7 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
|
@ -1122,7 +1122,7 @@ mod tests {
|
|||
let expected_main = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
|
||||
|
@ -1139,14 +1139,14 @@ mod tests {
|
|||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
|
@ -1158,7 +1158,7 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier(
|
||||
|
@ -1254,7 +1254,7 @@ mod tests {
|
|||
)
|
||||
.into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
"ret",
|
||||
Type::FieldElement,
|
||||
|
@ -1333,7 +1333,7 @@ mod tests {
|
|||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
|
@ -1485,7 +1485,7 @@ mod tests {
|
|||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
|
|
|
@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
.0
|
||||
.iter()
|
||||
.map(|(g, v)| {
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::new(
|
||||
g.name(),
|
||||
Type::Uint(UBitwidth::B32),
|
||||
|
@ -128,7 +128,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
|
||||
let e = self.fold_expression(e);
|
||||
|
||||
let a = match a {
|
||||
|
@ -139,9 +139,9 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
|||
a => fold_assignee(self, a),
|
||||
};
|
||||
|
||||
vec![TypedStatement::Definition(a, e)]
|
||||
vec![TypedStatement::definition(a, e)]
|
||||
}
|
||||
TypedStatement::EmbedCallDefinition(assignee, embed_call) => {
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
|
||||
let assignee = match assignee {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
|
@ -150,7 +150,7 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
|||
a => fold_assignee(self, a),
|
||||
};
|
||||
let embed_call = self.fold_embed_call(embed_call);
|
||||
vec![TypedStatement::EmbedCallDefinition(assignee, embed_call)]
|
||||
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
|
||||
}
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let from = self.fold_uint_expression(from);
|
||||
|
@ -238,13 +238,13 @@ mod tests {
|
|||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(0)
|
||||
)),
|
||||
|
@ -252,13 +252,13 @@ mod tests {
|
|||
)]
|
||||
);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
FieldElementExpression::Number(Bn128Field::from(6)).into(),
|
||||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(1)
|
||||
)),
|
||||
|
@ -288,13 +288,13 @@ mod tests {
|
|||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(0)
|
||||
)),
|
||||
|
@ -302,7 +302,7 @@ mod tests {
|
|||
)]
|
||||
);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
|
@ -312,7 +312,7 @@ mod tests {
|
|||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(1)
|
||||
)),
|
||||
|
@ -339,13 +339,13 @@ mod tests {
|
|||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(0)
|
||||
)),
|
||||
|
@ -353,7 +353,7 @@ mod tests {
|
|||
)]
|
||||
);
|
||||
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
|
@ -368,7 +368,7 @@ mod tests {
|
|||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
|
@ -400,7 +400,7 @@ mod tests {
|
|||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
|
@ -415,7 +415,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::array(
|
||||
Identifier::from("a").version(0),
|
||||
Type::FieldElement,
|
||||
|
@ -433,7 +433,7 @@ mod tests {
|
|||
)]
|
||||
);
|
||||
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
|
||||
TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
|
||||
box UExpression::from(1u32),
|
||||
|
@ -459,7 +459,7 @@ mod tests {
|
|||
|
||||
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone(), true)),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
|
@ -490,7 +490,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::new(
|
||||
Identifier::from("a").version(0),
|
||||
array_of_array_ty.clone(),
|
||||
|
@ -524,7 +524,7 @@ mod tests {
|
|||
)]
|
||||
);
|
||||
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::definition(
|
||||
TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::new(
|
||||
"a",
|
||||
|
@ -590,17 +590,17 @@ mod tests {
|
|||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
|
@ -609,12 +609,12 @@ mod tests {
|
|||
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
|
||||
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
|
@ -623,12 +623,12 @@ mod tests {
|
|||
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
|
||||
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
|
||||
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
|
@ -657,21 +657,21 @@ mod tests {
|
|||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
|
@ -683,12 +683,12 @@ mod tests {
|
|||
.annotate(UBitwidth::B32)
|
||||
* UExpressionInner::Identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
|
||||
),
|
||||
|
@ -700,12 +700,12 @@ mod tests {
|
|||
.annotate(UBitwidth::B32)
|
||||
* UExpressionInner::Identifier(Identifier::from("n").version(2))
|
||||
.annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::Definition(
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(4)).into(),
|
||||
),
|
||||
|
@ -775,21 +775,21 @@ mod tests {
|
|||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
|
@ -800,13 +800,13 @@ mod tests {
|
|||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
(FieldElementExpression::Identifier("a".into())
|
||||
* FieldElementExpression::function_call(
|
||||
|
@ -844,25 +844,25 @@ mod tests {
|
|||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier("n".into())
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
|
@ -877,13 +877,13 @@ mod tests {
|
|||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
|
||||
* FieldElementExpression::function_call(
|
||||
|
|
|
@ -457,11 +457,11 @@ fn is_constant<T>(assignee: &TypedAssignee<T>) -> bool {
|
|||
impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(assignee, expr) => {
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {
|
||||
let expr = self.fold_expression(expr);
|
||||
|
||||
if is_constant(&assignee) {
|
||||
vec![TypedStatement::Definition(assignee, expr)]
|
||||
vec![TypedStatement::definition(assignee, expr)]
|
||||
} else {
|
||||
// Note: here we redefine the whole object, ideally we would only redefine some of it
|
||||
// Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]`
|
||||
|
@ -511,7 +511,7 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
|
|||
|
||||
range_checks
|
||||
.into_iter()
|
||||
.chain(std::iter::once(TypedStatement::Definition(
|
||||
.chain(std::iter::once(TypedStatement::definition(
|
||||
TypedAssignee::Identifier(variable),
|
||||
e,
|
||||
)))
|
||||
|
|
Loading…
Reference in a new issue