clean
This commit is contained in:
parent
59023ca677
commit
6b4a80d891
3 changed files with 66 additions and 67 deletions
|
@ -1,2 +0,0 @@
|
|||
def main(field[3][2] a, u32 index) -> field[2]:
|
||||
return a[index]
|
|
@ -7,20 +7,20 @@ use std::collections::HashMap;
|
|||
use std::convert::TryInto;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type ModuleConstants<'ast, T> =
|
||||
type ProgramConstants<'ast, T> =
|
||||
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
|
||||
|
||||
pub struct ConstantInliner<'ast, T> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ModuleConstants<'ast, T>,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
||||
pub fn new(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ModuleConstants<'ast, T>,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
) -> Self {
|
||||
ConstantInliner {
|
||||
modules,
|
||||
|
@ -29,7 +29,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
}
|
||||
}
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
let constants = HashMap::new();
|
||||
let constants = ProgramConstants::new();
|
||||
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
self.location = location;
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
prev
|
||||
}
|
||||
|
||||
|
@ -44,18 +45,24 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
self.constants.contains_key(id)
|
||||
}
|
||||
|
||||
fn get_constant(&mut self, id: &Identifier) -> Option<TypedExpression<'ast, T>> {
|
||||
assert_eq!(id.version, 0);
|
||||
match id.id {
|
||||
CoreIdentifier::Call(..) => {
|
||||
unreachable!("calls identifiers are only available after call inlining")
|
||||
}
|
||||
CoreIdentifier::Source(id) => self
|
||||
.constants
|
||||
.get(&self.location)
|
||||
.and_then(|constants| constants.get(&id.into()))
|
||||
.cloned(),
|
||||
}
|
||||
fn get_constant(
|
||||
&self,
|
||||
id: &CanonicalConstantIdentifier<'ast>,
|
||||
) -> Option<TypedExpression<'ast, T>> {
|
||||
self.constants
|
||||
.get(&id.module)
|
||||
.and_then(|constants| constants.get(&id.id.into()))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn get_constant_for_identifier(
|
||||
&self,
|
||||
id: &Identifier<'ast>,
|
||||
) -> Option<TypedExpression<'ast, T>> {
|
||||
self.constants
|
||||
.get(&self.location)
|
||||
.and_then(|constants| constants.get(&id))
|
||||
.cloned()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,48 +71,42 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||
if !self.treated(&id) {
|
||||
let current_m_id = self.change_location(id.clone());
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
let m = self.fold_module(self.modules.get(&id).unwrap().clone());
|
||||
|
||||
let m = self.modules.remove(&id).unwrap();
|
||||
let m = self.fold_module(m);
|
||||
self.modules.insert(id.clone(), m);
|
||||
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
id
|
||||
}
|
||||
|
||||
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
||||
// initialise a constant map for this module
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
|
||||
TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
.into_iter()
|
||||
.map(|(id, tc)| {
|
||||
let id = self.fold_canonical_constant_identifier(id);
|
||||
|
||||
let constant = match tc {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
// visit the imported symbol. This triggers visiting the corresponding module if needed
|
||||
let imported_id = self.fold_canonical_constant_identifier(imported_id);
|
||||
self.constants
|
||||
.get(&imported_id.module)
|
||||
.unwrap()
|
||||
.get(&imported_id.id.into())
|
||||
.cloned()
|
||||
// after that, the constant must have been defined defined in the global map. It is already reduced
|
||||
// to a literal, so running propagation isn't required
|
||||
self.get_constant(&imported_id).unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => {
|
||||
let non_propagated_constant = fold_constant(self, c).expression;
|
||||
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
|
||||
// propagating with an empty map of constants reduces it to a single literal
|
||||
Propagator::with_constants(&mut HashMap::default())
|
||||
.fold_expression(non_propagated_constant)
|
||||
.unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
|
||||
};
|
||||
|
||||
let constant =
|
||||
Propagator::with_constants(self.constants.get_mut(&self.location).unwrap())
|
||||
.fold_expression(constant)
|
||||
.unwrap();
|
||||
|
||||
// add to the constant map. The value added is always a single litteral
|
||||
self.constants
|
||||
.entry(self.location.clone())
|
||||
.or_default()
|
||||
.get_mut(&self.location)
|
||||
.unwrap()
|
||||
.insert(id.id.into(), constant.clone());
|
||||
|
||||
(
|
||||
|
@ -136,22 +137,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
) -> DeclarationConstant<'ast> {
|
||||
match c {
|
||||
// replace constants by their concrete value in declaration types
|
||||
DeclarationConstant::Constant(id) => DeclarationConstant::Concrete(
|
||||
match self
|
||||
.constants
|
||||
.get(&id.module)
|
||||
.unwrap()
|
||||
.get(&id.id.into())
|
||||
.cloned()
|
||||
.unwrap()
|
||||
{
|
||||
DeclarationConstant::Constant(id) => {
|
||||
DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
|
||||
TypedExpression::Uint(UExpression {
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
}) => v as u32,
|
||||
_ => unreachable!("all constants should be reduceable to u32 literals"),
|
||||
},
|
||||
),
|
||||
})
|
||||
}
|
||||
c => c,
|
||||
}
|
||||
}
|
||||
|
@ -161,10 +155,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => c.try_into().unwrap(),
|
||||
None => fold_field_expression(self, e),
|
||||
},
|
||||
FieldElementExpression::Identifier(ref id) => {
|
||||
match self.get_constant_for_identifier(id) {
|
||||
Some(c) => c.try_into().unwrap(),
|
||||
None => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
@ -174,8 +170,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => self.fold_expression(c).try_into().unwrap(),
|
||||
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
|
||||
Some(c) => c.try_into().unwrap(),
|
||||
None => fold_boolean_expression(self, e),
|
||||
},
|
||||
e => fold_boolean_expression(self, e),
|
||||
|
@ -188,9 +184,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
|
||||
Some(c) => {
|
||||
let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
let e: UExpression<'ast, T> = c.try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_uint_expression_inner(self, size, e),
|
||||
|
@ -205,13 +201,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> ArrayExpressionInner<'ast, T> {
|
||||
match e {
|
||||
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
e.into_inner()
|
||||
ArrayExpressionInner::Identifier(ref id) => {
|
||||
match self.get_constant_for_identifier(id) {
|
||||
Some(c) => {
|
||||
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
None => fold_array_expression_inner(self, ty, e),
|
||||
},
|
||||
}
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
@ -222,9 +220,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: StructExpressionInner<'ast, T>,
|
||||
) -> StructExpressionInner<'ast, T> {
|
||||
match e {
|
||||
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
|
||||
{
|
||||
Some(c) => {
|
||||
let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
let e: StructExpression<'ast, T> = c.try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_struct_expression_inner(self, ty, e),
|
||||
|
|
|
@ -298,6 +298,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct TypedConstant<'ast, T> {
|
||||
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
|
||||
pub ty: Type<'ast, T>,
|
||||
pub expression: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
@ -310,6 +311,7 @@ impl<'ast, T> TypedConstant<'ast, T> {
|
|||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
|
||||
write!(f, "const {}({})", self.ty, self.expression)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue