1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fold module id in folders

This commit is contained in:
schaeff 2021-06-15 16:41:38 +02:00
parent 0bb9f5897a
commit 759f4dd553
5 changed files with 73 additions and 44 deletions

2
access.zok Normal file
View file

@ -0,0 +1,2 @@
def main(field[3][2] a, u32 index) -> field[2]:
return a[index]

View file

@ -1,8 +1,8 @@
from "./bar" import main as bar
//from "./bar" import main as bar
from "./baz" import BAZ as baz
import "./foo" as f
def main() -> field:
field foo = f()
assert(foo == bar() + baz)
assert(foo == 21 + baz)
return foo

View file

@ -48,7 +48,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
assert_eq!(id.version, 0);
match id.id {
CoreIdentifier::Call(..) => {
unreachable!("calls indentifiers are only available after call inlining")
unreachable!("calls identifiers are only available after call inlining")
}
CoreIdentifier::Source(id) => self
.constants
@ -60,46 +60,34 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
.modules
.into_iter()
.map(|(m_id, m)| {
if !self.treated(&m_id) {
self.change_location(m_id.clone());
(m_id, self.fold_module(m))
} else {
(m_id, m)
}
})
.collect(),
..p
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
self.constants.entry(self.location.clone()).or_default();
let m = self.fold_module(self.modules.get(&id).unwrap().clone());
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
id
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
assert!(self
.constants
.insert(self.location.clone(), Default::default())
.is_none());
// initialise a constant map for this module
self.constants.entry(self.location.clone()).or_default();
TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let id = self.fold_canonical_constant_identifier(id);
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
if !self.treated(&imported_id.module) {
let current_m_id = self.change_location(imported_id.module.clone());
let m = self.fold_module(
self.modules.get(&imported_id.module).unwrap().clone(),
);
self.modules.insert(imported_id.module.clone(), m);
self.change_location(current_m_id);
}
let imported_id = self.fold_canonical_constant_identifier(imported_id);
self.constants
.get(&imported_id.module)
.unwrap()
@ -115,12 +103,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
.fold_expression(constant)
.unwrap();
assert!(self
.constants
self.constants
.entry(self.location.clone())
.or_default()
.insert(id.id.into(), constant.clone())
.is_none());
.insert(id.id.into(), constant.clone());
(
id,
@ -176,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_expression(c).try_into().unwrap(),
Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e),
},
e => fold_field_expression(self, e),

View file

@ -215,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized {
}
}
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> CanonicalConstantIdentifier<'ast> {
CanonicalConstantIdentifier {
module: self.fold_module_id(i.module),
id: i.id,
}
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId {
i
}
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
match e {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
@ -342,7 +356,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
constants: m
.constants
.into_iter()
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
.map(|(id, tc)| {
(
f.fold_canonical_constant_identifier(id),
f.fold_constant_symbol(tc),
)
})
.collect(),
functions: m
.functions
@ -1014,7 +1033,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedConstantSymbol<'ast, T> {
match s {
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
there => there,
TypedConstantSymbol::There(id) => {
TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id))
}
}
}
@ -1052,8 +1073,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
modules: p
.modules
.into_iter()
.map(|(module_id, module)| (module_id, f.fold_module(module)))
.map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module)))
.collect(),
main: p.main,
main: f.fold_module_id(p.main),
}
}

View file

@ -114,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
})
}
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
Ok(CanonicalConstantIdentifier {
module: self.fold_module_id(i.module)?,
id: i.id,
})
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result<OwnedTypedModuleId, Self::Error> {
Ok(i)
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
Ok(n)
}
@ -923,6 +937,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
Ok(DeclarationFunctionKey {
module: f.fold_module_id(key.module)?,
signature: f.fold_signature(key.signature)?,
..key
})
@ -1067,7 +1082,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
match s {
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
there => Ok(there),
TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There(
f.fold_canonical_constant_identifier(id)?,
)),
}
}
@ -1077,7 +1094,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
match s {
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
there => Ok(there), // by default, do not fold modules recursively
TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There(
f.fold_declaration_function_key(key)?,
)),
s => Ok(s),
}
}
@ -1109,6 +1129,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter()
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
.collect::<Result<_, _>>()?,
main: p.main,
main: f.fold_module_id(p.main)?,
})
}