1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fold module id in folders

This commit is contained in:
schaeff 2021-06-15 16:41:38 +02:00
parent 0bb9f5897a
commit 759f4dd553
5 changed files with 73 additions and 44 deletions

2
access.zok Normal file
View file

@ -0,0 +1,2 @@
def main(field[3][2] a, u32 index) -> field[2]:
return a[index]

View file

@ -1,8 +1,8 @@
from "./bar" import main as bar //from "./bar" import main as bar
from "./baz" import BAZ as baz from "./baz" import BAZ as baz
import "./foo" as f import "./foo" as f
def main() -> field: def main() -> field:
field foo = f() field foo = f()
assert(foo == bar() + baz) assert(foo == 21 + baz)
return foo return foo

View file

@ -48,7 +48,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
assert_eq!(id.version, 0); assert_eq!(id.version, 0);
match id.id { match id.id {
CoreIdentifier::Call(..) => { CoreIdentifier::Call(..) => {
unreachable!("calls indentifiers are only available after call inlining") unreachable!("calls identifiers are only available after call inlining")
} }
CoreIdentifier::Source(id) => self CoreIdentifier::Source(id) => self
.constants .constants
@ -60,46 +60,34 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
} }
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
TypedProgram { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
modules: p if !self.treated(&id) {
.modules let current_m_id = self.change_location(id.clone());
.into_iter() self.constants.entry(self.location.clone()).or_default();
.map(|(m_id, m)| { let m = self.fold_module(self.modules.get(&id).unwrap().clone());
if !self.treated(&m_id) {
self.change_location(m_id.clone()); self.modules.insert(id.clone(), m);
(m_id, self.fold_module(m))
} else { self.change_location(current_m_id);
(m_id, m)
}
})
.collect(),
..p
} }
id
} }
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
assert!(self // initialise a constant map for this module
.constants self.constants.entry(self.location.clone()).or_default();
.insert(self.location.clone(), Default::default())
.is_none());
TypedModule { TypedModule {
constants: m constants: m
.constants .constants
.into_iter() .into_iter()
.map(|(id, tc)| { .map(|(id, tc)| {
let id = self.fold_canonical_constant_identifier(id);
let constant = match tc { let constant = match tc {
TypedConstantSymbol::There(imported_id) => { TypedConstantSymbol::There(imported_id) => {
if !self.treated(&imported_id.module) { let imported_id = self.fold_canonical_constant_identifier(imported_id);
let current_m_id = self.change_location(imported_id.module.clone());
let m = self.fold_module(
self.modules.get(&imported_id.module).unwrap().clone(),
);
self.modules.insert(imported_id.module.clone(), m);
self.change_location(current_m_id);
}
self.constants self.constants
.get(&imported_id.module) .get(&imported_id.module)
.unwrap() .unwrap()
@ -115,12 +103,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
.fold_expression(constant) .fold_expression(constant)
.unwrap(); .unwrap();
assert!(self self.constants
.constants
.entry(self.location.clone()) .entry(self.location.clone())
.or_default() .or_default()
.insert(id.id.into(), constant.clone()) .insert(id.id.into(), constant.clone());
.is_none());
( (
id, id,
@ -176,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
) -> FieldElementExpression<'ast, T> { ) -> FieldElementExpression<'ast, T> {
match e { match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_expression(c).try_into().unwrap(), Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e), None => fold_field_expression(self, e),
}, },
e => fold_field_expression(self, e), e => fold_field_expression(self, e),

View file

@ -215,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized {
} }
} }
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> CanonicalConstantIdentifier<'ast> {
CanonicalConstantIdentifier {
module: self.fold_module_id(i.module),
id: i.id,
}
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId {
i
}
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> { fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
match e { match e {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(), TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
@ -342,7 +356,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
constants: m constants: m
.constants .constants
.into_iter() .into_iter()
.map(|(key, tc)| (key, f.fold_constant_symbol(tc))) .map(|(id, tc)| {
(
f.fold_canonical_constant_identifier(id),
f.fold_constant_symbol(tc),
)
})
.collect(), .collect(),
functions: m functions: m
.functions .functions
@ -1014,7 +1033,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedConstantSymbol<'ast, T> { ) -> TypedConstantSymbol<'ast, T> {
match s { match s {
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)), TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
there => there, TypedConstantSymbol::There(id) => {
TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id))
}
} }
} }
@ -1052,8 +1073,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
modules: p modules: p
.modules .modules
.into_iter() .into_iter()
.map(|(module_id, module)| (module_id, f.fold_module(module))) .map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module)))
.collect(), .collect(),
main: p.main, main: f.fold_module_id(p.main),
} }
} }

View file

@ -114,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}) })
} }
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
Ok(CanonicalConstantIdentifier {
module: self.fold_module_id(i.module)?,
id: i.id,
})
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result<OwnedTypedModuleId, Self::Error> {
Ok(i)
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> { fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
Ok(n) Ok(n)
} }
@ -923,6 +937,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
key: DeclarationFunctionKey<'ast>, key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> { ) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
Ok(DeclarationFunctionKey { Ok(DeclarationFunctionKey {
module: f.fold_module_id(key.module)?,
signature: f.fold_signature(key.signature)?, signature: f.fold_signature(key.signature)?,
..key ..key
}) })
@ -1067,7 +1082,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> { ) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
match s { match s {
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)), TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
there => Ok(there), TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There(
f.fold_canonical_constant_identifier(id)?,
)),
} }
} }
@ -1077,7 +1094,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> { ) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
match s { match s {
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)), TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
there => Ok(there), // by default, do not fold modules recursively TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There(
f.fold_declaration_function_key(key)?,
)),
s => Ok(s),
} }
} }
@ -1109,6 +1129,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter() .into_iter()
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m))) .map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
.collect::<Result<_, _>>()?, .collect::<Result<_, _>>()?,
main: p.main, main: f.fold_module_id(p.main)?,
}) })
} }