fold module id in folders
This commit is contained in:
parent
0bb9f5897a
commit
759f4dd553
5 changed files with 73 additions and 44 deletions
2
access.zok
Normal file
2
access.zok
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
def main(field[3][2] a, u32 index) -> field[2]:
|
||||||
|
return a[index]
|
|
@ -1,8 +1,8 @@
|
||||||
from "./bar" import main as bar
|
//from "./bar" import main as bar
|
||||||
from "./baz" import BAZ as baz
|
from "./baz" import BAZ as baz
|
||||||
import "./foo" as f
|
import "./foo" as f
|
||||||
|
|
||||||
def main() -> field:
|
def main() -> field:
|
||||||
field foo = f()
|
field foo = f()
|
||||||
assert(foo == bar() + baz)
|
assert(foo == 21 + baz)
|
||||||
return foo
|
return foo
|
|
@ -48,7 +48,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
||||||
assert_eq!(id.version, 0);
|
assert_eq!(id.version, 0);
|
||||||
match id.id {
|
match id.id {
|
||||||
CoreIdentifier::Call(..) => {
|
CoreIdentifier::Call(..) => {
|
||||||
unreachable!("calls indentifiers are only available after call inlining")
|
unreachable!("calls identifiers are only available after call inlining")
|
||||||
}
|
}
|
||||||
CoreIdentifier::Source(id) => self
|
CoreIdentifier::Source(id) => self
|
||||||
.constants
|
.constants
|
||||||
|
@ -60,46 +60,34 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||||
TypedProgram {
|
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||||
modules: p
|
if !self.treated(&id) {
|
||||||
.modules
|
let current_m_id = self.change_location(id.clone());
|
||||||
.into_iter()
|
self.constants.entry(self.location.clone()).or_default();
|
||||||
.map(|(m_id, m)| {
|
let m = self.fold_module(self.modules.get(&id).unwrap().clone());
|
||||||
if !self.treated(&m_id) {
|
|
||||||
self.change_location(m_id.clone());
|
self.modules.insert(id.clone(), m);
|
||||||
(m_id, self.fold_module(m))
|
|
||||||
} else {
|
self.change_location(current_m_id);
|
||||||
(m_id, m)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
..p
|
|
||||||
}
|
}
|
||||||
|
id
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
||||||
assert!(self
|
// initialise a constant map for this module
|
||||||
.constants
|
self.constants.entry(self.location.clone()).or_default();
|
||||||
.insert(self.location.clone(), Default::default())
|
|
||||||
.is_none());
|
|
||||||
TypedModule {
|
TypedModule {
|
||||||
constants: m
|
constants: m
|
||||||
.constants
|
.constants
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(id, tc)| {
|
.map(|(id, tc)| {
|
||||||
|
let id = self.fold_canonical_constant_identifier(id);
|
||||||
|
|
||||||
let constant = match tc {
|
let constant = match tc {
|
||||||
TypedConstantSymbol::There(imported_id) => {
|
TypedConstantSymbol::There(imported_id) => {
|
||||||
if !self.treated(&imported_id.module) {
|
let imported_id = self.fold_canonical_constant_identifier(imported_id);
|
||||||
let current_m_id = self.change_location(imported_id.module.clone());
|
|
||||||
let m = self.fold_module(
|
|
||||||
self.modules.get(&imported_id.module).unwrap().clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
self.modules.insert(imported_id.module.clone(), m);
|
|
||||||
|
|
||||||
self.change_location(current_m_id);
|
|
||||||
}
|
|
||||||
self.constants
|
self.constants
|
||||||
.get(&imported_id.module)
|
.get(&imported_id.module)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -115,12 +103,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||||
.fold_expression(constant)
|
.fold_expression(constant)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(self
|
self.constants
|
||||||
.constants
|
|
||||||
.entry(self.location.clone())
|
.entry(self.location.clone())
|
||||||
.or_default()
|
.or_default()
|
||||||
.insert(id.id.into(), constant.clone())
|
.insert(id.id.into(), constant.clone());
|
||||||
.is_none());
|
|
||||||
|
|
||||||
(
|
(
|
||||||
id,
|
id,
|
||||||
|
@ -176,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||||
) -> FieldElementExpression<'ast, T> {
|
) -> FieldElementExpression<'ast, T> {
|
||||||
match e {
|
match e {
|
||||||
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||||
Some(c) => self.fold_expression(c).try_into().unwrap(),
|
Some(c) => c.try_into().unwrap(),
|
||||||
None => fold_field_expression(self, e),
|
None => fold_field_expression(self, e),
|
||||||
},
|
},
|
||||||
e => fold_field_expression(self, e),
|
e => fold_field_expression(self, e),
|
||||||
|
|
|
@ -215,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_canonical_constant_identifier(
|
||||||
|
&mut self,
|
||||||
|
i: CanonicalConstantIdentifier<'ast>,
|
||||||
|
) -> CanonicalConstantIdentifier<'ast> {
|
||||||
|
CanonicalConstantIdentifier {
|
||||||
|
module: self.fold_module_id(i.module),
|
||||||
|
id: i.id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||||
|
i
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
|
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
|
||||||
match e {
|
match e {
|
||||||
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
|
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
|
||||||
|
@ -342,7 +356,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
constants: m
|
constants: m
|
||||||
.constants
|
.constants
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
|
.map(|(id, tc)| {
|
||||||
|
(
|
||||||
|
f.fold_canonical_constant_identifier(id),
|
||||||
|
f.fold_constant_symbol(tc),
|
||||||
|
)
|
||||||
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
functions: m
|
functions: m
|
||||||
.functions
|
.functions
|
||||||
|
@ -1014,7 +1033,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
) -> TypedConstantSymbol<'ast, T> {
|
) -> TypedConstantSymbol<'ast, T> {
|
||||||
match s {
|
match s {
|
||||||
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
|
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
|
||||||
there => there,
|
TypedConstantSymbol::There(id) => {
|
||||||
|
TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1052,8 +1073,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
modules: p
|
modules: p
|
||||||
.modules
|
.modules
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(module_id, module)| (module_id, f.fold_module(module)))
|
.map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module)))
|
||||||
.collect(),
|
.collect(),
|
||||||
main: p.main,
|
main: f.fold_module_id(p.main),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_canonical_constant_identifier(
|
||||||
|
&mut self,
|
||||||
|
i: CanonicalConstantIdentifier<'ast>,
|
||||||
|
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
|
||||||
|
Ok(CanonicalConstantIdentifier {
|
||||||
|
module: self.fold_module_id(i.module)?,
|
||||||
|
id: i.id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result<OwnedTypedModuleId, Self::Error> {
|
||||||
|
Ok(i)
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
|
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
|
||||||
Ok(n)
|
Ok(n)
|
||||||
}
|
}
|
||||||
|
@ -923,6 +937,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
key: DeclarationFunctionKey<'ast>,
|
key: DeclarationFunctionKey<'ast>,
|
||||||
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
|
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
|
||||||
Ok(DeclarationFunctionKey {
|
Ok(DeclarationFunctionKey {
|
||||||
|
module: f.fold_module_id(key.module)?,
|
||||||
signature: f.fold_signature(key.signature)?,
|
signature: f.fold_signature(key.signature)?,
|
||||||
..key
|
..key
|
||||||
})
|
})
|
||||||
|
@ -1067,7 +1082,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
|
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
|
||||||
match s {
|
match s {
|
||||||
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
|
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
|
||||||
there => Ok(there),
|
TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There(
|
||||||
|
f.fold_canonical_constant_identifier(id)?,
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1077,7 +1094,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
|
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
|
||||||
match s {
|
match s {
|
||||||
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
|
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
|
||||||
there => Ok(there), // by default, do not fold modules recursively
|
TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There(
|
||||||
|
f.fold_declaration_function_key(key)?,
|
||||||
|
)),
|
||||||
|
s => Ok(s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1109,6 +1129,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
|
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
|
||||||
.collect::<Result<_, _>>()?,
|
.collect::<Result<_, _>>()?,
|
||||||
main: p.main,
|
main: f.fold_module_id(p.main)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue