1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

keep constants ordered for inlining, clean

This commit is contained in:
schaeff 2021-06-11 12:04:38 +02:00
parent 33f0db4d47
commit 09c86e1735
3 changed files with 81 additions and 75 deletions

View file

@ -445,8 +445,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
declaration: SymbolDeclarationNode<'ast>,
module_id: &ModuleId,
state: &mut State<'ast, T>,
functions: &mut HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>,
constants: &mut HashMap<CanonicalConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>,
functions: &mut TypedFunctionSymbols<'ast, T>,
constants: &mut TypedConstantSymbols<'ast, T>,
symbol_unifier: &mut SymbolUnifier<'ast>,
) -> Result<(), Vec<Error>> {
let mut errors: Vec<Error> = vec![];
@ -506,13 +506,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
.in_file(module_id),
),
true => {
constants.insert(
constants.push((
CanonicalConstantIdentifier::new(
declaration.id,
module_id.into(),
),
TypedConstantSymbol::Here(c.clone()),
);
));
self.insert_into_scope(Variable::with_id_and_type(
declaration.id,
c.get_type(),
@ -663,7 +663,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id);
let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into());
constants.insert(id.clone(), TypedConstantSymbol::There(imported_id));
constants.push((id.clone(), TypedConstantSymbol::There(imported_id)));
self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone()));
state
@ -760,8 +760,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id: &ModuleId,
state: &mut State<'ast, T>,
) -> Result<(), Vec<Error>> {
let mut checked_functions = HashMap::new();
let mut checked_constants = HashMap::new();
let mut checked_functions = TypedFunctionSymbols::new();
let mut checked_constants = TypedConstantSymbols::new();
// check if the module was already removed from the untyped ones
let to_insert = match state.modules.remove(module_id) {

View file

@ -34,16 +34,16 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
inliner.fold_program(p)
}
// fn module(&self) -> &TypedModule<'ast, T> {
// self.modules.get(&self.location).unwrap()
// }
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
prev
}
fn treated(&self, id: &OwnedTypedModuleId) -> bool {
self.constants.contains_key(id)
}
fn get_constant(&mut self, id: &Identifier) -> Option<TypedExpression<'ast, T>> {
assert_eq!(id.version, 0);
match id.id {
@ -66,8 +66,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
.modules
.into_iter()
.map(|(m_id, m)| {
self.change_location(m_id.clone());
(m_id, self.fold_module(m))
if !self.treated(&m_id) {
self.change_location(m_id.clone());
(m_id, self.fold_module(m))
} else {
(m_id, m)
}
})
.collect(),
..p
@ -75,69 +79,68 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
// only treat this module if its constants are not in the map yet
if !self.constants.contains_key(&self.location) {
self.constants.entry(self.location.clone()).or_default();
TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
if !self.constants.contains_key(&imported_id.module) {
let current_m_id =
self.change_location(imported_id.module.clone());
let _ = self.fold_module(
self.modules.get(&imported_id.module).unwrap().clone(),
);
self.change_location(current_m_id);
}
self.constants
.get(&imported_id.module)
.unwrap()
.get(&imported_id.id.into())
.cloned()
.unwrap()
assert!(self
.constants
.insert(self.location.clone(), Default::default())
.is_none());
TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
if !self.treated(&imported_id.module) {
let current_m_id = self.change_location(imported_id.module.clone());
let m = self.fold_module(
self.modules.get(&imported_id.module).unwrap().clone(),
);
self.modules.insert(imported_id.module.clone(), m);
self.change_location(current_m_id);
}
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
};
self.constants
.get(&imported_id.module)
.unwrap()
.get(&imported_id.id.into())
.cloned()
.unwrap()
}
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
};
let constant = Propagator::with_constants(
self.constants.get_mut(&self.location).unwrap(),
)
.fold_expression(constant)
.unwrap();
let constant =
Propagator::with_constants(self.constants.get_mut(&self.location).unwrap())
.fold_expression(constant)
.unwrap();
assert!(self
.constants
.entry(self.location.clone())
.or_default()
.insert(id.id.into(), constant.clone())
.is_none());
assert!(self
.constants
.entry(self.location.clone())
.or_default()
.insert(id.id.into(), constant.clone())
.is_none());
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
})
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
self.fold_declaration_function_key(key),
self.fold_function_symbol(fun),
)
})
.collect(),
}
} else {
m
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
})
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
self.fold_declaration_function_key(key),
self.fold_function_symbol(fun),
)
})
.collect(),
}
}
@ -159,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!(),
_ => unreachable!("all constants should be reduceable to u32 literals"),
},
),
c => c,

View file

@ -70,8 +70,11 @@ pub enum TypedConstantSymbol<'ast, T> {
}
/// A collection of `TypedConstantSymbol`s
pub type TypedConstantSymbols<'ast, T> =
HashMap<CanonicalConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>;
/// It is still ordered, as we inline the constants in the order they are declared
pub type TypedConstantSymbols<'ast, T> = Vec<(
CanonicalConstantIdentifier<'ast>,
TypedConstantSymbol<'ast, T>,
)>;
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Debug, Clone)]