1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #1283 from Zokrates/greedy-reducer

Refactor reducer to reduce memory usage and runtime
This commit is contained in:
Thibaut Schaeffer 2023-02-28 12:02:17 +01:00 committed by GitHub
commit 8ca79372df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 775 additions and 1150 deletions

View file

@ -0,0 +1 @@
Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining)

View file

@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>(
})
.collect(),
)],
typed::TypedStatement::PushCallLog(..) => vec![],
typed::TypedStatement::PopCallLog => vec![],
typed::TypedStatement::For(..) => unreachable!(),
};

View file

@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>(
let r = reduce_program(r).map_err(Error::from)?;
log::trace!("\n{}", r);
log::debug!("Static analyser: Propagate");
let r = Propagator::propagate(r)?;
log::trace!("\n{}", r);
log::debug!("Static analyser: Concretize structs");
let r = StructConcretizer::concretize(r);
log::trace!("\n{}", r);

View file

@ -44,25 +44,16 @@ impl fmt::Display for Error {
}
}
#[derive(Debug)]
pub struct Propagator<'ast, 'a, T: Field> {
#[derive(Debug, Default)]
pub struct Propagator<'ast, T> {
// constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
constants: &'a mut Constants<'ast, T>,
constants: Constants<'ast, T>,
}
impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self {
Propagator { constants }
}
impl<'ast, T: Field> Propagator<'ast, T> {
pub fn propagate(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
let mut constants = Constants::new();
Propagator {
constants: &mut constants,
}
.fold_program(p)
Propagator::default().fold_program(p)
}
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
type Error = Error;
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
@ -629,8 +620,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
_ => Ok(vec![TypedStatement::Assertion(expr, err)]),
}
}
s @ TypedStatement::PushCallLog(..) => Ok(vec![s]),
s @ TypedStatement::PopCallLog => Ok(vec![s]),
s => fold_statement(self, s),
}
}
@ -1502,7 +1491,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(5)))
);
}
@ -1515,7 +1504,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
}
@ -1528,7 +1517,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(6)))
);
}
@ -1541,7 +1530,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1554,15 +1543,14 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
);
}
#[test]
fn left_shift() {
let mut constants = Constants::new();
let mut propagator = Propagator::with_constants(&mut constants);
let mut propagator = Propagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
@ -1607,8 +1595,7 @@ mod tests {
#[test]
fn right_shift() {
let mut constants = Constants::new();
let mut propagator = Propagator::with_constants(&mut constants);
let mut propagator = Propagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
@ -1676,7 +1663,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
);
}
@ -1691,7 +1678,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1713,7 +1700,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1735,18 +1722,15 @@ mod tests {
BooleanExpression::Not(box BooleanExpression::identifier("a".into()));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_default.clone()),
Propagator::default().fold_boolean_expression(e_default.clone()),
Ok(e_default)
);
}
@ -1776,23 +1760,19 @@ mod tests {
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_true),
Propagator::default().fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_false),
Propagator::default().fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Propagator::default().fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
}
@ -1800,38 +1780,42 @@ mod tests {
#[test]
fn bool_eq() {
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(false)
))),
))
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(true)
))),
))
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(false)
))),
))
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(true)
))),
))
),
Ok(BooleanExpression::Value(false))
);
}
@ -1933,33 +1917,27 @@ mod tests {
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_true),
Propagator::default().fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_false),
Propagator::default().fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Propagator::default().fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_true),
Propagator::default().fold_boolean_expression(e_non_canonical_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_false),
Propagator::default().fold_boolean_expression(e_non_canonical_false),
Ok(BooleanExpression::Value(false))
);
}
@ -1977,13 +1955,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2001,13 +1977,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2025,13 +1999,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2049,13 +2021,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2065,67 +2035,75 @@ mod tests {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
}
@ -2135,67 +2113,75 @@ mod tests {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
}

View file

@ -2,10 +2,11 @@
use crate::reducer::ConstantDefinitions;
use zokrates_ast::typed::{
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression,
StructExpression, StructExpressionInner, StructType, TupleExpression, TupleExpressionInner,
TupleType, TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner,
folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType,
BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id,
Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType,
TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration,
UBitwidth, UExpression, UExpressionInner,
};
use zokrates_field::Field;
@ -61,7 +62,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
FieldElementExpression::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..
@ -86,7 +91,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
BooleanExpression::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..
@ -112,7 +121,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
UExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..
@ -136,7 +149,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
ArrayExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..
@ -162,7 +179,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
TupleExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..
@ -188,7 +209,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
StructExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame: _,
},
version,
},
..

View file

@ -5,9 +5,9 @@ use crate::reducer::{
};
use std::collections::{BTreeMap, HashSet};
use zokrates_ast::typed::{
result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed,
TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId,
TypedProgram, TypedSymbolDeclaration, UExpression,
result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol,
TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration,
UExpression,
};
use zokrates_field::Field;
@ -118,11 +118,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
signature: DeclarationSignature::new().output(c.ty.clone()),
};
let mut inlined_wrapper = reduce_function(
wrapper,
ConcreteGenericsAssignment::default(),
&self.program,
)?;
let mut inlined_wrapper = reduce_function(wrapper, &self.program)?;
if let TypedStatement::Return(expression) =
inlined_wrapper.statements.pop().unwrap()

View file

@ -15,25 +15,24 @@
// ```
//
// Becomes
// ```
// # Call foo::<42> with a_0 := x
// n_0 = 42
// a_1 = a_0
// n_1 = n_0
// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1
// inputs: [a]
// arguments: [x]
// generics_bindings: [n = 42]
// statements:
// n = 42
// a = a
// n = n
// return_expression: a
// Notes:
// - The body of the function is in SSA form
// - The return value(s) are assigned to internal variables
use crate::reducer::Output;
use crate::reducer::ShallowTransformer;
use crate::reducer::Versions;
// - The body of the function is *not* in SSA form
use zokrates_ast::common::FlatEmbed;
use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType};
use zokrates_ast::typed::CoreIdentifier;
use zokrates_ast::typed::Identifier;
use zokrates_ast::typed::TypedAssignee;
use zokrates_ast::typed::UBitwidth;
use zokrates_ast::typed::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
Signature, Type, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
@ -43,22 +42,12 @@ use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
Flat(
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Type<'ast, T>,
),
NonConstant(
DeclarationFunctionKey<'ast, T>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Type<'ast, T>,
),
Flat(FlatEmbed, Vec<u32>, Type<'ast, T>),
NonConstant,
}
fn get_canonical_function<'ast, T: Field>(
function_key: DeclarationFunctionKey<'ast, T>,
function_key: &DeclarationFunctionKey<'ast, T>,
program: &TypedProgram<'ast, T>,
) -> TypedFunctionSymbolDeclaration<'ast, T> {
let s = program
@ -66,30 +55,35 @@ fn get_canonical_function<'ast, T: Field>(
.get(&function_key.module)
.unwrap()
.functions_iter()
.find(|d| d.key == function_key)
.find(|d| d.key == *function_key)
.unwrap();
match &s.symbol {
TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), program),
TypedFunctionSymbol::There(key) => get_canonical_function(key, program),
_ => s.clone(),
}
}
type InlineResult<'ast, T> = Result<
Output<(Vec<TypedStatement<'ast, T>>, TypedExpression<'ast, T>), Vec<Versions<'ast>>>,
InlineError<'ast, T>,
>;
pub struct InlineValue<'ast, T> {
/// the pre-SSA input variables to assign the arguments to
pub input_variables: Vec<Variable<'ast, T>>,
/// the pre-SSA statements for this call, including definition of the generic parameters
pub statements: Vec<TypedStatement<'ast, T>>,
/// the pre-SSA return value for this call
pub return_value: TypedExpression<'ast, T>,
}
type InlineResult<'ast, T> = Result<InlineValue<'ast, T>, InlineError<'ast, T>>;
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
k: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output: &E::Ty,
k: &DeclarationFunctionKey<'ast, T>,
generics: &[Option<UExpression<'ast, T>>],
arguments: &[TypedExpression<'ast, T>],
output_ty: &E::Ty,
program: &TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> InlineResult<'ast, T> {
use zokrates_ast::typed::Typed;
let output_type = output.clone().into_type();
let output_type = output_ty.clone().into_type();
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
@ -103,36 +97,23 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
.transpose()
})
.collect::<Result<_, _>>()
.map_err(|_| {
InlineError::NonConstant(
k.clone(),
generics.clone(),
arguments.clone(),
output_type.clone(),
)
})?;
.map_err(|_| InlineError::NonConstant)?;
// we infer a signature based on inputs and outputs
// this is where we could handle explicit annotations
let inferred_signature = Signature::new()
.generics(generics.clone())
.generics(generics.to_vec().clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.output(output_type.clone());
// we try to get concrete values for the whole signature. if this fails we should propagate again
// we try to get concrete values for the whole signature
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
Ok(s) => s,
Err(_) => {
return Err(InlineError::NonConstant(
k,
generics,
arguments,
output_type,
));
return Err(InlineError::NonConstant);
}
};
let decl = get_canonical_function(k.clone(), program);
let decl = get_canonical_function(k, program);
// get an assignment of generics for this call site
let assignment: ConcreteGenericsAssignment<'ast> = k
@ -154,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
e,
e.generics::<T>(&assignment),
arguments.clone(),
output_type,
)),
_ => unreachable!(),
@ -162,59 +142,38 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
assert_eq!(f.arguments.len(), arguments.len());
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
Output::Complete(v) => (v, None),
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
};
let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| {
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::uint(
CoreIdentifier::from(identifier),
UBitwidth::B32,
)),
TypedExpression::from(UExpression::from(value)).into(),
)
});
let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone());
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
let input_variables: Vec<Variable<'ast, T>> = f
.arguments
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
.zip(arguments.clone())
.map(|(v, a)| TypedStatement::definition(Variable::from(v).into(), a))
.map(Variable::from)
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
.statements
.into_iter()
let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings
.chain(f.statements)
.partition(|s| !matches!(s, TypedStatement::Return(..)));
assert_eq!(returns.len(), 1);
let return_expression = match returns.pop().unwrap() {
let return_value = match returns.pop().unwrap() {
TypedStatement::Return(e) => e,
_ => unreachable!(),
};
let v: ConcreteVariable<'ast> = ConcreteVariable::new(
Identifier::from(CoreIdentifier::Call(0)).version(
*versions
.entry(CoreIdentifier::Call(0))
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
),
*inferred_signature.output.clone(),
false,
);
let expression = TypedExpression::from(Variable::from(v.clone()));
let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression);
let pop_log = TypedStatement::PopCallLog;
let statements: Vec<_> = std::iter::once(call_log)
.chain(input_bindings)
.chain(statements)
.chain(std::iter::once(output_binding))
.chain(std::iter::once(pop_log))
.collect();
Ok(incomplete_data
.map(|d| Output::Incomplete((statements.clone(), expression.clone()), d))
.unwrap_or_else(|| Output::Complete((statements, expression))))
Ok(InlineValue {
input_variables,
statements,
return_value,
})
}

View file

@ -3,40 +3,42 @@
// - free of function calls (except for low level calls) thanks to inlining
// - free of for-loops thanks to unrolling
// The process happens in two steps
// 1. Shallow SSA for the `main` function
// We turn the `main` function into SSA form, but ignoring function calls and for loops
// 2. Unroll and inline
// We go through the shallow-SSA program and
// - unroll loops
// - inline function calls. This includes applying shallow-ssa on the target function
// The process happens in a greedy way, starting from the main function
// For each statement:
// * put it in ssa form
// * propagate it
// * inline it (calling this process recursively)
// * propagate again
// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach
mod constants_reader;
mod constants_writer;
mod inline;
mod shallow_ssa;
use self::inline::InlineValue;
use self::inline::{inline_call, InlineError};
use std::collections::HashMap;
use zokrates_ast::typed::result_folder::*;
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
use zokrates_ast::typed::types::GGenericsAssignment;
use zokrates_ast::typed::DeclarationParameter;
use zokrates_ast::typed::Folder;
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
use zokrates_ast::typed::TypedAssignee;
use zokrates_ast::typed::{
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression,
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram,
TypedStatement, UExpression, UExpressionInner,
};
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
use zokrates_field::Field;
use self::constants_writer::ConstantsWriter;
use self::shallow_ssa::ShallowTransformer;
use crate::propagation::{Constants, Propagator};
use crate::propagation;
use crate::propagation::Propagator;
use std::fmt;
@ -46,25 +48,15 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
pub type ConstantDefinitions<'ast, T> =
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
// An SSA version map, giving access to the latest version number for each identifier
pub type Versions<'ast> = HashMap<CoreIdentifier<'ast>, usize>;
// A container to represent whether more treatment must be applied to the function
#[derive(Debug, PartialEq, Eq)]
pub enum Output<U, V> {
Complete(U),
Incomplete(U, V),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
Incompatible(String),
GenericsInMain,
// TODO: give more details about what's blocking the progress
NoProgress,
LoopTooLarge(u128),
ConstantReduction(String, OwnedTypedModuleId),
NonConstant(String),
Type(String),
Propagation(propagation::Error),
}
impl fmt::Display for Error {
@ -76,133 +68,36 @@ impl fmt::Display for Error {
s
),
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
Error::Type(message) => write!(f, "{}", message),
Error::NonConstant(s) => write!(f, "{}", s),
Error::Type(s) => write!(f, "{}", s),
Error::Propagation(e) => write!(f, "{}", e),
}
}
}
#[derive(Debug, Default)]
struct Substitutions<'ast>(HashMap<CoreIdentifier<'ast>, HashMap<usize, usize>>);
impl<'ast> Substitutions<'ast> {
// create an equivalent substitution map where all paths
// are of length 1
fn canonicalize(self) -> Self {
Substitutions(
self.0
.into_iter()
.map(|(id, sub)| (id, Self::canonicalize_sub(sub)))
.collect(),
)
}
// canonicalize substitutions for a given id
fn canonicalize_sub(sub: HashMap<usize, usize>) -> HashMap<usize, usize> {
fn add_to_cache(
sub: &HashMap<usize, usize>,
cache: HashMap<usize, usize>,
k: usize,
) -> HashMap<usize, usize> {
match cache.contains_key(&k) {
// `k` is already in the cache, no changes to the cache
true => cache,
_ => match sub.get(&k) {
// `k` does not point to anything, no changes to the cache
None => cache,
// `k` points to some `v
Some(v) => {
// add `v` to the cache
let cache = add_to_cache(sub, cache, *v);
// `k` points to what `v` points to, or to `v`
let v = cache.get(v).cloned().unwrap_or(*v);
let mut cache = cache;
cache.insert(k, v);
cache
}
},
}
}
sub.keys()
.fold(HashMap::new(), |cache, k| add_to_cache(&sub, cache, *k))
}
}
struct Sub<'a, 'ast> {
substitutions: &'a Substitutions<'ast>,
}
impl<'a, 'ast> Sub<'a, 'ast> {
fn new(substitutions: &'a Substitutions<'ast>) -> Self {
Self { substitutions }
}
}
impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> {
fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> {
let version = self
.substitutions
.0
.get(&id.id)
.map(|sub| sub.get(&id.version).cloned().unwrap_or(id.version))
.unwrap_or(id.version);
id.version(version)
}
}
fn register<'ast>(
substitutions: &mut Substitutions<'ast>,
substitute: &Versions<'ast>,
with: &Versions<'ast>,
) {
for (id, key, value) in substitute
.iter()
.filter_map(|(id, version)| with.get(id).map(|to| (id, version, to)))
.filter(|(_, key, value)| key != value)
{
let sub = substitutions.0.entry(id.clone()).or_default();
// redirect `k` to `v`, unless `v` is already redirected to `v0`, in which case we redirect to `v0`
sub.insert(*key, *sub.get(value).unwrap_or(value));
impl From<propagation::Error> for Error {
fn from(e: propagation::Error) -> Self {
Self::Propagation(e)
}
}
#[derive(Debug)]
struct Reducer<'ast, 'a, T> {
statement_buffer: Vec<TypedStatement<'ast, T>>,
for_loop_versions: Vec<Versions<'ast>>,
for_loop_versions_after: Vec<Versions<'ast>>,
program: &'a TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
substitutions: &'a mut Substitutions<'ast>,
complete: bool,
propagator: Propagator<'ast, T>,
ssa: ShallowTransformer<'ast>,
statement_buffer: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
fn new(
program: &'a TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
substitutions: &'a mut Substitutions<'ast>,
for_loop_versions: Vec<Versions<'ast>>,
) -> Self {
// we reverse the vector as it's cheaper to `pop` than to take from
// the head
let mut for_loop_versions = for_loop_versions;
for_loop_versions.reverse();
fn new(program: &'a TypedProgram<'ast, T>) -> Self {
Reducer {
propagator: Propagator::default(),
ssa: ShallowTransformer::default(),
statement_buffer: vec![],
for_loop_versions_after: vec![],
for_loop_versions,
substitutions,
program,
versions,
complete: true,
}
}
}
@ -210,6 +105,13 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast, T>,
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
Ok(self.ssa.fold_parameter(p))
}
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
@ -217,65 +119,98 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
let generics = e
// generics are already in ssa form
let generics: Vec<_> = e
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.map(|g| {
g.map(|g| {
let g = self.propagator.fold_uint_expression(g)?;
let g = self.fold_uint_expression(g)?;
self.propagator
.fold_uint_expression(g)
.map_err(Self::Error::from)
})
.transpose()
})
.collect::<Result<_, _>>()?;
let arguments = e
// arguments are already in ssa form
let arguments: Vec<_> = e
.arguments
.into_iter()
.map(|e| self.fold_expression(e))
.map(|e| {
let e = self.propagator.fold_expression(e)?;
let e = self.fold_expression(e)?;
self.propagator
.fold_expression(e)
.map_err(Self::Error::from)
})
.collect::<Result<_, _>>()?;
let res = inline_call::<_, E>(
e.function_key.clone(),
generics,
arguments,
ty,
self.program,
self.versions,
);
self.ssa.push_call_frame();
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
let res = match res {
Ok(InlineValue {
input_variables,
statements,
return_value,
}) => {
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
let input_bindings: Vec<_> = input_variables
.into_iter()
.zip(arguments)
.map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a))
.collect();
let input_bindings = input_bindings
.into_iter()
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
self.statement_buffer.extend(input_bindings);
let statements = statements
.into_iter()
.map(|s| self.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
match res {
Ok(Output::Complete((statements, expression))) => {
self.complete &= true;
self.statement_buffer.extend(statements);
let return_value = self.ssa.fold_expression(return_value);
let return_value = self.propagator.fold_expression(return_value)?;
let return_value = self.fold_expression(return_value)?;
Ok(FunctionCallOrExpression::Expression(
E::from(expression).into_inner(),
))
}
Ok(Output::Incomplete((statements, expression), delta_for_loop_versions)) => {
self.complete = false;
self.statement_buffer.extend(statements);
self.for_loop_versions_after.extend(delta_for_loop_versions);
Ok(FunctionCallOrExpression::Expression(
E::from(expression.clone()).into_inner(),
E::from(return_value).into_inner(),
))
}
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
"Call site `{}` incompatible with declaration `{}`",
conc, decl
))),
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
self.complete = false;
Ok(FunctionCallOrExpression::Expression(E::function_call(
key, generics, arguments,
)))
}
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
*self
.versions
.entry(CoreIdentifier::Call(0))
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
);
Err(InlineError::NonConstant) => Err(Error::NonConstant(format!(
"Generic parameters must be compile-time constants, found {}",
FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments)
))),
Err(InlineError::Flat(embed, generics, output_type)) => {
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
let var = Variable::immutable(identifier.clone(), output_type);
let v = var.clone().into();
let v: TypedAssignee<'ast, T> = var.clone().into();
self.statement_buffer
.push(TypedStatement::embed_call_definition(
@ -286,7 +221,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
identifier,
)))
}
}
};
self.ssa.pop_call_frame();
res
}
fn fold_block_expression<E: ResultFold<'ast, T>>(
@ -325,74 +264,70 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
let res = match s {
TypedStatement::For(v, from, to, statements) => {
let versions_before = self.for_loop_versions.pop().unwrap();
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let from = self.propagator.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
let to = self.propagator.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
let mut out_statements = vec![];
// get a fresh set of versions for all variables to use as a starting point inside the loop
self.versions.values_mut().for_each(|v| *v += 1);
// add this set of versions to the substitution, pointing to the versions before the loop
register(self.substitutions, self.versions, &versions_before);
// the versions after the loop are found by applying an offset of 1 to the versions before the loop
let versions_after = versions_before
.clone()
.into_iter()
.map(|(k, v)| (k, v + 1))
.collect();
let mut transformer = ShallowTransformer::with_versions(self.versions);
if to - from > MAX_FOR_LOOP_SIZE {
return Err(Error::LoopTooLarge(to.saturating_sub(*from)));
}
for index in *from..*to {
let statements: Vec<TypedStatement<_>> =
std::iter::once(TypedStatement::definition(
v.clone().into(),
UExpression::from(index as u32).into(),
))
.chain(statements.clone().into_iter())
.flat_map(|s| transformer.fold_statement(s))
.collect();
out_statements.extend(statements);
}
let backups = transformer.for_loop_backups;
let blocked = transformer.blocked;
// we know the final versions of the variables after full unrolling of the loop
// the versions after the loop need to point to these, so we add to the substitutions
register(self.substitutions, &versions_after, self.versions);
// we may have found new for loops when unrolling this one, which means new backed up versions
// we insert these in our backup list and update our cursor
self.for_loop_versions_after.extend(backups);
// if the ssa transform got blocked, the reduction is not complete
self.complete &= !blocked;
Ok(out_statements)
(UExpressionInner::Value(from), UExpressionInner::Value(to))
if to - from > MAX_FOR_LOOP_SIZE =>
{
Err(Error::LoopTooLarge(to.saturating_sub(*from)))
}
_ => {
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
self.complete = false;
self.for_loop_versions_after.push(versions_before);
Ok(vec![TypedStatement::For(v, from, to, statements)])
}
}
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
..*to)
.flat_map(|index| {
std::iter::once(TypedStatement::definition(
v.clone().into(),
UExpression::from(index as u32).into(),
))
.chain(statements.clone())
.map(|s| self.fold_statement(s))
.collect::<Vec<_>>()
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>()),
_ => Err(Error::NonConstant(format!(
"Expected loop bounds to be constant, found {}..{}",
from, to
))),
}?
}
s => {
let statements = self.ssa.fold_statement(s);
let statements = statements
.into_iter()
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
let statements = statements
.map(|s| fold_statement(self, s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
let statements = statements
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
statements.collect()
}
s => fold_statement(self, s),
};
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
Ok(self.statement_buffer.drain(..).chain(res).collect())
}
fn fold_array_expression_inner(
@ -402,18 +337,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.ssa.fold_array_expression(array);
let array = self.propagator.fold_array_expression(array)?;
let array = self.fold_array_expression(array)?;
let array = self.propagator.fold_array_expression(array)?;
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let from = self.propagator.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
let to = self.propagator.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
}
_ => {
self.complete = false;
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
}
_ => Err(Error::NonConstant(format!(
"Slice bounds must be compile time constants, found {}",
ArrayExpressionInner::Slice(box array, box from, box to)
))),
}
}
_ => fold_array_expression_inner(self, array_ty, e),
@ -443,7 +389,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
match main_function.signature.generics.len() {
0 => {
let main_function = reduce_function(main_function, GGenericsAssignment::default(), &p)?;
let main_function = Reducer::new(&p).fold_function(main_function)?;
Ok(TypedProgram {
main: p.main.clone(),
@ -467,91 +413,11 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
fn reduce_function<'ast, T: Field>(
f: TypedFunction<'ast, T>,
generics: ConcreteGenericsAssignment<'ast>,
program: &TypedProgram<'ast, T>,
) -> Result<TypedFunction<'ast, T>, Error> {
let mut versions = Versions::default();
assert!(f.signature.generics.is_empty());
let mut constants = Constants::default();
let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
Output::Complete(f) => Ok(f),
Output::Incomplete(new_f, new_for_loop_versions) => {
let mut for_loop_versions = new_for_loop_versions;
let mut f = new_f;
let mut substitutions = Substitutions::default();
let mut hash = None;
loop {
let mut reducer = Reducer::new(
program,
&mut versions,
&mut substitutions,
for_loop_versions,
);
let new_f = TypedFunction {
statements: f
.statements
.into_iter()
.map(|s| reducer.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
..f
};
assert!(reducer.for_loop_versions.is_empty());
match reducer.complete {
true => {
substitutions = substitutions.canonicalize();
let new_f = Sub::new(&substitutions).fold_function(new_f);
let new_f = Propagator::with_constants(&mut constants)
.fold_function(new_f)
.map_err(|e| Error::Incompatible(format!("{}", e)))?;
break Ok(new_f);
}
false => {
for_loop_versions = reducer.for_loop_versions_after;
let new_f = Sub::new(&substitutions).fold_function(new_f);
f = Propagator::with_constants(&mut constants)
.fold_function(new_f)
.map_err(|e| Error::Incompatible(format!("{}", e)))?;
let new_hash = Some(compute_hash(&f));
if new_hash == hash {
break Err(Error::NoProgress);
} else {
hash = new_hash
}
}
}
}
}
}?;
Propagator::with_constants(&mut constants)
.fold_function(f)
.map_err(|e| Error::Incompatible(format!("{}", e)))
}
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
f.hash(&mut s);
s.finish()
Reducer::new(program).fold_function(f)
}
#[cfg(test)]
@ -588,14 +454,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// a_1 = a_0;
// # PUSH CALL to foo
// a_3 := a_1; // input binding
// #RETURN_AT_INDEX_0_0 := a_3;
// # POP CALL
// a_2 = #RETURN_AT_INDEX_0_0;
// return a_2;
// def main(field a_f0_v0) -> field {
// a_f0_v1 = a_f0_v0; // redef
// a_f1_v0 = a_f0_v1; // input binding
// a_f0_v2 = a_f1_v0; // output binding
// return a_f0_v2;
// }
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -691,30 +554,13 @@ mod tests {
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
),
GGenericsAssignment::default(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
Variable::field_element(Identifier::from("a").in_frame(1)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
.into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.into(),
FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(),
),
TypedStatement::Return(
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
@ -763,14 +609,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// field[1] b_0 = [42];
// # PUSH CALL to foo::<1>
// a_0 = b_0;
// #RETURN_AT_INDEX_0_0 := a_0;
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0;
// return a_2 + b_1[0];
// def main(field a_f0_v0) -> field {
// field[1] b_f0_v0 = [a_f0_v0];
// a_f1_v0 = b_f0_v0;
// b_f0_v1 = a_f1_v0;
// return a_f0_v0 + b_f0_v1[0];
// }
let foo_signature = DeclarationSignature::new()
@ -897,42 +740,19 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
1u32,
)
.into(),
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::identifier("a".into())
@ -987,14 +807,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// field[1] b_0 = [42];
// # PUSH CALL to foo::<1>
// a_0 = b_0;
// #RETURN_AT_INDEX_0_0 := a_0;
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0;
// return a_2 + b_1[0];
// def main(field a) -> field {
// field[1] b = [a];
// a_f1 = b;
// b_1 = a_f1;
// return a + b_1[0];
// }
let foo_signature = DeclarationSignature::new()
@ -1125,47 +942,25 @@ mod tests {
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier(Identifier::from("a")).into()]
.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
1u32,
)
.into(),
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::identifier("a".into())
@ -1391,33 +1186,11 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "bar")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 2)]
.into_iter()
.collect(),
),
),
TypedStatement::PopCallLog,
TypedStatement::PopCallLog,
TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
),
],
statements: vec![TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
)],
signature: DeclarationSignature::new(),
};

View file

@ -1,7 +1,6 @@
// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can
// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility.
// Function calls are also left unvisited
// Saving the indices is not required for function calls, as they cannot modify their environment
// The SSA transformation
// * introduces new versions if and only if we are assigning to an identifier
// * does not visit the statements of loops
// Example:
// def main(field a) -> field {
@ -19,178 +18,167 @@
// u32 n_0 = 42;
// a_1 = a_0 + 1;
// field b_0 = foo(a_1); // we keep the function call as is
// # versions: {n: 0, a: 1, b: 0}
// for u32 i_0 in 0..n_0 {
// <body> // we keep the loop body as is
// }
// return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop
// }
use std::collections::HashMap;
use zokrates_ast::typed::folder::*;
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
use zokrates_ast::typed::types::Type;
use zokrates_ast::typed::*;
use zokrates_field::Field;
use super::{Output, Versions};
pub struct ShallowTransformer<'ast, 'a> {
// version index for any variable name
pub versions: &'a mut Versions<'ast>,
// A backup of the versions before each for-loop
pub for_loop_backups: Vec<Versions<'ast>>,
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
pub blocked: bool,
// An SSA version map, giving access to the latest version number for each identifier
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Versions<'ast> {
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
}
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
ShallowTransformer {
versions,
for_loop_backups: Vec::default(),
blocked: false,
impl<'ast> Default for Versions<'ast> {
fn default() -> Self {
// create a call frame at index 0
Self {
map: vec![(0, Default::default())].into_iter().collect(),
}
}
}
// increase all versions by 1 and return the old versions
fn create_version_gap(&mut self) -> Versions<'ast> {
let ret = self.versions.clone();
self.versions.values_mut().for_each(|v| *v += 1);
ret
}
#[derive(Debug, Default)]
pub struct ShallowTransformer<'ast> {
// version index for any variable name
pub versions: Versions<'ast>,
pub frames: Vec<usize>,
pub latest_frame: usize,
}
fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
let version = *self
.versions
impl<'ast> ShallowTransformer<'ast> {
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
let frame = self.frame();
let frame_versions = self.versions.map.entry(frame).or_default();
let version = frame_versions
.entry(c_id.clone())
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0); // otherwise, we start from this version
.or_default(); // otherwise, we start from this version
Identifier::from(c_id).version(version)
Identifier::from(c_id.in_frame(frame)).version(*version)
}
fn issue_next_ssa_variable<T: Field>(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> {
assert_eq!(v.id.version, 0);
Variable {
id: self.issue_next_identifier(v.id.id),
id: self.issue_next_identifier(v.id.id.id),
..v
}
}
pub fn transform<T: Field>(
f: TypedFunction<'ast, T>,
generics: &ConcreteGenericsAssignment<'ast>,
versions: &'a mut Versions<'ast>,
) -> Output<TypedFunction<'ast, T>, Vec<Versions<'ast>>> {
let mut unroller = ShallowTransformer::with_versions(versions);
let f = unroller.fold_function(f, generics);
match unroller.blocked {
false => Output::Complete(f),
true => Output::Incomplete(f, unroller.for_loop_backups),
}
fn frame(&self) -> usize {
*self.frames.last().unwrap_or(&0)
}
fn fold_function<T: Field>(
pub fn push_call_frame(&mut self) {
self.latest_frame += 1;
self.frames.push(self.latest_frame);
self.versions
.map
.insert(self.latest_frame, Default::default());
}
pub fn pop_call_frame(&mut self) {
let frame = self.frames.pop().unwrap();
self.versions.map.remove(&frame);
}
// fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers,
// but this should not be applied recursively to complex assignees
fn fold_assignee_no_ssa_increase<T: Field>(
&mut self,
f: TypedFunction<'ast, T>,
generics: &ConcreteGenericsAssignment<'ast>,
) -> TypedFunction<'ast, T> {
let mut f = f;
a: TypedAssignee<'ast, T>,
) -> TypedAssignee<'ast, T> {
match a {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee_no_ssa_increase(a),
box self.fold_uint_expression(index),
),
TypedAssignee::Member(box s, m) => {
TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m)
}
TypedAssignee::Element(box s, index) => {
TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index)
}
}
}
}
f.statements = generics
.0
.clone()
.into_iter()
.map(|(g, v)| {
TypedStatement::definition(
Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false)
.into(),
UExpression::from(v as u32).into(),
)
})
.chain(f.statements)
.collect();
for arg in &f.arguments {
let _ = self.issue_next_identifier(arg.id.id.id.clone());
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
for g in &f.signature.generics {
let generic_parameter = match g.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
};
let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone()));
}
fold_function(self, f)
}
fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast, T>,
) -> DeclarationParameter<'ast, T> {
DeclarationParameter {
id: DeclarationVariable {
id: self.issue_next_identifier(p.id.id.id.id),
..p.id
},
..p
}
}
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
match a {
// create a new version for assignments to identifiers
TypedAssignee::Identifier(v) => {
let v = self.issue_next_ssa_variable(v);
TypedAssignee::Identifier(self.fold_variable(v))
}
a => fold_assignee(self, a),
// otherwise, simply replace by the current version
a => self.fold_assignee_no_ssa_increase(a),
}
}
}
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
let e = self.fold_expression(e);
let a = self.fold_assignee(a);
vec![TypedAssemblyStatement::Assignment(a, e)]
}
s => fold_assembly_statement(self, s),
}
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
let e = self.fold_expression(e);
let a = self.fold_assignee(a);
vec![TypedStatement::definition(a, e)]
}
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
let embed_call = self.fold_embed_call(embed_call);
let assignee = self.fold_assignee(assignee);
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
}
// only fold bounds of for loop statements
TypedStatement::For(v, from, to, stats) => {
let from = self.fold_uint_expression(from);
let to = self.fold_uint_expression(to);
self.blocked = true;
let versions_before_loop = self.create_version_gap();
self.for_loop_backups.push(versions_before_loop);
vec![TypedStatement::For(v, from, to, stats)]
}
s => fold_statement(self, s),
}
}
// retrieve the latest version
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
let res = Identifier {
version: *self.versions.get(&(n.id)).unwrap_or(&0),
..n
};
res
}
let version = self
.versions
.map
.get(&self.frame())
.unwrap()
.get(&n.id.id)
.cloned()
.unwrap_or(0);
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
ty: &E::Ty,
c: FunctionCallExpression<'ast, T, E>,
) -> FunctionCallOrExpression<'ast, T, E> {
if !c.function_key.id.starts_with('_') {
self.blocked = true;
}
fold_function_call_expression(self, ty, c)
n.in_frame(self.frame()).version(version)
}
}
@ -203,36 +191,57 @@ mod tests {
use super::*;
#[test]
fn detect_non_constant_bound() {
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
Variable::new("i", Type::Uint(UBitwidth::B32), false),
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
vec![],
)];
fn ignore_loop_content() {
// field foo = 0
// u32 i = 4;
// for u32 i in i..2 {
// foo = 5;
// }
let statements = loops;
// should be left unchanged, as we do not visit the loop content nor the index variable
let f = TypedFunction {
arguments: vec![],
signature: DeclarationSignature::new(),
statements,
statements: vec![
TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))),
FieldElementExpression::Number(Bn128Field::from(4)).into(),
),
TypedStatement::definition(
TypedAssignee::Identifier(Variable::uint(
Identifier::from("i"),
UBitwidth::B32,
)),
UExpression::from(0u32).into(),
),
TypedStatement::For(
Variable::new("i", Type::Uint(UBitwidth::B32), false),
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(Identifier::from(
"foo",
))),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
)],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
),
],
signature: DeclarationSignature::default(),
};
match ShallowTransformer::transform(
f,
&ConcreteGenericsAssignment::default(),
&mut Versions::default(),
) {
Output::Incomplete(..) => {}
_ => unreachable!(),
};
let mut ssa = ShallowTransformer::default();
assert_eq!(ssa.fold_function(f.clone()), f);
}
#[test]
fn definition() {
// field a
// a = 5
// field a = 5
// a = 6
// a
@ -241,9 +250,7 @@ mod tests {
// a_1 = 6
// a_1
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -283,17 +290,14 @@ mod tests {
#[test]
fn incremental_definition() {
// field a
// a = 5
// field a = 5
// a = a + 1
// should be turned into
// a_0 = 5
// a_1 = a_0 + 1
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -342,9 +346,7 @@ mod tests {
// a_0 = 2
// a_1 = foo(a_0)
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -403,9 +405,7 @@ mod tests {
// a_0 = [1, 1]
// a_0[1] = 2
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
@ -460,9 +460,7 @@ mod tests {
// a_0 = [[0, 1], [2, 3]]
// a_0 = [4, 5]
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
@ -557,10 +555,10 @@ mod tests {
mod for_loop {
use super::*;
use zokrates_ast::typed::types::GGenericsAssignment;
#[test]
fn treat_loop() {
// def main<K>(field a) -> field {
// def main(field a) -> field {
// u32 n = 42;
// n = n;
// a = a;
@ -575,24 +573,21 @@ mod tests {
// return a;
// }
// When called with K := 1, expected:
// expected:
// def main(field a_0) -> field {
// u32 K = 1;
// u32 n_0 = 42;
// n_1 = n_0;
// a_1 = a_0;
// # versions: {n: 1, a: 1, K: 0}
// for u32 i_0 in n_1..n_1*n_1 {
// a_0 = a_0;
// }
// a_2 = a_1;
// for u32 i_0 in n_1..n_1*n_1 {
// a_0 = a_0;
// }
// a_3 = a_2;
// # versions: {n: 2, a: 3, K: 1}
// for u32 i_0 in n_2..n_2*n_2 {
// a_0 = a_0;
// }
// a_5 = a_4;
// return a_5;
// } # versions: {n: 3, a: 5, K: 2}
// return a_3;
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
@ -642,32 +637,15 @@ mod tests {
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier::with_name("K").with_index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::transform(
f,
&GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
&mut versions,
);
let mut ssa = ShallowTransformer::default();
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
@ -696,16 +674,16 @@ mod tests {
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpression::identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
UExpression::identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32)
* UExpression::identifier(Identifier::from("n").version(2))
* UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
@ -713,47 +691,35 @@ mod tests {
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(5)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(4)).into(),
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
),
TypedStatement::Return(
FieldElementExpression::identifier(Identifier::from("a").version(5)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier::with_name("K").with_index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let res = ssa.fold_function(f);
assert_eq!(
versions,
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
.into_iter()
.collect::<Versions>()
ssa.versions.map,
vec![(
0,
vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect()
)]
.into_iter()
.collect()
);
let expected = Output::Incomplete(
expected,
vec![
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
.into_iter()
.collect::<Versions>(),
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
.into_iter()
.collect::<Versions>(),
],
);
assert_eq!(ssa, expected);
assert_eq!(res, expected);
}
}
mod shadowing {
use zokrates_ast::typed::types::GGenericsAssignment;
use super::*;
#[test]
@ -764,11 +730,11 @@ mod tests {
// return;
// }
// should become
// should become (only the field variable is affected as shadowing is taken care of in semantics already)
// def main(field a_0) {
// field a_1 = 42;
// bool a_2 = true;
// def main(field a_s0_v0) {
// field a_s0_v1 = 42;
// bool a_s1_v0 = true
// return;
// }
@ -780,7 +746,11 @@ mod tests {
TypedExpression::Uint(42u32.into()),
),
TypedStatement::definition(
Variable::boolean("a").into(),
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
"a".into(),
1,
)))
.into(),
BooleanExpression::Value(true).into(),
),
TypedStatement::Return(
@ -789,9 +759,7 @@ mod tests {
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement]),
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
};
let expected: TypedFunction<Bn128Field> = TypedFunction {
@ -802,7 +770,11 @@ mod tests {
TypedExpression::Uint(42u32.into()),
),
TypedStatement::definition(
Variable::boolean(Identifier::from("a").version(2)).into(),
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
"a".into(),
1,
)))
.into(),
BooleanExpression::Value(true).into(),
),
TypedStatement::Return(
@ -811,121 +783,17 @@ mod tests {
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement]),
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::default().fold_function(f);
let ssa =
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
assert_eq!(ssa, Output::Complete(expected));
}
#[test]
fn next_scope() {
// def main(field a) {
// for u32 i in 0..1 {
// a = a + 1
// field a = 42
// }
// return a
// }
// should become
// def main(field a_0) {
// # versions: {a: 0}
// for u32 i in 0..1 {
// a_0 = a_0
// field a_0 = 42
// }
// return a_1
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
0u32.into(),
1u32.into(),
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Number(42usize.into()).into(),
),
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
"a".into(),
)
.into()])
.annotate(TupleType::new(vec![Type::FieldElement]))
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let expected: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
0u32.into(),
1u32.into(),
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::identifier(Identifier::from("a")).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Number(42usize.into()).into(),
),
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
Identifier::from("a").version(1),
)
.into()])
.annotate(TupleType::new(vec![Type::FieldElement]))
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa =
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
assert_eq!(
ssa,
Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()])
);
assert_eq!(ssa, expected);
}
}
mod function_call {
use super::*;
use zokrates_ast::typed::types::GGenericsAssignment;
// test that function calls are left in
#[test]
fn treat_calls() {
@ -939,17 +807,12 @@ mod tests {
// return a;
// }
// When called with K := 1, expected:
// def main(field a_0) -> field {
// K = 1;
// u32 n_0 = 42;
// n_1 = n_0;
// a_1 = a_0;
// a_2 = foo::<n_1>(a_1);
// n_2 = n_1;
// a_3 = a_2 * foo::<n_2>(a_2);
// a_2 = foo::<42>(a_1);
// a_3 = a_2 * foo::<42>(a_2);
// return a_3;
// } # versions: {n: 2, a: 3}
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
@ -1007,25 +870,9 @@ mod tests {
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::transform(
f,
&GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
&mut versions,
);
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
@ -1089,14 +936,23 @@ mod tests {
.output(DeclarationType::FieldElement),
};
let mut ssa = ShallowTransformer::default();
let res = ssa.fold_function(f);
assert_eq!(
versions,
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
.into_iter()
.collect::<Versions>()
ssa.versions.map,
vec![(
0,
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
.into_iter()
.collect()
)]
.into_iter()
.collect()
);
assert_eq!(ssa, Output::Incomplete(expected, vec![],));
assert_eq!(res, expected);
}
}
}

View file

@ -4,6 +4,8 @@ use crate::typed::types::*;
use crate::typed::*;
use zokrates_field::Field;
use super::identifier::FrameIdentifier;
pub trait Fold<'ast, T: Field>: Sized {
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self;
}
@ -128,11 +130,12 @@ pub trait Folder<'ast, T: Field>: Sized {
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
let id = match n.id {
CoreIdentifier::Constant(c) => {
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c))
}
id => id,
let id = match n.id.id.clone() {
CoreIdentifier::Constant(c) => FrameIdentifier {
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)),
frame: 0,
},
_ => n.id,
};
Identifier { id, ..n }
@ -528,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a),
f.fold_expression(e),
)]
let e = f.fold_expression(e);
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
@ -549,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
TypedStatement::Definition(a, rhs) => {
let rhs = f.fold_definition_rhs(rhs);
TypedStatement::Definition(f.fold_assignee(a), rhs)
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
@ -573,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
.flat_map(|s| f.fold_assembly_statement(s))
.collect(),
),
s => s,
};
vec![res]
}

View file

@ -24,18 +24,49 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> {
CoreIdentifier::Constant(s)
impl<'ast> FrameIdentifier<'ast> {
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
FrameIdentifier { frame, ..self }
}
}
impl<'ast> Identifier<'ast> {
pub fn in_frame(self, frame: usize) -> Identifier<'ast> {
Identifier {
id: self.id.in_frame(frame),
..self
}
}
}
impl<'ast> CoreIdentifier<'ast> {
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
FrameIdentifier { id: self, frame }
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for FrameIdentifier<'ast> {
fn from(s: CanonicalConstantIdentifier<'ast>) -> FrameIdentifier<'ast> {
FrameIdentifier::from(CoreIdentifier::Constant(s))
}
}
/// A identifier for a variable in a given call frame
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct FrameIdentifier<'ast> {
/// the id of the variable
#[serde(borrow)]
pub id: CoreIdentifier<'ast>,
/// the frame of the variable
pub frame: usize,
}
/// A identifier for a variable
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Identifier<'ast> {
/// the id of the variable
#[serde(borrow)]
pub id: CoreIdentifier<'ast>,
pub id: FrameIdentifier<'ast>,
/// the version of the variable, used after SSA transformation
pub version: usize,
}
@ -58,7 +89,7 @@ impl<'ast> fmt::Display for ShadowedIdentifier<'ast> {
if self.shadow == 0 {
write!(f, "{}", self.id)
} else {
write!(f, "{}_{}", self.id, self.shadow)
write!(f, "{}_s{}", self.id, self.shadow)
}
}
}
@ -68,20 +99,45 @@ impl<'ast> fmt::Display for Identifier<'ast> {
if self.version == 0 {
write!(f, "{}", self.id)
} else {
write!(f, "{}_{}", self.id, self.version)
write!(f, "{}_v{}", self.id, self.version)
}
}
}
impl<'ast> fmt::Display for FrameIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.frame == 0 {
write!(f, "{}", self.id)
} else {
write!(f, "{}_f{}", self.id, self.frame)
}
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for Identifier<'ast> {
fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> {
Identifier::from(CoreIdentifier::Constant(id))
Identifier::from(FrameIdentifier::from(CoreIdentifier::Constant(id)))
}
}
impl<'ast> From<FrameIdentifier<'ast>> for Identifier<'ast> {
fn from(id: FrameIdentifier<'ast>) -> Identifier<'ast> {
Identifier { id, version: 0 }
}
}
impl<'ast> From<CoreIdentifier<'ast>> for Identifier<'ast> {
fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> {
Identifier { id, version: 0 }
Identifier {
id: FrameIdentifier::from(id),
version: 0,
}
}
}
impl<'ast> From<CoreIdentifier<'ast>> for FrameIdentifier<'ast> {
fn from(id: CoreIdentifier<'ast>) -> FrameIdentifier<'ast> {
FrameIdentifier { id, frame: 0 }
}
}
@ -107,6 +163,6 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
impl<'ast> From<&'ast str> for Identifier<'ast> {
fn from(id: &'ast str) -> Identifier<'ast> {
Identifier::from(CoreIdentifier::from(id))
Identifier::from(FrameIdentifier::from(CoreIdentifier::from(id)))
}
}

View file

@ -27,7 +27,7 @@ pub use self::types::{
UBitwidth,
};
use self::types::{ConcreteArrayType, ConcreteStructType};
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
use crate::typed::types::IntoType;
pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable};
use std::marker::PhantomData;
@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
writeln!(f)?;
let mut tab = 0;
for s in &self.statements {
if let TypedStatement::PopCallLog = s {
tab -= 1;
};
s.fmt_indented(f, 1 + tab)?;
writeln!(f)?;
if let TypedStatement::PushCallLog(..) = s {
tab += 1;
};
writeln!(f, "{}", s)?;
}
writeln!(f, "}}")?;
@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> {
Vec<TypedStatement<'ast, T>>,
),
Log(FormatString, Vec<TypedExpression<'ast, T>>),
// Aux
PushCallLog(
DeclarationFunctionKey<'ast, T>,
ConcreteGenericsAssignment<'ast>,
),
PopCallLog,
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
}
@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> {
}
}
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
match self {
TypedStatement::For(variable, from, to, statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "for {} in {}..{} {{", variable, from, to)?;
for s in statements {
s.fmt_indented(f, depth + 1)?;
writeln!(f)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
TypedStatement::Assembly(statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
s => write!(f, "{}{}", "\t".repeat(depth), s),
}
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
.collect::<Vec<_>>()
.join(", ")
),
TypedStatement::PushCallLog(ref key, ref generics) => write!(
f,
"// PUSH CALL TO {}/{}::<{}>",
key.module.display(),
key.id,
generics,
),
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
TypedStatement::Assembly(ref statements) => {
writeln!(f, "asm {{")?;
for s in statements {

View file

@ -4,6 +4,8 @@ use crate::typed::types::*;
use crate::typed::*;
use zokrates_field::Field;
use super::identifier::FrameIdentifier;
pub trait ResultFold<'ast, T: Field>: Sized {
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error>;
}
@ -156,11 +158,12 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
let id = match n.id {
CoreIdentifier::Constant(c) => {
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?)
}
id => id,
let id = match n.id.id.clone() {
CoreIdentifier::Constant(c) => FrameIdentifier {
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?),
frame: 0,
},
_ => n.id,
};
Ok(Identifier { id, ..n })
@ -529,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a)?,
f.fold_expression(e)?,
)]
let e = f.fold_expression(e)?;
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
@ -551,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
let rhs = f.fold_definition_rhs(e)?;
TypedStatement::Definition(f.fold_assignee(a)?, rhs)
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
@ -583,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
.flatten()
.collect(),
),
s => s,
};
Ok(vec![res])
}

View file

@ -241,13 +241,14 @@ impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T>
fn from(c: DeclarationConstant<'ast, T>) -> Self {
match c {
DeclarationConstant::Generic(g) => {
UExpression::identifier(CoreIdentifier::from(g).into()).annotate(UBitwidth::B32)
UExpression::identifier(Identifier::from(CoreIdentifier::from(g)))
.annotate(UBitwidth::B32)
}
DeclarationConstant::Concrete(v) => {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
}
DeclarationConstant::Constant(v) => {
UExpression::identifier(CoreIdentifier::from(v).into()).annotate(UBitwidth::B32)
UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32)
}
DeclarationConstant::Expression(e) => e.try_into().unwrap(),
}
@ -1144,8 +1145,7 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq<u32>>(
impl<'ast, T: Field> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
.annotate(UBitwidth::B32)
UExpression::identifier(Identifier::from(FrameIdentifier::from(c))).annotate(UBitwidth::B32)
}
}
@ -1230,6 +1230,7 @@ pub use self::signature::{
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
};
use super::identifier::FrameIdentifier;
use super::{Id, ShadowedIdentifier};
pub mod signature {

View file

@ -1170,7 +1170,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let id = arg.id.value.id;
let info = IdentifierInfo {
id: decl_v.id.id.clone(),
id: decl_v.id.id.id.clone(),
ty,
is_mutable,
};

View file

@ -0,0 +1,15 @@
{
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"value": "4"
}
}
}
]
}

View file

@ -0,0 +1,11 @@
// main should be x -> x + 4
def foo(field mut a) -> field {
a = a + 1;
return a + 1;
}
def main(field mut a) -> field {
a = foo(a + 1);
return a + 1;
}