diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index b7133bd8..dfa75dfc 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -445,8 +445,8 @@ impl<'ast, T: Field> Checker<'ast, T> { declaration: SymbolDeclarationNode<'ast>, module_id: &ModuleId, state: &mut State<'ast, T>, - functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, - constants: &mut HashMap, TypedConstantSymbol<'ast, T>>, + functions: &mut TypedFunctionSymbols<'ast, T>, + constants: &mut TypedConstantSymbols<'ast, T>, symbol_unifier: &mut SymbolUnifier<'ast>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -506,13 +506,13 @@ impl<'ast, T: Field> Checker<'ast, T> { .in_file(module_id), ), true => { - constants.insert( + constants.push(( CanonicalConstantIdentifier::new( declaration.id, module_id.into(), ), TypedConstantSymbol::Here(c.clone()), - ); + )); self.insert_into_scope(Variable::with_id_and_type( declaration.id, c.get_type(), @@ -663,7 +663,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id); let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into()); - constants.insert(id.clone(), TypedConstantSymbol::There(imported_id)); + constants.push((id.clone(), TypedConstantSymbol::There(imported_id))); self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone())); state @@ -760,8 +760,8 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &mut State<'ast, T>, ) -> Result<(), Vec> { - let mut checked_functions = HashMap::new(); - let mut checked_constants = HashMap::new(); + let mut checked_functions = TypedFunctionSymbols::new(); + let mut checked_constants = TypedConstantSymbols::new(); // check if the module was already removed from the untyped ones let to_insert = match state.modules.remove(module_id) { diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 67c8a101..28ea448a 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -34,16 +34,16 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { inliner.fold_program(p) } - // fn module(&self) -> &TypedModule<'ast, T> { - // self.modules.get(&self.location).unwrap() - // } - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { let prev = self.location.clone(); self.location = location; prev } + fn treated(&self, id: &OwnedTypedModuleId) -> bool { + self.constants.contains_key(id) + } + fn get_constant(&mut self, id: &Identifier) -> Option> { assert_eq!(id.version, 0); match id.id { @@ -66,8 +66,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { .modules .into_iter() .map(|(m_id, m)| { - self.change_location(m_id.clone()); - (m_id, self.fold_module(m)) + if !self.treated(&m_id) { + self.change_location(m_id.clone()); + (m_id, self.fold_module(m)) + } else { + (m_id, m) + } }) .collect(), ..p @@ -75,69 +79,68 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - // only treat this module if its constants are not in the map yet - if !self.constants.contains_key(&self.location) { - self.constants.entry(self.location.clone()).or_default(); - TypedModule { - constants: m - .constants - .into_iter() - .map(|(id, tc)| { - let constant = match tc { - TypedConstantSymbol::There(imported_id) => { - if !self.constants.contains_key(&imported_id.module) { - let current_m_id = - self.change_location(imported_id.module.clone()); - let _ = self.fold_module( - self.modules.get(&imported_id.module).unwrap().clone(), - ); - self.change_location(current_m_id); - } - self.constants - .get(&imported_id.module) - .unwrap() - .get(&imported_id.id.into()) - .cloned() - .unwrap() + assert!(self + .constants + .insert(self.location.clone(), Default::default()) + .is_none()); + TypedModule { + constants: m + .constants + .into_iter() + .map(|(id, tc)| { + let constant = match tc { + TypedConstantSymbol::There(imported_id) => { + if !self.treated(&imported_id.module) { + let current_m_id = self.change_location(imported_id.module.clone()); + let m = self.fold_module( + self.modules.get(&imported_id.module).unwrap().clone(), + ); + + self.modules.insert(imported_id.module.clone(), m); + + self.change_location(current_m_id); } - TypedConstantSymbol::Here(c) => fold_constant(self, c).expression, - }; + self.constants + .get(&imported_id.module) + .unwrap() + .get(&imported_id.id.into()) + .cloned() + .unwrap() + } + TypedConstantSymbol::Here(c) => fold_constant(self, c).expression, + }; - let constant = Propagator::with_constants( - self.constants.get_mut(&self.location).unwrap(), - ) - .fold_expression(constant) - .unwrap(); + let constant = + Propagator::with_constants(self.constants.get_mut(&self.location).unwrap()) + .fold_expression(constant) + .unwrap(); - assert!(self - .constants - .entry(self.location.clone()) - .or_default() - .insert(id.id.into(), constant.clone()) - .is_none()); + assert!(self + .constants + .entry(self.location.clone()) + .or_default() + .insert(id.id.into(), constant.clone()) + .is_none()); - ( - id, - TypedConstantSymbol::Here(TypedConstant { - ty: constant.get_type().clone(), - expression: constant, - }), - ) - }) - .collect(), - functions: m - .functions - .into_iter() - .map(|(key, fun)| { - ( - self.fold_declaration_function_key(key), - self.fold_function_symbol(fun), - ) - }) - .collect(), - } - } else { - m + ( + id, + TypedConstantSymbol::Here(TypedConstant { + ty: constant.get_type().clone(), + expression: constant, + }), + ) + }) + .collect(), + functions: m + .functions + .into_iter() + .map(|(key, fun)| { + ( + self.fold_declaration_function_key(key), + self.fold_function_symbol(fun), + ) + }) + .collect(), } } @@ -159,7 +162,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { inner: UExpressionInner::Value(v), .. }) => v as u32, - _ => unreachable!(), + _ => unreachable!("all constants should be reduceable to u32 literals"), }, ), c => c, diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index af0f17d0..0e5e5486 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -70,8 +70,11 @@ pub enum TypedConstantSymbol<'ast, T> { } /// A collection of `TypedConstantSymbol`s -pub type TypedConstantSymbols<'ast, T> = - HashMap, TypedConstantSymbol<'ast, T>>; +/// It is still ordered, as we inline the constants in the order they are declared +pub type TypedConstantSymbols<'ast, T> = Vec<( + CanonicalConstantIdentifier<'ast>, + TypedConstantSymbol<'ast, T>, +)>; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)]