From e12a4b46d3c1d3a0dc95249af5b32a82f7763367 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 25 Jan 2023 14:03:39 +0100 Subject: [PATCH 01/15] wip --- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 34 +++++++++++++++++++++++++++--- zokrates_ast/src/ir/clean.rs | 1 + zokrates_ast/src/ir/folder.rs | 1 + zokrates_ast/src/ir/from_flat.rs | 3 +++ zokrates_ast/src/ir/mod.rs | 17 ++++++++++++--- zokrates_codegen/src/lib.rs | 7 ++++-- 7 files changed, 56 insertions(+), 9 deletions(-) diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfd..e0337dec 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{Solver, ZirSolver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 9b4f5c90..42318342 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,22 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub enum ZirSolver<'ast, T> { + #[serde(borrow)] + Function(ZirFunction<'ast, T>), + Indexed(usize, usize), +} + +impl<'ast, T> fmt::Display for ZirSolver<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ZirSolver::Function(_) => write!(f, "Zir(..)"), + ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n), + } + } +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -13,13 +29,22 @@ pub enum Solver<'ast, T> { ShaCh, EuclideanDiv, #[serde(borrow)] - Zir(ZirFunction<'ast, T>), + Zir(ZirSolver<'ast, T>), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } +impl<'ast, T> Solver<'ast, T> { + pub fn zir_function(function: ZirFunction<'ast, T>) -> Self { + Solver::Zir(ZirSolver::Function(function)) + } + pub fn zir_indexed(index: usize, argument_count: usize) -> Self { + Solver::Zir(ZirSolver::Indexed(index, argument_count)) + } +} + impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -31,7 +56,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), - Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::Zir(s) => write!(f, "{}", s), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -51,7 +76,10 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), - Solver::Zir(f) => (f.arguments.len(), 1), + Solver::Zir(s) => match s { + ZirSolver::Function(f) => (f.arguments.len(), 1), + ZirSolver::Indexed(_, n) => (*n, 1), + }, #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index b4fb8f44..998e3298 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -14,6 +14,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a .statements .into_iter() .flat_map(|s| Cleaner::default().fold_statement(s)), + solvers: self.solvers, } } } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 6e67c15d..99466add 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -50,6 +50,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_statement(s)) .collect(), return_count: p.return_count, + solvers: p.solvers, } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index fc961cd8..7932f74b 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -2,6 +2,8 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; use zokrates_field::Field; +use super::SolverMap; + impl QuadComb { fn from_flat_expression>>(flat_expression: U) -> QuadComb { let flat_expression = flat_expression.into(); @@ -24,6 +26,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + solvers: SolverMap::default(), } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 78b48f80..de73b683 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -2,7 +2,7 @@ use crate::common::FormatString; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; use std::fmt; use std::hash::Hash; use zokrates_field::Field; @@ -22,8 +22,8 @@ pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::ProgEnum; pub use crate::common::Parameter; pub use crate::common::RuntimeError; -pub use crate::common::Solver; pub use crate::common::Variable; +pub use crate::common::{Solver, ZirSolver}; pub use self::witness::Witness; @@ -124,20 +124,29 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { } pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; +pub type SolverMap<'ast, T> = HashMap>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, + #[serde(borrow)] + pub solvers: SolverMap<'ast, T>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { - pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { + pub fn new( + arguments: Vec, + statements: I, + return_count: usize, + solvers: SolverMap<'ast, T>, + ) -> Self { Self { arguments, return_count, statements, + solvers, } } @@ -146,6 +155,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } @@ -192,6 +202,7 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } } diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index e149a968..795fa00a 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::{RuntimeError, Variable}; use zokrates_ast::flat::*; -use zokrates_ast::ir::Solver; +use zokrates_ast::ir::{Solver, ZirSolver}; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, @@ -2241,7 +2241,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); - let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); + + let solver = Solver::Zir(ZirSolver::Function(function)); + let directive = FlatDirective::new(outputs, solver, inputs); + statements_flattened.push_back(FlatStatement::Directive(directive)); } ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { From 49889111836b9b6f7c3d65ebda2d058af8a8552f Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:18:31 +0100 Subject: [PATCH 02/15] optimize zir solvers by indexing --- .../src/flatten_complex_types.rs | 4 +- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 37 +---- zokrates_ast/src/ir/from_flat.rs | 4 +- zokrates_ast/src/ir/mod.rs | 10 +- zokrates_ast/src/ir/serialize.rs | 142 ++++++++++-------- zokrates_ast/src/ir/solver_indexer.rs | 50 ++++++ zokrates_ast/src/zir/identifier.rs | 12 +- zokrates_ast/src/zir/mod.rs | 1 + zokrates_ast/src/zir/substitution.rs | 26 ++++ zokrates_codegen/src/lib.rs | 32 +++- zokrates_core/src/optimizer/duplicate.rs | 20 ++- zokrates_core/src/optimizer/mod.rs | 1 + zokrates_core/src/optimizer/redefinition.rs | 2 +- zokrates_interpreter/src/lib.rs | 11 +- 15 files changed, 235 insertions(+), 119 deletions(-) create mode 100644 zokrates_ast/src/ir/solver_indexer.rs create mode 100644 zokrates_ast/src/zir/substitution.rs diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index f4b81d8e..4fa123bc 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; @@ -481,7 +481,7 @@ impl<'ast, T: Field> Flattener { // } #[derive(Default)] pub struct ArgumentFinder<'ast, T> { - pub identifiers: HashMap, zir::Type>, + pub identifiers: BTreeMap, zir::Type>, _phantom: PhantomData, } diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index e0337dec..13d23bfd 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::{Solver, ZirSolver}; +pub use self::solvers::Solver; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 42318342..0551bb6d 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,22 +2,6 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum ZirSolver<'ast, T> { - #[serde(borrow)] - Function(ZirFunction<'ast, T>), - Indexed(usize, usize), -} - -impl<'ast, T> fmt::Display for ZirSolver<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ZirSolver::Function(_) => write!(f, "Zir(..)"), - ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n), - } - } -} - #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -29,22 +13,14 @@ pub enum Solver<'ast, T> { ShaCh, EuclideanDiv, #[serde(borrow)] - Zir(ZirSolver<'ast, T>), + Zir(ZirFunction<'ast, T>), + IndexedCall(usize, usize), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } -impl<'ast, T> Solver<'ast, T> { - pub fn zir_function(function: ZirFunction<'ast, T>) -> Self { - Solver::Zir(ZirSolver::Function(function)) - } - pub fn zir_indexed(index: usize, argument_count: usize) -> Self { - Solver::Zir(ZirSolver::Indexed(index, argument_count)) - } -} - impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -56,7 +32,8 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), - Solver::Zir(s) => write!(f, "{}", s), + Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -76,10 +53,8 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), - Solver::Zir(s) => match s { - ZirSolver::Function(f) => (f.arguments.len(), 1), - ZirSolver::Indexed(_, n) => (*n, 1), - }, + Solver::Zir(f) => (f.arguments.len(), 1), + Solver::IndexedCall(_, n) => (*n, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index 7932f74b..3404c6c4 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -2,8 +2,6 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; use zokrates_field::Field; -use super::SolverMap; - impl QuadComb { fn from_flat_expression>>(flat_expression: U) -> QuadComb { let flat_expression = flat_expression.into(); @@ -26,7 +24,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, - solvers: SolverMap::default(), + solvers: vec![], } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index de73b683..b401b757 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -2,7 +2,7 @@ use crate::common::FormatString; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::{BTreeSet, HashMap}; +use std::collections::BTreeSet; use std::fmt; use std::hash::Hash; use zokrates_field::Field; @@ -14,6 +14,7 @@ pub mod folder; pub mod from_flat; mod serialize; pub mod smtlib2; +mod solver_indexer; pub mod visitor; mod witness; @@ -22,8 +23,8 @@ pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::ProgEnum; pub use crate::common::Parameter; pub use crate::common::RuntimeError; +pub use crate::common::Solver; pub use crate::common::Variable; -pub use crate::common::{Solver, ZirSolver}; pub use self::witness::Witness; @@ -124,7 +125,6 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { } pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; -pub type SolverMap<'ast, T> = HashMap>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { @@ -132,7 +132,7 @@ pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub return_count: usize, pub statements: I, #[serde(borrow)] - pub solvers: SolverMap<'ast, T>, + pub solvers: Vec>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { @@ -140,7 +140,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, arguments: Vec, statements: I, return_count: usize, - solvers: SolverMap<'ast, T>, + solvers: Vec>, ) -> Self { Self { arguments, diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 09d00390..3bd40366 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,14 +1,18 @@ -use crate::ir::check::UnconstrainedVariableDetector; +use crate::{ + ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}, + Solver, +}; use super::{ProgIterator, Statement}; +use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; -use std::io::{Read, Write}; +use std::io::{Read, Seek, Write}; use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; +const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -61,17 +65,21 @@ impl< impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives - pub fn serialize(self, mut w: W) -> Result { + pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_2)?; + w.write_all(ZOKRATES_VERSION_3)?; w.write_all(&T::id())?; + let solver_list_ptr_offset = w.stream_position()?; + w.write_all(&[0u8; std::mem::size_of::()])?; // reserve 8 bytes + serde_cbor::to_writer(&mut w, &self.arguments)?; serde_cbor::to_writer(&mut w, &self.return_count)?; let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let statements = self.statements.into_iter(); @@ -80,12 +88,22 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a if matches!(s, Statement::Constraint(..)) { count += 1; } - let s = unconstrained_variable_detector.fold_statement(s); + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); for s in s { serde_cbor::to_writer(&mut w, &s)?; } } + let solver_list_offset = w.stream_position()?; + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?; + w.write_all(&solver_list_offset.to_le_bytes())?; + unconstrained_variable_detector .finalize() .map(|_| count) @@ -103,11 +121,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator type Item = T; fn next(&mut self) -> Option { - self.s.next().transpose().unwrap() + self.s.next().and_then(|v| v.ok()) } } -impl<'de, R: Read> +impl<'de, R: Read + Seek> ProgEnum< 'de, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, @@ -128,104 +146,108 @@ impl<'de, R: Read> r.read_exact(&mut version) .map_err(|_| String::from("Cannot read version"))?; - if &version == ZOKRATES_VERSION_2 { + if &version == ZOKRATES_VERSION_3 { // Check the curve identifier, deserializing accordingly let mut curve = [0; 4]; r.read_exact(&mut curve) .map_err(|_| String::from("Cannot read curve identifier"))?; - use serde::de::Deserializer; - let mut p = serde_cbor::Deserializer::from_reader(r); + let mut buffer = [0u8; std::mem::size_of::()]; + r.read_exact(&mut buffer) + .map_err(|_| String::from("Cannot read solver list pointer"))?; - struct ArgumentsVisitor; + let solver_list_offset = u64::from_le_bytes(buffer); - impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor { - type Value = Vec; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("seq of flat param") - } + let (arguments, return_count) = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = vec![]; - while let Some(e) = seq.next_element().unwrap() { - res.push(e); - } - Ok(res) - } - } + let arguments: Vec = Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters"))?; - let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap(); + let return_count = usize::deserialize(&mut p) + .map_err(|_| String::from("Cannot read return count"))?; - struct ReturnCountVisitor; + (arguments, return_count) + }; - impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor { - type Value = usize; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("usize") - } - - fn visit_u32(self, v: u32) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u8(self, v: u8) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u16(self, v: u16) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - } - - let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap(); + let statement_offset = r.stream_position().unwrap(); + r.seek(std::io::SeekFrom::Start(solver_list_offset)) + .unwrap(); match curve { m if m == Bls12_381Field::id() => { - let s = p.into_iter::>(); + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); Ok(ProgEnum::Bls12_381Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bn128Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bn128Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bls12_377Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bls12_377Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bw6_761Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bw6_761Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } _ => Err(String::from("Unknown curve identifier")), diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs new file mode 100644 index 00000000..59393d25 --- /dev/null +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -0,0 +1,50 @@ +use crate::ir::folder::Folder; +use crate::ir::Directive; +use crate::ir::Solver; +use crate::zir::ZirFunction; +use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use zokrates_field::Field; + +type Hash = u64; + +fn hash(f: &ZirFunction) -> Hash { + use std::hash::Hash; + use std::hash::Hasher; + let mut hasher = DefaultHasher::new(); + f.hash(&mut hasher); + hasher.finish() +} + +#[derive(Debug, Default)] +pub struct SolverIndexer<'ast, T> { + pub solvers: Vec>, + pub index_map: HashMap, +} + +impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { + match d.solver { + Solver::Zir(f) => { + let argc = f.arguments.len(); + let h = hash(&f); + let index = match self.index_map.entry(h) { + Entry::Occupied(v) => *v.get(), + Entry::Vacant(entry) => { + let index = self.solvers.len(); + entry.insert(index); + self.solvers.push(Solver::Zir(f)); + index + } + }; + Directive { + inputs: d.inputs, + outputs: d.outputs, + solver: Solver::IndexedCall(index, argc), + } + } + _ => d, + } + } +} diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index 249b2630..bc60f566 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -4,13 +4,14 @@ use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), + Internal(String), } -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { #[serde(borrow)] Basic(CoreIdentifier<'ast>), @@ -34,10 +35,17 @@ impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Identifier::Source(s) => write!(f, "{}", s), + Identifier::Internal(s) => write!(f, "{}", s), } } } +impl<'ast> From for Identifier<'ast> { + fn from(id: String) -> Identifier<'ast> { + Identifier::Internal(id) + } +} + // this is only used in tests but somehow cfg(test) does not work impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 60dc1467..c2af87b0 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -4,6 +4,7 @@ mod identifier; pub mod lqc; mod parameter; pub mod result_folder; +pub mod substitution; pub mod types; mod uint; mod variable; diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs new file mode 100644 index 00000000..89a8f82d --- /dev/null +++ b/zokrates_ast/src/zir/substitution.rs @@ -0,0 +1,26 @@ +use super::{Folder, Identifier}; +use std::{collections::HashMap, marker::PhantomData}; +use zokrates_field::Field; + +pub struct ZirSubstitutor<'a, 'ast, T> { + substitution: &'a HashMap, Identifier<'ast>>, + _phantom: PhantomData, +} + +impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> { + pub fn new(substitution: &'a HashMap, Identifier<'ast>>) -> Self { + ZirSubstitutor { + substitution, + _phantom: PhantomData::default(), + } + } +} + +impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> { + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.substitution.get(&n) { + Some(v) => v.clone(), + None => n, + } + } +} diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 795fa00a..ea954b89 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,8 +11,8 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, - ZirExpressionList, + substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::{RuntimeError, Variable}; use zokrates_ast::flat::*; -use zokrates_ast::ir::{Solver, ZirSolver}; +use zokrates_ast::ir::Solver; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, @@ -1885,7 +1885,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[]) .unwrap() .into_iter() .map(FlatExpression::Number) @@ -2237,12 +2237,34 @@ impl<'ast, T: Field> Flattener<'ast, T> { .cloned() .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) .collect(); + let outputs: Vec = assignees .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); - let solver = Solver::Zir(ZirSolver::Function(function)); + let mut substitution_map = HashMap::default(); + for (index, p) in function.arguments.iter().enumerate() { + let new_id = format!("i{}", index).into(); + substitution_map.insert(p.id.id.clone(), new_id); + } + + let mut substitutor = ZirSubstitutor::new(&substitution_map); + let function = ZirFunction { + arguments: function + .arguments + .into_iter() + .map(|p| substitutor.fold_parameter(p)) + .collect(), + statements: function + .statements + .into_iter() + .flat_map(|s| substitutor.fold_statement(s)) + .collect(), + signature: function.signature, + }; + + let solver = Solver::Zir(function); let directive = FlatDirective::new(outputs, solver, inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 664cfc2d..a08a3f73 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -39,14 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { } fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - let hashed = hash(&s); - let result = match self.seen.get(&hashed) { - Some(_) => vec![], - None => vec![s], - }; - - self.seen.insert(hashed); - result + match s { + Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(), + s => { + let hashed = hash(&s); + let result = match self.seen.get(&hashed) { + Some(_) => vec![], + None => vec![s], + }; + self.seen.insert(hashed); + result + } + } } } diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 1f94740a..95e9f25c 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator>>( .flat_map(move |s| directive_optimizer.fold_statement(s)) .flat_map(move |s| duplicate_optimizer.fold_statement(s)), return_count: p.return_count, + solvers: p.solvers, }; log::debug!("Done"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index b0877fd0..bc7fc279 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -146,7 +146,7 @@ impl RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 9f90e00c..c8182b20 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -83,7 +83,7 @@ impl Interpreter { inputs.pop().unwrap(), )) } - _ => Self::execute_solver(&d.solver, &inputs), + _ => Self::execute_solver(&d.solver, &inputs, &program.solvers), } .map_err(Error::Solver)?; @@ -164,7 +164,15 @@ impl Interpreter { pub fn execute_solver<'ast, T: Field>( solver: &Solver<'ast, T>, inputs: &[T], + solvers: &[Solver<'ast, T>], ) -> Result, String> { + let solver = match solver { + Solver::IndexedCall(index, _) => solvers + .get(*index) + .ok_or_else(|| format!("Could not resolve solver at index {}", index))?, + s => s, + }; + let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); @@ -334,6 +342,7 @@ impl Interpreter { &inputs[*n + 8usize..], ) } + _ => unreachable!("unexpected solver"), }; assert_eq!(res.len(), expected_output_count); From 6a16198bed54afc7dbe7ceb1afdee32ec15044c0 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:21:00 +0100 Subject: [PATCH 03/15] fix message --- zokrates_ast/src/ir/serialize.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 3bd40366..246cdbb5 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -154,7 +154,7 @@ impl<'de, R: Read + Seek> let mut buffer = [0u8; std::mem::size_of::()]; r.read_exact(&mut buffer) - .map_err(|_| String::from("Cannot read solver list pointer"))?; + .map_err(|_| String::from("Cannot read solver list offset"))?; let solver_list_offset = u64::from_le_bytes(buffer); @@ -179,7 +179,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -197,7 +197,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -216,7 +216,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -235,7 +235,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); From fdd441c37447abafbe13ed740871cc2d50b3006d Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:24:24 +0100 Subject: [PATCH 04/15] use cursor in zokrates_js --- zokrates_js/src/lib.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 07082c3e..c776f34b 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -471,7 +471,8 @@ pub fn compute_witness( config: JsValue, log_callback: &js_sys::Function, ) -> Result { - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); match prog { @@ -533,7 +534,8 @@ pub fn setup(program: &[u8], options: JsValue) -> Result { ) .map_err(|e| JsValue::from_str(&e))?; - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); @@ -572,7 +574,8 @@ pub fn setup_with_srs(srs: &[u8], program: &[u8], options: JsValue) -> Result Date: Tue, 28 Feb 2023 02:04:02 +0100 Subject: [PATCH 05/15] refactor, fix tests --- Cargo.lock | 1 + zokrates_ark/src/gm17.rs | 2 + zokrates_ark/src/groth16.rs | 2 + zokrates_ark/src/marlin.rs | 2 + zokrates_ast/Cargo.toml | 1 + zokrates_ast/src/ir/serialize.rs | 403 +++++++++++++------- zokrates_bellman/src/groth16.rs | 1 + zokrates_bellman/src/lib.rs | 6 + zokrates_circom/src/lib.rs | 1 + zokrates_circom/src/r1cs.rs | 2 + zokrates_core/src/optimizer/duplicate.rs | 3 + zokrates_core/src/optimizer/redefinition.rs | 11 + zokrates_interpreter/src/lib.rs | 22 +- zokrates_test/tests/wasm.rs | 1 + 14 files changed, 314 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e2191772..dce729b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2959,6 +2959,7 @@ name = "zokrates_ast" version = "0.1.4" dependencies = [ "ark-bls12-377", + "byteorder", "cfg-if 0.1.10", "csv", "derivative", diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index a8941ed0..337770b7 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -123,6 +123,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -148,6 +149,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index e9281210..1a832f1e 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -120,6 +120,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -145,6 +146,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index b0e0f75e..233f9ac5 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -404,6 +404,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -439,6 +440,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324..62a40343 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -9,6 +9,7 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] [dependencies] +byteorder = "1.4.3" zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 246cdbb5..6266662d 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -4,6 +4,7 @@ use crate::{ }; use super::{ProgIterator, Statement}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; use std::io::{Read, Seek, Write}; @@ -12,7 +13,7 @@ use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3]; +const FILE_VERSION: &[u8; 4] = &[3, 0, 0, 0]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -62,47 +63,195 @@ impl< } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum SectionType { + Arguments = 1, + Constraints = 2, + Solvers = 3, +} + +impl TryFrom for SectionType { + type Error = String; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(SectionType::Arguments), + 2 => Ok(SectionType::Constraints), + 3 => Ok(SectionType::Solvers), + _ => Err("invalid section type".to_string()), + } + } +} + +#[derive(Debug, Clone)] +pub struct Section { + pub ty: SectionType, + pub offset: u64, + pub length: u64, +} + +impl Section { + pub fn new(ty: SectionType) -> Self { + Self { + ty, + offset: 0, + length: 0, + } + } + + pub fn set_offset(&mut self, offset: u64) { + self.offset = offset; + } + + pub fn set_length(&mut self, length: u64) { + self.length = length; + } +} + +#[derive(Debug, Clone)] +pub struct ProgHeader { + pub magic: [u8; 4], + pub version: [u8; 4], + pub curve_id: [u8; 4], + pub constraint_count: u32, + pub return_count: u32, + pub sections: Vec
, +} + +impl ProgHeader { + pub fn write(&self, mut w: W) -> std::io::Result<()> { + w.write_all(&self.magic)?; + w.write_all(&self.version)?; + w.write_all(&self.curve_id)?; + w.write_u32::(self.constraint_count)?; + w.write_u32::(self.return_count)?; + + w.write_u32::(self.sections.len() as u32)?; + for s in &self.sections { + w.write_u32::(s.ty as u32)?; + w.write_u64::(s.offset)?; + w.write_u64::(s.length)?; + } + + Ok(()) + } + + pub fn read(mut r: R) -> std::io::Result { + let mut magic = [0; 4]; + r.read_exact(&mut magic)?; + + let mut version = [0; 4]; + r.read_exact(&mut version)?; + + let mut curve_id = [0; 4]; + r.read_exact(&mut curve_id)?; + + let constraint_count = r.read_u32::()?; + let return_count = r.read_u32::()?; + + let section_count = r.read_u32::()?; + let mut sections = vec![]; + + for _ in 0..section_count { + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + sections.push(section); + } + + Ok(ProgHeader { + magic, + version, + curve_id, + constraint_count, + return_count, + sections, + }) + } + + fn get_section(&self, ty: SectionType) -> Option<&Section> { + self.sections.iter().find(|s| s.ty == ty) + } +} + impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_3)?; - w.write_all(&T::id())?; + const SECTION_COUNT: usize = 3; + const HEADER_SIZE: usize = 24 + SECTION_COUNT * 20; - let solver_list_ptr_offset = w.stream_position()?; - w.write_all(&[0u8; std::mem::size_of::()])?; // reserve 8 bytes + let mut header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: 0, + return_count: self.return_count as u32, + sections: Vec::with_capacity(SECTION_COUNT), + }; - serde_cbor::to_writer(&mut w, &self.arguments)?; - serde_cbor::to_writer(&mut w, &self.return_count)?; + w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header + + let arguments = { + let mut section = Section::new(SectionType::Arguments); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &self.arguments)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; - let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); + let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut count: usize = 0; - let statements = self.statements.into_iter(); + let constraints = { + let mut section = Section::new(SectionType::Constraints); + section.set_offset(w.stream_position()?); - let mut count = 0; - for s in statements { - if matches!(s, Statement::Constraint(..)) { - count += 1; + let statements = self.statements.into_iter(); + for s in statements { + if matches!(s, Statement::Constraint(..)) { + count += 1; + } + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); + for s in s { + serde_cbor::to_writer(&mut w, &s)?; + } } - let s: Vec> = solver_indexer - .fold_statement(s) - .into_iter() - .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) - .collect(); - for s in s { - serde_cbor::to_writer(&mut w, &s)?; - } - } - let solver_list_offset = w.stream_position()?; - serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + section.set_length(w.stream_position()? - section.offset); + section + }; - w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?; - w.write_all(&solver_list_offset.to_le_bytes())?; + let solvers = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + header.constraint_count = count as u32; + header + .sections + .extend_from_slice(&[arguments, constraints, solvers]); + + w.rewind()?; + header.write(&mut w)?; unconstrained_variable_detector .finalize() @@ -135,128 +284,108 @@ impl<'de, R: Read + Seek> > { pub fn deserialize(mut r: R) -> Result { + let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; + // Check the magic number, `ZOK` - let mut magic = [0; 4]; - r.read_exact(&mut magic) - .map_err(|_| String::from("Cannot read magic number"))?; + if &header.magic != ZOKRATES_MAGIC { + return Err("Invalid magic number".to_string()); + } - if &magic == ZOKRATES_MAGIC { - // Check the version, 2 - let mut version = [0; 4]; - r.read_exact(&mut version) - .map_err(|_| String::from("Cannot read version"))?; + // Check the file version + if &header.version != FILE_VERSION { + return Err("Invalid file version".to_string()); + } - if &version == ZOKRATES_VERSION_3 { - // Check the curve identifier, deserializing accordingly - let mut curve = [0; 4]; - r.read_exact(&mut curve) - .map_err(|_| String::from("Cannot read curve identifier"))?; + let arguments = { + let section = header.get_section(SectionType::Arguments).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let mut buffer = [0u8; std::mem::size_of::()]; - r.read_exact(&mut buffer) - .map_err(|_| String::from("Cannot read solver list offset"))?; + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))? + }; - let solver_list_offset = u64::from_le_bytes(buffer); + let solvers_section = header.get_section(SectionType::Solvers).unwrap(); + r.seek(std::io::SeekFrom::Start(solvers_section.offset)) + .unwrap(); - let (arguments, return_count) = { + match header.curve_id { + m if m == Bls12_381Field::id() => { + let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - - let arguments: Vec = Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read parameters"))?; - - let return_count = usize::deserialize(&mut p) - .map_err(|_| String::from("Cannot read return count"))?; - - (arguments, return_count) + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? }; - let statement_offset = r.stream_position().unwrap(); - r.seek(std::io::SeekFrom::Start(solver_list_offset)) - .unwrap(); + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - match curve { - m if m == Bls12_381Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bn128Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bls12_377Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bw6_761Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - _ => Err(String::from("Unknown curve identifier")), - } - } else { - Err(String::from("Unknown version")) + Ok(ProgEnum::Bls12_381Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) } - } else { - Err(String::from("Wrong magic number")) + m if m == Bn128Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bn128Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + m if m == Bls12_377Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bls12_377Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + m if m == Bw6_761Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bw6_761Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + _ => Err(String::from("Unknown curve identifier")), } } } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index d0e6c417..1eae3b46 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -214,6 +214,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 9828b974..821e3ea2 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -276,6 +276,7 @@ mod tests { arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -297,6 +298,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -318,6 +320,7 @@ mod tests { arguments: vec![], return_count: 1, statements: vec![Statement::constraint(Variable::one(), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -350,6 +353,7 @@ mod tests { Variable::public(1), ), ], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -373,6 +377,7 @@ mod tests { LinComb::from(Variable::new(42)) + LinComb::one(), Variable::public(0), )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -400,6 +405,7 @@ mod tests { LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), )], + solvers: vec![], }; let interpreter = Interpreter::default(); diff --git a/zokrates_circom/src/lib.rs b/zokrates_circom/src/lib.rs index 9b16742f..e87c0af3 100644 --- a/zokrates_circom/src/lib.rs +++ b/zokrates_circom/src/lib.rs @@ -44,6 +44,7 @@ mod tests { None, ), ], + solvers: vec![], }; let mut r1cs = vec![]; diff --git a/zokrates_circom/src/r1cs.rs b/zokrates_circom/src/r1cs.rs index 854bc0ea..b2fc8db2 100644 --- a/zokrates_circom/src/r1cs.rs +++ b/zokrates_circom/src/r1cs.rs @@ -296,6 +296,7 @@ mod tests { Variable::public(0).into(), None, )], + solvers: vec![], }; let mut buf = Vec::new(); @@ -365,6 +366,7 @@ mod tests { None, ), ], + solvers: vec![], }; let mut buf = Vec::new(); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index a08a3f73..346695a2 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -81,6 +81,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = p.clone(); @@ -117,6 +118,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = Prog { @@ -132,6 +134,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; assert_eq!( diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index bc7fc279..c1f90aa5 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -256,12 +256,14 @@ mod tests { Statement::definition(out, y), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { arguments: vec![x], statements: vec![Statement::definition(out, x.id)], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -280,6 +282,7 @@ mod tests { arguments: vec![x], statements: vec![Statement::definition(one, x.id)], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); @@ -316,6 +319,7 @@ mod tests { Statement::definition(out, z), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { @@ -325,6 +329,7 @@ mod tests { Statement::definition(out, x.id), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -365,6 +370,7 @@ mod tests { Statement::definition(out_1, w), ], return_count: 2, + solvers: vec![], }; let optimized: Prog = Prog { @@ -374,6 +380,7 @@ mod tests { Statement::definition(out_1, Bn128Field::from(1)), ], return_count: 2, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -422,6 +429,7 @@ mod tests { Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], return_count: 1, + solvers: vec![], }; let expected: Prog = Prog { @@ -442,6 +450,7 @@ mod tests { ), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -479,6 +488,7 @@ mod tests { Statement::definition(z, LinComb::from(x.id)), ], return_count: 0, + solvers: vec![], }; let optimized = p.clone(); @@ -507,6 +517,7 @@ mod tests { Statement::constraint(x.id, Bn128Field::from(2)), ], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index c8182b20..14fbade2 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -443,6 +443,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -459,6 +460,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -469,9 +471,12 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -481,9 +486,12 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -496,7 +504,7 @@ mod tests { #[test] fn five_hundred_bits_of_1() { let inputs = vec![Bn128Field::from(1)]; - let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs, &[]).unwrap(); let mut expected = vec![Bn128Field::from(0); 500]; expected[499] = Bn128Field::from(1); diff --git a/zokrates_test/tests/wasm.rs b/zokrates_test/tests/wasm.rs index 8b81f3bf..659f5a23 100644 --- a/zokrates_test/tests/wasm.rs +++ b/zokrates_test/tests/wasm.rs @@ -20,6 +20,7 @@ fn generate_proof() { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::new(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); From 779bb2bc203ac955028743086dc472616aa51bc6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Feb 2023 02:06:01 +0100 Subject: [PATCH 06/15] add changelog --- changelogs/unreleased/1268-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1268-dark64 diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 new file mode 100644 index 00000000..8b060b16 --- /dev/null +++ b/changelogs/unreleased/1268-dark64 @@ -0,0 +1 @@ +Optimize assembly solver \ No newline at end of file From 3ce57d93a4bd70456bc3b83821a99a55dc4583e3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 3 Mar 2023 18:19:24 +0100 Subject: [PATCH 07/15] optimize evaluation of lincomb --- zokrates_ast/src/ir/expression.rs | 7 ++--- zokrates_field/src/dummy_curve.rs | 4 ++- zokrates_field/src/lib.rs | 3 +- zokrates_interpreter/src/lib.rs | 52 ++++++++++++++----------------- 4 files changed, 31 insertions(+), 35 deletions(-) diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index a32a1293..28ac7993 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -121,9 +121,8 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - self.0.len() == 1 - && self.0.get(0).unwrap().1 == T::from(1) - && !witness.0.contains_key(&self.0.get(0).unwrap().0) + let (var, val) = self.0.get(0).unwrap(); + self.0.len() == 1 && val == &T::from(1) && !witness.0.contains_key(var) } pub fn try_summand(self) -> Result<(Variable, T), Self> { @@ -258,7 +257,7 @@ impl Mul<&T> for LinComb { LinComb( self.0 .into_iter() - .map(|(var, coeff)| (var, coeff * scalar.clone())) + .map(|(var, coeff)| (var, coeff * scalar)) .collect(), ) } diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs index 5d3aed4a..b1326592 100644 --- a/zokrates_field/src/dummy_curve.rs +++ b/zokrates_field/src/dummy_curve.rs @@ -10,7 +10,9 @@ use std::ops::{Add, Div, Mul, Sub}; const _PRIME: u8 = 7; -#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)] +#[derive( + Default, Debug, Hash, Clone, Copy, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq, +)] pub struct FieldPrime { v: u8, } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index 38f76905..81e9915f 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -70,6 +70,7 @@ pub trait Field: + Zero + One + Clone + + Copy + PartialEq + Eq + Hash @@ -148,7 +149,7 @@ mod prime_field { type Fr = <$v as ark_ec::PairingEngine>::Fr; - #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash)] + #[derive(PartialEq, PartialOrd, Clone, Copy, Eq, Ord, Hash)] pub struct FieldPrime { v: Fr, } diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 14fbade2..275624fe 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -88,7 +88,7 @@ impl Interpreter { .map_err(Error::Solver)?; for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); + witness.insert(*o, res[i]); } } Statement::Log(l, expressions) => { @@ -264,41 +264,38 @@ impl Interpreter { .collect() } Solver::Xor => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - T::from(2) * x * y] + vec![x + y - T::from(2) * x * y] } Solver::Or => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - x * y] + vec![x + y - x * y] } // res = b * c - (2b * c - b - c) * (a) Solver::ShaAndXorAndXorAnd => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![b * c - (T::from(2) * b * c - b - c) * a] } // res = a(b - c) + c Solver::ShaCh => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![a * (b - c.clone()) + c] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![a * (b - c) + c] } - Solver::Div => vec![inputs[0] - .clone() - .checked_div(&inputs[1]) - .unwrap_or_else(T::one)], + Solver::Div => vec![inputs[0].checked_div(&inputs[1]).unwrap_or_else(T::one)], Solver::EuclideanDiv => { use num::CheckedDiv; - let n = inputs[0].clone().to_biguint(); - let d = inputs[1].clone().to_biguint(); + let n = inputs[0].to_biguint(); + let d = inputs[1].to_biguint(); let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into()); let r = n - d * &q; @@ -363,14 +360,11 @@ pub enum Error { } fn evaluate_lin(w: &Witness, l: &LinComb) -> Result { - l.0.iter() - .map(|(var, mult)| { - w.0.get(var) - .map(|v| v.clone() * mult) - .ok_or(EvaluationError) - }) // get each term - .collect::, _>>() // fail if any term isn't found - .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum + l.0.iter().try_fold(T::from(0), |acc, (var, mult)| { + w.0.get(var) + .map(|v| acc + (*v * mult)) + .ok_or(EvaluationError) // fail if any term isn't found + }) } pub fn evaluate_quad(w: &Witness, q: &QuadComb) -> Result { From e8abfb51ea3b12a8196ddfabb6fed202c8437401 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Mar 2023 13:37:12 +0400 Subject: [PATCH 08/15] clippy --- zokrates_analysis/src/flat_propagation.rs | 2 +- zokrates_analysis/src/propagation.rs | 2 +- zokrates_analysis/src/uint_optimizer.rs | 6 +++--- zokrates_ark/src/lib.rs | 2 +- zokrates_ast/src/flat/utils.rs | 2 +- zokrates_ast/src/ir/expression.rs | 4 ++-- zokrates_ast/src/ir/mod.rs | 2 +- zokrates_ast/src/zir/lqc.rs | 16 ++++++++-------- zokrates_bellman/src/lib.rs | 2 +- zokrates_codegen/src/lib.rs | 2 +- zokrates_interpreter/src/lib.rs | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/zokrates_analysis/src/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs index 155d803c..eb2ca797 100644 --- a/zokrates_analysis/src/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -32,7 +32,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator { match e { FlatExpression::Number(n) => FlatExpression::Number(n), FlatExpression::Identifier(id) => match self.constants.get(&id) { - Some(c) => FlatExpression::Number(c.clone()), + Some(c) => FlatExpression::Number(*c), None => FlatExpression::Identifier(id), }, FlatExpression::Add(box e1, box e2) => { diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7e5c0a1..708ebabb 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -517,7 +517,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { .unwrap() { FieldElementExpression::Number(num) => { - let mut acc = num.clone(); + let mut acc = num; let mut res = vec![]; for i in (0..bit_width as usize).rev() { diff --git a/zokrates_analysis/src/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs index ac96dbfe..6a6b217e 100644 --- a/zokrates_analysis/src/uint_optimizer.rs +++ b/zokrates_analysis/src/uint_optimizer.rs @@ -170,7 +170,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + range_max)) + .unwrap_or_else(|| (true, true, range_max + range_max)) }) }); @@ -223,7 +223,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&target_offset) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + target_offset)) + .unwrap_or_else(|| (true, true, range_max + target_offset)) } else { left_max .checked_add(&offset) @@ -294,7 +294,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_mul(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() * range_max)) + .unwrap_or_else(|| (true, true, range_max * range_max)) }) }); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index 425be3a8..322785c3 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -150,7 +150,7 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> self.program .public_inputs_values(self.witness.as_ref().unwrap()) .iter() - .map(|v| v.clone().into_ark()) + .map(|v| v.into_ark()) .collect() } } diff --git a/zokrates_ast/src/flat/utils.rs b/zokrates_ast/src/flat/utils.rs index 03239687..c37d2f68 100644 --- a/zokrates_ast/src/flat/utils.rs +++ b/zokrates_ast/src/flat/utils.rs @@ -36,7 +36,7 @@ pub fn flat_expression_from_variable_summands(v: &[(T, usize)]) -> Fla match v.len() { 0 => FlatExpression::Number(T::zero()), 1 => { - let (val, var) = v[0].clone(); + let (val, var) = v[0]; FlatExpression::Mult( box FlatExpression::Number(val), box FlatExpression::Identifier(Variable::new(var)), diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index 28ac7993..ccebfa9a 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -164,8 +164,8 @@ impl LinComb { match acc.entry(val) { Entry::Occupied(o) => { // if the new value is non zero, update, else remove the term entirely - if o.get().clone() + coeff.clone() != T::zero() { - *o.into_mut() = o.get().clone() + coeff; + if *o.get() + coeff != T::zero() { + *o.into_mut() = *o.get() + coeff; } else { o.remove(); } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index b401b757..66535d3c 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -181,7 +181,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a self.arguments .iter() .filter(|p| !p.private) - .map(|p| witness.0.get(&p.id).unwrap().clone()) + .map(|p| *witness.0.get(&p.id).unwrap()) .chain(witness.return_values()) .collect() } diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs index 121b2a39..834f601f 100644 --- a/zokrates_ast/src/zir/lqc.rs +++ b/zokrates_ast/src/zir/lqc.rs @@ -43,7 +43,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { linear: { let mut l = self.linear; other.linear.iter_mut().for_each(|(c, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); l.append(&mut other.linear); l @@ -51,7 +51,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { quadratic: { let mut q = self.quadratic; other.quadratic.iter_mut().for_each(|(c, _, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); q.append(&mut other.quadratic); q @@ -68,18 +68,18 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { } Ok(Self { - constant: self.constant.clone() * rhs.constant.clone(), + constant: self.constant * rhs.constant, linear: { // lin0 * const1 + lin1 * const0 self.linear .clone() .into_iter() - .map(|(c, i)| (c * rhs.constant.clone(), i)) + .map(|(c, i)| (c * rhs.constant, i)) .chain( rhs.linear .clone() .into_iter() - .map(|(c, i)| (c * self.constant.clone(), i)), + .map(|(c, i)| (c * self.constant, i)), ) .collect() }, @@ -87,16 +87,16 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { // quad0 * const1 + quad1 * const0 + lin0 * lin1 self.quadratic .into_iter() - .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .map(|(c, i0, i1)| (c * rhs.constant, i0, i1)) .chain( rhs.quadratic .into_iter() - .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + .map(|(c, i0, i1)| (c * self.constant, i0, i1)), ) .chain(self.linear.iter().flat_map(|(cl, l)| { rhs.linear .iter() - .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + .map(|(cr, r)| (*cl * *cr, l.clone(), r.clone())) })) .collect() }, diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 821e3ea2..263cfc21 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -189,7 +189,7 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[*x], &[]) .unwrap() .into_iter() .map(FlatExpression::Number) diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 275624fe..8a7f8854 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -50,7 +50,7 @@ impl Interpreter { witness.insert(Variable::one(), T::one()); for (arg, value) in program.arguments.iter().zip(inputs.iter()) { - witness.insert(arg.id, value.clone()); + witness.insert(arg.id, *value); } for statement in program.statements.into_iter() { @@ -188,7 +188,7 @@ impl Interpreter { .map(|(a, v)| match &a.id._type { zir::Type::FieldElement => Ok(( a.id.id.clone(), - zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(), + zokrates_ast::zir::FieldElementExpression::Number(*v).into(), )), zir::Type::Boolean => match v { v if *v == T::from(0) => Ok(( From cfe43e672dd05f8953c3772cd65cd9ced275b200 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Mar 2023 13:53:03 +0400 Subject: [PATCH 09/15] revert is_assignee --- zokrates_ast/src/ir/expression.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index ccebfa9a..33e3856e 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -121,8 +121,9 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - let (var, val) = self.0.get(0).unwrap(); - self.0.len() == 1 && val == &T::from(1) && !witness.0.contains_key(var) + self.0.len() == 1 + && self.0.get(0).unwrap().1 == T::from(1) + && !witness.0.contains_key(&self.0.get(0).unwrap().0) } pub fn try_summand(self) -> Result<(Variable, T), Self> { From 0f31a1b42ea03c419a7b5600d3ee3a05b7db6fbd Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 14 Mar 2023 19:20:02 +0100 Subject: [PATCH 10/15] apply suggestions (part 1) --- changelogs/unreleased/1268-dark64 | 2 +- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 12 +++++++++--- zokrates_ast/src/ir/solver_indexer.rs | 10 +++++++--- zokrates_ast/src/zir/identifier.rs | 12 ++++++------ zokrates_ast/src/zir/substitution.rs | 19 ++++++++----------- zokrates_codegen/src/lib.rs | 18 +++--------------- zokrates_interpreter/src/lib.rs | 6 +++--- 8 files changed, 38 insertions(+), 43 deletions(-) diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 index 8b060b16..7d2f33da 100644 --- a/changelogs/unreleased/1268-dark64 +++ b/changelogs/unreleased/1268-dark64 @@ -1 +1 @@ -Optimize assembly solver \ No newline at end of file +Reduce compiled program size by deduplicating assembly solvers \ No newline at end of file diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfd..c8ad0944 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{RefCall, Solver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 0551bb6d..9c6e9bbc 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,12 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub struct RefCall { + pub index: usize, + pub argument_count: usize, +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -14,7 +20,7 @@ pub enum Solver<'ast, T> { EuclideanDiv, #[serde(borrow)] Zir(ZirFunction<'ast, T>), - IndexedCall(usize, usize), + Ref(RefCall), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] @@ -33,7 +39,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), Solver::Zir(_) => write!(f, "Zir(..)"), - Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc), + Solver::Ref(call) => write!(f, "Ref@{}({})", call.index, call.argument_count), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -54,7 +60,7 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), Solver::Zir(f) => (f.arguments.len(), 1), - Solver::IndexedCall(_, n) => (*n, 1), + Solver::Ref(c) => (c.argument_count, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs index 59393d25..cbe1a151 100644 --- a/zokrates_ast/src/ir/solver_indexer.rs +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -1,3 +1,4 @@ +use crate::common::RefCall; use crate::ir::folder::Folder; use crate::ir::Directive; use crate::ir::Solver; @@ -20,14 +21,14 @@ fn hash(f: &ZirFunction) -> Hash { #[derive(Debug, Default)] pub struct SolverIndexer<'ast, T> { pub solvers: Vec>, - pub index_map: HashMap, + pub index_map: HashMap, } impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { match d.solver { Solver::Zir(f) => { - let argc = f.arguments.len(); + let argument_count = f.arguments.len(); let h = hash(&f); let index = match self.index_map.entry(h) { Entry::Occupied(v) => *v.get(), @@ -41,7 +42,10 @@ impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { Directive { inputs: d.inputs, outputs: d.outputs, - solver: Solver::IndexedCall(index, argc), + solver: Solver::Ref(RefCall { + index, + argument_count, + }), } } _ => d, diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index bc60f566..b036e42a 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -31,6 +31,12 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { } } +impl<'ast> Identifier<'ast> { + pub fn internal>(name: S) -> Self { + Identifier::Internal(name.into()) + } +} + impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -40,12 +46,6 @@ impl<'ast> fmt::Display for Identifier<'ast> { } } -impl<'ast> From for Identifier<'ast> { - fn from(id: String) -> Identifier<'ast> { - Identifier::Internal(id) - } -} - // this is only used in tests but somehow cfg(test) does not work impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs index 89a8f82d..d559ccff 100644 --- a/zokrates_ast/src/zir/substitution.rs +++ b/zokrates_ast/src/zir/substitution.rs @@ -1,22 +1,19 @@ use super::{Folder, Identifier}; -use std::{collections::HashMap, marker::PhantomData}; +use std::collections::HashMap; use zokrates_field::Field; -pub struct ZirSubstitutor<'a, 'ast, T> { - substitution: &'a HashMap, Identifier<'ast>>, - _phantom: PhantomData, +#[derive(Default)] +pub struct ZirSubstitutor<'ast> { + substitution: HashMap, Identifier<'ast>>, } -impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> { - pub fn new(substitution: &'a HashMap, Identifier<'ast>>) -> Self { - ZirSubstitutor { - substitution, - _phantom: PhantomData::default(), - } +impl<'ast> ZirSubstitutor<'ast> { + pub fn new(substitution: HashMap, Identifier<'ast>>) -> Self { + Self { substitution } } } -impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> { +impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { match self.substitution.get(&n) { Some(v) => v.clone(), diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 77a3944f..57aa7fcd 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -2245,24 +2245,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { let mut substitution_map = HashMap::default(); for (index, p) in function.arguments.iter().enumerate() { - let new_id = format!("i{}", index).into(); + let new_id = Identifier::internal(format!("i{}", index)); substitution_map.insert(p.id.id.clone(), new_id); } - let mut substitutor = ZirSubstitutor::new(&substitution_map); - let function = ZirFunction { - arguments: function - .arguments - .into_iter() - .map(|p| substitutor.fold_parameter(p)) - .collect(), - statements: function - .statements - .into_iter() - .flat_map(|s| substitutor.fold_statement(s)) - .collect(), - signature: function.signature, - }; + let mut substitutor = ZirSubstitutor::new(substitution_map); + let function = substitutor.fold_function(function); let solver = Solver::Zir(function); let directive = FlatDirective::new(outputs, solver, inputs); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 8a7f8854..19a3ef39 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -167,9 +167,9 @@ impl Interpreter { solvers: &[Solver<'ast, T>], ) -> Result, String> { let solver = match solver { - Solver::IndexedCall(index, _) => solvers - .get(*index) - .ok_or_else(|| format!("Could not resolve solver at index {}", index))?, + Solver::Ref(call) => solvers + .get(call.index) + .ok_or_else(|| format!("Could not get solver at index {}", call.index))?, s => s, }; From 475744bf6e0469b714d08f49dd42119fef38a4d9 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 20 Mar 2023 20:21:23 +0100 Subject: [PATCH 11/15] refactor --- zokrates_ast/src/ir/mod.rs | 2 +- zokrates_ast/src/ir/serialize.rs | 168 ++++++++++++------------------- 2 files changed, 67 insertions(+), 103 deletions(-) diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 66535d3c..7de47396 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -20,7 +20,7 @@ mod witness; pub use self::expression::QuadComb; pub use self::expression::{CanonicalLinComb, LinComb}; -pub use self::serialize::ProgEnum; +pub use self::serialize::{ProgEnum, ProgHeader}; pub use crate::common::Parameter; pub use crate::common::RuntimeError; pub use crate::common::Solver; diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 6266662d..37b28f7d 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,7 +1,4 @@ -use crate::{ - ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}, - Solver, -}; +use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}; use super::{ProgIterator, Statement}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -65,7 +62,7 @@ impl< #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum SectionType { - Arguments = 1, + Parameters = 1, Constraints = 2, Solvers = 3, } @@ -75,7 +72,7 @@ impl TryFrom for SectionType { fn try_from(value: u32) -> Result { match value { - 1 => Ok(SectionType::Arguments), + 1 => Ok(SectionType::Parameters), 2 => Ok(SectionType::Constraints), 3 => Ok(SectionType::Solvers), _ => Err("invalid section type".to_string()), @@ -198,21 +195,23 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header - let arguments = { - let mut section = Section::new(SectionType::Arguments); + // write parameters + if self.arguments.len() > 0 { + let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &self.arguments)?; section.set_length(w.stream_position()? - section.offset); - section - }; + header.sections.push(section); + } let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut count: usize = 0; - let constraints = { + // write constraints + { let mut section = Section::new(SectionType::Constraints); section.set_offset(w.stream_position()?); @@ -232,24 +231,23 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a } section.set_length(w.stream_position()? - section.offset); - section + header.sections.push(section); }; - let solvers = { + // write solvers + if solver_indexer.solvers.len() > 0 { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; section.set_length(w.stream_position()? - section.offset); - section - }; + header.sections.push(section); + } header.constraint_count = count as u32; - header - .sections - .extend_from_slice(&[arguments, constraints, solvers]); + // rewind to write the header w.rewind()?; header.write(&mut w)?; @@ -283,6 +281,52 @@ impl<'de, R: Read + Seek> UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { + fn read( + mut r: R, + header: &ProgHeader, + ) -> ProgIterator< + 'de, + T, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, + > { + let parameters = match header.get_section(SectionType::Parameters) { + Some(section) => { + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() + } + None => vec![], + }; + + let solvers = match header.get_section(SectionType::Solvers) { + Some(section) => { + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() + } + None => vec![], + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + ProgIterator::new( + parameters, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ) + } + pub fn deserialize(mut r: R) -> Result { let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; @@ -296,95 +340,15 @@ impl<'de, R: Read + Seek> return Err("Invalid file version".to_string()); } - let arguments = { - let section = header.get_section(SectionType::Arguments).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))? - }; - - let solvers_section = header.get_section(SectionType::Solvers).unwrap(); - r.seek(std::io::SeekFrom::Start(solvers_section.offset)) - .unwrap(); - match header.curve_id { m if m == Bls12_381Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) - } - m if m == Bn128Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) + Ok(ProgEnum::Bls12_381Program(Self::read(r, &header))) } + m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))), m if m == Bls12_377Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) - } - m if m == Bw6_761Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) + Ok(ProgEnum::Bls12_377Program(Self::read(r, &header))) } + m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))), _ => Err(String::from("Unknown curve identifier")), } } From 87f356a7c62f118ee5000a6fd27e6e3c923274b5 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 21 Mar 2023 00:09:37 +0100 Subject: [PATCH 12/15] add test --- zokrates_ast/src/ir/serialize.rs | 4 +-- zokrates_interpreter/src/lib.rs | 42 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 37b28f7d..2882171f 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -196,7 +196,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header // write parameters - if self.arguments.len() > 0 { + if !self.arguments.is_empty() { let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); @@ -235,7 +235,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a }; // write solvers - if solver_indexer.solvers.len() > 0 { + if !solver_indexer.solvers.is_empty() { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 19a3ef39..ccfff818 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -505,4 +505,46 @@ mod tests { assert_eq!(res, expected); } + + #[test] + fn solver_ref() { + use zir::{ + types::{Signature, Type}, + FieldElementExpression, Identifier, IdentifierExpression, Parameter, Variable, + ZirFunction, ZirStatement, + }; + use zokrates_ast::common::RefCall; + + let id = IdentifierExpression::new(Identifier::internal("i0")); + + // (field i0) -> i0 * i0 + let solvers = vec![Solver::Zir(ZirFunction { + arguments: vec![Parameter { + id: Variable::with_id_and_type(id.id.clone(), Type::FieldElement), + private: true, + }], + statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( + Box::new(FieldElementExpression::Identifier(id.clone())), + Box::new(FieldElementExpression::Identifier(id.clone())), + ) + .into()])], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + })]; + + let inputs = vec![Bn128Field::from(2)]; + let res = Interpreter::execute_solver( + &Solver::Ref(RefCall { + index: 0, + argument_count: 1, + }), + &inputs, + &solvers, + ) + .unwrap(); + + let expected = vec![Bn128Field::from(4)]; + assert_eq!(res, expected); + } } From 95bec7be646ca57e2cfc5ae8a4ef46d4732ec531 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 21 Mar 2023 13:16:07 +0100 Subject: [PATCH 13/15] fix zir substitutor --- zokrates_ast/src/ir/serialize.rs | 125 +++++++++++++-------------- zokrates_ast/src/zir/identifier.rs | 8 +- zokrates_ast/src/zir/substitution.rs | 28 +++--- zokrates_codegen/src/lib.rs | 12 +-- zokrates_interpreter/src/lib.rs | 2 +- 5 files changed, 84 insertions(+), 91 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 2882171f..f3262429 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -61,6 +61,7 @@ impl< } #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(u32)] pub enum SectionType { Parameters = 1, Constraints = 2, @@ -112,7 +113,7 @@ pub struct ProgHeader { pub curve_id: [u8; 4], pub constraint_count: u32, pub return_count: u32, - pub sections: Vec
, + pub sections: [Section; 3], } impl ProgHeader { @@ -123,7 +124,6 @@ impl ProgHeader { w.write_u32::(self.constraint_count)?; w.write_u32::(self.return_count)?; - w.write_u32::(self.sections.len() as u32)?; for s in &self.sections { w.write_u32::(s.ty as u32)?; w.write_u64::(s.offset)?; @@ -146,19 +146,9 @@ impl ProgHeader { let constraint_count = r.read_u32::()?; let return_count = r.read_u32::()?; - let section_count = r.read_u32::()?; - let mut sections = vec![]; - - for _ in 0..section_count { - let id = r.read_u32::()?; - let mut section = Section::new( - SectionType::try_from(id) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, - ); - section.set_offset(r.read_u64::()?); - section.set_length(r.read_u64::()?); - sections.push(section); - } + let parameters = Self::read_section(r.by_ref())?; + let constraints = Self::read_section(r.by_ref())?; + let solvers = Self::read_section(r.by_ref())?; Ok(ProgHeader { magic, @@ -166,12 +156,19 @@ impl ProgHeader { curve_id, constraint_count, return_count, - sections, + sections: [parameters, constraints, solvers], }) } - fn get_section(&self, ty: SectionType) -> Option<&Section> { - self.sections.iter().find(|s| s.ty == ty) + fn read_section(mut r: R) -> std::io::Result
{ + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + Ok(section) } } @@ -181,37 +178,26 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - const SECTION_COUNT: usize = 3; - const HEADER_SIZE: usize = 24 + SECTION_COUNT * 20; + // reserve bytes for the header + w.write_all(&[0u8; std::mem::size_of::()])?; - let mut header = ProgHeader { - magic: *ZOKRATES_MAGIC, - version: *FILE_VERSION, - curve_id: T::id(), - constraint_count: 0, - return_count: self.return_count as u32, - sections: Vec::with_capacity(SECTION_COUNT), - }; - - w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header - - // write parameters - if !self.arguments.is_empty() { + // write parameters section + let parameters = { let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &self.arguments)?; section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); - } + section + }; let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut count: usize = 0; - // write constraints - { + // write constraints section + let constraints = { let mut section = Section::new(SectionType::Constraints); section.set_offset(w.stream_position()?); @@ -231,21 +217,28 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a } section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); + section }; - // write solvers - if !solver_indexer.solvers.is_empty() { + // write solvers section + let solvers = { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); - } + section + }; - header.constraint_count = count as u32; + let header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: count as u32, + return_count: self.return_count as u32, + sections: [parameters, constraints, solvers], + }; // rewind to write the header w.rewind()?; @@ -289,39 +282,39 @@ impl<'de, R: Read + Seek> T, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, > { - let parameters = match header.get_section(SectionType::Parameters) { - Some(section) => { - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + let parameters = { + let section = &header.sections[0]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read parameters")) - .unwrap() - } - None => vec![], + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() }; - let solvers = match header.get_section(SectionType::Solvers) { - Some(section) => { - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + let solvers = { + let section = &header.sections[2]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solvers")) - .unwrap() - } - None => vec![], + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() }; - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + let statements_deserializer = { + let section = &header.sections[1]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + UnwrappedStreamDeserializer { s } + }; ProgIterator::new( parameters, - UnwrappedStreamDeserializer { s }, + statements_deserializer, header.return_count as usize, solvers, ) diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index b036e42a..f0c387df 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -8,7 +8,7 @@ use crate::typed::Identifier as CoreIdentifier; pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), - Internal(String), + Internal(usize), } #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] @@ -32,8 +32,8 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { } impl<'ast> Identifier<'ast> { - pub fn internal>(name: S) -> Self { - Identifier::Internal(name.into()) + pub fn internal>(id: T) -> Self { + Identifier::Internal(id.into()) } } @@ -41,7 +41,7 @@ impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Identifier::Source(s) => write!(f, "{}", s), - Identifier::Internal(s) => write!(f, "{}", s), + Identifier::Internal(i) => write!(f, "i{}", i), } } } diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs index d559ccff..d5709c5d 100644 --- a/zokrates_ast/src/zir/substitution.rs +++ b/zokrates_ast/src/zir/substitution.rs @@ -1,23 +1,31 @@ -use super::{Folder, Identifier}; +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; use std::collections::HashMap; use zokrates_field::Field; #[derive(Default)] pub struct ZirSubstitutor<'ast> { - substitution: HashMap, Identifier<'ast>>, -} - -impl<'ast> ZirSubstitutor<'ast> { - pub fn new(substitution: HashMap, Identifier<'ast>>) -> Self { - Self { substitution } - } + substitution: HashMap, usize>, } impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.substitution.len(); + self.substitution.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.substitution.len(); + self.substitution.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { match self.substitution.get(&n) { - Some(v) => v.clone(), - None => n, + Some(v) => Identifier::internal(*v), + None => unreachable!(), } } } diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 57aa7fcd..9ec1f426 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -2243,18 +2243,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|assignee| self.use_variable(&assignee)) .collect(); - let mut substitution_map = HashMap::default(); - for (index, p) in function.arguments.iter().enumerate() { - let new_id = Identifier::internal(format!("i{}", index)); - substitution_map.insert(p.id.id.clone(), new_id); - } - - let mut substitutor = ZirSubstitutor::new(substitution_map); + let mut substitutor = ZirSubstitutor::default(); let function = substitutor.fold_function(function); - let solver = Solver::Zir(function); - let directive = FlatDirective::new(outputs, solver, inputs); - + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); } ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index ccfff818..73709add 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -515,7 +515,7 @@ mod tests { }; use zokrates_ast::common::RefCall; - let id = IdentifierExpression::new(Identifier::internal("i0")); + let id = IdentifierExpression::new(Identifier::internal(0usize)); // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { From e0b029959d8c34eca7eebfba0ca72ac7096fabe3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Mar 2023 14:39:08 +0100 Subject: [PATCH 14/15] rename substitutor to canonicalizer, add tests --- zokrates_ast/src/zir/canonicalizer.rs | 94 +++++++++++++++++++++++++++ zokrates_ast/src/zir/mod.rs | 2 +- zokrates_ast/src/zir/substitution.rs | 31 --------- zokrates_codegen/src/lib.rs | 6 +- zokrates_interpreter/src/lib.rs | 2 +- 5 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 zokrates_ast/src/zir/canonicalizer.rs delete mode 100644 zokrates_ast/src/zir/substitution.rs diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs new file mode 100644 index 00000000..ca67d8b5 --- /dev/null +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -0,0 +1,94 @@ +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; +use std::collections::HashMap; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ZirCanonicalizer<'ast> { + identifier_map: HashMap, usize>, +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.identifier_map.get(&n) { + Some(v) => Identifier::internal(*v), + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::zir::{ + FieldElementExpression, IdentifierExpression, Signature, Type, ZirAssignee, ZirFunction, + ZirStatement, + }; + + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn canonicalize() { + let func = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element("a"), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element("b"), + FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new("b".into()), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + let mut canonicalizer = ZirCanonicalizer::default(); + let result = canonicalizer.fold_function(func); + + let expected = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element(Identifier::internal(0usize)), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element(Identifier::internal(1usize)), + FieldElementExpression::Identifier(IdentifierExpression::new( + Identifier::internal(0usize), + )) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new(Identifier::internal(1usize)), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + assert_eq!(result, expected); + } +} diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index c2af87b0..08110966 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,10 +1,10 @@ +pub mod canonicalizer; pub mod folder; mod from_typed; mod identifier; pub mod lqc; mod parameter; pub mod result_folder; -pub mod substitution; pub mod types; mod uint; mod variable; diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs deleted file mode 100644 index d5709c5d..00000000 --- a/zokrates_ast/src/zir/substitution.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; -use std::collections::HashMap; -use zokrates_field::Field; - -#[derive(Default)] -pub struct ZirSubstitutor<'ast> { - substitution: HashMap, usize>, -} - -impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { - fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(p.id.id.clone(), new_id); - - Parameter { - id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), - ..p - } - } - fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(a.id.clone(), new_id); - ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) - } - fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - match self.substitution.get(&n) { - Some(v) => Identifier::internal(*v), - None => unreachable!(), - } - } -} diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 9ec1f426..a09eb61e 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,7 +11,7 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + canonicalizer::ZirCanonicalizer, ConditionalExpression, Folder, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -2243,8 +2243,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|assignee| self.use_variable(&assignee)) .collect(); - let mut substitutor = ZirSubstitutor::default(); - let function = substitutor.fold_function(function); + let mut canonicalizer = ZirCanonicalizer::default(); + let function = canonicalizer.fold_function(function); let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 73709add..db6ed4a5 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -520,7 +520,7 @@ mod tests { // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { arguments: vec![Parameter { - id: Variable::with_id_and_type(id.id.clone(), Type::FieldElement), + id: Variable::field_element(id.id.clone()), private: true, }], statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( From 16ef50fa9c51c1558d4dcd59e064b84b2d70e443 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 28 Mar 2023 10:45:53 +0200 Subject: [PATCH 15/15] bump versions, generate changelog --- CHANGELOG.md | 11 +++++++++++ Cargo.lock | 12 ++++++------ changelogs/unreleased/1275-dark64 | 1 - changelogs/unreleased/1277-dark64 | 1 - changelogs/unreleased/1280-dark64 | 1 - changelogs/unreleased/1283-schaeff | 1 - zokrates_analysis/Cargo.toml | 2 +- zokrates_ast/Cargo.toml | 2 +- zokrates_cli/Cargo.toml | 2 +- zokrates_core/Cargo.toml | 2 +- zokrates_interpreter/Cargo.toml | 2 +- zokrates_js/Cargo.toml | 2 +- zokrates_js/package.json | 2 +- 13 files changed, 24 insertions(+), 17 deletions(-) delete mode 100644 changelogs/unreleased/1275-dark64 delete mode 100644 changelogs/unreleased/1277-dark64 delete mode 100644 changelogs/unreleased/1280-dark64 delete mode 100644 changelogs/unreleased/1283-schaeff diff --git a/CHANGELOG.md b/CHANGELOG.md index 4188c1c3..9ea84c10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.5] - 2023-03-28 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.5 + +### Changes +- Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) (#1283, @schaeff) +- Fix `radix-path` help message on `mpc init` subcommand (#1280, @dark64) +- Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair (#1277, @dark64) +- Show help when running `zokrates mpc` (#1275, @dark64) + ## [0.8.4] - 2023-01-31 ### Release diff --git a/Cargo.lock b/Cargo.lock index e2191772..e780f027 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2906,7 +2906,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" dependencies = [ "cfg-if 0.1.10", "csv", @@ -2956,7 +2956,7 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" dependencies = [ "ark-bls12-377", "cfg-if 0.1.10", @@ -3004,7 +3004,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3064,7 +3064,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3147,7 +3147,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "num", @@ -3163,7 +3163,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" dependencies = [ "console_error_panic_hook", "getrandom", diff --git a/changelogs/unreleased/1275-dark64 b/changelogs/unreleased/1275-dark64 deleted file mode 100644 index 506dce65..00000000 --- a/changelogs/unreleased/1275-dark64 +++ /dev/null @@ -1 +0,0 @@ -Show help when running `zokrates mpc` \ No newline at end of file diff --git a/changelogs/unreleased/1277-dark64 b/changelogs/unreleased/1277-dark64 deleted file mode 100644 index e94ff6a5..00000000 --- a/changelogs/unreleased/1277-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair diff --git a/changelogs/unreleased/1280-dark64 b/changelogs/unreleased/1280-dark64 deleted file mode 100644 index e6392c71..00000000 --- a/changelogs/unreleased/1280-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix `radix-path` help message on `mpc init` subcommand \ No newline at end of file diff --git a/changelogs/unreleased/1283-schaeff b/changelogs/unreleased/1283-schaeff deleted file mode 100644 index f4cf3909..00000000 --- a/changelogs/unreleased/1283-schaeff +++ /dev/null @@ -1 +0,0 @@ -Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) \ No newline at end of file diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index e347abcd..f93dc7d8 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" edition = "2021" [features] diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324..60eb498c 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" edition = "2021" [features] diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 3fa5b5fe..ccf98794 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index aa1964c5..2b253004 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 41cdfa35..27043668 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 0c86329b..02374289 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 7617f20f..02b6e919 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.5", + "version": "1.1.6", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates",