fold module id in folders
This commit is contained in:
parent
0bb9f5897a
commit
759f4dd553
5 changed files with 73 additions and 44 deletions
2
access.zok
Normal file
2
access.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main(field[3][2] a, u32 index) -> field[2]:
|
||||
return a[index]
|
|
@ -1,8 +1,8 @@
|
|||
from "./bar" import main as bar
|
||||
//from "./bar" import main as bar
|
||||
from "./baz" import BAZ as baz
|
||||
import "./foo" as f
|
||||
|
||||
def main() -> field:
|
||||
field foo = f()
|
||||
assert(foo == bar() + baz)
|
||||
assert(foo == 21 + baz)
|
||||
return foo
|
|
@ -48,7 +48,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
assert_eq!(id.version, 0);
|
||||
match id.id {
|
||||
CoreIdentifier::Call(..) => {
|
||||
unreachable!("calls indentifiers are only available after call inlining")
|
||||
unreachable!("calls identifiers are only available after call inlining")
|
||||
}
|
||||
CoreIdentifier::Source(id) => self
|
||||
.constants
|
||||
|
@ -60,46 +60,34 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
TypedProgram {
|
||||
modules: p
|
||||
.modules
|
||||
.into_iter()
|
||||
.map(|(m_id, m)| {
|
||||
if !self.treated(&m_id) {
|
||||
self.change_location(m_id.clone());
|
||||
(m_id, self.fold_module(m))
|
||||
} else {
|
||||
(m_id, m)
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
..p
|
||||
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||
if !self.treated(&id) {
|
||||
let current_m_id = self.change_location(id.clone());
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
let m = self.fold_module(self.modules.get(&id).unwrap().clone());
|
||||
|
||||
self.modules.insert(id.clone(), m);
|
||||
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
id
|
||||
}
|
||||
|
||||
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
||||
assert!(self
|
||||
.constants
|
||||
.insert(self.location.clone(), Default::default())
|
||||
.is_none());
|
||||
// initialise a constant map for this module
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
|
||||
TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
.into_iter()
|
||||
.map(|(id, tc)| {
|
||||
let id = self.fold_canonical_constant_identifier(id);
|
||||
|
||||
let constant = match tc {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
if !self.treated(&imported_id.module) {
|
||||
let current_m_id = self.change_location(imported_id.module.clone());
|
||||
let m = self.fold_module(
|
||||
self.modules.get(&imported_id.module).unwrap().clone(),
|
||||
);
|
||||
|
||||
self.modules.insert(imported_id.module.clone(), m);
|
||||
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
let imported_id = self.fold_canonical_constant_identifier(imported_id);
|
||||
self.constants
|
||||
.get(&imported_id.module)
|
||||
.unwrap()
|
||||
|
@ -115,12 +103,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
.fold_expression(constant)
|
||||
.unwrap();
|
||||
|
||||
assert!(self
|
||||
.constants
|
||||
self.constants
|
||||
.entry(self.location.clone())
|
||||
.or_default()
|
||||
.insert(id.id.into(), constant.clone())
|
||||
.is_none());
|
||||
.insert(id.id.into(), constant.clone());
|
||||
|
||||
(
|
||||
id,
|
||||
|
@ -176,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => self.fold_expression(c).try_into().unwrap(),
|
||||
Some(c) => c.try_into().unwrap(),
|
||||
None => fold_field_expression(self, e),
|
||||
},
|
||||
e => fold_field_expression(self, e),
|
||||
|
|
|
@ -215,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_canonical_constant_identifier(
|
||||
&mut self,
|
||||
i: CanonicalConstantIdentifier<'ast>,
|
||||
) -> CanonicalConstantIdentifier<'ast> {
|
||||
CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module),
|
||||
id: i.id,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
i
|
||||
}
|
||||
|
||||
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
|
||||
match e {
|
||||
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
|
||||
|
@ -342,7 +356,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
constants: m
|
||||
.constants
|
||||
.into_iter()
|
||||
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
|
||||
.map(|(id, tc)| {
|
||||
(
|
||||
f.fold_canonical_constant_identifier(id),
|
||||
f.fold_constant_symbol(tc),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
functions: m
|
||||
.functions
|
||||
|
@ -1014,7 +1033,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> TypedConstantSymbol<'ast, T> {
|
||||
match s {
|
||||
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
|
||||
there => there,
|
||||
TypedConstantSymbol::There(id) => {
|
||||
TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1052,8 +1073,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
modules: p
|
||||
.modules
|
||||
.into_iter()
|
||||
.map(|(module_id, module)| (module_id, f.fold_module(module)))
|
||||
.map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module)))
|
||||
.collect(),
|
||||
main: p.main,
|
||||
main: f.fold_module_id(p.main),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_canonical_constant_identifier(
|
||||
&mut self,
|
||||
i: CanonicalConstantIdentifier<'ast>,
|
||||
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
|
||||
Ok(CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module)?,
|
||||
id: i.id,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result<OwnedTypedModuleId, Self::Error> {
|
||||
Ok(i)
|
||||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
|
||||
Ok(n)
|
||||
}
|
||||
|
@ -923,6 +937,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
|
||||
Ok(DeclarationFunctionKey {
|
||||
module: f.fold_module_id(key.module)?,
|
||||
signature: f.fold_signature(key.signature)?,
|
||||
..key
|
||||
})
|
||||
|
@ -1067,7 +1082,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
|
||||
match s {
|
||||
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
|
||||
there => Ok(there),
|
||||
TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There(
|
||||
f.fold_canonical_constant_identifier(id)?,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1077,7 +1094,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
|
||||
match s {
|
||||
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
|
||||
there => Ok(there), // by default, do not fold modules recursively
|
||||
TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There(
|
||||
f.fold_declaration_function_key(key)?,
|
||||
)),
|
||||
s => Ok(s),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1109,6 +1129,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.into_iter()
|
||||
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
|
||||
.collect::<Result<_, _>>()?,
|
||||
main: p.main,
|
||||
main: f.fold_module_id(p.main)?,
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue