1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

remove wasm helpers, simplify, rename helper to solver

This commit is contained in:
schaeff 2019-11-25 19:02:15 +01:00
parent 6ad10fd339
commit 840c268a6a
17 changed files with 168 additions and 922 deletions

53
Cargo.lock generated
View file

@ -652,11 +652,6 @@ dependencies = [
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "memory_units"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "nodrop"
version = "0.1.13"
@ -767,22 +762,6 @@ dependencies = [
"rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "parity-wasm"
version = "0.31.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "parity-wasm"
version = "0.35.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "percent-encoding"
version = "1.0.1"
@ -1036,11 +1015,6 @@ name = "rustc-demangle"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rustc-hex"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rustc-serialize"
version = "0.3.24"
@ -1377,24 +1351,6 @@ name = "wasi"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasmi"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"memory_units 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)",
"wasmi-validation 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasmi-validation"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "winapi"
version = "0.3.8"
@ -1480,17 +1436,14 @@ dependencies = [
"num 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)",
"num-bigint 0.1.44 (registry+https://github.com/rust-lang/crates.io-index)",
"pairing_ce 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)",
"parity-wasm 0.35.7 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"reduce 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)",
"rustc-hex 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_bytes 0.10.5 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_derive 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
"typed-arena 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
"wasmi 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)",
"zokrates_embed 0.1.0",
"zokrates_field 0.3.4",
"zokrates_pest_ast 0.1.3",
@ -1659,7 +1612,6 @@ dependencies = [
"checksum matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
"checksum memoffset 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ce6075db033bbbb7ee5a0bbd3a3186bbae616f57fb001c485c7ff77955f8177f"
"checksum memory_units 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "71d96e3f3c0b6325d8ccd83c33b28acb183edcb6c67938ba104ec546854b0882"
"checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945"
"checksum num 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "4703ad64153382334aa8db57c637364c322d3372e097840c72000dabdcf6156e"
"checksum num-bigint 0.1.44 (registry+https://github.com/rust-lang/crates.io-index)" = "e63899ad0da84ce718c14936262a41cee2c79c981fc0a0e7c7beb47d5a07e8c1"
@ -1673,8 +1625,6 @@ dependencies = [
"checksum openssl-probe 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de"
"checksum openssl-sys 0.9.50 (registry+https://github.com/rust-lang/crates.io-index)" = "2c42dcccb832556b5926bc9ae61e8775f2a61e725ab07ab3d1e7fcf8ae62c3b6"
"checksum pairing_ce 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f075a9c570e2026111cb6dddf6a320e5163c42aa32500b315ec34acbcf7c9b36"
"checksum parity-wasm 0.31.3 (registry+https://github.com/rust-lang/crates.io-index)" = "511379a8194230c2395d2f5fa627a5a7e108a9f976656ce723ae68fca4097bfc"
"checksum parity-wasm 0.35.7 (registry+https://github.com/rust-lang/crates.io-index)" = "3e1e076c4e01399b6cd0793a8df42f90bba3ae424671ef421d1608a943155d93"
"checksum percent-encoding 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
"checksum pest 2.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7e4fb201c5c22a55d8b24fef95f78be52738e5e1361129be1b5e862ecdb6894a"
"checksum pest-ast 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3fbf404899169771dd6a32c84248b83cd67a26cc7cc957aac87661490e1227e4"
@ -1706,7 +1656,6 @@ dependencies = [
"checksum rgb 0.8.14 (registry+https://github.com/rust-lang/crates.io-index)" = "2089e4031214d129e201f8c3c8c2fe97cd7322478a0d1cdf78e7029b0042efdb"
"checksum rust-crypto 0.2.36 (registry+https://github.com/rust-lang/crates.io-index)" = "f76d05d3993fd5f4af9434e8e436db163a12a9d40e1a58a726f27a01dfd12a2a"
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
"checksum rustc-hex 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0ceb8ce7a5e520de349e1fa172baeba4a9e8d5ef06c47471863530bc4972ee1e"
"checksum rustc-serialize 0.3.24 (registry+https://github.com/rust-lang/crates.io-index)" = "dcf128d1287d2ea9d80910b5f1120d0b8eede3fbf1abe91c40d39ea7d51e6fda"
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
"checksum ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c92464b447c0ee8c4fb3824ecc8383b81717b9f1e74ba2e72540aef7b9f82997"
@ -1751,8 +1700,6 @@ dependencies = [
"checksum void 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
"checksum walkdir 2.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9658c94fa8b940eab2250bd5a457f9c48b748420d71293b165c8cdbe2f55f71e"
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
"checksum wasmi 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "aebbaef470840d157a5c47c8c49f024da7b1b80e90ff729ca982b2b80447e78b"
"checksum wasmi-validation 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ab380192444b3e8522ae79c0a1976e42a82920916ccdfbce3def89f456ea33f3"
"checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
"checksum winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7168bab6e1daee33b4557efd0e95d5ca70a03706d39fa5f3fe7a236f584b03c9"

View file

@ -8,7 +8,6 @@ edition = "2018"
[features]
default = []
libsnark = ["zokrates_core/libsnark"]
wasm = ["zokrates_core/wasm"]
[dependencies]
clap = "2.26.2"

View file

@ -9,7 +9,6 @@ build = "build.rs"
[features]
default = []
libsnark = ["cc", "cmake", "git2"]
wasm = ["wasmi", "parity-wasm", "rustc-hex"]
[dependencies]
libc = "0.2.0"
@ -32,9 +31,6 @@ zokrates_field = { version = "0.3.0", path = "../zokrates_field" }
zokrates_pest_ast = { version = "0.1.0", path = "../zokrates_pest_ast" }
zokrates_embed = { path = "../zokrates_embed" }
rand = "0.4"
wasmi = { version = "=0.4.5", optional = true }
parity-wasm = { version = "0.35.3", optional = true }
rustc-hex = { version = "1.0", optional = true }
csv = "1"
[dev-dependencies]

View file

@ -789,7 +789,7 @@ mod tests {
mod types {
use super::*;
/// Helper method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty
/// solver method to generate the ast for `def main(private {ty} a) -> (): return` which we use to check ty
fn wrap(ty: absy::UnresolvedType) -> absy::Module<'static, FieldPrime> {
absy::Module {
symbols: vec![absy::SymbolDeclaration {

View file

@ -1,7 +1,8 @@
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::solvers::Solver;
use bellman::pairing::ff::ScalarEngine;
use flat_absy::{
FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement, FlatVariable,
FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement,
FlatVariable,
};
use reduce::Reduce;
use std::collections::HashMap;
@ -166,13 +167,13 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(DirectiveStatement {
let directive_statement = FlatStatement::Directive(FlatDirective {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
.collect(),
helper: Helper::Rust(RustHelper::Sha256Round),
solver: Solver::Sha256Round,
});
// insert a statement to return the subset of the witness
@ -233,7 +234,7 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
.collect();
let helper = Helper::bits();
let solver = Solver::bits();
let signature = Signature {
inputs: vec![Type::FieldElement],
@ -281,10 +282,10 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
statements.insert(
0,
FlatStatement::Directive(DirectiveStatement {
FlatStatement::Directive(FlatDirective {
inputs: directive_inputs,
outputs: directive_outputs,
helper: helper,
solver: solver,
}),
);
@ -322,11 +323,11 @@ mod tests {
); // 128 bit checks, 1 directive, 1 sum check, 1 return
assert_eq!(
unpack.statements[0],
FlatStatement::Directive(DirectiveStatement::new(
FlatStatement::Directive(FlatDirective::new(
(0..FieldPrime::get_required_bits())
.map(|i| FlatVariable::new(i + 1))
.collect(),
Helper::bits(),
Solver::bits(),
vec![FlatVariable::new(0)]
))
);

View file

@ -11,8 +11,8 @@ pub mod flat_variable;
pub use self::flat_parameter::FlatParameter;
pub use self::flat_variable::FlatVariable;
use crate::helpers::DirectiveStatement;
use crate::typed_absy::types::Signature;
use solvers::{Signed, Solver};
use std::collections::HashMap;
use std::fmt;
use zokrates_field::field::Field;
@ -95,7 +95,7 @@ pub enum FlatStatement<T: Field> {
Return(FlatExpressionList<T>),
Condition(FlatExpression<T>, FlatExpression<T>),
Definition(FlatVariable, FlatExpression<T>),
Directive(DirectiveStatement<T>),
Directive(FlatDirective<T>),
}
impl<T: Field> fmt::Display for FlatStatement<T> {
@ -149,7 +149,7 @@ impl<T: Field> FlatStatement<T> {
.map(|i| i.apply_substitution(substitution))
.collect();
FlatStatement::Directive(DirectiveStatement {
FlatStatement::Directive(FlatDirective {
outputs,
inputs,
..d
@ -159,6 +159,50 @@ impl<T: Field> FlatStatement<T> {
}
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct FlatDirective<T: Field> {
pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>,
pub solver: Solver,
}
impl<T: Field> FlatDirective<T> {
pub fn new<E: Into<FlatExpression<T>>>(
outputs: Vec<FlatVariable>,
solver: Solver,
inputs: Vec<E>,
) -> Self {
let (in_len, out_len) = solver.get_signature();
assert_eq!(in_len, inputs.len());
assert_eq!(out_len, outputs.len());
FlatDirective {
solver,
inputs: inputs.into_iter().map(|i| i.into()).collect(),
outputs,
}
}
}
impl<T: Field> fmt::Display for FlatDirective<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"# {} = {}({})",
self.outputs
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", "),
self.solver,
self.inputs
.iter()
.map(|i| i.to_string())
.collect::<Vec<String>>()
.join(", ")
)
}
}
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub enum FlatExpression<T: Field> {
Number(T),

View file

@ -6,7 +6,7 @@
//! @date 2017
use crate::flat_absy::*;
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::solvers::Solver;
use crate::typed_absy::types::{FunctionIdentifier, FunctionKey, MemberId, Signature, Type};
use crate::typed_absy::*;
use std::collections::HashMap;
@ -590,9 +590,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..bitwidth).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
lhs_bits_be.clone(),
Helper::bits(),
Solver::bits(),
vec![lhs_id],
)));
@ -637,9 +637,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..bitwidth).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
rhs_bits_be.clone(),
Helper::bits(),
Solver::bits(),
vec![rhs_id],
)));
@ -690,9 +690,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..bitwidth).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Helper::bits(),
Solver::bits(),
vec![subtraction_result.clone()],
)));
@ -776,9 +776,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FieldElementExpression::Sub(box lhs, box rhs),
);
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Helper::Rust(RustHelper::ConditionEq),
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
@ -963,9 +963,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.into_iter()
.map(|i| i.apply_substitution(&replacement_map))
.collect();
FlatStatement::Directive(DirectiveStatement {
FlatStatement::Directive(FlatDirective {
outputs: new_outputs,
helper: d.helper,
solver: d.solver,
inputs: new_inputs,
})
}
@ -1138,9 +1138,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let inverse = self.use_sym();
// # invb = 1/b
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![invb],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![FlatExpression::Number(T::one()), new_right.clone()],
)));
@ -1151,9 +1151,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
));
// # c = a/b
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![inverse],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![new_left.clone(), new_right.clone()],
)));
@ -2041,9 +2041,9 @@ mod tests {
FlatStatement::Definition(five, FlatExpression::Number(FieldPrime::from(5))),
FlatStatement::Definition(b0, b.into()),
// check div by 0
FlatStatement::Directive(DirectiveStatement::new(
FlatStatement::Directive(FlatDirective::new(
vec![invb0],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![FlatExpression::Number(FieldPrime::from(1)), b0.into()]
)),
FlatStatement::Condition(
@ -2051,9 +2051,9 @@ mod tests {
FlatExpression::Mult(box invb0.into(), box b0.into()),
),
// execute div
FlatStatement::Directive(DirectiveStatement::new(
FlatStatement::Directive(FlatDirective::new(
vec![sym_0],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![five, b0]
)),
FlatStatement::Condition(
@ -2064,9 +2064,9 @@ mod tests {
FlatStatement::Definition(sym_1, sym_0.into()),
FlatStatement::Definition(b1, b.into()),
// check div by 0
FlatStatement::Directive(DirectiveStatement::new(
FlatStatement::Directive(FlatDirective::new(
vec![invb1],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![FlatExpression::Number(FieldPrime::from(1)), b1.into()]
)),
FlatStatement::Condition(
@ -2074,9 +2074,9 @@ mod tests {
FlatExpression::Mult(box invb1.into(), box b1.into()),
),
// execute div
FlatStatement::Directive(DirectiveStatement::new(
FlatStatement::Directive(FlatDirective::new(
vec![sym_2],
Helper::Rust(RustHelper::Div),
Solver::Div,
vec![sym_1, b1]
)),
FlatStatement::Condition(

View file

@ -1,171 +0,0 @@
mod rust;
#[cfg(feature = "wasm")]
mod wasm;
pub use self::rust::RustHelper;
#[cfg(feature = "wasm")]
pub use self::wasm::WasmHelper;
use crate::flat_absy::{FlatExpression, FlatVariable};
use std::fmt;
use zokrates_field::field::Field;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct DirectiveStatement<T: Field> {
pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>,
pub helper: Helper,
}
impl<T: Field> DirectiveStatement<T> {
pub fn new<E: Into<FlatExpression<T>>>(
outputs: Vec<FlatVariable>,
helper: Helper,
inputs: Vec<E>,
) -> Self {
let (in_len, out_len) = helper.get_signature();
assert_eq!(in_len, inputs.len());
assert_eq!(out_len, outputs.len());
DirectiveStatement {
helper,
inputs: inputs.into_iter().map(|i| i.into()).collect(),
outputs,
}
}
}
impl<T: Field> fmt::Display for DirectiveStatement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"# {} = {}({})",
self.outputs
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", "),
self.helper,
self.inputs
.iter()
.map(|i| i.to_string())
.collect::<Vec<String>>()
.join(", ")
)
}
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Helper {
Rust(RustHelper),
#[cfg(feature = "wasm")]
Wasm(WasmHelper),
}
#[cfg(feature = "wasm")]
impl Helper {
pub fn identity() -> Self {
Helper::Wasm(WasmHelper::from_hex(WasmHelper::IDENTITY_WASM))
}
pub fn bits() -> Self {
Helper::Wasm(WasmHelper::from(WasmHelper::BITS_WASM))
}
}
#[cfg(not(feature = "wasm"))]
impl Helper {
pub fn identity() -> Self {
Helper::Rust(RustHelper::Identity)
}
pub fn bits() -> Self {
Helper::Rust(RustHelper::Bits)
}
}
impl fmt::Display for Helper {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Helper::Rust(ref h) => write!(f, "Rust::{}", h),
#[cfg(feature = "wasm")]
Helper::Wasm(ref h) => write!(f, "Wasm::{}", h),
}
}
}
pub trait Executable<T: Field>: Signed {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String>;
}
pub trait Signed {
fn get_signature(&self) -> (usize, usize);
}
impl<T: Field> Executable<T> for Helper {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = self.get_signature();
assert!(inputs.len() == expected_input_count);
let result = match self {
Helper::Rust(helper) => helper.execute(inputs),
#[cfg(feature = "wasm")]
Helper::Wasm(helper) => helper.execute(inputs),
};
match result {
Ok(ref r) if r.len() != expected_output_count => Err(format!(
"invalid witness size: is {} but should be {}",
r.len(),
expected_output_count
)
.to_string()),
r => r,
}
}
}
impl Signed for Helper {
fn get_signature(&self) -> (usize, usize) {
match self {
Helper::Rust(helper) => helper.get_signature(),
#[cfg(feature = "wasm")]
Helper::Wasm(helper) => helper.get_signature(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = RustHelper::ConditionEq;
let inputs = vec![0];
let r = cond_eq
.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect())
.unwrap();
let res: Vec<FieldPrime> = vec![0, 1].iter().map(|&i| FieldPrime::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = RustHelper::ConditionEq;
let inputs = vec![1];
let r = cond_eq
.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect())
.unwrap();
let res: Vec<FieldPrime> = vec![1, 1].iter().map(|&i| FieldPrime::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
}

View file

@ -1,621 +0,0 @@
use helpers::{Executable, Signed};
use std::fmt;
use rustc_hex::FromHex;
use serde::{Deserialize, Deserializer};
use std::hash::{Hash, Hasher};
use std::rc::Rc;
use wasmi::{ImportsBuilder, ModuleInstance, ModuleRef, NopExternals};
use zokrates_field::field::Field;
#[derive(Clone, Debug, Serialize)]
pub struct WasmHelper(
#[serde(skip)] std::rc::Rc<ModuleRef>,
#[serde(serialize_with = "serde_bytes::serialize")] Vec<u8>,
);
impl WasmHelper {
// Hand-coded assembly for identity.
// Source available at https://gist.github.com/gballet/f14d11053d8f846bfbb3687581b0eecb#file-identity-wast
pub const IDENTITY_WASM: &'static str = "0061736d010000000105016000017f030302000005030100010615047f0041010b7f0041010b7f0041200b7f0141000b074b06066d656d6f727902000e6765745f696e707574735f6f6666000105736f6c766500000a6d696e5f696e7075747303000b6d696e5f6f75747075747303010a6669656c645f73697a6503020a300229000340412023036a410023036a280200360200230341016a240323032302470d000b4100240341200b040041000b0b4b020041000b20ffffffff000000000000000000000000ffffffff0000000000000000000000000041200b20deadbeef000000000000000000000000deadbeef000000000000000000000000";
// Generated from C code, normalized and cleaned up by hand.
// Source available at https://gist.github.com/gballet/f14d11053d8f846bfbb3687581b0eecb#file-bits_v2-c
pub const BITS_WASM: &'static [u8] = &[
0, 97, 115, 109, 1, 0, 0, 0, 1, 8, 2, 96, 0, 0, 96, 0, 1, 127, 3, 5, 4, 0, 1, 1, 1, 4, 5,
1, 112, 1, 1, 1, 5, 3, 1, 0, 2, 6, 38, 6, 127, 1, 65, 240, 199, 4, 11, 127, 0, 65, 240,
199, 4, 11, 127, 0, 65, 240, 199, 0, 11, 127, 0, 65, 32, 11, 127, 0, 65, 1, 11, 127, 0, 65,
254, 1, 11, 7, 109, 9, 6, 109, 101, 109, 111, 114, 121, 2, 0, 11, 95, 95, 104, 101, 97,
112, 95, 98, 97, 115, 101, 3, 1, 10, 95, 95, 100, 97, 116, 97, 95, 101, 110, 100, 3, 2, 14,
103, 101, 116, 95, 105, 110, 112, 117, 116, 115, 95, 111, 102, 102, 0, 1, 5, 115, 111, 108,
118, 101, 0, 2, 4, 109, 97, 105, 110, 0, 3, 10, 102, 105, 101, 108, 100, 95, 115, 105, 122,
101, 3, 3, 10, 109, 105, 110, 95, 105, 110, 112, 117, 116, 115, 3, 4, 11, 109, 105, 110,
95, 111, 117, 116, 112, 117, 116, 115, 3, 5, 9, 1, 0, 10, 85, 4, 3, 0, 1, 11, 5, 0, 65,
144, 8, 11, 68, 1, 2, 127, 65, 253, 1, 33, 0, 65, 176, 8, 33, 1, 3, 64, 32, 1, 65, 1, 32,
0, 65, 7, 113, 116, 32, 0, 65, 3, 118, 65, 144, 8, 106, 45, 0, 0, 113, 65, 0, 71, 58, 0, 0,
32, 1, 65, 32, 106, 33, 1, 32, 0, 65, 127, 106, 34, 0, 65, 127, 71, 13, 0, 11, 65, 176, 8,
11, 4, 0, 65, 0, 11, 11, 19, 1, 0, 65, 128, 8, 11, 12, 32, 0, 0, 0, 1, 0, 0, 0, 254, 0, 0,
0,
];
pub fn from_hex<U: Into<String>>(u: U) -> Self {
let code_hex = u.into();
let code = FromHex::from_hex(&code_hex[..])
.expect(format!("invalid bytecode: {}", code_hex).as_str());
WasmHelper::from(code)
}
}
impl<U: Into<Vec<u8>>> From<U> for WasmHelper {
fn from(code: U) -> Self {
let code_vec = code.into();
let module = wasmi::Module::from_buffer(code_vec.clone()).expect("Error decoding buffer");
let modinst = ModuleInstance::new(&module, &ImportsBuilder::default())
.expect("Failed to instantiate module")
.assert_no_start();
WasmHelper(Rc::new(modinst), code_vec)
}
}
impl<'de> Deserialize<'de> for WasmHelper {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hex: Vec<u8> = serde_bytes::deserialize(deserializer)?;
Ok(WasmHelper::from(hex))
}
}
impl PartialEq for WasmHelper {
fn eq(&self, other: &WasmHelper) -> bool {
self.1 == other.1
}
}
impl Eq for WasmHelper {}
impl Hash for WasmHelper {
fn hash<H: Hasher>(&self, state: &mut H) {
self.1.hash(state);
}
}
impl fmt::Display for WasmHelper {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Hex(\"{:?}\")", &self.1[..])
}
}
fn get_export<T: wasmi::FromRuntimeValue>(varname: &str, modref: &ModuleRef) -> Result<T, String> {
modref
.export_by_name(varname)
.ok_or(&format!("Could not find exported symbol `{}` in module", varname)[..])?
.as_global()
.ok_or(format!(
"Error getting {} from the list of globals",
varname
))?
.get()
.try_into::<T>()
.ok_or(format!("Error converting `{}` to i32", varname))
}
impl Signed for WasmHelper {
fn get_signature(&self) -> (usize, usize) {
// Check that the module has the following exports:
// * min_inputs = the (minimum) number of inputs
// * min_outputs = the (minimum) number of outputs
let ni = get_export::<i32>("min_inputs", self.0.as_ref()).unwrap();
let no = get_export::<i32>("min_outputs", self.0.as_ref()).unwrap();
(ni as usize, no as usize)
}
}
impl<T: Field> Executable<T> for WasmHelper {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
let field_size = get_export::<i32>("field_size", self.0.as_ref())? as usize;
let ninputs = get_export::<i32>("min_inputs", self.0.as_ref())? as usize;
if ninputs != inputs.len() {
return Err(format!(
"`solve` expected {} inputs, received {}",
ninputs,
inputs.len()
));
}
/* Prepare the inputs */
let input_offset = self
.0
.invoke_export("get_inputs_off", &[], &mut NopExternals)
.map_err(|e| format!("Error getting the input offset: {}", e.to_string()))?
.ok_or("`get_inputs_off` did not return any value")?
.try_into::<i32>()
.ok_or("`get_inputs_off` returned the wrong type")?;
let mem = self
.0
.as_ref()
.export_by_name("memory")
.ok_or("Module didn't export its memory section")?
.as_memory()
.unwrap()
.clone();
for (index, input) in inputs.iter().enumerate() {
// Get the field's bytes and check they correspond to
// the value that the module expects.
let mut bv = input.into_byte_vector();
if bv.len() > field_size {
return Err(format!(
"Input #{} is stored on {} bytes which is greater than the field size of {}",
index,
bv.len(),
field_size
));
} else {
bv.resize(field_size, 0);
}
let addr = (input_offset as u32) + (index as u32) * (field_size as u32);
mem.set(addr, &bv[..])
.map_err(|_e| format!("Could not write at memory address {}", addr))?;
}
let output_offset = self
.0
.as_ref()
.invoke_export("solve", &[], &mut NopExternals)
.map_err(|e| format!("Error solving the problem: {}", e.to_string()))?
.ok_or("`solve` did not return any value")?
.try_into::<i32>()
.ok_or("`solve returned the wrong type`")?;
// NOTE: The question regarding the way that an error code is
// returned is still open.
//
// The current model considers that 2GB is more than enough
// to store the output data.
//
// This being said this approach is tacky and I am considering
// others at this point:
//
// 1. Use a 64 bit return code, values greater than 32-bits
// are considered error codes.
// 2. Export an extra global called `errno` which contains
// the error code.
// 3. 32-bit alignment gives a 2-bit error field
// 4. Return a pointer to a structure that contains
// the error code just before the output data.
//
// Experimenting with other languages will help decide what
// is the better approach.
if output_offset > 0 {
let mut outputs = Vec::new();
let noutputs = get_export::<i32>("min_outputs", self.0.as_ref())?;
for i in 0..noutputs {
let index = i as u32;
let fs = field_size as u32;
let value = mem
.get(output_offset as u32 + fs * index, field_size)
.map_err(|e| format!("Could not retrieve the output offset: {}", e))?;
outputs.push(T::from_byte_vector(value));
}
Ok(outputs)
} else {
Err(format!("`solve` returned error code {}", output_offset))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use parity_wasm::builder::*;
use parity_wasm::elements::{Instruction, Instructions, ValueType};
use std::panic;
use zokrates_field::field::FieldPrime;
fn remove_export(code: &str, symbol: &str) -> Vec<u8> {
let code = FromHex::from_hex(code).unwrap();
let mut idmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..])
.expect("Could not deserialize Identity module");
idmod
.export_section_mut()
.expect("Could not get export section")
.entries_mut()
.retain(|ref export| export.field() != symbol);
parity_wasm::serialize(idmod).expect("Could not serialize buffer")
}
fn replace_function(
code: &str,
symbol: &str,
params: Vec<ValueType>,
ret: Option<ValueType>,
instr: Vec<Instruction>,
) -> Vec<u8> {
/* Deserialize to parity_wasm format */
let code = FromHex::from_hex(code).unwrap();
let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..])
.expect("Could not deserialize Identity module");
/* Remove export, if it exists */
pwmod
.export_section_mut()
.expect("Could not get export section")
.entries_mut()
.retain(|ref export| export.field() != symbol);
/* Add a new function and give it the export name */
let wmod: parity_wasm::elements::Module = from_module(pwmod)
.function()
.signature()
.with_params(params)
.with_return_type(ret)
.build()
.body()
.with_instructions(Instructions::new(instr))
.build()
.build()
.export()
.field(symbol)
.internal()
.func(2)
.build()
.build();
parity_wasm::serialize(wmod).expect("Could not serialize buffer")
}
fn replace_global(code: &str, symbol: &str, value: i32) -> Vec<u8> {
/* Deserialize to parity_wasm format */
let code = FromHex::from_hex(code).unwrap();
let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..])
.expect("Could not deserialize Identity module");
/* Remove export, if it exists */
pwmod
.export_section_mut()
.expect("Could not get export section")
.entries_mut()
.retain(|ref export| export.field() != symbol);
/* Add a new function and give it the export name */
let wmod: parity_wasm::elements::Module = from_module(pwmod)
.global()
.value_type()
.i32()
.init_expr(Instruction::I32Const(value))
.build()
.export()
.field(symbol)
.internal()
.global(4)
.build()
.build();
parity_wasm::serialize(wmod).expect("Could not serialize buffer")
}
fn replace_global_type(code: &str, symbol: &str) -> Vec<u8> {
/* Deserialize to parity_wasm format */
let code = FromHex::from_hex(code).unwrap();
let mut pwmod: parity_wasm::elements::Module = parity_wasm::deserialize_buffer(&code[..])
.expect("Could not deserialize Identity module");
/* Remove export, if it exists */
pwmod
.export_section_mut()
.expect("Could not get export section")
.entries_mut()
.retain(|ref export| export.field() != symbol);
/* Add a new function and give it the export name */
let wmod: parity_wasm::elements::Module = from_module(pwmod)
.global()
.value_type()
.f32()
.init_expr(Instruction::F32Const(0))
.build()
.export()
.field(symbol)
.internal()
.global(4)
.build()
.build();
parity_wasm::serialize(wmod).expect("Could not serialize buffer")
}
#[test]
fn check_signatures() {
let h1 = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
assert_eq!(h1.get_signature(), (1, 1));
}
#[test]
#[should_panic(
expected = "invalid bytecode: invalid bytecode: Invalid character 'i' at position 0"
)]
fn check_invalid_bytecode_fails() {
WasmHelper::from_hex("invalid bytecode");
}
#[test]
#[should_panic(expected = "Error decoding buffer: Validation(\"I/O Error: UnexpectedEof\")")]
fn check_truncated_bytecode_fails() {
WasmHelper::from_hex(&WasmHelper::IDENTITY_WASM[..20]);
}
#[test]
fn validate_exports() {
/* Test identity without the `solve` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "solve"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Error solving the problem: Function: Module doesn\'t have export solve",
))
);
/* Test identity, without the `get_inputs_off` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "get_inputs_off"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Error getting the input offset: Function: Module doesn\'t have export get_inputs_off",
))
);
/* Test identity, without the `min_inputs` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "min_inputs"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Could not find exported symbol `min_inputs` in module",
))
);
/* Test identity, without the `min_outputs` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "min_outputs"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Could not find exported symbol `min_outputs` in module",
))
);
/* Test identity, without the `field_size` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "field_size"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Could not find exported symbol `field_size` in module",
))
);
/* Test identity, without the `memory` export */
let id = WasmHelper::from(remove_export(WasmHelper::IDENTITY_WASM, "memory"));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from("Module didn\'t export its memory section"))
);
}
#[test]
fn check_invalid_function_type() {
/* Test identity, with a different function return type */
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"get_inputs_off",
Vec::new(),
Some(ValueType::I64),
vec![Instruction::I64Const(0), Instruction::End],
));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from("`get_inputs_off` returned the wrong type"))
);
/* Test identity, with no return type for function */
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"get_inputs_off",
Vec::new(),
None,
vec![Instruction::Nop, Instruction::End],
));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from("`get_inputs_off` did not return any value"))
);
/* Test identity, with extra parameter for function */
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"get_inputs_off",
vec![ValueType::I64],
Some(ValueType::I32),
vec![Instruction::I32Const(0), Instruction::End],
));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Error getting the input offset: Trap: Trap { kind: UnexpectedSignature }",
))
);
}
#[test]
fn check_invalid_field_size() {
/* Test identity, with 1-byte filed size */
let id = WasmHelper::from(replace_global(WasmHelper::IDENTITY_WASM, "field_size", 1));
let input = vec![FieldPrime::from(65536)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Input #0 is stored on 3 bytes which is greater than the field size of 1",
))
);
/* Test identity, tweaked so that field_size is a f32 */
let id = WasmHelper::from(replace_global_type(WasmHelper::IDENTITY_WASM, "field_size"));
let input = vec![FieldPrime::from(65536)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from("Error converting `field_size` to i32"))
);
}
#[test]
fn check_identity() {
let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input).expect("Identity call failed");
assert_eq!(outputs, input);
let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
let input = vec![FieldPrime::from(0)];
let outputs = id.execute(&input).expect("Identity call failed");
assert_eq!(outputs, input);
}
#[test]
fn check_identity_3_bytes() {
let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
let input = vec![FieldPrime::from(16777216)];
let outputs = id.execute(&input).expect("Identity call failed");
assert_eq!(outputs, input);
}
#[test]
fn check_identity_multiple_calls() {
let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
let input = vec![FieldPrime::from(16777216)];
for _i in 0..10 {
let outputs = id.execute(&input).expect("Identity call failed");
assert_eq!(outputs, input);
}
}
#[test]
fn check_invalid_arg_number() {
let id = WasmHelper::from_hex(WasmHelper::IDENTITY_WASM);
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input).expect("Identity call failed");
assert_eq!(outputs, input);
}
#[test]
fn check_memory_boundaries() {
// Check that input writes are boundary-checked: same as identity, but
// get_inputs_off returns an OOB offset.
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"get_inputs_off",
Vec::new(),
Some(ValueType::I32),
vec![Instruction::I32Const(65536), Instruction::End],
));
let input = vec![FieldPrime::from(65536)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from("Could not write at memory address 65536"))
);
/* Check that output writes are boundary-checked */
// Check that input writes are boundary-checked: same as identity, but
// solve returns an OOB offset.
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"solve",
Vec::new(),
Some(ValueType::I32),
vec![Instruction::I32Const(65536), Instruction::End],
));
let input = vec![FieldPrime::from(65536)];
let outputs = id.execute(&input);
assert_eq!(
outputs,
Err(String::from(
"Could not retrieve the output offset: Memory: trying to access region [65536..65568] in memory [0..64]",
))
);
}
#[test]
fn check_negative_output_value() {
/* Same as identity, but `solve` returns -1 */
let id = WasmHelper::from(replace_function(
WasmHelper::IDENTITY_WASM,
"solve",
Vec::new(),
Some(ValueType::I32),
vec![Instruction::I32Const(-1), Instruction::End],
));
let input = vec![FieldPrime::from(1)];
let outputs = id.execute(&input);
assert_eq!(outputs, Err(String::from("`solve` returned error code -1")));
}
#[test]
fn check_bits() {
let bits = WasmHelper::from(WasmHelper::BITS_WASM);
let input = vec![FieldPrime::from(0xdeadbeef as u32)];
let outputs = bits.execute(&input).unwrap();
assert_eq!(254, outputs.len());
for i in 0..32 {
let bitval = (0xdeadbeef as i64 >> i) & 1;
assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(bitval as i32));
}
for i in 32..254 {
assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(0));
}
}
#[test]
fn check_bits_multiple_times() {
let bits = WasmHelper::from(WasmHelper::BITS_WASM);
let input = vec![FieldPrime::from(0xdeadbeef as u32)];
for _ in 0..10 {
let outputs = bits.execute(&input).unwrap();
assert_eq!(254, outputs.len());
for i in 0..32 {
let bitval = (0xdeadbeef as i64 >> i) & 1;
assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(bitval as i32));
}
for i in 32..254 {
assert_eq!(outputs[(253 - i) as usize], FieldPrime::from(0));
}
}
}
}

View file

@ -1,5 +1,6 @@
use crate::flat_absy::{FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable};
use crate::helpers;
use crate::flat_absy::{
FlatDirective, FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable,
};
use crate::ir::{Directive, Function, LinComb, Prog, QuadComb, Statement};
use num::Zero;
use zokrates_field::field::Field;
@ -126,11 +127,11 @@ impl<T: Field> From<FlatStatement<T>> for Statement<T> {
}
}
impl<T: Field> From<helpers::DirectiveStatement<T>> for Directive<T> {
fn from(ds: helpers::DirectiveStatement<T>) -> Directive<T> {
impl<T: Field> From<FlatDirective<T>> for Directive<T> {
fn from(ds: FlatDirective<T>) -> Directive<T> {
Directive {
inputs: ds.inputs.into_iter().map(|i| i.into()).collect(),
helper: ds.helper,
solver: ds.solver,
outputs: ds.outputs,
}
}

View file

@ -1,6 +1,6 @@
use crate::flat_absy::flat_variable::FlatVariable;
use crate::helpers::Executable;
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
use crate::solvers::Executable;
use std::collections::BTreeMap;
use std::fmt;
use zokrates_field::field::Field;
@ -41,7 +41,7 @@ impl<T: Field> Prog<T> {
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
match d.helper.execute(&input_values) {
match d.solver.execute(&input_values) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone());

View file

@ -1,6 +1,6 @@
use crate::flat_absy::flat_parameter::FlatParameter;
use crate::flat_absy::FlatVariable;
use crate::helpers::Helper;
use crate::solvers::Solver;
use std::fmt;
use typed_absy::types::signature::Signature;
use zokrates_field::field::Field;
@ -37,7 +37,7 @@ impl<T: Field> Statement<T> {
pub struct Directive<T: Field> {
pub inputs: Vec<LinComb<T>>,
pub outputs: Vec<FlatVariable>,
pub helper: Helper,
pub solver: Solver,
}
impl<T: Field> fmt::Display for Directive<T> {
@ -50,7 +50,7 @@ impl<T: Field> fmt::Display for Directive<T> {
.map(|o| format!("{}", o))
.collect::<Vec<_>>()
.join(", "),
self.helper,
self.solver,
self.inputs
.iter()
.map(|i| format!("{}", i))

View file

@ -28,11 +28,11 @@ extern crate zokrates_pest_ast;
mod embed;
mod flatten;
mod helpers;
mod imports;
mod optimizer;
mod parser;
mod semantics;
mod solvers;
mod static_analysis;
pub mod absy;

View file

@ -1948,7 +1948,7 @@ mod tests {
mod symbols {
use super::*;
/// Helper function to create (() -> (): return)
/// solver function to create (() -> (): return)
fn function0() -> FunctionNode<'static, FieldPrime> {
let statements: Vec<StatementNode<FieldPrime>> = vec![Statement::Return(
ExpressionList {
@ -1970,7 +1970,7 @@ mod tests {
.mock()
}
/// Helper function to create ((private field a) -> (): return)
/// solver function to create ((private field a) -> (): return)
fn function1() -> FunctionNode<'static, FieldPrime> {
let statements: Vec<StatementNode<FieldPrime>> = vec![Statement::Return(
ExpressionList {
@ -3304,7 +3304,7 @@ mod tests {
mod structs {
use super::*;
/// helper function to create a module at location "" with a single symbol `Foo { foo: field }`
/// solver function to create a module at location "" with a single symbol `Foo { foo: field }`
fn create_module_with_foo(
s: StructType<'static>,
) -> (Checker<'static>, State<'static, FieldPrime>) {

View file

@ -1,10 +1,9 @@
use crate::helpers::{Executable, Signed};
use std::fmt;
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::field::Field;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum RustHelper {
pub enum Solver {
Identity,
ConditionEq,
Bits,
@ -12,33 +11,36 @@ pub enum RustHelper {
Sha256Round,
}
impl fmt::Display for RustHelper {
impl fmt::Display for Solver {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl Signed for RustHelper {
impl Signed for Solver {
fn get_signature(&self) -> (usize, usize) {
match self {
RustHelper::Identity => (1, 1),
RustHelper::ConditionEq => (1, 2),
RustHelper::Bits => (1, 254),
RustHelper::Div => (2, 1),
RustHelper::Sha256Round => (768, 26935),
Solver::Identity => (1, 1),
Solver::ConditionEq => (1, 2),
Solver::Bits => (1, 254),
Solver::Div => (2, 1),
Solver::Sha256Round => (768, 26935),
}
}
}
impl<T: Field> Executable<T> for RustHelper {
impl<T: Field> Executable<T> for Solver {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = self.get_signature();
assert!(inputs.len() == expected_input_count);
match self {
RustHelper::Identity => Ok(inputs.clone()),
RustHelper::ConditionEq => match inputs[0].is_zero() {
Solver::Identity => Ok(inputs.clone()),
Solver::ConditionEq => match inputs[0].is_zero() {
true => Ok(vec![T::zero(), T::one()]),
false => Ok(vec![T::one(), T::one() / inputs[0].clone()]),
},
RustHelper::Bits => {
Solver::Bits => {
let mut num = inputs[0].clone();
let mut res = vec![];
let bits = 254;
@ -53,8 +55,8 @@ impl<T: Field> Executable<T> for RustHelper {
assert_eq!(num, T::zero());
Ok(res)
}
RustHelper::Div => Ok(vec![inputs[0].clone() / inputs[1].clone()]),
RustHelper::Sha256Round => {
Solver::Div => Ok(vec![inputs[0].clone() / inputs[1].clone()]),
Solver::Sha256Round => {
let i = &inputs[0..512];
let h = &inputs[512..];
let i: Vec<_> = i.iter().map(|x| x.clone().into_bellman()).collect();
@ -69,15 +71,64 @@ impl<T: Field> Executable<T> for RustHelper {
}
}
impl Solver {
pub fn identity() -> Self {
Solver::Identity
}
pub fn bits() -> Self {
Solver::Bits
}
}
pub trait Executable<T: Field>: Signed {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String>;
}
pub trait Signed {
fn get_signature(&self) -> (usize, usize);
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let r = cond_eq
.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect())
.unwrap();
let res: Vec<FieldPrime> = vec![0, 1].iter().map(|&i| FieldPrime::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let r = cond_eq
.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect())
.unwrap();
let res: Vec<FieldPrime> = vec![1, 1].iter().map(|&i| FieldPrime::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
#[test]
fn bits_of_one() {
let inputs = vec![FieldPrime::from(1)];
let res = RustHelper::Bits.execute(&inputs).unwrap();
let res = Solver::Bits.execute(&inputs).unwrap();
assert_eq!(res[253], FieldPrime::from(1));
for i in 0..252 {
assert_eq!(res[i], FieldPrime::from(0));
@ -87,7 +138,7 @@ mod tests {
#[test]
fn bits_of_42() {
let inputs = vec![FieldPrime::from(42)];
let res = RustHelper::Bits.execute(&inputs).unwrap();
let res = Solver::Bits.execute(&inputs).unwrap();
assert_eq!(res[253], FieldPrime::from(0));
assert_eq!(res[252], FieldPrime::from(1));
assert_eq!(res[251], FieldPrime::from(0));

View file

@ -1,6 +1,6 @@
use crate::flat_absy::{FlatExpression, FlatExpressionList, FlatFunction, FlatStatement};
use crate::flat_absy::{FlatParameter, FlatVariable};
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::solvers::{FlatDirective, solver, Rustsolver};
use crate::types::{Signature, Type};
use bellman::pairing::ff::ScalarEngine;
use reduce::Reduce;
@ -118,13 +118,13 @@ pub fn sha_round<T: Field>() -> FlatFunction<T> {
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(DirectiveStatement {
let directive_statement = FlatStatement::Directive(FlatDirective {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
.collect(),
helper: Helper::Rust(RustHelper::Sha256Round),
solver: Solver::Sha256Round,
});
// insert a statement to return the subset of the witness

View file

@ -5,7 +5,6 @@
//! @date 2018
use crate::flat_absy::*;
use crate::helpers::DirectiveStatement;
use std::collections::HashMap;
use zokrates_field::field::Field;
@ -74,7 +73,7 @@ impl<T: Field> FlatStatement<T> {
e1.propagate(constants),
e2.propagate(constants),
)),
FlatStatement::Directive(d) => Some(FlatStatement::Directive(DirectiveStatement {
FlatStatement::Directive(d) => Some(FlatStatement::Directive(FlatDirective {
inputs: d
.inputs
.into_iter()