diff --git a/Cargo.lock b/Cargo.lock index 3372d1e1..d3b8ac94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -652,11 +652,6 @@ dependencies = [ "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", ] -[[package]] -name = "memory_units" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" - [[package]] name = "nodrop" version = "0.1.13" @@ -767,22 +762,6 @@ dependencies = [ "rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", ] -[[package]] -name = "parity-wasm" -version = "0.31.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] -name = "parity-wasm" -version = "0.35.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", -] - [[package]] name = "percent-encoding" version = "1.0.1" @@ -1036,11 +1015,6 @@ name = "rustc-demangle" version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -[[package]] -name = "rustc-hex" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" - [[package]] name = "rustc-serialize" version = "0.3.24" @@ -1377,24 +1351,6 @@ name = "wasi" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -[[package]] -name = "wasmi" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "memory_units 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)", - "wasmi-validation 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] -name = "wasmi-validation" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)", -] - [[package]] name = "winapi" version = "0.3.8" @@ -1480,17 +1436,14 @@ dependencies = [ "num 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", "num-bigint 0.1.44 (registry+https://github.com/rust-lang/crates.io-index)", "pairing_ce 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)", - "parity-wasm 0.35.7 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "reduce 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", "regex 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", - "rustc-hex 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)", "serde_bytes 0.10.5 (registry+https://github.com/rust-lang/crates.io-index)", "serde_derive 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)", "typed-arena 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)", - "wasmi 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)", "zokrates_embed 0.1.0", "zokrates_field 0.3.4", "zokrates_pest_ast 0.1.3", @@ -1659,7 +1612,6 @@ dependencies = [ "checksum matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" "checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e" "checksum memoffset 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ce6075db033bbbb7ee5a0bbd3a3186bbae616f57fb001c485c7ff77955f8177f" -"checksum memory_units 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "71d96e3f3c0b6325d8ccd83c33b28acb183edcb6c67938ba104ec546854b0882" "checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" "checksum num 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "4703ad64153382334aa8db57c637364c322d3372e097840c72000dabdcf6156e" "checksum num-bigint 0.1.44 (registry+https://github.com/rust-lang/crates.io-index)" = "e63899ad0da84ce718c14936262a41cee2c79c981fc0a0e7c7beb47d5a07e8c1" @@ -1673,8 +1625,6 @@ dependencies = [ "checksum openssl-probe 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de" "checksum openssl-sys 0.9.50 (registry+https://github.com/rust-lang/crates.io-index)" = "2c42dcccb832556b5926bc9ae61e8775f2a61e725ab07ab3d1e7fcf8ae62c3b6" "checksum pairing_ce 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f075a9c570e2026111cb6dddf6a320e5163c42aa32500b315ec34acbcf7c9b36" -"checksum parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)" = "511379a8194230c2395d2f5fa627a5a7e108a9f976656ce723ae68fca4097bfc" -"checksum parity-wasm 0.35.7 (registry+https://github.com/rust-lang/crates.io-index)" = "3e1e076c4e01399b6cd0793a8df42f90bba3ae424671ef421d1608a943155d93" "checksum percent-encoding 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831" "checksum pest 2.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7e4fb201c5c22a55d8b24fef95f78be52738e5e1361129be1b5e862ecdb6894a" "checksum pest-ast 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3fbf404899169771dd6a32c84248b83cd67a26cc7cc957aac87661490e1227e4" @@ -1706,7 +1656,6 @@ dependencies = [ "checksum rgb 0.8.14 (registry+https://github.com/rust-lang/crates.io-index)" = "2089e4031214d129e201f8c3c8c2fe97cd7322478a0d1cdf78e7029b0042efdb" "checksum rust-crypto 0.2.36 (registry+https://github.com/rust-lang/crates.io-index)" = "f76d05d3993fd5f4af9434e8e436db163a12a9d40e1a58a726f27a01dfd12a2a" "checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" -"checksum rustc-hex 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0ceb8ce7a5e520de349e1fa172baeba4a9e8d5ef06c47471863530bc4972ee1e" "checksum rustc-serialize 0.3.24 (registry+https://github.com/rust-lang/crates.io-index)" = "dcf128d1287d2ea9d80910b5f1120d0b8eede3fbf1abe91c40d39ea7d51e6fda" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" "checksum ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c92464b447c0ee8c4fb3824ecc8383b81717b9f1e74ba2e72540aef7b9f82997" @@ -1751,8 +1700,6 @@ dependencies = [ "checksum void 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" "checksum walkdir 2.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9658c94fa8b940eab2250bd5a457f9c48b748420d71293b165c8cdbe2f55f71e" "checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d" -"checksum wasmi 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "aebbaef470840d157a5c47c8c49f024da7b1b80e90ff729ca982b2b80447e78b" -"checksum wasmi-validation 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ab380192444b3e8522ae79c0a1976e42a82920916ccdfbce3def89f456ea33f3" "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" "checksum winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7168bab6e1daee33b4557efd0e95d5ca70a03706d39fa5f3fe7a236f584b03c9" diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 60dcabca..01969525 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -8,7 +8,6 @@ edition = "2018" [features] default = [] libsnark = ["zokrates_core/libsnark"] -wasm = ["zokrates_core/wasm"] [dependencies] clap = "2.26.2" diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 7ad8d3ad..026243be 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -9,7 +9,6 @@ build = "build.rs" [features] default = [] libsnark = ["cc", "cmake", "git2"] -wasm = ["wasmi", "parity-wasm", "rustc-hex"] [dependencies] libc = "0.2.0" @@ -32,9 +31,6 @@ zokrates_field = { version = "0.3.0", path = "../zokrates_field" } zokrates_pest_ast = { version = "0.1.0", path = "../zokrates_pest_ast" } zokrates_embed = { path = "../zokrates_embed" } rand = "0.4" -wasmi = { version = "=0.4.5", optional = true } -parity-wasm = { version = "0.35.3", optional = true } -rustc-hex = { version = "1.0", optional = true } csv = "1" [dev-dependencies] diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 44532609..81d209f6 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -789,7 +789,7 @@ mod tests { mod types { use super::*; - /// Helper method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty + /// solver method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty fn wrap(ty: absy::UnresolvedType) -> absy::Module<'static, FieldPrime> { absy::Module { symbols: vec![absy::SymbolDeclaration { diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 6ced72be..814bc349 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -1,7 +1,8 @@ -use crate::helpers::{DirectiveStatement, Helper, RustHelper}; +use crate::solvers::Solver; use bellman::pairing::ff::ScalarEngine; use flat_absy::{ - FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement, FlatVariable, + FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement, + FlatVariable, }; use reduce::Reduce; use std::collections::HashMap; @@ -166,13 +167,13 @@ pub fn sha256_round() -> FlatFunction { .collect(); // insert a directive to set the witness based on the bellman gadget and inputs - let directive_statement = FlatStatement::Directive(DirectiveStatement { + let directive_statement = FlatStatement::Directive(FlatDirective { outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(), inputs: input_argument_indices .chain(current_hash_argument_indices) .map(|i| FlatVariable::new(i).into()) .collect(), - helper: Helper::Rust(RustHelper::Sha256Round), + solver: Solver::Sha256Round, }); // insert a statement to return the subset of the witness @@ -233,7 +234,7 @@ pub fn unpack() -> FlatFunction { .map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter)) .collect(); - let helper = Helper::bits(); + let solver = Solver::bits(); let signature = Signature { inputs: vec![Type::FieldElement], @@ -281,10 +282,10 @@ pub fn unpack() -> FlatFunction { statements.insert( 0, - FlatStatement::Directive(DirectiveStatement { + FlatStatement::Directive(FlatDirective { inputs: directive_inputs, outputs: directive_outputs, - helper: helper, + solver: solver, }), ); @@ -322,11 +323,11 @@ mod tests { ); // 128 bit checks, 1 directive, 1 sum check, 1 return assert_eq!( unpack.statements[0], - FlatStatement::Directive(DirectiveStatement::new( + FlatStatement::Directive(FlatDirective::new( (0..FieldPrime::get_required_bits()) .map(|i| FlatVariable::new(i + 1)) .collect(), - Helper::bits(), + Solver::bits(), vec![FlatVariable::new(0)] )) ); diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index 0762476e..0426e60c 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -11,8 +11,8 @@ pub mod flat_variable; pub use self::flat_parameter::FlatParameter; pub use self::flat_variable::FlatVariable; -use crate::helpers::DirectiveStatement; use crate::typed_absy::types::Signature; +use solvers::{Signed, Solver}; use std::collections::HashMap; use std::fmt; use zokrates_field::field::Field; @@ -95,7 +95,7 @@ pub enum FlatStatement { Return(FlatExpressionList), Condition(FlatExpression, FlatExpression), Definition(FlatVariable, FlatExpression), - Directive(DirectiveStatement), + Directive(FlatDirective), } impl fmt::Display for FlatStatement { @@ -149,7 +149,7 @@ impl FlatStatement { .map(|i| i.apply_substitution(substitution)) .collect(); - FlatStatement::Directive(DirectiveStatement { + FlatStatement::Directive(FlatDirective { outputs, inputs, ..d @@ -159,6 +159,50 @@ impl FlatStatement { } } +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)] +pub struct FlatDirective { + pub inputs: Vec>, + pub outputs: Vec, + pub solver: Solver, +} + +impl FlatDirective { + pub fn new>>( + outputs: Vec, + solver: Solver, + inputs: Vec, + ) -> Self { + let (in_len, out_len) = solver.get_signature(); + assert_eq!(in_len, inputs.len()); + assert_eq!(out_len, outputs.len()); + FlatDirective { + solver, + inputs: inputs.into_iter().map(|i| i.into()).collect(), + outputs, + } + } +} + +impl fmt::Display for FlatDirective { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "# {} = {}({})", + self.outputs + .iter() + .map(|o| o.to_string()) + .collect::>() + .join(", "), + self.solver, + self.inputs + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(", ") + ) + } +} + #[derive(Clone, PartialEq, Serialize, Deserialize)] pub enum FlatExpression { Number(T), diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 08ebb777..32c59f1b 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -6,7 +6,7 @@ //! @date 2017 use crate::flat_absy::*; -use crate::helpers::{DirectiveStatement, Helper, RustHelper}; +use crate::solvers::Solver; use crate::typed_absy::types::{FunctionIdentifier, FunctionKey, MemberId, Signature, Type}; use crate::typed_absy::*; use std::collections::HashMap; @@ -590,9 +590,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..bitwidth).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( lhs_bits_be.clone(), - Helper::bits(), + Solver::bits(), vec![lhs_id], ))); @@ -637,9 +637,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..bitwidth).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( rhs_bits_be.clone(), - Helper::bits(), + Solver::bits(), vec![rhs_id], ))); @@ -690,9 +690,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..bitwidth).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( sub_bits_be.clone(), - Helper::bits(), + Solver::bits(), vec![subtraction_result.clone()], ))); @@ -776,9 +776,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { FieldElementExpression::Sub(box lhs, box rhs), ); - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( vec![name_y, name_m], - Helper::Rust(RustHelper::ConditionEq), + Solver::ConditionEq, vec![x.clone()], ))); statements_flattened.push(FlatStatement::Condition( @@ -963,9 +963,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|i| i.apply_substitution(&replacement_map)) .collect(); - FlatStatement::Directive(DirectiveStatement { + FlatStatement::Directive(FlatDirective { outputs: new_outputs, - helper: d.helper, + solver: d.solver, inputs: new_inputs, }) } @@ -1138,9 +1138,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { let inverse = self.use_sym(); // # invb = 1/b - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( vec![invb], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![FlatExpression::Number(T::one()), new_right.clone()], ))); @@ -1151,9 +1151,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { )); // # c = a/b - statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new( + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( vec![inverse], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![new_left.clone(), new_right.clone()], ))); @@ -2041,9 +2041,9 @@ mod tests { FlatStatement::Definition(five, FlatExpression::Number(FieldPrime::from(5))), FlatStatement::Definition(b0, b.into()), // check div by 0 - FlatStatement::Directive(DirectiveStatement::new( + FlatStatement::Directive(FlatDirective::new( vec![invb0], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![FlatExpression::Number(FieldPrime::from(1)), b0.into()] )), FlatStatement::Condition( @@ -2051,9 +2051,9 @@ mod tests { FlatExpression::Mult(box invb0.into(), box b0.into()), ), // execute div - FlatStatement::Directive(DirectiveStatement::new( + FlatStatement::Directive(FlatDirective::new( vec![sym_0], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![five, b0] )), FlatStatement::Condition( @@ -2064,9 +2064,9 @@ mod tests { FlatStatement::Definition(sym_1, sym_0.into()), FlatStatement::Definition(b1, b.into()), // check div by 0 - FlatStatement::Directive(DirectiveStatement::new( + FlatStatement::Directive(FlatDirective::new( vec![invb1], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![FlatExpression::Number(FieldPrime::from(1)), b1.into()] )), FlatStatement::Condition( @@ -2074,9 +2074,9 @@ mod tests { FlatExpression::Mult(box invb1.into(), box b1.into()), ), // execute div - FlatStatement::Directive(DirectiveStatement::new( + FlatStatement::Directive(FlatDirective::new( vec![sym_2], - Helper::Rust(RustHelper::Div), + Solver::Div, vec![sym_1, b1] )), FlatStatement::Condition( diff --git a/zokrates_core/src/helpers/mod.rs b/zokrates_core/src/helpers/mod.rs deleted file mode 100644 index 366a2116..00000000 --- a/zokrates_core/src/helpers/mod.rs +++ /dev/null @@ -1,171 +0,0 @@ -mod rust; -#[cfg(feature = "wasm")] -mod wasm; - -pub use self::rust::RustHelper; -#[cfg(feature = "wasm")] -pub use self::wasm::WasmHelper; -use crate::flat_absy::{FlatExpression, FlatVariable}; -use std::fmt; -use zokrates_field::field::Field; - -#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)] -pub struct DirectiveStatement { - pub inputs: Vec>, - pub outputs: Vec, - pub helper: Helper, -} - -impl DirectiveStatement { - pub fn new>>( - outputs: Vec, - helper: Helper, - inputs: Vec, - ) -> Self { - let (in_len, out_len) = helper.get_signature(); - assert_eq!(in_len, inputs.len()); - assert_eq!(out_len, outputs.len()); - DirectiveStatement { - helper, - inputs: inputs.into_iter().map(|i| i.into()).collect(), - outputs, - } - } -} - -impl fmt::Display for DirectiveStatement { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "# {} = {}({})", - self.outputs - .iter() - .map(|o| o.to_string()) - .collect::>() - .join(", "), - self.helper, - self.inputs - .iter() - .map(|i| i.to_string()) - .collect::>() - .join(", ") - ) - } -} - -#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum Helper { - Rust(RustHelper), - #[cfg(feature = "wasm")] - Wasm(WasmHelper), -} - -#[cfg(feature = "wasm")] -impl Helper { - pub fn identity() -> Self { - Helper::Wasm(WasmHelper::from_hex(WasmHelper::IDENTITY_WASM)) - } - - pub fn bits() -> Self { - Helper::Wasm(WasmHelper::from(WasmHelper::BITS_WASM)) - } -} - -#[cfg(not(feature = "wasm"))] -impl Helper { - pub fn identity() -> Self { - Helper::Rust(RustHelper::Identity) - } - - pub fn bits() -> Self { - Helper::Rust(RustHelper::Bits) - } -} - -impl fmt::Display for Helper { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Helper::Rust(ref h) => write!(f, "Rust::{}", h), - #[cfg(feature = "wasm")] - Helper::Wasm(ref h) => write!(f, "Wasm::{}", h), - } - } -} - -pub trait Executable: Signed { - fn execute(&self, inputs: &Vec) -> Result, String>; -} - -pub trait Signed { - fn get_signature(&self) -> (usize, usize); -} - -impl Executable for Helper { - fn execute(&self, inputs: &Vec) -> Result, String> { - let (expected_input_count, expected_output_count) = self.get_signature(); - assert!(inputs.len() == expected_input_count); - - let result = match self { - Helper::Rust(helper) => helper.execute(inputs), - #[cfg(feature = "wasm")] - Helper::Wasm(helper) => helper.execute(inputs), - }; - - match result { - Ok(ref r) if r.len() != expected_output_count => Err(format!( - "invalid witness size: is {} but should be {}", - r.len(), - expected_output_count - ) - .to_string()), - r => r, - } - } -} - -impl Signed for Helper { - fn get_signature(&self) -> (usize, usize) { - match self { - Helper::Rust(helper) => helper.get_signature(), - #[cfg(feature = "wasm")] - Helper::Wasm(helper) => helper.get_signature(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use zokrates_field::field::FieldPrime; - - mod eq_condition { - - // Wanted: (Y = (X != 0) ? 1 : 0) - // # Y = if X == 0 then 0 else 1 fi - // # M = if X == 0 then 1 else 1/X fi - - use super::*; - - #[test] - fn execute() { - let cond_eq = RustHelper::ConditionEq; - let inputs = vec![0]; - let r = cond_eq - .execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()) - .unwrap(); - let res: Vec = vec![0, 1].iter().map(|&i| FieldPrime::from(i)).collect(); - assert_eq!(r, &res[..]); - } - - #[test] - fn execute_non_eq() { - let cond_eq = RustHelper::ConditionEq; - let inputs = vec![1]; - let r = cond_eq - .execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()) - .unwrap(); - let res: Vec = vec![1, 1].iter().map(|&i| FieldPrime::from(i)).collect(); - assert_eq!(r, &res[..]); - } - } -} diff --git a/zokrates_core/src/helpers/wasm.rs b/zokrates_core/src/helpers/wasm.rs deleted file mode 100644 index 89ba1a6f..00000000 --- a/zokrates_core/src/helpers/wasm.rs +++ /dev/null @@ -1,621 +0,0 @@ -use helpers::{Executable, Signed}; -use std::fmt; - -use rustc_hex::FromHex; -use serde::{Deserialize, Deserializer}; -use std::hash::{Hash, Hasher}; -use std::rc::Rc; -use wasmi::{ImportsBuilder, ModuleInstance, ModuleRef, NopExternals}; -use zokrates_field::field::Field; - -#[derive(Clone, Debug, Serialize)] -pub struct WasmHelper( - #[serde(skip)] std::rc::Rc, - #[serde(serialize_with = "serde_bytes::serialize")] Vec, -); - -impl WasmHelper { - // Hand-coded assembly for identity. - // Source available at https://gist.github.com/gballet/f14d11053d8f846bfbb3687581b0eecb#file-identity-wast - pub const IDENTITY_WASM: &'static str = "0061736d010000000105016000017f030302000005030100010615047f0041010b7f0041010b7f0041200b7f0141000b074b06066d656d6f727902000e6765745f696e707574735f6f6666000105736f6c766500000a6d696e5f696e7075747303000b6d696e5f6f75747075747303010a6669656c645f73697a6503020a300229000340412023036a410023036a280200360200230341016a240323032302470d000b4100240341200b040041000b0b4b020041000b20ffffffff000000000000000000000000ffffffff0000000000000000000000000041200b20deadbeef000000000000000000000000deadbeef000000000000000000000000"; - // Generated from C code, normalized and cleaned up by hand. - // Source available at https://gist.github.com/gballet/f14d11053d8f846bfbb3687581b0eecb#file-bits_v2-c - pub const BITS_WASM: &'static [u8] = &[ - 0, 97, 115, 109, 1, 0, 0, 0, 1, 8, 2, 96, 0, 0, 96, 0, 1, 127, 3, 5, 4, 0, 1, 1, 1, 4, 5, - 1, 112, 1, 1, 1, 5, 3, 1, 0, 2, 6, 38, 6, 127, 1, 65, 240, 199, 4, 11, 127, 0, 65, 240, - 199, 4, 11, 127, 0, 65, 240, 199, 0, 11, 127, 0, 65, 32, 11, 127, 0, 65, 1, 11, 127, 0, 65, - 254, 1, 11, 7, 109, 9, 6, 109, 101, 109, 111, 114, 121, 2, 0, 11, 95, 95, 104, 101, 97, - 112, 95, 98, 97, 115, 101, 3, 1, 10, 95, 95, 100, 97, 116, 97, 95, 101, 110, 100, 3, 2, 14, - 103, 101, 116, 95, 105, 110, 112, 117, 116, 115, 95, 111, 102, 102, 0, 1, 5, 115, 111, 108, - 118, 101, 0, 2, 4, 109, 97, 105, 110, 0, 3, 10, 102, 105, 101, 108, 100, 95, 115, 105, 122, - 101, 3, 3, 10, 109, 105, 110, 95, 105, 110, 112, 117, 116, 115, 3, 4, 11, 109, 105, 110, - 95, 111, 117, 116, 112, 117, 116, 115, 3, 5, 9, 1, 0, 10, 85, 4, 3, 0, 1, 11, 5, 0, 65, - 144, 8, 11, 68, 1, 2, 127, 65, 253, 1, 33, 0, 65, 176, 8, 33, 1, 3, 64, 32, 1, 65, 1, 32, - 0, 65, 7, 113, 116, 32, 0, 65, 3, 118, 65, 144, 8, 106, 45, 0, 0, 113, 65, 0, 71, 58, 0, 0, - 32, 1, 65, 32, 106, 33, 1, 32, 0, 65, 127, 106, 34, 0, 65, 127, 71, 13, 0, 11, 65, 176, 8, - 11, 4, 0, 65, 0, 11, 11, 19, 1, 0, 65, 128, 8, 11, 12, 32, 0, 0, 0, 1, 0, 0, 0, 254, 0, 0, - 0, - ]; - - pub fn from_hex>(u: U) -> Self { - let code_hex = u.into(); - let code = FromHex::from_hex(&code_hex[..]) - .expect(format!("invalid bytecode: {}", code_hex).as_str()); - WasmHelper::from(code) - } -} - -impl>> From for WasmHelper { - fn from(code: U) -> Self { - let code_vec = code.into(); - let module = wasmi::Module::from_buffer(code_vec.clone()).expect("Error decoding buffer"); - let modinst = ModuleInstance::new(&module, &ImportsBuilder::default()) - .expect("Failed to instantiate module") - .assert_no_start(); - WasmHelper(Rc::new(modinst), code_vec) - } -} - -impl<'de> Deserialize<'de> for WasmHelper { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let hex: Vec = serde_bytes::deserialize(deserializer)?; - Ok(WasmHelper::from(hex)) - } -} - -impl PartialEq for WasmHelper { - fn eq(&self, other: &WasmHelper) -> bool { - self.1 == other.1 - } -} - -impl Eq for WasmHelper {} - -impl Hash for WasmHelper { - fn hash(&self, state: &mut H) { - self.1.hash(state); - } -} - -impl fmt::Display for WasmHelper { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Hex(\"{:?}\")", &self.1[..]) - } -} - -fn get_export(varname: &str, modref: &ModuleRef) -> Result { - modref - .export_by_name(varname) - .ok_or(&format!("Could not find exported symbol `{}` in module", varname)[..])? - .as_global() - .ok_or(format!( - "Error getting {} from the list of globals", - varname - ))? - .get() - .try_into::() - .ok_or(format!("Error converting `{}` to i32", varname)) -} - -impl Signed for WasmHelper { - fn get_signature(&self) -> (usize, usize) { - // Check that the module has the following exports: - // * min_inputs = the (minimum) number of inputs - // * min_outputs = the (minimum) number of outputs - let ni = get_export::("min_inputs", self.0.as_ref()).unwrap(); - let no = get_export::("min_outputs", self.0.as_ref()).unwrap(); - - (ni as usize, no as usize) - } -} - -impl Executable for WasmHelper { - fn execute(&self, inputs: &Vec) -> Result, String> { - let field_size = get_export::("field_size", self.0.as_ref())? as usize; - let ninputs = get_export::("min_inputs", self.0.as_ref())? as usize; - - if ninputs != inputs.len() { - return Err(format!( - "`solve` expected {} inputs, received {}", - ninputs, - inputs.len() - )); - } - - /* Prepare the inputs */ - let input_offset = self - .0 - .invoke_export("get_inputs_off", &[], &mut NopExternals) - .map_err(|e| format!("Error getting the input offset: {}", e.to_string()))? - .ok_or("`get_inputs_off` did not return any value")? - .try_into::() - .ok_or("`get_inputs_off` returned the wrong type")?; - - let mem = self - .0 - .as_ref() - .export_by_name("memory") - .ok_or("Module didn't export its memory section")? - .as_memory() - .unwrap() - .clone(); - - for (index, input) in inputs.iter().enumerate() { - // Get the field's bytes and check they correspond to - // the value that the module expects. - let mut bv = input.into_byte_vector(); - if bv.len() > field_size { - return Err(format!( - "Input #{} is stored on {} bytes which is greater than the field size of {}", - index, - bv.len(), - field_size - )); - } else { - bv.resize(field_size, 0); - } - - let addr = (input_offset as u32) + (index as u32) * (field_size as u32); - mem.set(addr, &bv[..]) - .map_err(|_e| format!("Could not write at memory address {}", addr))?; - } - - let output_offset = self - .0 - .as_ref() - .invoke_export("solve", &[], &mut NopExternals) - .map_err(|e| format!("Error solving the problem: {}", e.to_string()))? - .ok_or("`solve` did not return any value")? - .try_into::() - .ok_or("`solve returned the wrong type`")?; - - // NOTE: The question regarding the way that an error code is - // returned is still open. - // - // The current model considers that 2GB is more than enough - // to store the output data. - // - // This being said this approach is tacky and I am considering - // others at this point: - // - // 1. Use a 64 bit return code, values greater than 32-bits - // are considered error codes. - // 2. Export an extra global called `errno` which contains - // the error code. - // 3. 32-bit alignment gives a 2-bit error field - // 4. Return a pointer to a structure that contains - // the error code just before the output data. - // - // Experimenting with other languages will help decide what - // is the better approach. - if output_offset > 0 { - let mut outputs = Vec::new(); - let noutputs = get_export::("min_outputs", self.0.as_ref())?; - for i in 0..noutputs { - let index = i as u32; - let fs = field_size as u32; - let value = mem - .get(output_offset as u32 + fs * index, field_size) - .map_err(|e| format!("Could not retrieve the output offset: {}", e))?; - - outputs.push(T::from_byte_vector(value)); - } - - Ok(outputs) - } else { - Err(format!("`solve` returned error code {}", output_offset)) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use parity_wasm::builder::*; - use parity_wasm::elements::{Instruction, Instructions, ValueType}; - use std::panic; - use zokrates_field::field::FieldPrime; - - fn remove_export(code: &str, symbol: &str) -> Vec { - let code = FromHex::from_hex(code).unwrap(); - let mut idmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..]) - .expect("Could not deserialize Identity module"); - idmod - .export_section_mut() - .expect("Could not get export section") - .entries_mut() - .retain(|ref export| export.field() != symbol); - parity_wasm::serialize(idmod).expect("Could not serialize buffer") - } - - fn replace_function( - code: &str, - symbol: &str, - params: Vec, - ret: Option, - instr: Vec, - ) -> Vec { - /* Deserialize to parity_wasm format */ - let code = FromHex::from_hex(code).unwrap(); - let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..]) - .expect("Could not deserialize Identity module"); - - /* Remove export, if it exists */ - pwmod - .export_section_mut() - .expect("Could not get export section") - .entries_mut() - .retain(|ref export| export.field() != symbol); - - /* Add a new function and give it the export name */ - let wmod: parity_wasm::elements::Module = from_module(pwmod) - .function() - .signature() - .with_params(params) - .with_return_type(ret) - .build() - .body() - .with_instructions(Instructions::new(instr)) - .build() - .build() - .export() - .field(symbol) - .internal() - .func(2) - .build() - .build(); - parity_wasm::serialize(wmod).expect("Could not serialize buffer") - } - - fn replace_global(code: &str, symbol: &str, value: i32) -> Vec { - /* Deserialize to parity_wasm format */ - let code = FromHex::from_hex(code).unwrap(); - let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..]) - .expect("Could not deserialize Identity module"); - - /* Remove export, if it exists */ - pwmod - .export_section_mut() - .expect("Could not get export section") - .entries_mut() - .retain(|ref export| export.field() != symbol); - - /* Add a new function and give it the export name */ - let wmod: parity_wasm::elements::Module = from_module(pwmod) - .global() - .value_type() - .i32() - .init_expr(Instruction::I32Const(value)) - .build() - .export() - .field(symbol) - .internal() - .global(4) - .build() - .build(); - parity_wasm::serialize(wmod).expect("Could not serialize buffer") - } - - fn replace_global_type(code: &str, symbol: &str) -> Vec { - /* Deserialize to parity_wasm format */ - let code = FromHex::from_hex(code).unwrap(); - let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..]) - .expect("Could not deserialize Identity module"); - - /* Remove export, if it exists */ - pwmod - .export_section_mut() - .expect("Could not get export section") - .entries_mut() - .retain(|ref export| export.field() != symbol); - - /* Add a new function and give it the export name */ - let wmod: parity_wasm::elements::Module = from_module(pwmod) - .global() - .value_type() - .f32() - .init_expr(Instruction::F32Const(0)) - .build() - .export() - .field(symbol) - .internal() - .global(4) - .build() - .build(); - parity_wasm::serialize(wmod).expect("Could not serialize buffer") - } - - #[test] - fn check_signatures() { - let h1 = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - assert_eq!(h1.get_signature(), (1, 1)); - } - - #[test] - #[should_panic( - expected = "invalid bytecode: invalid bytecode: Invalid character 'i' at position 0" - )] - fn check_invalid_bytecode_fails() { - WasmHelper::from_hex("invalid bytecode"); - } - - #[test] - #[should_panic(expected = "Error decoding buffer: Validation(\"I/O Error: UnexpectedEof\")")] - fn check_truncated_bytecode_fails() { - WasmHelper::from_hex(&WasmHelper::IDENTITY_WASM[..20]); - } - - #[test] - fn validate_exports() { - /* Test identity without the `solve` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "solve")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Error solving the problem: Function: Module doesn\'t have export solve", - )) - ); - - /* Test identity, without the `get_inputs_off` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "get_inputs_off")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Error getting the input offset: Function: Module doesn\'t have export get_inputs_off", - )) - ); - - /* Test identity, without the `min_inputs` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "min_inputs")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Could not find exported symbol `min_inputs` in module", - )) - ); - - /* Test identity, without the `min_outputs` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "min_outputs")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Could not find exported symbol `min_outputs` in module", - )) - ); - - /* Test identity, without the `field_size` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "field_size")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Could not find exported symbol `field_size` in module", - )) - ); - - /* Test identity, without the `memory` export */ - let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "memory")); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from("Module didn\'t export its memory section")) - ); - } - - #[test] - fn check_invalid_function_type() { - /* Test identity, with a different function return type */ - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "get_inputs_off", - Vec::new(), - Some(ValueType::I64), - vec![Instruction::I64Const(0), Instruction::End], - )); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from("`get_inputs_off` returned the wrong type")) - ); - - /* Test identity, with no return type for function */ - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "get_inputs_off", - Vec::new(), - None, - vec![Instruction::Nop, Instruction::End], - )); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from("`get_inputs_off` did not return any value")) - ); - - /* Test identity, with extra parameter for function */ - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "get_inputs_off", - vec![ValueType::I64], - Some(ValueType::I32), - vec![Instruction::I32Const(0), Instruction::End], - )); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Error getting the input offset: Trap: Trap { kind: UnexpectedSignature }", - )) - ); - } - - #[test] - fn check_invalid_field_size() { - /* Test identity, with 1-byte filed size */ - let id = WasmHelper::from(replace_global(WasmHelper::IDENTITY_WASM, "field_size", 1)); - let input = vec![FieldPrime::from(65536)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Input #0 is stored on 3 bytes which is greater than the field size of 1", - )) - ); - - /* Test identity, tweaked so that field_size is a f32 */ - let id = WasmHelper::from(replace_global_type(WasmHelper::IDENTITY_WASM, "field_size")); - let input = vec![FieldPrime::from(65536)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from("Error converting `field_size` to i32")) - ); - } - - #[test] - fn check_identity() { - let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input).expect("Identity call failed"); - assert_eq!(outputs, input); - - let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - let input = vec![FieldPrime::from(0)]; - let outputs = id.execute(&input).expect("Identity call failed"); - assert_eq!(outputs, input); - } - - #[test] - fn check_identity_3_bytes() { - let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - let input = vec![FieldPrime::from(16777216)]; - let outputs = id.execute(&input).expect("Identity call failed"); - assert_eq!(outputs, input); - } - - #[test] - fn check_identity_multiple_calls() { - let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - let input = vec![FieldPrime::from(16777216)]; - for _i in 0..10 { - let outputs = id.execute(&input).expect("Identity call failed"); - assert_eq!(outputs, input); - } - } - - #[test] - fn check_invalid_arg_number() { - let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input).expect("Identity call failed"); - assert_eq!(outputs, input); - } - - #[test] - fn check_memory_boundaries() { - // Check that input writes are boundary-checked: same as identity, but - // get_inputs_off returns an OOB offset. - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "get_inputs_off", - Vec::new(), - Some(ValueType::I32), - vec![Instruction::I32Const(65536), Instruction::End], - )); - let input = vec![FieldPrime::from(65536)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from("Could not write at memory address 65536")) - ); - - /* Check that output writes are boundary-checked */ - // Check that input writes are boundary-checked: same as identity, but - // solve returns an OOB offset. - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "solve", - Vec::new(), - Some(ValueType::I32), - vec![Instruction::I32Const(65536), Instruction::End], - )); - let input = vec![FieldPrime::from(65536)]; - let outputs = id.execute(&input); - assert_eq!( - outputs, - Err(String::from( - "Could not retrieve the output offset: Memory: trying to access region [65536..65568] in memory [0..64]", - )) - ); - } - - #[test] - fn check_negative_output_value() { - /* Same as identity, but `solve` returns -1 */ - let id = WasmHelper::from(replace_function( - WasmHelper::IDENTITY_WASM, - "solve", - Vec::new(), - Some(ValueType::I32), - vec![Instruction::I32Const(-1), Instruction::End], - )); - let input = vec![FieldPrime::from(1)]; - let outputs = id.execute(&input); - assert_eq!(outputs, Err(String::from("`solve` returned error code -1"))); - } - - #[test] - fn check_bits() { - let bits = WasmHelper::from(WasmHelper::BITS_WASM); - let input = vec![FieldPrime::from(0xdeadbeef as u32)]; - let outputs = bits.execute(&input).unwrap(); - - assert_eq!(254, outputs.len()); - - for i in 0..32 { - let bitval = (0xdeadbeef as i64 >> i) & 1; - assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(bitval as i32)); - } - - for i in 32..254 { - assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(0)); - } - } - - #[test] - fn check_bits_multiple_times() { - let bits = WasmHelper::from(WasmHelper::BITS_WASM); - let input = vec![FieldPrime::from(0xdeadbeef as u32)]; - - for _ in 0..10 { - let outputs = bits.execute(&input).unwrap(); - - assert_eq!(254, outputs.len()); - - for i in 0..32 { - let bitval = (0xdeadbeef as i64 >> i) & 1; - assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(bitval as i32)); - } - - for i in 32..254 { - assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(0)); - } - } - } -} diff --git a/zokrates_core/src/ir/from_flat.rs b/zokrates_core/src/ir/from_flat.rs index 13268a02..7e89ba51 100644 --- a/zokrates_core/src/ir/from_flat.rs +++ b/zokrates_core/src/ir/from_flat.rs @@ -1,5 +1,6 @@ -use crate::flat_absy::{FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable}; -use crate::helpers; +use crate::flat_absy::{ + FlatDirective, FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable, +}; use crate::ir::{Directive, Function, LinComb, Prog, QuadComb, Statement}; use num::Zero; use zokrates_field::field::Field; @@ -126,11 +127,11 @@ impl From> for Statement { } } -impl From> for Directive { - fn from(ds: helpers::DirectiveStatement) -> Directive { +impl From> for Directive { + fn from(ds: FlatDirective) -> Directive { Directive { inputs: ds.inputs.into_iter().map(|i| i.into()).collect(), - helper: ds.helper, + solver: ds.solver, outputs: ds.outputs, } } diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 65495aea..f900ab28 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -1,6 +1,6 @@ use crate::flat_absy::flat_variable::FlatVariable; -use crate::helpers::Executable; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; +use crate::solvers::Executable; use std::collections::BTreeMap; use std::fmt; use zokrates_field::field::Field; @@ -41,7 +41,7 @@ impl Prog { .iter() .map(|i| i.evaluate(&witness).unwrap()) .collect(); - match d.helper.execute(&input_values) { + match d.solver.execute(&input_values) { Ok(res) => { for (i, o) in d.outputs.iter().enumerate() { witness.insert(o.clone(), res[i].clone()); diff --git a/zokrates_core/src/ir/mod.rs b/zokrates_core/src/ir/mod.rs index c483c646..2633698e 100644 --- a/zokrates_core/src/ir/mod.rs +++ b/zokrates_core/src/ir/mod.rs @@ -1,6 +1,6 @@ use crate::flat_absy::flat_parameter::FlatParameter; use crate::flat_absy::FlatVariable; -use crate::helpers::Helper; +use crate::solvers::Solver; use std::fmt; use typed_absy::types::signature::Signature; use zokrates_field::field::Field; @@ -37,7 +37,7 @@ impl Statement { pub struct Directive { pub inputs: Vec>, pub outputs: Vec, - pub helper: Helper, + pub solver: Solver, } impl fmt::Display for Directive { @@ -50,7 +50,7 @@ impl fmt::Display for Directive { .map(|o| format!("{}", o)) .collect::>() .join(", "), - self.helper, + self.solver, self.inputs .iter() .map(|i| format!("{}", i)) diff --git a/zokrates_core/src/lib.rs b/zokrates_core/src/lib.rs index a387e875..93f39344 100644 --- a/zokrates_core/src/lib.rs +++ b/zokrates_core/src/lib.rs @@ -28,11 +28,11 @@ extern crate zokrates_pest_ast; mod embed; mod flatten; -mod helpers; mod imports; mod optimizer; mod parser; mod semantics; +mod solvers; mod static_analysis; pub mod absy; diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 30830c48..31bf1aa3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1948,7 +1948,7 @@ mod tests { mod symbols { use super::*; - /// Helper function to create (() -> (): return) + /// solver function to create (() -> (): return) fn function0() -> FunctionNode<'static, FieldPrime> { let statements: Vec> = vec![Statement::Return( ExpressionList { @@ -1970,7 +1970,7 @@ mod tests { .mock() } - /// Helper function to create ((private field a) -> (): return) + /// solver function to create ((private field a) -> (): return) fn function1() -> FunctionNode<'static, FieldPrime> { let statements: Vec> = vec![Statement::Return( ExpressionList { @@ -3304,7 +3304,7 @@ mod tests { mod structs { use super::*; - /// helper function to create a module at location "" with a single symbol `Foo { foo: field }` + /// solver function to create a module at location "" with a single symbol `Foo { foo: field }` fn create_module_with_foo( s: StructType<'static>, ) -> (Checker<'static>, State<'static, FieldPrime>) { diff --git a/zokrates_core/src/helpers/rust.rs b/zokrates_core/src/solvers/mod.rs similarity index 53% rename from zokrates_core/src/helpers/rust.rs rename to zokrates_core/src/solvers/mod.rs index f1bb3f4f..b8599121 100644 --- a/zokrates_core/src/helpers/rust.rs +++ b/zokrates_core/src/solvers/mod.rs @@ -1,10 +1,9 @@ -use crate::helpers::{Executable, Signed}; use std::fmt; use zokrates_embed::generate_sha256_round_witness; use zokrates_field::field::Field; #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum RustHelper { +pub enum Solver { Identity, ConditionEq, Bits, @@ -12,33 +11,36 @@ pub enum RustHelper { Sha256Round, } -impl fmt::Display for RustHelper { +impl fmt::Display for Solver { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self) } } -impl Signed for RustHelper { +impl Signed for Solver { fn get_signature(&self) -> (usize, usize) { match self { - RustHelper::Identity => (1, 1), - RustHelper::ConditionEq => (1, 2), - RustHelper::Bits => (1, 254), - RustHelper::Div => (2, 1), - RustHelper::Sha256Round => (768, 26935), + Solver::Identity => (1, 1), + Solver::ConditionEq => (1, 2), + Solver::Bits => (1, 254), + Solver::Div => (2, 1), + Solver::Sha256Round => (768, 26935), } } } -impl Executable for RustHelper { +impl Executable for Solver { fn execute(&self, inputs: &Vec) -> Result, String> { + let (expected_input_count, expected_output_count) = self.get_signature(); + assert!(inputs.len() == expected_input_count); + match self { - RustHelper::Identity => Ok(inputs.clone()), - RustHelper::ConditionEq => match inputs[0].is_zero() { + Solver::Identity => Ok(inputs.clone()), + Solver::ConditionEq => match inputs[0].is_zero() { true => Ok(vec![T::zero(), T::one()]), false => Ok(vec![T::one(), T::one() / inputs[0].clone()]), }, - RustHelper::Bits => { + Solver::Bits => { let mut num = inputs[0].clone(); let mut res = vec![]; let bits = 254; @@ -53,8 +55,8 @@ impl Executable for RustHelper { assert_eq!(num, T::zero()); Ok(res) } - RustHelper::Div => Ok(vec![inputs[0].clone() / inputs[1].clone()]), - RustHelper::Sha256Round => { + Solver::Div => Ok(vec![inputs[0].clone() / inputs[1].clone()]), + Solver::Sha256Round => { let i = &inputs[0..512]; let h = &inputs[512..]; let i: Vec<_> = i.iter().map(|x| x.clone().into_bellman()).collect(); @@ -69,15 +71,64 @@ impl Executable for RustHelper { } } +impl Solver { + pub fn identity() -> Self { + Solver::Identity + } + + pub fn bits() -> Self { + Solver::Bits + } +} + +pub trait Executable: Signed { + fn execute(&self, inputs: &Vec) -> Result, String>; +} + +pub trait Signed { + fn get_signature(&self) -> (usize, usize); +} + #[cfg(test)] mod tests { use super::*; use zokrates_field::field::FieldPrime; + mod eq_condition { + + // Wanted: (Y = (X != 0) ? 1 : 0) + // # Y = if X == 0 then 0 else 1 fi + // # M = if X == 0 then 1 else 1/X fi + + use super::*; + + #[test] + fn execute() { + let cond_eq = Solver::ConditionEq; + let inputs = vec![0]; + let r = cond_eq + .execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()) + .unwrap(); + let res: Vec = vec![0, 1].iter().map(|&i| FieldPrime::from(i)).collect(); + assert_eq!(r, &res[..]); + } + + #[test] + fn execute_non_eq() { + let cond_eq = Solver::ConditionEq; + let inputs = vec![1]; + let r = cond_eq + .execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()) + .unwrap(); + let res: Vec = vec![1, 1].iter().map(|&i| FieldPrime::from(i)).collect(); + assert_eq!(r, &res[..]); + } + } + #[test] fn bits_of_one() { let inputs = vec![FieldPrime::from(1)]; - let res = RustHelper::Bits.execute(&inputs).unwrap(); + let res = Solver::Bits.execute(&inputs).unwrap(); assert_eq!(res[253], FieldPrime::from(1)); for i in 0..252 { assert_eq!(res[i], FieldPrime::from(0)); @@ -87,7 +138,7 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![FieldPrime::from(42)]; - let res = RustHelper::Bits.execute(&inputs).unwrap(); + let res = Solver::Bits.execute(&inputs).unwrap(); assert_eq!(res[253], FieldPrime::from(0)); assert_eq!(res[252], FieldPrime::from(1)); assert_eq!(res[251], FieldPrime::from(0)); diff --git a/zokrates_core/src/standard.rs b/zokrates_core/src/standard.rs index 0584a195..d4aa4b84 100644 --- a/zokrates_core/src/standard.rs +++ b/zokrates_core/src/standard.rs @@ -1,6 +1,6 @@ use crate::flat_absy::{FlatExpression, FlatExpressionList, FlatFunction, FlatStatement}; use crate::flat_absy::{FlatParameter, FlatVariable}; -use crate::helpers::{DirectiveStatement, Helper, RustHelper}; +use crate::solvers::{FlatDirective, solver, Rustsolver}; use crate::types::{Signature, Type}; use bellman::pairing::ff::ScalarEngine; use reduce::Reduce; @@ -118,13 +118,13 @@ pub fn sha_round() -> FlatFunction { .collect(); // insert a directive to set the witness based on the bellman gadget and inputs - let directive_statement = FlatStatement::Directive(DirectiveStatement { + let directive_statement = FlatStatement::Directive(FlatDirective { outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(), inputs: input_argument_indices .chain(current_hash_argument_indices) .map(|i| FlatVariable::new(i).into()) .collect(), - helper: Helper::Rust(RustHelper::Sha256Round), + solver: Solver::Sha256Round, }); // insert a statement to return the subset of the witness diff --git a/zokrates_core/src/static_analysis/flat_propagation.rs b/zokrates_core/src/static_analysis/flat_propagation.rs index 1b26d5fa..a759f8af 100644 --- a/zokrates_core/src/static_analysis/flat_propagation.rs +++ b/zokrates_core/src/static_analysis/flat_propagation.rs @@ -5,7 +5,6 @@ //! @date 2018 use crate::flat_absy::*; -use crate::helpers::DirectiveStatement; use std::collections::HashMap; use zokrates_field::field::Field; @@ -74,7 +73,7 @@ impl FlatStatement { e1.propagate(constants), e2.propagate(constants), )), - FlatStatement::Directive(d) => Some(FlatStatement::Directive(DirectiveStatement { + FlatStatement::Directive(d) => Some(FlatStatement::Directive(FlatDirective { inputs: d .inputs .into_iter()