diff --git a/access.zok b/access.zok deleted file mode 100644 index 7d01f505..00000000 --- a/access.zok +++ /dev/null @@ -1,2 +0,0 @@ -def main(field[3][2] a, u32 index) -> field[2]: - return a[index] \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index e92bd259..a0408729 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -7,20 +7,20 @@ use std::collections::HashMap; use std::convert::TryInto; use zokrates_field::Field; -type ModuleConstants<'ast, T> = +type ProgramConstants<'ast, T> = HashMap, TypedExpression<'ast, T>>>; pub struct ConstantInliner<'ast, T> { modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, - constants: ModuleConstants<'ast, T>, + constants: ProgramConstants<'ast, T>, } impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { pub fn new( modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, - constants: ModuleConstants<'ast, T>, + constants: ProgramConstants<'ast, T>, ) -> Self { ConstantInliner { modules, @@ -29,7 +29,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } } pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - let constants = HashMap::new(); + let constants = ProgramConstants::new(); let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); inliner.fold_program(p) } @@ -37,6 +37,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { let prev = self.location.clone(); self.location = location; + self.constants.entry(self.location.clone()).or_default(); prev } @@ -44,18 +45,24 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { self.constants.contains_key(id) } - fn get_constant(&mut self, id: &Identifier) -> Option> { - assert_eq!(id.version, 0); - match id.id { - CoreIdentifier::Call(..) => { - unreachable!("calls identifiers are only available after call inlining") - } - CoreIdentifier::Source(id) => self - .constants - .get(&self.location) - .and_then(|constants| constants.get(&id.into())) - .cloned(), - } + fn get_constant( + &self, + id: &CanonicalConstantIdentifier<'ast>, + ) -> Option> { + self.constants + .get(&id.module) + .and_then(|constants| constants.get(&id.id.into())) + .cloned() + } + + fn get_constant_for_identifier( + &self, + id: &Identifier<'ast>, + ) -> Option> { + self.constants + .get(&self.location) + .and_then(|constants| constants.get(&id)) + .cloned() } } @@ -64,48 +71,42 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); - self.constants.entry(self.location.clone()).or_default(); - let m = self.fold_module(self.modules.get(&id).unwrap().clone()); - + let m = self.modules.remove(&id).unwrap(); + let m = self.fold_module(m); self.modules.insert(id.clone(), m); - self.change_location(current_m_id); } id } fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - // initialise a constant map for this module - self.constants.entry(self.location.clone()).or_default(); - TypedModule { constants: m .constants .into_iter() .map(|(id, tc)| { - let id = self.fold_canonical_constant_identifier(id); - let constant = match tc { TypedConstantSymbol::There(imported_id) => { + // visit the imported symbol. This triggers visiting the corresponding module if needed let imported_id = self.fold_canonical_constant_identifier(imported_id); - self.constants - .get(&imported_id.module) - .unwrap() - .get(&imported_id.id.into()) - .cloned() + // after that, the constant must have been defined defined in the global map. It is already reduced + // to a literal, so running propagation isn't required + self.get_constant(&imported_id).unwrap() + } + TypedConstantSymbol::Here(c) => { + let non_propagated_constant = fold_constant(self, c).expression; + // folding the constant above only reduces it to an expression containing only literals, not to a single literal. + // propagating with an empty map of constants reduces it to a single literal + Propagator::with_constants(&mut HashMap::default()) + .fold_expression(non_propagated_constant) .unwrap() } - TypedConstantSymbol::Here(c) => fold_constant(self, c).expression, }; - let constant = - Propagator::with_constants(self.constants.get_mut(&self.location).unwrap()) - .fold_expression(constant) - .unwrap(); - + // add to the constant map. The value added is always a single litteral self.constants - .entry(self.location.clone()) - .or_default() + .get_mut(&self.location) + .unwrap() .insert(id.id.into(), constant.clone()); ( @@ -136,22 +137,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { ) -> DeclarationConstant<'ast> { match c { // replace constants by their concrete value in declaration types - DeclarationConstant::Constant(id) => DeclarationConstant::Concrete( - match self - .constants - .get(&id.module) - .unwrap() - .get(&id.id.into()) - .cloned() - .unwrap() - { + DeclarationConstant::Constant(id) => { + DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() { TypedExpression::Uint(UExpression { inner: UExpressionInner::Value(v), .. }) => v as u32, _ => unreachable!("all constants should be reduceable to u32 literals"), - }, - ), + }) + } c => c, } } @@ -161,10 +155,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { - FieldElementExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => c.try_into().unwrap(), - None => fold_field_expression(self, e), - }, + FieldElementExpression::Identifier(ref id) => { + match self.get_constant_for_identifier(id) { + Some(c) => c.try_into().unwrap(), + None => fold_field_expression(self, e), + } + } e => fold_field_expression(self, e), } } @@ -174,8 +170,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::Identifier(ref id) => match self.get_constant(id) { - Some(c) => self.fold_expression(c).try_into().unwrap(), + BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) { + Some(c) => c.try_into().unwrap(), None => fold_boolean_expression(self, e), }, e => fold_boolean_expression(self, e), @@ -188,9 +184,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { match e { - UExpressionInner::Identifier(ref id) => match self.get_constant(id) { + UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { Some(c) => { - let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); + let e: UExpression<'ast, T> = c.try_into().unwrap(); e.into_inner() } None => fold_uint_expression_inner(self, size, e), @@ -205,13 +201,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { - ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) { - Some(c) => { - let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); - e.into_inner() + ArrayExpressionInner::Identifier(ref id) => { + match self.get_constant_for_identifier(id) { + Some(c) => { + let e: ArrayExpression<'ast, T> = c.try_into().unwrap(); + e.into_inner() + } + None => fold_array_expression_inner(self, ty, e), } - None => fold_array_expression_inner(self, ty, e), - }, + } e => fold_array_expression_inner(self, ty, e), } } @@ -222,9 +220,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { - StructExpressionInner::Identifier(ref id) => match self.get_constant(id) { + StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) + { Some(c) => { - let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap(); + let e: StructExpression<'ast, T> = c.try_into().unwrap(); e.into_inner() } None => fold_struct_expression_inner(self, ty, e), diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 0e5e5486..61012739 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -298,6 +298,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { #[derive(Clone, PartialEq, Debug)] pub struct TypedConstant<'ast, T> { + // the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display` pub ty: Type<'ast, T>, pub expression: TypedExpression<'ast, T>, } @@ -310,6 +311,7 @@ impl<'ast, T> TypedConstant<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds write!(f, "const {}({})", self.ty, self.expression) } }