diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 37c4baeb..b7133bd8 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1122,7 +1122,8 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_generic_expression( &mut self, expr: ExpressionNode<'ast>, - constants_map: &HashMap<&'ast str, Type<'ast, T>>, + module_id: &ModuleId, + constants_map: &HashMap, Type<'ast, T>>, generics_map: &HashMap, usize>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1148,7 +1149,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match (constants_map.get(name), generics_map.get(&name)) { (Some(ty), None) => { match ty { - Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(name)), + Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))), _ => Err(ErrorInner { pos: Some(pos), message: format!( @@ -1192,6 +1193,7 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Array(t, size) => { let checked_size = self.check_generic_expression( size.clone(), + module_id, state.constants.get(module_id).unwrap_or(&HashMap::new()), generics_map, )?; diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 788c2841..67c8a101 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,13 +1,14 @@ +use crate::static_analysis::Propagator; use crate::typed_absy::folder::*; +use crate::typed_absy::result_folder::ResultFolder; use crate::typed_absy::types::DeclarationConstant; use crate::typed_absy::*; -use core::str; use std::collections::HashMap; use std::convert::TryInto; use zokrates_field::Field; type ModuleConstants<'ast, T> = - HashMap>>; + HashMap, TypedExpression<'ast, T>>>; pub struct ConstantInliner<'ast, T> { modules: TypedModules<'ast, T>, @@ -43,7 +44,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { prev } - fn get_constant(&mut self, id: &Identifier) -> Option> { + fn get_constant(&mut self, id: &Identifier) -> Option> { assert_eq!(id.version, 0); match id.id { CoreIdentifier::Call(..) => { @@ -52,7 +53,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { CoreIdentifier::Source(id) => self .constants .get(&self.location) - .and_then(|constants| constants.get(id)) + .and_then(|constants| constants.get(&id.into())) .cloned(), } } @@ -85,29 +86,43 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { let constant = match tc { TypedConstantSymbol::There(imported_id) => { if !self.constants.contains_key(&imported_id.module) { - let current_m_id = self.change_location(id.module.clone()); - let _ = self - .fold_module(self.modules.get(&id.module).unwrap().clone()); + let current_m_id = + self.change_location(imported_id.module.clone()); + let _ = self.fold_module( + self.modules.get(&imported_id.module).unwrap().clone(), + ); self.change_location(current_m_id); } self.constants .get(&imported_id.module) .unwrap() - .get(&imported_id.id) + .get(&imported_id.id.into()) .cloned() .unwrap() } - TypedConstantSymbol::Here(c) => fold_constant(self, c), + TypedConstantSymbol::Here(c) => fold_constant(self, c).expression, }; + let constant = Propagator::with_constants( + self.constants.get_mut(&self.location).unwrap(), + ) + .fold_expression(constant) + .unwrap(); + assert!(self .constants .entry(self.location.clone()) .or_default() - .insert(id.id, constant.clone()) + .insert(id.id.into(), constant.clone()) .is_none()); - (id, TypedConstantSymbol::Here(constant)) + ( + id, + TypedConstantSymbol::Here(TypedConstant { + ty: constant.get_type().clone(), + expression: constant, + }), + ) }) .collect(), functions: m @@ -130,28 +145,20 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { &mut self, c: DeclarationConstant<'ast>, ) -> DeclarationConstant<'ast> { - println!("id {}", c); - println!("constants {:#?}", self.constants); - println!("location {}", self.location.display()); - match c { DeclarationConstant::Constant(id) => DeclarationConstant::Concrete( match self .constants - .get(&self.location) + .get(&id.module) .unwrap() - .get(&id) + .get(&id.id.into()) .cloned() .unwrap() { - TypedConstant { - ty: Type::Uint(UBitwidth::B32), - expression: - TypedExpression::Uint(UExpression { - inner: UExpressionInner::Value(v), - .. - }), - } => v as u32, + TypedExpression::Uint(UExpression { + inner: UExpressionInner::Value(v), + .. + }) => v as u32, _ => unreachable!(), }, ), @@ -165,7 +172,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> FieldElementExpression<'ast, T> { match e { FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => self.fold_constant(c).try_into().unwrap(), + Some(c) => self.fold_expression(c).try_into().unwrap(), None => fold_field_expression(self, e), }, e => fold_field_expression(self, e), @@ -178,7 +185,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> BooleanExpression<'ast, T> { match e { BooleanExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => self.fold_constant(c).try_into().unwrap(), + Some(c) => self.fold_expression(c).try_into().unwrap(), None => fold_boolean_expression(self, e), }, e => fold_boolean_expression(self, e), @@ -193,7 +200,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { UExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); e.into_inner() } None => fold_uint_expression_inner(self, size, e), @@ -210,7 +217,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); e.into_inner() } None => fold_array_expression_inner(self, ty, e), @@ -227,7 +234,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { match e { StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { Some(c) => { - let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap(); + let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); e.into_inner() } None => fold_struct_expression_inner(self, ty, e), @@ -266,7 +273,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - const_id, + CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), @@ -354,7 +361,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - const_id, + CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::Boolean, TypedExpression::Boolean(BooleanExpression::Value(true)), @@ -443,7 +450,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - const_id, + CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::Uint(UBitwidth::B32), UExpressionInner::Value(1u128) @@ -544,7 +551,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - const_id, + CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::Array( @@ -683,7 +690,7 @@ mod tests { .collect(), constants: vec![ ( - const_a_id, + CanonicalConstantIdentifier::new(const_a_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number( @@ -692,7 +699,7 @@ mod tests { )), ), ( - const_b_id, + CanonicalConstantIdentifier::new(const_b_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Add( @@ -741,7 +748,7 @@ mod tests { .collect(), constants: vec![ ( - const_a_id, + CanonicalConstantIdentifier::new(const_a_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number( @@ -750,7 +757,7 @@ mod tests { )), ), ( - const_b_id, + CanonicalConstantIdentifier::new(const_b_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number( @@ -802,7 +809,7 @@ mod tests { .into_iter() .collect(), constants: vec![( - foo_const_id, + CanonicalConstantIdentifier::new(foo_const_id, "foo".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number( @@ -834,8 +841,11 @@ mod tests { .into_iter() .collect(), constants: vec![( - foo_const_id, - TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id), + CanonicalConstantIdentifier::new(foo_const_id, "main".into()), + TypedConstantSymbol::There(CanonicalConstantIdentifier::new( + foo_const_id, + "foo".into(), + )), )] .into_iter() .collect(), @@ -872,7 +882,7 @@ mod tests { .into_iter() .collect(), constants: vec![( - foo_const_id, + CanonicalConstantIdentifier::new(foo_const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, TypedExpression::FieldElement(FieldElementExpression::Number( diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 256f19fe..80670eca 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -78,13 +78,13 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> { // inline user-defined constants let r = ConstantInliner::inline(self); - println!("{}", r); // isolate branches let r = if config.isolate_branches { Isolator::isolate(r) } else { r }; + // reduce the program to a single function let r = reduce_program(r).map_err(Error::from)?; // generate abi diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 2777bea0..22f3e229 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -1024,7 +1024,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>( ) -> TypedFunctionSymbol<'ast, T> { match s { TypedFunctionSymbol::Here(fun) => TypedFunctionSymbol::Here(f.fold_function(fun)), - there => there, // by default, do not fold modules recursively + TypedFunctionSymbol::There(key) => { + TypedFunctionSymbol::There(f.fold_declaration_function_key(key)) + } + s => s, } } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 8ad4c91a..94794fcb 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -119,7 +119,7 @@ impl<'ast> CanonicalConstantIdentifier<'ast> { pub enum DeclarationConstant<'ast> { Generic(GenericIdentifier<'ast>), Concrete(u32), - Constant(ConstantIdentifier<'ast>), + Constant(CanonicalConstantIdentifier<'ast>), } impl<'ast> From for DeclarationConstant<'ast> { @@ -145,7 +145,7 @@ impl<'ast> fmt::Display for DeclarationConstant<'ast> { match self { DeclarationConstant::Generic(i) => write!(f, "{}", i), DeclarationConstant::Concrete(v) => write!(f, "{}", v), - DeclarationConstant::Constant(v) => write!(f, "{}", v), + DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id), } } } @@ -166,7 +166,7 @@ impl<'ast, T> From> for UExpression<'ast, T> { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { - UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::B32) + UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32) } } } diff --git a/zokrates_core_test/tests/tests/constants/import/origin.zok b/zokrates_core_test/tests/tests/constants/import/origin.zok index e48514af..297e1ac4 100644 --- a/zokrates_core_test/tests/tests/constants/import/origin.zok +++ b/zokrates_core_test/tests/tests/constants/import/origin.zok @@ -1,3 +1,3 @@ -const u32 N = 1 +const u32 N = 1 + 1 def foo(field[N] a) -> bool: return true \ No newline at end of file