diff --git a/CHANGELOG.md b/CHANGELOG.md index 4188c1c3..9ea84c10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.5] - 2023-03-28 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.5 + +### Changes +- Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) (#1283, @schaeff) +- Fix `radix-path` help message on `mpc init` subcommand (#1280, @dark64) +- Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair (#1277, @dark64) +- Show help when running `zokrates mpc` (#1275, @dark64) + ## [0.8.4] - 2023-01-31 ### Release diff --git a/Cargo.lock b/Cargo.lock index 1be2fe83..b7952272 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2906,7 +2906,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" dependencies = [ "cfg-if 0.1.10", "csv", @@ -2956,9 +2956,10 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" dependencies = [ "ark-bls12-377", + "byteorder", "cfg-if 0.1.10", "derivative", "num-bigint 0.2.6", @@ -3003,7 +3004,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3063,7 +3064,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3146,7 +3147,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "num", @@ -3162,7 +3163,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" dependencies = [ "console_error_panic_hook", "getrandom", diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 new file mode 100644 index 00000000..7d2f33da --- /dev/null +++ b/changelogs/unreleased/1268-dark64 @@ -0,0 +1 @@ +Reduce compiled program size by deduplicating assembly solvers \ No newline at end of file diff --git a/changelogs/unreleased/1275-dark64 b/changelogs/unreleased/1275-dark64 deleted file mode 100644 index 506dce65..00000000 --- a/changelogs/unreleased/1275-dark64 +++ /dev/null @@ -1 +0,0 @@ -Show help when running `zokrates mpc` \ No newline at end of file diff --git a/changelogs/unreleased/1277-dark64 b/changelogs/unreleased/1277-dark64 deleted file mode 100644 index e94ff6a5..00000000 --- a/changelogs/unreleased/1277-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair diff --git a/changelogs/unreleased/1280-dark64 b/changelogs/unreleased/1280-dark64 deleted file mode 100644 index e6392c71..00000000 --- a/changelogs/unreleased/1280-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix `radix-path` help message on `mpc init` subcommand \ No newline at end of file diff --git a/changelogs/unreleased/1283-schaeff b/changelogs/unreleased/1283-schaeff deleted file mode 100644 index f4cf3909..00000000 --- a/changelogs/unreleased/1283-schaeff +++ /dev/null @@ -1 +0,0 @@ -Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) \ No newline at end of file diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index e347abcd..f93dc7d8 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" edition = "2021" [features] diff --git a/zokrates_analysis/src/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs index 155d803c..eb2ca797 100644 --- a/zokrates_analysis/src/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -32,7 +32,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator { match e { FlatExpression::Number(n) => FlatExpression::Number(n), FlatExpression::Identifier(id) => match self.constants.get(&id) { - Some(c) => FlatExpression::Number(c.clone()), + Some(c) => FlatExpression::Number(*c), None => FlatExpression::Identifier(id), }, FlatExpression::Add(box e1, box e2) => { diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index 0b834ef8..578cd8ee 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; @@ -481,7 +481,7 @@ impl<'ast, T: Field> Flattener { // } #[derive(Default)] pub struct ArgumentFinder<'ast, T> { - pub identifiers: HashMap, zir::Type>, + pub identifiers: BTreeMap, zir::Type>, _phantom: PhantomData, } diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 7d77e86d..1385a243 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -508,7 +508,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .unwrap() { FieldElementExpression::Number(num) => { - let mut acc = num.clone(); + let mut acc = num; let mut res = vec![]; for i in (0..bit_width as usize).rev() { diff --git a/zokrates_analysis/src/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs index ac96dbfe..6a6b217e 100644 --- a/zokrates_analysis/src/uint_optimizer.rs +++ b/zokrates_analysis/src/uint_optimizer.rs @@ -170,7 +170,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + range_max)) + .unwrap_or_else(|| (true, true, range_max + range_max)) }) }); @@ -223,7 +223,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&target_offset) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + target_offset)) + .unwrap_or_else(|| (true, true, range_max + target_offset)) } else { left_max .checked_add(&offset) @@ -294,7 +294,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_mul(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() * range_max)) + .unwrap_or_else(|| (true, true, range_max * range_max)) }) }); diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index d1347c1a..b756913a 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -127,6 +127,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -155,6 +156,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index f24d8209..3c7b4b80 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -126,6 +126,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -154,6 +155,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index 586e09b2..e466dd1f 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -135,7 +135,7 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> self.program .public_inputs_values(self.witness.as_ref().unwrap()) .iter() - .map(|v| v.clone().into_ark()) + .map(|v| v.into_ark()) .collect() } } diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index d4eeb1c4..0750c2b8 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -410,6 +410,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -448,6 +449,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index c19392aa..b135b06f 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" edition = "2021" [features] @@ -9,6 +9,7 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] [dependencies] +byteorder = "1.4.3" zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfd..c8ad0944 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{RefCall, Solver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 9b4f5c90..9c6e9bbc 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,12 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub struct RefCall { + pub index: usize, + pub argument_count: usize, +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -14,6 +20,7 @@ pub enum Solver<'ast, T> { EuclideanDiv, #[serde(borrow)] Zir(ZirFunction<'ast, T>), + Ref(RefCall), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] @@ -32,6 +39,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::Ref(call) => write!(f, "Ref@{}({})", call.index, call.argument_count), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -52,6 +60,7 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), Solver::Zir(f) => (f.arguments.len(), 1), + Solver::Ref(c) => (c.argument_count, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/flat/utils.rs b/zokrates_ast/src/flat/utils.rs index 03239687..c37d2f68 100644 --- a/zokrates_ast/src/flat/utils.rs +++ b/zokrates_ast/src/flat/utils.rs @@ -36,7 +36,7 @@ pub fn flat_expression_from_variable_summands(v: &[(T, usize)]) -> Fla match v.len() { 0 => FlatExpression::Number(T::zero()), 1 => { - let (val, var) = v[0].clone(); + let (val, var) = v[0]; FlatExpression::Mult( box FlatExpression::Number(val), box FlatExpression::Identifier(Variable::new(var)), diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index b4fb8f44..998e3298 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -14,6 +14,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a .statements .into_iter() .flat_map(|s| Cleaner::default().fold_statement(s)), + solvers: self.solvers, } } } diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index a32a1293..33e3856e 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -165,8 +165,8 @@ impl LinComb { match acc.entry(val) { Entry::Occupied(o) => { // if the new value is non zero, update, else remove the term entirely - if o.get().clone() + coeff.clone() != T::zero() { - *o.into_mut() = o.get().clone() + coeff; + if *o.get() + coeff != T::zero() { + *o.into_mut() = *o.get() + coeff; } else { o.remove(); } @@ -258,7 +258,7 @@ impl Mul<&T> for LinComb { LinComb( self.0 .into_iter() - .map(|(var, coeff)| (var, coeff * scalar.clone())) + .map(|(var, coeff)| (var, coeff * scalar)) .collect(), ) } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 6e67c15d..99466add 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -50,6 +50,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_statement(s)) .collect(), return_count: p.return_count, + solvers: p.solvers, } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index fc961cd8..3404c6c4 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -24,6 +24,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + solvers: vec![], } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 78b48f80..7de47396 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -14,12 +14,13 @@ pub mod folder; pub mod from_flat; mod serialize; pub mod smtlib2; +mod solver_indexer; pub mod visitor; mod witness; pub use self::expression::QuadComb; pub use self::expression::{CanonicalLinComb, LinComb}; -pub use self::serialize::ProgEnum; +pub use self::serialize::{ProgEnum, ProgHeader}; pub use crate::common::Parameter; pub use crate::common::RuntimeError; pub use crate::common::Solver; @@ -130,14 +131,22 @@ pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, + #[serde(borrow)] + pub solvers: Vec>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { - pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { + pub fn new( + arguments: Vec, + statements: I, + return_count: usize, + solvers: Vec>, + ) -> Self { Self { arguments, return_count, statements, + solvers, } } @@ -146,6 +155,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } @@ -171,7 +181,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a self.arguments .iter() .filter(|p| !p.private) - .map(|p| witness.0.get(&p.id).unwrap().clone()) + .map(|p| *witness.0.get(&p.id).unwrap()) .chain(witness.return_values()) .collect() } @@ -192,6 +202,7 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } } diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 09d00390..f3262429 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,14 +1,16 @@ -use crate::ir::check::UnconstrainedVariableDetector; +use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}; use super::{ProgIterator, Statement}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; -use std::io::{Read, Write}; +use std::io::{Read, Seek, Write}; use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; +const FILE_VERSION: &[u8; 4] = &[3, 0, 0, 0]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -58,33 +60,189 @@ impl< } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(u32)] +pub enum SectionType { + Parameters = 1, + Constraints = 2, + Solvers = 3, +} + +impl TryFrom for SectionType { + type Error = String; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(SectionType::Parameters), + 2 => Ok(SectionType::Constraints), + 3 => Ok(SectionType::Solvers), + _ => Err("invalid section type".to_string()), + } + } +} + +#[derive(Debug, Clone)] +pub struct Section { + pub ty: SectionType, + pub offset: u64, + pub length: u64, +} + +impl Section { + pub fn new(ty: SectionType) -> Self { + Self { + ty, + offset: 0, + length: 0, + } + } + + pub fn set_offset(&mut self, offset: u64) { + self.offset = offset; + } + + pub fn set_length(&mut self, length: u64) { + self.length = length; + } +} + +#[derive(Debug, Clone)] +pub struct ProgHeader { + pub magic: [u8; 4], + pub version: [u8; 4], + pub curve_id: [u8; 4], + pub constraint_count: u32, + pub return_count: u32, + pub sections: [Section; 3], +} + +impl ProgHeader { + pub fn write(&self, mut w: W) -> std::io::Result<()> { + w.write_all(&self.magic)?; + w.write_all(&self.version)?; + w.write_all(&self.curve_id)?; + w.write_u32::(self.constraint_count)?; + w.write_u32::(self.return_count)?; + + for s in &self.sections { + w.write_u32::(s.ty as u32)?; + w.write_u64::(s.offset)?; + w.write_u64::(s.length)?; + } + + Ok(()) + } + + pub fn read(mut r: R) -> std::io::Result { + let mut magic = [0; 4]; + r.read_exact(&mut magic)?; + + let mut version = [0; 4]; + r.read_exact(&mut version)?; + + let mut curve_id = [0; 4]; + r.read_exact(&mut curve_id)?; + + let constraint_count = r.read_u32::()?; + let return_count = r.read_u32::()?; + + let parameters = Self::read_section(r.by_ref())?; + let constraints = Self::read_section(r.by_ref())?; + let solvers = Self::read_section(r.by_ref())?; + + Ok(ProgHeader { + magic, + version, + curve_id, + constraint_count, + return_count, + sections: [parameters, constraints, solvers], + }) + } + + fn read_section(mut r: R) -> std::io::Result
{ + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + Ok(section) + } +} + impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives - pub fn serialize(self, mut w: W) -> Result { + pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_2)?; - w.write_all(&T::id())?; + // reserve bytes for the header + w.write_all(&[0u8; std::mem::size_of::()])?; - serde_cbor::to_writer(&mut w, &self.arguments)?; - serde_cbor::to_writer(&mut w, &self.return_count)?; + // write parameters section + let parameters = { + let mut section = Section::new(SectionType::Parameters); + section.set_offset(w.stream_position()?); + serde_cbor::to_writer(&mut w, &self.arguments)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut count: usize = 0; - let statements = self.statements.into_iter(); + // write constraints section + let constraints = { + let mut section = Section::new(SectionType::Constraints); + section.set_offset(w.stream_position()?); - let mut count = 0; - for s in statements { - if matches!(s, Statement::Constraint(..)) { - count += 1; + let statements = self.statements.into_iter(); + for s in statements { + if matches!(s, Statement::Constraint(..)) { + count += 1; + } + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); + for s in s { + serde_cbor::to_writer(&mut w, &s)?; + } } - let s = unconstrained_variable_detector.fold_statement(s); - for s in s { - serde_cbor::to_writer(&mut w, &s)?; - } - } + + section.set_length(w.stream_position()? - section.offset); + section + }; + + // write solvers section + let solvers = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + let header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: count as u32, + return_count: self.return_count as u32, + sections: [parameters, constraints, solvers], + }; + + // rewind to write the header + w.rewind()?; + header.write(&mut w)?; unconstrained_variable_detector .finalize() @@ -103,11 +261,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator type Item = T; fn next(&mut self) -> Option { - self.s.next().transpose().unwrap() + self.s.next().and_then(|v| v.ok()) } } -impl<'de, R: Read> +impl<'de, R: Read + Seek> ProgEnum< 'de, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, @@ -116,125 +274,75 @@ impl<'de, R: Read> UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { + fn read( + mut r: R, + header: &ProgHeader, + ) -> ProgIterator< + 'de, + T, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, + > { + let parameters = { + let section = &header.sections[0]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() + }; + + let solvers = { + let section = &header.sections[2]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() + }; + + let statements_deserializer = { + let section = &header.sections[1]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + UnwrappedStreamDeserializer { s } + }; + + ProgIterator::new( + parameters, + statements_deserializer, + header.return_count as usize, + solvers, + ) + } + pub fn deserialize(mut r: R) -> Result { + let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; + // Check the magic number, `ZOK` - let mut magic = [0; 4]; - r.read_exact(&mut magic) - .map_err(|_| String::from("Cannot read magic number"))?; + if &header.magic != ZOKRATES_MAGIC { + return Err("Invalid magic number".to_string()); + } - if &magic == ZOKRATES_MAGIC { - // Check the version, 2 - let mut version = [0; 4]; - r.read_exact(&mut version) - .map_err(|_| String::from("Cannot read version"))?; + // Check the file version + if &header.version != FILE_VERSION { + return Err("Invalid file version".to_string()); + } - if &version == ZOKRATES_VERSION_2 { - // Check the curve identifier, deserializing accordingly - let mut curve = [0; 4]; - r.read_exact(&mut curve) - .map_err(|_| String::from("Cannot read curve identifier"))?; - - use serde::de::Deserializer; - let mut p = serde_cbor::Deserializer::from_reader(r); - - struct ArgumentsVisitor; - - impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor { - type Value = Vec; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("seq of flat param") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = vec![]; - while let Some(e) = seq.next_element().unwrap() { - res.push(e); - } - Ok(res) - } - } - - let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap(); - - struct ReturnCountVisitor; - - impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor { - type Value = usize; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("usize") - } - - fn visit_u32(self, v: u32) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u8(self, v: u8) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u16(self, v: u16) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - } - - let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap(); - - match curve { - m if m == Bls12_381Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bn128Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bls12_377Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bw6_761Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - _ => Err(String::from("Unknown curve identifier")), - } - } else { - Err(String::from("Unknown version")) + match header.curve_id { + m if m == Bls12_381Field::id() => { + Ok(ProgEnum::Bls12_381Program(Self::read(r, &header))) } - } else { - Err(String::from("Wrong magic number")) + m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))), + m if m == Bls12_377Field::id() => { + Ok(ProgEnum::Bls12_377Program(Self::read(r, &header))) + } + m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))), + _ => Err(String::from("Unknown curve identifier")), } } } diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs new file mode 100644 index 00000000..cbe1a151 --- /dev/null +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -0,0 +1,54 @@ +use crate::common::RefCall; +use crate::ir::folder::Folder; +use crate::ir::Directive; +use crate::ir::Solver; +use crate::zir::ZirFunction; +use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use zokrates_field::Field; + +type Hash = u64; + +fn hash(f: &ZirFunction) -> Hash { + use std::hash::Hash; + use std::hash::Hasher; + let mut hasher = DefaultHasher::new(); + f.hash(&mut hasher); + hasher.finish() +} + +#[derive(Debug, Default)] +pub struct SolverIndexer<'ast, T> { + pub solvers: Vec>, + pub index_map: HashMap, +} + +impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { + match d.solver { + Solver::Zir(f) => { + let argument_count = f.arguments.len(); + let h = hash(&f); + let index = match self.index_map.entry(h) { + Entry::Occupied(v) => *v.get(), + Entry::Vacant(entry) => { + let index = self.solvers.len(); + entry.insert(index); + self.solvers.push(Solver::Zir(f)); + index + } + }; + Directive { + inputs: d.inputs, + outputs: d.outputs, + solver: Solver::Ref(RefCall { + index, + argument_count, + }), + } + } + _ => d, + } + } +} diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs new file mode 100644 index 00000000..ca67d8b5 --- /dev/null +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -0,0 +1,94 @@ +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; +use std::collections::HashMap; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ZirCanonicalizer<'ast> { + identifier_map: HashMap, usize>, +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.identifier_map.get(&n) { + Some(v) => Identifier::internal(*v), + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::zir::{ + FieldElementExpression, IdentifierExpression, Signature, Type, ZirAssignee, ZirFunction, + ZirStatement, + }; + + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn canonicalize() { + let func = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element("a"), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element("b"), + FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new("b".into()), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + let mut canonicalizer = ZirCanonicalizer::default(); + let result = canonicalizer.fold_function(func); + + let expected = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element(Identifier::internal(0usize)), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element(Identifier::internal(1usize)), + FieldElementExpression::Identifier(IdentifierExpression::new( + Identifier::internal(0usize), + )) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new(Identifier::internal(1usize)), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + assert_eq!(result, expected); + } +} diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index 249b2630..f0c387df 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -4,13 +4,14 @@ use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), + Internal(usize), } -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { #[serde(borrow)] Basic(CoreIdentifier<'ast>), @@ -30,10 +31,17 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { } } +impl<'ast> Identifier<'ast> { + pub fn internal>(id: T) -> Self { + Identifier::Internal(id.into()) + } +} + impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Identifier::Source(s) => write!(f, "{}", s), + Identifier::Internal(i) => write!(f, "i{}", i), } } } diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs index 121b2a39..834f601f 100644 --- a/zokrates_ast/src/zir/lqc.rs +++ b/zokrates_ast/src/zir/lqc.rs @@ -43,7 +43,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { linear: { let mut l = self.linear; other.linear.iter_mut().for_each(|(c, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); l.append(&mut other.linear); l @@ -51,7 +51,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { quadratic: { let mut q = self.quadratic; other.quadratic.iter_mut().for_each(|(c, _, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); q.append(&mut other.quadratic); q @@ -68,18 +68,18 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { } Ok(Self { - constant: self.constant.clone() * rhs.constant.clone(), + constant: self.constant * rhs.constant, linear: { // lin0 * const1 + lin1 * const0 self.linear .clone() .into_iter() - .map(|(c, i)| (c * rhs.constant.clone(), i)) + .map(|(c, i)| (c * rhs.constant, i)) .chain( rhs.linear .clone() .into_iter() - .map(|(c, i)| (c * self.constant.clone(), i)), + .map(|(c, i)| (c * self.constant, i)), ) .collect() }, @@ -87,16 +87,16 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { // quad0 * const1 + quad1 * const0 + lin0 * lin1 self.quadratic .into_iter() - .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .map(|(c, i0, i1)| (c * rhs.constant, i0, i1)) .chain( rhs.quadratic .into_iter() - .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + .map(|(c, i0, i1)| (c * self.constant, i0, i1)), ) .chain(self.linear.iter().flat_map(|(cl, l)| { rhs.linear .iter() - .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + .map(|(cr, r)| (*cl * *cr, l.clone(), r.clone())) })) .collect() }, diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 60dc1467..08110966 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,3 +1,4 @@ +pub mod canonicalizer; pub mod folder; mod from_typed; mod identifier; diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index 457c7ed7..619d95e7 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -219,6 +219,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 040683bf..ce1ce717 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -179,7 +179,7 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index e149a968..a09eb61e 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,8 +11,8 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, - ZirExpressionList, + canonicalizer::ZirCanonicalizer, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -1885,7 +1885,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[*x], &[]) .unwrap() .into_iter() .map(FlatExpression::Number) @@ -2237,10 +2237,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { .cloned() .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) .collect(); + let outputs: Vec = assignees .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); + + let mut canonicalizer = ZirCanonicalizer::default(); + let function = canonicalizer.fold_function(function); + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); } diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index aa1964c5..2b253004 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 664cfc2d..346695a2 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -39,14 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { } fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - let hashed = hash(&s); - let result = match self.seen.get(&hashed) { - Some(_) => vec![], - None => vec![s], - }; - - self.seen.insert(hashed); - result + match s { + Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(), + s => { + let hashed = hash(&s); + let result = match self.seen.get(&hashed) { + Some(_) => vec![], + None => vec![s], + }; + self.seen.insert(hashed); + result + } + } } } @@ -77,6 +81,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = p.clone(); @@ -113,6 +118,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = Prog { @@ -128,6 +134,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; assert_eq!( diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 1f94740a..95e9f25c 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator>>( .flat_map(move |s| directive_optimizer.fold_statement(s)) .flat_map(move |s| duplicate_optimizer.fold_statement(s)), return_count: p.return_count, + solvers: p.solvers, }; log::debug!("Done"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index b0877fd0..c1f90aa5 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -146,7 +146,7 @@ impl RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution @@ -256,12 +256,14 @@ mod tests { Statement::definition(out, y), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { arguments: vec![x], statements: vec![Statement::definition(out, x.id)], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -280,6 +282,7 @@ mod tests { arguments: vec![x], statements: vec![Statement::definition(one, x.id)], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); @@ -316,6 +319,7 @@ mod tests { Statement::definition(out, z), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { @@ -325,6 +329,7 @@ mod tests { Statement::definition(out, x.id), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -365,6 +370,7 @@ mod tests { Statement::definition(out_1, w), ], return_count: 2, + solvers: vec![], }; let optimized: Prog = Prog { @@ -374,6 +380,7 @@ mod tests { Statement::definition(out_1, Bn128Field::from(1)), ], return_count: 2, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -422,6 +429,7 @@ mod tests { Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], return_count: 1, + solvers: vec![], }; let expected: Prog = Prog { @@ -442,6 +450,7 @@ mod tests { ), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -479,6 +488,7 @@ mod tests { Statement::definition(z, LinComb::from(x.id)), ], return_count: 0, + solvers: vec![], }; let optimized = p.clone(); @@ -507,6 +517,7 @@ mod tests { Statement::constraint(x.id, Bn128Field::from(2)), ], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs index cb3af053..55460f03 100644 --- a/zokrates_field/src/dummy_curve.rs +++ b/zokrates_field/src/dummy_curve.rs @@ -10,7 +10,9 @@ use std::ops::{Add, Div, Mul, Sub}; const _PRIME: u8 = 7; -#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)] +#[derive( + Default, Debug, Hash, Clone, Copy, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq, +)] pub struct FieldPrime { v: u8, } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index a33bf31a..a767f7d0 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -70,6 +70,7 @@ pub trait Field: + Zero + One + Clone + + Copy + PartialEq + Eq + Hash @@ -153,7 +154,7 @@ mod prime_field { type Fr = <$v as ark_ec::PairingEngine>::Fr; - #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash)] + #[derive(PartialEq, PartialOrd, Clone, Copy, Eq, Ord, Hash)] pub struct FieldPrime { v: Fr, } diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 41cdfa35..27043668 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 9f90e00c..db6ed4a5 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -50,7 +50,7 @@ impl Interpreter { witness.insert(Variable::one(), T::one()); for (arg, value) in program.arguments.iter().zip(inputs.iter()) { - witness.insert(arg.id, value.clone()); + witness.insert(arg.id, *value); } for statement in program.statements.into_iter() { @@ -83,12 +83,12 @@ impl Interpreter { inputs.pop().unwrap(), )) } - _ => Self::execute_solver(&d.solver, &inputs), + _ => Self::execute_solver(&d.solver, &inputs, &program.solvers), } .map_err(Error::Solver)?; for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); + witness.insert(*o, res[i]); } } Statement::Log(l, expressions) => { @@ -164,7 +164,15 @@ impl Interpreter { pub fn execute_solver<'ast, T: Field>( solver: &Solver<'ast, T>, inputs: &[T], + solvers: &[Solver<'ast, T>], ) -> Result, String> { + let solver = match solver { + Solver::Ref(call) => solvers + .get(call.index) + .ok_or_else(|| format!("Could not get solver at index {}", call.index))?, + s => s, + }; + let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); @@ -180,7 +188,7 @@ impl Interpreter { .map(|(a, v)| match &a.id._type { zir::Type::FieldElement => Ok(( a.id.id.clone(), - zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(), + zokrates_ast::zir::FieldElementExpression::Number(*v).into(), )), zir::Type::Boolean => match v { v if *v == T::from(0) => Ok(( @@ -256,41 +264,38 @@ impl Interpreter { .collect() } Solver::Xor => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - T::from(2) * x * y] + vec![x + y - T::from(2) * x * y] } Solver::Or => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - x * y] + vec![x + y - x * y] } // res = b * c - (2b * c - b - c) * (a) Solver::ShaAndXorAndXorAnd => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![b * c - (T::from(2) * b * c - b - c) * a] } // res = a(b - c) + c Solver::ShaCh => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![a * (b - c.clone()) + c] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![a * (b - c) + c] } - Solver::Div => vec![inputs[0] - .clone() - .checked_div(&inputs[1]) - .unwrap_or_else(T::one)], + Solver::Div => vec![inputs[0].checked_div(&inputs[1]).unwrap_or_else(T::one)], Solver::EuclideanDiv => { use num::CheckedDiv; - let n = inputs[0].clone().to_biguint(); - let d = inputs[1].clone().to_biguint(); + let n = inputs[0].to_biguint(); + let d = inputs[1].to_biguint(); let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into()); let r = n - d * &q; @@ -334,6 +339,7 @@ impl Interpreter { &inputs[*n + 8usize..], ) } + _ => unreachable!("unexpected solver"), }; assert_eq!(res.len(), expected_output_count); @@ -354,14 +360,11 @@ pub enum Error { } fn evaluate_lin(w: &Witness, l: &LinComb) -> Result { - l.0.iter() - .map(|(var, mult)| { - w.0.get(var) - .map(|v| v.clone() * mult) - .ok_or(EvaluationError) - }) // get each term - .collect::, _>>() // fail if any term isn't found - .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum + l.0.iter().try_fold(T::from(0), |acc, (var, mult)| { + w.0.get(var) + .map(|v| acc + (*v * mult)) + .ok_or(EvaluationError) // fail if any term isn't found + }) } pub fn evaluate_quad(w: &Witness, q: &QuadComb) -> Result { @@ -434,6 +437,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -450,6 +454,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -460,9 +465,12 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -472,9 +480,12 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -487,11 +498,53 @@ mod tests { #[test] fn five_hundred_bits_of_1() { let inputs = vec![Bn128Field::from(1)]; - let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs, &[]).unwrap(); let mut expected = vec![Bn128Field::from(0); 500]; expected[499] = Bn128Field::from(1); assert_eq!(res, expected); } + + #[test] + fn solver_ref() { + use zir::{ + types::{Signature, Type}, + FieldElementExpression, Identifier, IdentifierExpression, Parameter, Variable, + ZirFunction, ZirStatement, + }; + use zokrates_ast::common::RefCall; + + let id = IdentifierExpression::new(Identifier::internal(0usize)); + + // (field i0) -> i0 * i0 + let solvers = vec![Solver::Zir(ZirFunction { + arguments: vec![Parameter { + id: Variable::field_element(id.id.clone()), + private: true, + }], + statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( + Box::new(FieldElementExpression::Identifier(id.clone())), + Box::new(FieldElementExpression::Identifier(id.clone())), + ) + .into()])], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + })]; + + let inputs = vec![Bn128Field::from(2)]; + let res = Interpreter::execute_solver( + &Solver::Ref(RefCall { + index: 0, + argument_count: 1, + }), + &inputs, + &solvers, + ) + .unwrap(); + + let expected = vec![Bn128Field::from(4)]; + assert_eq!(res, expected); + } } diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 0c86329b..02374289 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 7617f20f..02b6e919 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.5", + "version": "1.1.6", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates", diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 2f3ac09e..ed3f2650 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -520,7 +520,8 @@ pub fn compute_witness( config: JsValue, log_callback: &js_sys::Function, ) -> Result { - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); match prog { @@ -582,7 +583,8 @@ pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result Result