From e0b029959d8c34eca7eebfba0ca72ac7096fabe3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Mar 2023 14:39:08 +0100 Subject: [PATCH] rename substitutor to canonicalizer, add tests --- zokrates_ast/src/zir/canonicalizer.rs | 94 +++++++++++++++++++++++++++ zokrates_ast/src/zir/mod.rs | 2 +- zokrates_ast/src/zir/substitution.rs | 31 --------- zokrates_codegen/src/lib.rs | 6 +- zokrates_interpreter/src/lib.rs | 2 +- 5 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 zokrates_ast/src/zir/canonicalizer.rs delete mode 100644 zokrates_ast/src/zir/substitution.rs diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs new file mode 100644 index 00000000..ca67d8b5 --- /dev/null +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -0,0 +1,94 @@ +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; +use std::collections::HashMap; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ZirCanonicalizer<'ast> { + identifier_map: HashMap, usize>, +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.identifier_map.get(&n) { + Some(v) => Identifier::internal(*v), + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::zir::{ + FieldElementExpression, IdentifierExpression, Signature, Type, ZirAssignee, ZirFunction, + ZirStatement, + }; + + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn canonicalize() { + let func = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element("a"), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element("b"), + FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new("b".into()), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + let mut canonicalizer = ZirCanonicalizer::default(); + let result = canonicalizer.fold_function(func); + + let expected = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element(Identifier::internal(0usize)), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element(Identifier::internal(1usize)), + FieldElementExpression::Identifier(IdentifierExpression::new( + Identifier::internal(0usize), + )) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new(Identifier::internal(1usize)), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + assert_eq!(result, expected); + } +} diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index c2af87b0..08110966 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,10 +1,10 @@ +pub mod canonicalizer; pub mod folder; mod from_typed; mod identifier; pub mod lqc; mod parameter; pub mod result_folder; -pub mod substitution; pub mod types; mod uint; mod variable; diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs deleted file mode 100644 index d5709c5d..00000000 --- a/zokrates_ast/src/zir/substitution.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; -use std::collections::HashMap; -use zokrates_field::Field; - -#[derive(Default)] -pub struct ZirSubstitutor<'ast> { - substitution: HashMap, usize>, -} - -impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { - fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(p.id.id.clone(), new_id); - - Parameter { - id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), - ..p - } - } - fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(a.id.clone(), new_id); - ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) - } - fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - match self.substitution.get(&n) { - Some(v) => Identifier::internal(*v), - None => unreachable!(), - } - } -} diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 9ec1f426..a09eb61e 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,7 +11,7 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + canonicalizer::ZirCanonicalizer, ConditionalExpression, Folder, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -2243,8 +2243,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|assignee| self.use_variable(&assignee)) .collect(); - let mut substitutor = ZirSubstitutor::default(); - let function = substitutor.fold_function(function); + let mut canonicalizer = ZirCanonicalizer::default(); + let function = canonicalizer.fold_function(function); let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 73709add..db6ed4a5 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -520,7 +520,7 @@ mod tests { // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { arguments: vec![Parameter { - id: Variable::with_id_and_type(id.id.clone(), Type::FieldElement), + id: Variable::field_element(id.id.clone()), private: true, }], statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult(