diff --git a/.circleci/config.yml b/.circleci/config.yml index 9d4ca794..f1655ac2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -30,6 +30,9 @@ jobs: - run: name: Run integration tests command: WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark RUSTFLAGS="-D warnings" cargo test --release -- --ignored + - run: + name: Generate code coverage report + command: ./scripts/cov.sh - save_cache: paths: - /usr/local/cargo/registry diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..db247200 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1 @@ +comment: off diff --git a/scripts/cov.sh b/scripts/cov.sh new file mode 100755 index 00000000..dfc79144 --- /dev/null +++ b/scripts/cov.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Exit if any subcommand fails +set -e + +apt-get update +apt-get install -qq curl zlib1g-dev build-essential python +apt-get install -qq cmake g++ pkg-config jq +apt-get install -qq libcurl4-openssl-dev libelf-dev libdw-dev binutils-dev libiberty-dev +cargo install cargo-kcov +cargo kcov --print-install-kcov-sh | sh +cd zokrates_fs_resolver && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd .. +cd zokrates_core && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd .. +cd zokrates_cli && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd .. +bash <(curl -s https://codecov.io/bash) +echo "Uploaded code coverage" diff --git a/zokrates_book/src/concepts/control_flow.md b/zokrates_book/src/concepts/control_flow.md index a6cf3f58..14f94868 100644 --- a/zokrates_book/src/concepts/control_flow.md +++ b/zokrates_book/src/concepts/control_flow.md @@ -20,6 +20,16 @@ def main() -> (field): return 1 ``` +### If expressions + +An if expression allows you to branch your code depending on conditions. + +```zokrates +def main(field x) -> (field): + field y = if x + 2 == 3 then 1 else 5 fi + return y +``` + ### For loops For loops are available with the following syntax: diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index a0f77c95..e4ba730f 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -21,11 +21,9 @@ use std::path::{Path, PathBuf}; use std::string::String; use zokrates_core::compile::compile; use zokrates_core::field::{Field, FieldPrime}; -use zokrates_core::flat_absy::FlatProg; +use zokrates_core::ir; #[cfg(feature = "libsnark")] use zokrates_core::proof_system::{ProofSystem, GM17, PGHR13}; -#[cfg(feature = "libsnark")] -use zokrates_core::r1cs::r1cs_program; use zokrates_fs_resolver::resolve as fs_resolve; fn main() { @@ -245,20 +243,14 @@ fn main() { let mut reader = BufReader::new(file); - let program_flattened: FlatProg = + let program_flattened: ir::Prog = match compile(&mut reader, Some(location), Some(fs_resolve)) { Ok(p) => p, Err(why) => panic!("Compilation failed: {}", why), }; // number of constraints the flattened program will translate to. - let num_constraints = &program_flattened - .functions - .iter() - .find(|x| x.id == "main") - .unwrap() - .statements - .len(); + let num_constraints = program_flattened.constraint_count(); // serialize flattened program and write to binary file let mut bin_output_file = match File::create(&bin_output_path) { @@ -305,7 +297,7 @@ fn main() { Err(why) => panic!("couldn't open {}: {}", path.display(), why), }; - let program_ast: FlatProg = match deserialize_from(&mut file, Infinite) { + let program_ast: ir::Prog = match deserialize_from(&mut file, Infinite) { Ok(x) => x, Err(why) => { println!("{:?}", why); @@ -313,14 +305,8 @@ fn main() { } }; - let main_flattened = program_ast - .functions - .iter() - .find(|x| x.id == "main") - .unwrap(); - // print deserialized flattened program - println!("{}", main_flattened); + println!("{}", program_ast); // validate #arguments let mut cli_arguments: Vec = Vec::new(); @@ -339,11 +325,11 @@ fn main() { let is_interactive = sub_matches.occurrences_of("interactive") > 0; // in interactive mode, only public inputs are expected - let expected_cli_args_count = main_flattened - .arguments - .iter() - .filter(|x| !(x.private && is_interactive)) - .count(); + let expected_cli_args_count = if is_interactive { + program_ast.public_arguments_count() + } else { + program_ast.public_arguments_count() + program_ast.private_arguments_count() + }; if cli_arguments.len() != expected_cli_args_count { println!( @@ -355,10 +341,9 @@ fn main() { } let mut cli_arguments_iter = cli_arguments.into_iter(); - let arguments = main_flattened - .arguments - .clone() - .into_iter() + let arguments: Vec = program_ast + .parameters() + .iter() .map(|x| { match x.private && is_interactive { // private inputs are passed interactively when the flag is present @@ -383,7 +368,9 @@ fn main() { }) .collect(); - let witness_map = main_flattened.get_witness(arguments).unwrap(); + let witness_map = program_ast + .execute(arguments) + .unwrap_or_else(|e| panic!(format!("Execution failed: {}", e))); println!( "\nWitness: \n\n{}", @@ -431,7 +418,7 @@ fn main() { Err(why) => panic!("couldn't open {}: {}", path.display(), why), }; - let program_ast: FlatProg = match deserialize_from(&mut file, Infinite) { + let program: ir::Prog = match deserialize_from(&mut file, Infinite) { Ok(x) => x, Err(why) => { println!("{:?}", why); @@ -439,17 +426,11 @@ fn main() { } }; - let main_flattened = program_ast - .functions - .iter() - .find(|x| x.id == "main") - .unwrap(); - // print deserialized flattened program - println!("{}", main_flattened); + println!("{}", program); // transform to R1CS - let (variables, public_variables_count, a, b, c) = r1cs_program(&program_ast); + let (variables, public_variables_count, a, b, c) = r1cs_program(program); // write variables meta information to file let var_inf_path = Path::new(sub_matches.value_of("meta-information").unwrap()); @@ -629,6 +610,7 @@ mod tests { extern crate glob; use self::glob::glob; use super::*; + use zokrates_core::ir::r1cs_program; #[test] fn examples() { @@ -655,10 +637,10 @@ mod tests { .into_string() .unwrap(); - let program_flattened: FlatProg = + let program_flattened: ir::Prog = compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); - let (..) = r1cs_program(&program_flattened); + let (..) = r1cs_program(program_flattened); } } @@ -684,12 +666,12 @@ mod tests { let mut reader = BufReader::new(file); - let program_flattened: FlatProg = + let program_flattened: ir::Prog = compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); - let (..) = r1cs_program(&program_flattened); + let (..) = r1cs_program(program_flattened.clone()); let _ = program_flattened - .get_witness(vec![FieldPrime::from(0)]) + .execute(vec![FieldPrime::from(0)]) .unwrap(); } } @@ -716,14 +698,14 @@ mod tests { let mut reader = BufReader::new(file); - let program_flattened: FlatProg = + let program_flattened: ir::Prog = compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); - let (..) = r1cs_program(&program_flattened); + let (..) = r1cs_program(program_flattened.clone()); let result = std::panic::catch_unwind(|| { let _ = program_flattened - .get_witness(vec![FieldPrime::from(0)]) + .execute(vec![FieldPrime::from(0)]) .unwrap(); }); assert!(result.is_err()); diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 84c9ca73..38c4d647 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -17,7 +17,7 @@ use flat_absy::*; use imports::Import; use std::fmt; -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct Prog { /// Functions of the program pub functions: Vec>, @@ -74,7 +74,7 @@ impl fmt::Debug for Prog { } } -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct Function { /// Name of the program pub id: String, @@ -122,7 +122,7 @@ impl fmt::Debug for Function { } } -#[derive(Clone, Serialize, Deserialize, PartialEq)] +#[derive(Clone, PartialEq)] pub enum Assignee { Identifier(String), ArrayElement(Box>, Box>), @@ -154,7 +154,7 @@ impl From> for Assignee { } } -#[derive(Clone, Serialize, Deserialize, PartialEq)] +#[derive(Clone, PartialEq)] pub enum Statement { Return(ExpressionList), Declaration(Variable), diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 56597ea5..f2f0ed95 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -8,6 +8,7 @@ use field::Field; use flat_absy::FlatProg; use flatten::Flattener; use imports::{self, Importer}; +use ir; use optimizer::Optimizer; use parser::{self, parse_program}; use semantics::{self, Checker}; @@ -64,9 +65,9 @@ pub fn compile>( reader: &mut R, location: Option, resolve_option: Option, &String) -> Result<(S, String, String), E>>, -) -> Result, CompileError> { +) -> Result, CompileError> { let compiled = compile_aux(reader, location, resolve_option)?; - Ok(Optimizer::new().optimize_program(compiled)) + Ok(ir::Prog::from(Optimizer::new().optimize_program(compiled))) } pub fn compile_aux>( @@ -113,7 +114,7 @@ mod test { "# .as_bytes(), ); - let res: Result, CompileError> = compile( + let res: Result, CompileError> = compile( &mut r, Some(String::from("./path/to/file")), None::< @@ -138,7 +139,7 @@ mod test { "# .as_bytes(), ); - let res: Result, CompileError> = compile( + let res: Result, CompileError> = compile( &mut r, Some(String::from("./path/to/file")), None::< diff --git a/zokrates_core/src/flat_absy/flat_variable.rs b/zokrates_core/src/flat_absy/flat_variable.rs index f575f0d9..5f273536 100644 --- a/zokrates_core/src/flat_absy/flat_variable.rs +++ b/zokrates_core/src/flat_absy/flat_variable.rs @@ -37,7 +37,7 @@ impl fmt::Display for FlatVariable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.id { 0 => write!(f, "~one"), - i if i > 0 => write!(f, "_{}", i + 1), + i if i > 0 => write!(f, "_{}", i - 1), i => write!(f, "~out_{}", -(i + 1)), } } @@ -72,18 +72,18 @@ mod tests { #[test] fn one() { - assert_eq!(FlatVariable::one().id, 0); + assert_eq!(format!("{}", FlatVariable::one()), "~one"); } #[test] fn public() { - assert_eq!(FlatVariable::public(0).id, -1); - assert_eq!(FlatVariable::public(42).id, -43); + assert_eq!(format!("{}", FlatVariable::public(0)), "~out_0"); + assert_eq!(format!("{}", FlatVariable::public(42)), "~out_42"); } #[test] fn private() { - assert_eq!(FlatVariable::new(0).id, 1); - assert_eq!(FlatVariable::new(42).id, 43); + assert_eq!(format!("{}", FlatVariable::new(0)), "_0"); + assert_eq!(format!("{}", FlatVariable::new(42)), "_42"); } } diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index b163434d..7dc4831c 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -19,7 +19,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt; use types::Signature; -#[derive(Serialize, Deserialize, Clone)] +#[derive(Clone)] pub struct FlatProg { /// FlatFunctions of the program pub functions: Vec>, @@ -71,7 +71,7 @@ impl From for FlatProg { } } -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct FlatFunction { /// Name of the program pub id: String, @@ -180,7 +180,7 @@ impl fmt::Debug for FlatFunction { /// /// * r1cs - R1CS in standard JSON data format -#[derive(Clone, Serialize, Deserialize, PartialEq)] +#[derive(Clone, PartialEq)] pub enum FlatStatement { Return(FlatExpressionList), Condition(FlatExpression, FlatExpression), diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index b542f516..62f7a9fd 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -492,7 +492,7 @@ impl Flattener { .into_iter() .map(|x| x.apply_direct_substitution(&replacement_map)) .collect(), - } + }; } FlatStatement::Definition(var, rhs) => { let new_var = self.issue_new_variable(); diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index fbf9861c..c1c97216 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -179,7 +179,7 @@ impl Importer { return Err(CompileError::ImportError(Error::new(format!( "Gadget {} not found", s - )))) + )))); } } } @@ -207,7 +207,7 @@ impl Importer { return Err(CompileError::ImportError(Error::new(format!( "Packing helper {} not found", s - )))) + )))); } } } else { @@ -226,7 +226,7 @@ impl Importer { Err(err) => return Err(CompileError::ImportError(err.into())), }, None => { - return Err(Error::new("Can't resolve import without a resolver").into()) + return Err(Error::new("Can't resolve import without a resolver").into()); } } } diff --git a/zokrates_core/src/ir/expression.rs b/zokrates_core/src/ir/expression.rs new file mode 100644 index 00000000..7a29a78b --- /dev/null +++ b/zokrates_core/src/ir/expression.rs @@ -0,0 +1,174 @@ +use field::Field; +use flat_absy::FlatVariable; +use num::Zero; +use std::collections::BTreeMap; +use std::fmt; +use std::ops::{Add, Sub}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct QuadComb { + pub left: LinComb, + pub right: LinComb, +} + +impl QuadComb { + pub fn from_linear_combinations(left: LinComb, right: LinComb) -> Self { + QuadComb { left, right } + } +} + +impl From for QuadComb { + fn from(v: FlatVariable) -> QuadComb { + LinComb::from(v).into() + } +} + +impl fmt::Display for QuadComb { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "({}) * ({})", self.left, self.right,) + } +} + +impl From> for QuadComb { + fn from(lc: LinComb) -> QuadComb { + QuadComb::from_linear_combinations(LinComb::one(), lc) + } +} + +#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] +pub struct LinComb(pub BTreeMap); + +impl LinComb { + pub fn summand>(mult: U, var: FlatVariable) -> LinComb { + let mut res = BTreeMap::new(); + res.insert(var, mult.into()); + LinComb(res) + } + + pub fn one() -> LinComb { + Self::summand(1, FlatVariable::one()) + } +} + +impl fmt::Display for LinComb { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + self.0 + .iter() + .map(|(k, v)| format!("{} * {}", v, k)) + .collect::>() + .join(" + ") + ) + } +} + +impl From for LinComb { + fn from(v: FlatVariable) -> LinComb { + let mut r = BTreeMap::new(); + r.insert(v, T::one()); + LinComb(r) + } +} + +impl Add> for LinComb { + type Output = LinComb; + + fn add(self, other: LinComb) -> LinComb { + let mut res = self.0.clone(); + for (k, v) in other.0 { + let new_val = v + res.get(&k).unwrap_or(&T::zero()); + if new_val == T::zero() { + res.remove(&k) + } else { + res.insert(k, new_val) + }; + } + LinComb(res) + } +} + +impl Sub> for LinComb { + type Output = LinComb; + + fn sub(self, other: LinComb) -> LinComb { + let mut res = self.0.clone(); + for (k, v) in other.0 { + let new_val = T::zero() - v + res.get(&k).unwrap_or(&T::zero()); + if new_val == T::zero() { + res.remove(&k) + } else { + res.insert(k, new_val) + }; + } + LinComb(res) + } +} + +impl Zero for LinComb { + fn zero() -> LinComb { + LinComb(BTreeMap::new()) + } + fn is_zero(&self) -> bool { + self.0.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use field::FieldPrime; + + mod linear { + + use super::*; + #[test] + fn add_zero() { + let a: LinComb = LinComb::zero(); + let b: LinComb = FlatVariable::new(42).into(); + let c = a + b.clone(); + assert_eq!(c, b); + } + #[test] + fn add() { + let a: LinComb = FlatVariable::new(42).into(); + let b: LinComb = FlatVariable::new(42).into(); + let c = a + b.clone(); + let mut expected_map = BTreeMap::new(); + expected_map.insert(FlatVariable::new(42), FieldPrime::from(2)); + assert_eq!(c, LinComb(expected_map)); + } + #[test] + fn sub() { + let a: LinComb = FlatVariable::new(42).into(); + let b: LinComb = FlatVariable::new(42).into(); + let c = a - b.clone(); + assert_eq!(c, LinComb::zero()); + } + } + + mod quadratic { + use super::*; + #[test] + fn from_linear() { + let a: LinComb = LinComb::summand(3, FlatVariable::new(42)) + + LinComb::summand(4, FlatVariable::new(33)); + let expected = QuadComb { + left: LinComb::one(), + right: a.clone(), + }; + assert_eq!(QuadComb::from(a), expected); + } + + #[test] + fn zero() { + let a: LinComb = LinComb::zero(); + let expected: QuadComb = QuadComb { + left: LinComb::one(), + right: LinComb::zero(), + }; + assert_eq!(QuadComb::from(a), expected); + } + } +} diff --git a/zokrates_core/src/ir/from_flat.rs b/zokrates_core/src/ir/from_flat.rs new file mode 100644 index 00000000..8fa1773f --- /dev/null +++ b/zokrates_core/src/ir/from_flat.rs @@ -0,0 +1,214 @@ +use field::Field; +use flat_absy::{FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable}; +use helpers; +use ir::{DirectiveStatement, Function, LinComb, Prog, QuadComb, Statement}; +use num::Zero; + +impl From> for Function { + fn from(flat_function: FlatFunction) -> Function { + let return_expressions: Vec> = flat_function + .statements + .iter() + .filter_map(|s| match s { + FlatStatement::Return(el) => Some(el.expressions.clone()), + _ => None, + }) + .next() + .unwrap(); + Function { + id: flat_function.id, + arguments: flat_function.arguments.into_iter().map(|p| p.id).collect(), + returns: return_expressions.into_iter().map(|e| e.into()).collect(), + statements: flat_function + .statements + .into_iter() + .filter_map(|s| match s { + FlatStatement::Return(..) => None, + s => Some(s.into()), + }) + .collect(), + } + } +} + +impl From> for Prog { + fn from(flat_prog: FlatProg) -> Prog { + // get the main function as all calls have been resolved + let main = flat_prog + .functions + .into_iter() + .find(|f| f.id == "main") + .unwrap(); + + // get the interface of the program, ie which inputs are private and public + let private = main.arguments.iter().map(|p| p.private).collect(); + + // convert the main function to this IR for functions + let main: Function = main.into(); + + // contrary to other functions, we need to make sure that return values are identifiers, so we define new (public) variables + let definitions = + main.returns.iter().enumerate().map(|(index, e)| { + Statement::Constraint(e.clone(), FlatVariable::public(index).into()) + }); + + // update the main function with the extra definition statements and replace the return values + let main = Function { + returns: (0..main.returns.len()) + .map(|i| FlatVariable::public(i).into()) + .collect(), + statements: main.statements.into_iter().chain(definitions).collect(), + ..main + }; + + let main = Function::from(main); + Prog { private, main } + } +} + +impl From> for QuadComb { + fn from(flat_expression: FlatExpression) -> QuadComb { + match flat_expression.is_linear() { + true => LinComb::from(flat_expression).into(), + false => match flat_expression { + FlatExpression::Mult(box e1, box e2) => { + QuadComb::from_linear_combinations(e1.into(), e2.into()) + } + e => unimplemented!("{}", e), + }, + } + } +} + +impl From> for LinComb { + fn from(flat_expression: FlatExpression) -> LinComb { + assert!(flat_expression.is_linear()); + match flat_expression { + FlatExpression::Number(ref n) if *n == T::from(0) => LinComb::zero(), + FlatExpression::Number(n) => LinComb::summand(n, FlatVariable::one()), + FlatExpression::Identifier(id) => LinComb::from(id), + FlatExpression::Add(box e1, box e2) => LinComb::from(e1) + LinComb::from(e2), + FlatExpression::Sub(box e1, box e2) => LinComb::from(e1) - LinComb::from(e2), + FlatExpression::Mult( + box FlatExpression::Number(n1), + box FlatExpression::Identifier(v1), + ) + | FlatExpression::Mult( + box FlatExpression::Identifier(v1), + box FlatExpression::Number(n1), + ) => LinComb::summand(n1, v1), + e => unimplemented!("{}", e), + } + } +} + +impl From> for Statement { + fn from(flat_statement: FlatStatement) -> Statement { + match flat_statement { + FlatStatement::Condition(linear, quadratic) => match quadratic { + FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( + QuadComb::from_linear_combinations(lhs.into(), rhs.into()), + linear.into(), + ), + e => Statement::Constraint(LinComb::from(e).into(), linear.into()), + }, + FlatStatement::Definition(var, quadratic) => match quadratic { + FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( + QuadComb::from_linear_combinations(lhs.into(), rhs.into()), + var.into(), + ), + e => Statement::Constraint(LinComb::from(e).into(), var.into()), + }, + FlatStatement::Directive(ds) => Statement::Directive(ds.into()), + _ => panic!("return should be handled at the function level"), + } + } +} + +impl From> for DirectiveStatement { + fn from(ds: helpers::DirectiveStatement) -> DirectiveStatement { + DirectiveStatement { + inputs: ds.inputs.into_iter().map(|i| i.into()).collect(), + helper: ds.helper, + outputs: ds.outputs, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use field::FieldPrime; + + #[test] + fn zero() { + // 0 + let zero = FlatExpression::Number(FieldPrime::from(0)); + let expected: LinComb = LinComb::zero(); + assert_eq!(LinComb::from(zero), expected); + } + + #[test] + fn one() { + // 1 + let one = FlatExpression::Number(FieldPrime::from(1)); + let expected: LinComb = FlatVariable::one().into(); + assert_eq!(LinComb::from(one), expected); + } + + #[test] + fn forty_two() { + // 42 + let one = FlatExpression::Number(FieldPrime::from(42)); + let expected: LinComb = LinComb::summand(42, FlatVariable::one()); + assert_eq!(LinComb::from(one), expected); + } + + #[test] + fn add() { + // x + y + let add = FlatExpression::Add( + box FlatExpression::Identifier(FlatVariable::new(42)), + box FlatExpression::Identifier(FlatVariable::new(21)), + ); + let expected: LinComb = + LinComb::summand(1, FlatVariable::new(42)) + LinComb::summand(1, FlatVariable::new(21)); + assert_eq!(LinComb::from(add), expected); + } + + #[test] + fn linear_combination() { + // 42*x + 21*y + let add = FlatExpression::Add( + box FlatExpression::Mult( + box FlatExpression::Number(FieldPrime::from(42)), + box FlatExpression::Identifier(FlatVariable::new(42)), + ), + box FlatExpression::Mult( + box FlatExpression::Number(FieldPrime::from(21)), + box FlatExpression::Identifier(FlatVariable::new(21)), + ), + ); + let expected: LinComb = LinComb::summand(42, FlatVariable::new(42)) + + LinComb::summand(21, FlatVariable::new(21)); + assert_eq!(LinComb::from(add), expected); + } + + #[test] + fn linear_combination_inverted() { + // x*42 + y*21 + let add = FlatExpression::Add( + box FlatExpression::Mult( + box FlatExpression::Identifier(FlatVariable::new(42)), + box FlatExpression::Number(FieldPrime::from(42)), + ), + box FlatExpression::Mult( + box FlatExpression::Identifier(FlatVariable::new(21)), + box FlatExpression::Number(FieldPrime::from(21)), + ), + ); + let expected: LinComb = LinComb::summand(42, FlatVariable::new(42)) + + LinComb::summand(21, FlatVariable::new(21)); + assert_eq!(LinComb::from(add), expected); + } +} diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs new file mode 100644 index 00000000..913f1d7d --- /dev/null +++ b/zokrates_core/src/ir/interpreter.rs @@ -0,0 +1,89 @@ +use field::Field; +use helpers::Executable; +use ir::*; +use std::collections::BTreeMap; + +impl Prog { + pub fn execute(self, inputs: Vec) -> Result, Error> { + let main = self.main; + assert_eq!(main.arguments.len(), inputs.len()); + let mut witness = BTreeMap::new(); + witness.insert(FlatVariable::one(), T::one()); + for (arg, value) in main.arguments.iter().zip(inputs.iter()) { + witness.insert(arg.clone(), value.clone()); + } + + for statement in main.statements { + match statement { + Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) { + true => { + let val = quad.evaluate(&witness); + witness.insert(lin.0.iter().next().unwrap().0.clone(), val); + } + false => { + let lhs_value = quad.evaluate(&witness); + let rhs_value = lin.evaluate(&witness); + if lhs_value != rhs_value { + return Err(Error::Constraint(quad, lin, lhs_value, rhs_value)); + } + } + }, + Statement::Directive(ref d) => { + let input_values: Vec = + d.inputs.iter().map(|i| i.evaluate(&witness)).collect(); + match d.helper.execute(&input_values) { + Ok(res) => { + for (i, o) in d.outputs.iter().enumerate() { + witness.insert(o.clone(), res[i].clone()); + } + continue; + } + Err(_) => return Err(Error::Solver), + }; + } + } + } + + Ok(witness) + } +} + +impl LinComb { + fn evaluate(&self, witness: &BTreeMap) -> T { + self.0 + .iter() + .map(|(var, val)| witness.get(var).unwrap().clone() * val) + .fold(T::from(0), |acc, t| acc + t) + } + + fn is_assignee(&self, witness: &BTreeMap) -> bool { + self.0.iter().count() == 1 + && self.0.iter().next().unwrap().1 == &T::from(1) + && !witness.contains_key(self.0.iter().next().unwrap().0) + } +} + +impl QuadComb { + fn evaluate(&self, witness: &BTreeMap) -> T { + self.left.evaluate(&witness) * self.right.evaluate(&witness) + } +} + +#[derive(PartialEq, Debug)] +pub enum Error { + Constraint(QuadComb, LinComb, T, T), + Solver, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Constraint(ref quad, ref lin, ref left_value, ref right_value) => write!( + f, + "Expected {} to equal {}, but {} != {}", + quad, lin, left_value, right_value + ), + Error::Solver => write!(f, ""), + } + } +} diff --git a/zokrates_core/src/ir/mod.rs b/zokrates_core/src/ir/mod.rs new file mode 100644 index 00000000..c073e46a --- /dev/null +++ b/zokrates_core/src/ir/mod.rs @@ -0,0 +1,270 @@ +use field::Field; +use flat_absy::flat_parameter::FlatParameter; +use flat_absy::FlatVariable; +use helpers::Helper; +use std::collections::HashMap; +use std::fmt; +use std::mem; + +mod expression; +mod from_flat; +mod interpreter; + +use self::expression::LinComb; +use self::expression::QuadComb; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Statement { + Constraint(QuadComb, LinComb), + Directive(DirectiveStatement), +} + +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)] +pub struct DirectiveStatement { + pub inputs: Vec>, + pub outputs: Vec, + pub helper: Helper, +} + +impl fmt::Display for DirectiveStatement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "# {} = {}({})", + self.outputs + .iter() + .map(|o| format!("{}", o)) + .collect::>() + .join(", "), + self.helper, + self.inputs + .iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ) + } +} + +impl fmt::Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Statement::Constraint(ref quad, ref lin) => write!(f, "{} == {}", quad, lin), + Statement::Directive(ref s) => write!(f, "{}", s), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Function { + pub id: String, + pub statements: Vec>, + pub arguments: Vec, + pub returns: Vec>, +} + +impl fmt::Display for Function { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "def {}({}) -> ({}):\n{}\n\t return {}", + self.id, + self.arguments + .iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(", "), + self.returns.len(), + self.statements + .iter() + .map(|s| format!("\t{}", s)) + .collect::>() + .join("\n"), + self.returns + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Prog { + pub main: Function, + pub private: Vec, +} + +impl Prog { + pub fn constraint_count(&self) -> usize { + self.main + .statements + .iter() + .filter(|s| match s { + Statement::Constraint(..) => true, + _ => false, + }) + .count() + } + + pub fn public_arguments_count(&self) -> usize { + self.private.iter().filter(|b| !**b).count() + } + + pub fn private_arguments_count(&self) -> usize { + self.private.iter().filter(|b| **b).count() + } + + pub fn parameters(&self) -> Vec { + self.main + .arguments + .iter() + .zip(self.private.iter()) + .map(|(id, private)| FlatParameter { + private: *private, + id: *id, + }) + .collect() + } +} + +impl fmt::Display for Prog { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.main) + } +} + +/// Returns the index of `var` in `variables`, adding `var` with incremented index if it not yet exists. +/// +/// # Arguments +/// +/// * `variables` - A mutual map that maps all existing variables to their index. +/// * `var` - Variable to be searched for. +pub fn provide_variable_idx( + variables: &mut HashMap, + var: &FlatVariable, +) -> usize { + let index = variables.len(); + *variables.entry(*var).or_insert(index) +} + +/// Calculates one R1CS row representation of a program and returns (V, A, B, C) so that: +/// * `V` contains all used variables and the index in the vector represents the used number in `A`, `B`, `C` +/// * `* = ` for a witness `x` +/// +/// # Arguments +/// +/// * `prog` - The program the representation is calculated for. +pub fn r1cs_program( + prog: Prog, +) -> ( + Vec, + usize, + Vec>, + Vec>, + Vec>, +) { + let mut variables: HashMap = HashMap::new(); + provide_variable_idx(&mut variables, &FlatVariable::one()); + + for x in prog + .main + .arguments + .iter() + .enumerate() + .filter(|(index, _)| !prog.private[*index]) + { + provide_variable_idx(&mut variables, &x.1); + } + + //Only the main function is relevant in this step, since all calls to other functions were resolved during flattening + let main = prog.main; + + //~out are added after main's arguments as we want variables (columns) + //in the r1cs to be aligned like "public inputs | private inputs" + let main_return_count = main.returns.len(); + + for i in 0..main_return_count { + provide_variable_idx(&mut variables, &FlatVariable::public(i)); + } + + // position where private part of witness starts + let private_inputs_offset = variables.len(); + + // first pass through statements to populate `variables` + for (quad, lin) in main.statements.iter().filter_map(|s| match s { + Statement::Constraint(quad, lin) => Some((quad, lin)), + Statement::Directive(..) => None, + }) { + for (k, _) in &quad.left.0 { + provide_variable_idx(&mut variables, &k); + } + for (k, _) in &quad.right.0 { + provide_variable_idx(&mut variables, &k); + } + for (k, _) in &lin.0 { + provide_variable_idx(&mut variables, &k); + } + } + + let mut a = vec![]; + let mut b = vec![]; + let mut c = vec![]; + + // second pass to convert program to raw sparse vectors + for (quad, lin) in main.statements.into_iter().filter_map(|s| match s { + Statement::Constraint(quad, lin) => Some((quad, lin)), + Statement::Directive(..) => None, + }) { + a.push( + quad.left + .0 + .into_iter() + .map(|(k, v)| (variables.get(&k).unwrap().clone(), v)) + .collect(), + ); + b.push( + quad.right + .0 + .into_iter() + .map(|(k, v)| (variables.get(&k).unwrap().clone(), v)) + .collect(), + ); + c.push( + lin.0 + .into_iter() + .map(|(k, v)| (variables.get(&k).unwrap().clone(), v)) + .collect(), + ); + } + + // Convert map back into list ordered by index + let mut variables_list = vec![FlatVariable::new(0); variables.len()]; + for (k, v) in variables.drain() { + assert_eq!(variables_list[v], FlatVariable::new(0)); + mem::replace(&mut variables_list[v], k); + } + (variables_list, private_inputs_offset, a, b, c) +} + +#[cfg(test)] +mod tests { + use super::*; + use field::FieldPrime; + + mod statement { + use super::*; + + #[test] + fn print_constraint() { + let c: Statement = Statement::Constraint( + QuadComb::from_linear_combinations( + FlatVariable::new(42).into(), + FlatVariable::new(42).into(), + ), + FlatVariable::new(42).into(), + ); + assert_eq!(format!("{}", c), "(1 * _42) * (1 * _42) == 1 * _42") + } + } +} diff --git a/zokrates_core/src/lib.rs b/zokrates_core/src/lib.rs index 015a9274..c966616c 100644 --- a/zokrates_core/src/lib.rs +++ b/zokrates_core/src/lib.rs @@ -29,8 +29,8 @@ pub mod absy; pub mod compile; pub mod field; pub mod flat_absy; +pub mod ir; #[cfg(feature = "libsnark")] pub mod libsnark; #[cfg(feature = "libsnark")] pub mod proof_system; -pub mod r1cs; diff --git a/zokrates_core/src/parser/parse/expression.rs b/zokrates_core/src/parser/parse/expression.rs index 2df658c3..4ead5452 100644 --- a/zokrates_core/src/parser/parse/expression.rs +++ b/zokrates_core/src/parser/parse/expression.rs @@ -362,14 +362,14 @@ pub fn parse_function_call( p = p2; } (Token::Close, s2, p2) => { - return parse_term1(Expression::FunctionCall(ide, args), s2, p2) + return parse_term1(Expression::FunctionCall(ide, args), s2, p2); } (t2, _, p2) => { return Err(Error { expected: vec![Token::Comma, Token::Close], got: t2, pos: p2, - }) + }); } } } @@ -417,7 +417,7 @@ pub fn parse_inline_array( expected: vec![Token::Comma, Token::RightBracket], got: t2, pos: p2, - }) + }); } } } diff --git a/zokrates_core/src/parser/parse/function.rs b/zokrates_core/src/parser/parse/function.rs index a8ea072a..64d030da 100644 --- a/zokrates_core/src/parser/parse/function.rs +++ b/zokrates_core/src/parser/parse/function.rs @@ -42,7 +42,7 @@ fn parse_function_header( expected: vec![Token::Close], got: t3, pos: p3, - }) + }); } }, Err(e) => return Err(e), @@ -52,7 +52,7 @@ fn parse_function_header( expected: vec![Token::Open], got: t1, pos: p1, - }) + }); } }?; @@ -67,7 +67,7 @@ fn parse_function_header( expected: vec![Token::Close], got: t3, pos: p3, - }) + }); } }, Err(e) => return Err(e), @@ -77,7 +77,7 @@ fn parse_function_header( expected: vec![Token::Open], got: t1, pos: p1, - }) + }); } }, (t0, _, p0) => { @@ -85,7 +85,7 @@ fn parse_function_header( expected: vec![Token::Arrow], got: t0, pos: p0, - }) + }); } }?; @@ -103,7 +103,7 @@ fn parse_function_header( expected: vec![Token::Unknown("".to_string())], got: t6, pos: p6, - }) + }); } }, (t5, _, p5) => { @@ -111,7 +111,7 @@ fn parse_function_header( expected: vec![Token::Colon], got: t5, pos: p5, - }) + }); } } } @@ -167,7 +167,7 @@ fn parse_function_arguments( expected: vec![Token::Comma, Token::Close], got: t3, pos: p3, - }) + }); } } } @@ -188,7 +188,7 @@ fn parse_function_arguments( expected: vec![Token::Comma, Token::Close], got: t3, pos: p3, - }) + }); } } } @@ -202,7 +202,7 @@ fn parse_function_arguments( ], got: t4, pos: p4, - }) + }); } } } @@ -231,7 +231,7 @@ fn parse_function_return_types( expected: vec![Token::Comma, Token::Close], got: t3, pos: p3, - }) + }); } } } @@ -245,7 +245,7 @@ fn parse_function_return_types( ], got: t4, pos: p4, - }) + }); } } } diff --git a/zokrates_core/src/parser/parse/import.rs b/zokrates_core/src/parser/parse/import.rs index 9d1462cc..dc6a9142 100644 --- a/zokrates_core/src/parser/parse/import.rs +++ b/zokrates_core/src/parser/parse/import.rs @@ -16,14 +16,14 @@ pub fn parse_import( (Token::Path(code_path), s2, p2) => match next_token::(&s2, &p2) { (Token::As, s3, p3) => match next_token(&s3, &p3) { (Token::Ide(id), _, p4) => { - return Ok((Import::new_with_alias(code_path, &id), p4)) + return Ok((Import::new_with_alias(code_path, &id), p4)); } (t4, _, p4) => { return Err(Error { expected: vec![Token::Ide("ide".to_string())], got: t4, pos: p4, - }) + }); } }, (Token::Unknown(_), _, p3) => return Ok((Import::new(code_path), p3)), @@ -35,7 +35,7 @@ pub fn parse_import( ], got: t3, pos: p3, - }) + }); } }, (t2, _, p2) => Err(Error { diff --git a/zokrates_core/src/parser/parse/program.rs b/zokrates_core/src/parser/parse/program.rs index 96399e6e..adb26a0e 100644 --- a/zokrates_core/src/parser/parse/program.rs +++ b/zokrates_core/src/parser/parse/program.rs @@ -47,7 +47,7 @@ pub fn parse_program(reader: &mut R) -> Result, Er expected: vec![Token::Def], got: t1, pos: p1, - }) + }); } }, None => break, diff --git a/zokrates_core/src/r1cs.rs b/zokrates_core/src/r1cs.rs deleted file mode 100644 index c4bcc298..00000000 --- a/zokrates_core/src/r1cs.rs +++ /dev/null @@ -1,618 +0,0 @@ -//! Module containing necessary functions to convert a flattened program or expression to r1cs. -//! -//! @file r1cs.rs -//! @author Dennis Kuhnert -//! @author Jacob Eberhardt [a, 2*b, c-d] -fn get_summands(expr: &FlatExpression) -> Vec<&FlatExpression> { - let mut trace = Vec::new(); - let mut add = Vec::new(); - trace.push(expr); - loop { - if let Some(e) = trace.pop() { - match *e { - ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) | ref e @ Sub(..) - if e.is_linear() => - { - add.push(e) - } - Add(ref l, ref r) => { - trace.push(l); - trace.push(r); - } - ref e => panic!("Not covered: {}", e), - } - } else { - return add; - } - } -} - -/// Returns a `HashMap` containing variables and the number of occurrences -/// -/// # Arguments -/// -/// * `expr` - FlatExpression only containing Numbers, Variables, Add and Mult -/// -/// # Example -/// -/// `7 * x + 4 * y + x` -> { x => 8, y = 4 } -fn count_variables_add(expr: &FlatExpression) -> HashMap { - let summands = get_summands(expr); - let mut count = HashMap::new(); - for s in summands { - match *s { - Number(ref x) => { - let num = count.entry(FlatVariable::one()).or_insert(T::zero()); - *num = num.clone() + x; - } - Identifier(ref v) => { - let num = count.entry(*v).or_insert(T::zero()); - *num = num.clone() + T::one(); - } - Mult(box Number(ref x1), box Number(ref x2)) => { - let num = count.entry(FlatVariable::one()).or_insert(T::zero()); - *num = num.clone() + x1 + x2; - } - Mult(box Number(ref x), box Identifier(ref v)) - | Mult(box Identifier(ref v), box Number(ref x)) => { - let num = count.entry(*v).or_insert(T::zero()); - *num = num.clone() + x; - } - ref e => panic!("Not covered: {}", e), - } - } - count -} - -/// Returns an equation equivalent to `lhs == rhs` only using `Add` and `Mult` -/// -/// # Arguments -/// -/// * `lhs` - Left hand side of the equation -/// * `rhs` - Right hand side of the equation -fn swap_sub( - lhs: &FlatExpression, - rhs: &FlatExpression, -) -> (FlatExpression, FlatExpression) { - let mut left = get_summands(lhs); - let mut right = get_summands(rhs); - let mut run = true; - while run { - run = false; - for i in 0..left.len() { - match *left[i] { - ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) if e.is_linear() => {} - Sub(ref l, ref r) => { - run = true; - left.swap_remove(i); - left.extend(get_summands(l)); - right.extend(get_summands(r)); - } - ref e => panic!("Unexpected: {}", e), - } - } - for i in 0..right.len() { - match *right[i] { - ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) if e.is_linear() => {} - Sub(ref l, ref r) => { - run = true; - right.swap_remove(i); - right.extend(get_summands(l)); - left.extend(get_summands(r)); - } - ref e => panic!("Unexpected: {}", e), - } - } - } - if let Some(left_init) = left.pop() { - if let Some(right_init) = right.pop() { - return ( - left.iter() - .fold(left_init.clone(), |acc, &x| Add(box acc, box x.clone())), - right - .iter() - .fold(right_init.clone(), |acc, &x| Add(box acc, box x.clone())), - ); - } - } - panic!("Unexpected"); -} - -/// Calculates one R1CS row representation for `linear_expr` = `expr`. -/// ( = *) -/// -/// # Arguments -/// -/// * `linear_expr` - Left hand side of the equation, has to be linear -/// * `expr` - Right hand side of the equation -/// * `variables` - a mutual vector that contains all existing variables. Not found variables will be added. -/// * `a_row` - Result row of matrix a -/// * `b_row` - Result row of matrix B -/// * `c_row` - Result row of matrix C -fn r1cs_expression( - linear_expr: FlatExpression, - expr: FlatExpression, - variables: &mut HashMap, - a_row: &mut Vec<(usize, T)>, - b_row: &mut Vec<(usize, T)>, - c_row: &mut Vec<(usize, T)>, -) { - assert!(linear_expr.is_linear()); - - match expr { - e @ Add(..) | e @ Sub(..) => { - let (lhs, rhs) = swap_sub(&linear_expr, &e); - for (key, value) in count_variables_add(&rhs) { - a_row.push((provide_variable_idx(variables, &key), value)); - } - b_row.push((0, T::one())); - for (key, value) in count_variables_add(&lhs) { - c_row.push((provide_variable_idx(variables, &key), value)); - } - } - Mult(lhs, rhs) => { - match lhs { - box Number(x) => a_row.push((0, x)), - box Identifier(x) => a_row.push((provide_variable_idx(variables, &x), T::one())), - box e @ Add(..) => { - for (key, value) in count_variables_add(&e) { - a_row.push((provide_variable_idx(variables, &key), value)); - } - } - e @ _ => panic!("Not flattened: {}", e), - }; - match rhs { - box Number(x) => b_row.push((0, x)), - box Identifier(x) => b_row.push((provide_variable_idx(variables, &x), T::one())), - box e @ Add(..) => { - for (key, value) in count_variables_add(&e) { - b_row.push((provide_variable_idx(variables, &key), value)); - } - } - e @ _ => panic!("Not flattened: {}", e), - }; - for (key, value) in count_variables_add(&linear_expr) { - c_row.push((provide_variable_idx(variables, &key), value)); - } - } - Identifier(var) => { - a_row.push((provide_variable_idx(variables, &var), T::one())); - b_row.push((0, T::one())); - for (key, value) in count_variables_add(&linear_expr) { - c_row.push((provide_variable_idx(variables, &key), value)); - } - } - Number(x) => { - a_row.push((0, x)); - b_row.push((0, T::one())); - for (key, value) in count_variables_add(&linear_expr) { - c_row.push((provide_variable_idx(variables, &key), value)); - } - } - } -} - -/// Returns the index of `var` in `variables`, adding `var` with incremented index if it not yet exists. -/// -/// # Arguments -/// -/// * `variables` - A mutual map that maps all existing variables to their index. -/// * `var` - Variable to be searched for. -fn provide_variable_idx(variables: &mut HashMap, var: &FlatVariable) -> usize { - let index = variables.len(); - *variables.entry(*var).or_insert(index) -} - -/// Calculates one R1CS row representation of a program and returns (V, A, B, C) so that: -/// * `V` contains all used variables and the index in the vector represents the used number in `A`, `B`, `C` -/// * `* = ` for a witness `x` -/// -/// # Arguments -/// -/// * `prog` - The program the representation is calculated for. -pub fn r1cs_program( - prog: &FlatProg, -) -> ( - Vec, - usize, - Vec>, - Vec>, - Vec>, -) { - let mut variables: HashMap = HashMap::new(); - provide_variable_idx(&mut variables, &FlatVariable::one()); - let mut a: Vec> = Vec::new(); - let mut b: Vec> = Vec::new(); - let mut c: Vec> = Vec::new(); - - //Only the main function is relevant in this step, since all calls to other functions were resolved during flattening - let main = prog - .clone() - .functions - .into_iter() - .find(|x: &FlatFunction| x.id == "main".to_string()) - .unwrap(); - - for x in main.arguments.iter().filter(|x| !x.private) { - provide_variable_idx(&mut variables, &x.id); - } - - // ~out is added after main's arguments as we want variables (columns) - // in the r1cs to be aligned like "public inputs | private inputs" - let main_return_count = main - .signature - .outputs - .iter() - .map(|t| t.get_primitive_count()) - .fold(0, |acc, x| acc + x); - - for i in 0..main_return_count { - provide_variable_idx(&mut variables, &FlatVariable::public(i)); - } - - // position where private part of witness starts - let private_inputs_offset = variables.len(); - - for def in &main.statements { - match *def { - FlatStatement::Return(ref list) => { - for (i, val) in list.expressions.iter().enumerate() { - let mut a_row = Vec::new(); - let mut b_row = Vec::new(); - let mut c_row = Vec::new(); - r1cs_expression( - Identifier(FlatVariable::public(i)), - val.clone(), - &mut variables, - &mut a_row, - &mut b_row, - &mut c_row, - ); - a.push(a_row); - b.push(b_row); - c.push(c_row); - } - } - FlatStatement::Definition(ref id, ref rhs) => { - let mut a_row = Vec::new(); - let mut b_row = Vec::new(); - let mut c_row = Vec::new(); - r1cs_expression( - FlatExpression::Identifier(*id), - rhs.clone(), - &mut variables, - &mut a_row, - &mut b_row, - &mut c_row, - ); - a.push(a_row); - b.push(b_row); - c.push(c_row); - } - FlatStatement::Condition(ref expr1, ref expr2) => { - let mut a_row = Vec::new(); - let mut b_row = Vec::new(); - let mut c_row = Vec::new(); - r1cs_expression( - expr1.clone(), - expr2.clone(), - &mut variables, - &mut a_row, - &mut b_row, - &mut c_row, - ); - a.push(a_row); - b.push(b_row); - c.push(c_row); - } - FlatStatement::Directive(..) => continue, - } - } - - // Convert map back into list ordered by index - let mut variables_list = vec![FlatVariable::new(0); variables.len()]; - for (k, v) in variables.drain() { - assert_eq!(variables_list[v], FlatVariable::new(0)); - mem::replace(&mut variables_list[v], k); - } - (variables_list, private_inputs_offset, a, b, c) -} - -#[cfg(test)] -mod tests { - use super::*; - use field::FieldPrime; - use std::cmp::Ordering; - - /// Sort function for tuples `(x, y)` which sorts based on `x` first. - /// If `x` is equal, `y` is used for comparison. - fn sort_tup(a: &(A, B), b: &(A, B)) -> Ordering { - if a.0 == b.0 { - a.1.cmp(&b.1) - } else { - a.0.cmp(&b.0) - } - } - - #[cfg(test)] - mod r1cs_expression { - use super::*; - - #[test] - fn add() { - // x = y + 5 - - let one = FlatVariable::one(); - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); - - let lhs = Identifier(x); - let rhs = Add(box Identifier(y), box Number(FieldPrime::from(5))); - - let mut variables: HashMap = HashMap::new(); - variables.insert(one, 0); - variables.insert(x, 1); - variables.insert(y, 2); - let mut a_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut b_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut c_row: Vec<(usize, FieldPrime)> = Vec::new(); - - r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row); - a_row.sort_by(sort_tup); - b_row.sort_by(sort_tup); - c_row.sort_by(sort_tup); - assert_eq!( - vec![(0, FieldPrime::from(5)), (2, FieldPrime::from(1))], - a_row - ); - assert_eq!(vec![(0, FieldPrime::from(1))], b_row); - assert_eq!(vec![(1, FieldPrime::from(1))], c_row); - } - - #[test] - fn add_sub_mix() { - // (x + y) - ((z + 3*x) - y) == (x - y) + ((2*x - 4*y) + (4*y - 2*z)) - // --> (x + y) + y + 4y + 2z + y == x + 2x + 4y + (z + 3x) - // <=> x + 7*y + 2*z == 6*x + 4y + z - - let one = FlatVariable::one(); - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); - let z = FlatVariable::new(2); - - let lhs = Sub( - box Add(box Identifier(x), box Identifier(y)), - box Sub( - box Add( - box Identifier(z), - box Mult(box Number(FieldPrime::from(3)), box Identifier(x)), - ), - box Identifier(y), - ), - ); - let rhs = Add( - box Sub(box Identifier(x), box Identifier(y)), - box Add( - box Sub( - box Mult(box Number(FieldPrime::from(2)), box Identifier(x)), - box Mult(box Number(FieldPrime::from(4)), box Identifier(y)), - ), - box Sub( - box Mult(box Number(FieldPrime::from(4)), box Identifier(y)), - box Mult(box Number(FieldPrime::from(2)), box Identifier(z)), - ), - ), - ); - - let mut variables: HashMap = HashMap::new(); - variables.insert(one, 0); - variables.insert(x, 1); - variables.insert(y, 2); - variables.insert(z, 3); - - let mut a_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut b_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut c_row: Vec<(usize, FieldPrime)> = Vec::new(); - - r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row); - a_row.sort_by(sort_tup); - b_row.sort_by(sort_tup); - c_row.sort_by(sort_tup); - assert_eq!( - vec![ - (1, FieldPrime::from(6)), - (2, FieldPrime::from(4)), - (3, FieldPrime::from(1)), - ], - a_row - ); - assert_eq!(vec![(0, FieldPrime::from(1))], b_row); - assert_eq!( - vec![ - (1, FieldPrime::from(1)), - (2, FieldPrime::from(7)), - (3, FieldPrime::from(2)), - ], - c_row - ); - } - - #[test] - fn sub() { - // 7 * x + y == 3 * y - z * 6 - - let one = FlatVariable::one(); - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); - let z = FlatVariable::new(2); - - let lhs = Add( - box Mult(box Number(FieldPrime::from(7)), box Identifier(x)), - box Identifier(y), - ); - let rhs = Sub( - box Mult(box Number(FieldPrime::from(3)), box Identifier(y)), - box Mult(box Identifier(z), box Number(FieldPrime::from(6))), - ); - - let mut variables: HashMap = HashMap::new(); - variables.insert(one, 0); - variables.insert(x, 1); - variables.insert(y, 2); - variables.insert(z, 3); - - let mut a_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut b_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut c_row: Vec<(usize, FieldPrime)> = Vec::new(); - - r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row); - a_row.sort_by(sort_tup); - b_row.sort_by(sort_tup); - c_row.sort_by(sort_tup); - assert_eq!(vec![(2, FieldPrime::from(3))], a_row); // 3 * y - assert_eq!(vec![(0, FieldPrime::from(1))], b_row); // 1 - assert_eq!( - vec![ - (1, FieldPrime::from(7)), - (2, FieldPrime::from(1)), - (3, FieldPrime::from(6)), - ], - c_row - ); // (7 * x + y) + z * 6 - } - - #[test] - fn sub_multiple() { - // (((3 * y) - (z * 2)) - (x * 12)) == (a - x) - // --> 3*y + x == a + 12*x + 2*z - - let one = FlatVariable::one(); - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); - let z = FlatVariable::new(2); - let a = FlatVariable::new(3); - - let lhs = Sub( - box Sub( - box Mult(box Number(FieldPrime::from(3)), box Identifier(y)), - box Mult(box Identifier(z), box Number(FieldPrime::from(2))), - ), - box Mult(box Identifier(x), box Number(FieldPrime::from(12))), - ); - let rhs = Sub(box Identifier(a), box Identifier(x)); - - let mut variables: HashMap = HashMap::new(); - variables.insert(one, 0); - variables.insert(x, 1); - variables.insert(y, 2); - variables.insert(z, 3); - variables.insert(a, 4); - - let mut a_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut b_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut c_row: Vec<(usize, FieldPrime)> = Vec::new(); - - r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row); - a_row.sort_by(sort_tup); - b_row.sort_by(sort_tup); - c_row.sort_by(sort_tup); - assert_eq!( - vec![ - (1, FieldPrime::from(12)), - (3, FieldPrime::from(2)), - (4, FieldPrime::from(1)), - ], - a_row - ); // a + 12*x + 2*z - assert_eq!(vec![(0, FieldPrime::from(1))], b_row); // 1 - assert_eq!( - vec![(1, FieldPrime::from(1)), (2, FieldPrime::from(3))], - c_row - ); // 3*y + x - } - - #[test] - fn add_mult() { - // 4 * y + 3 * x + 3 * z == (3 * x + 6 * y + 4 * z) * (31 * x + 4 * z) - - let one = FlatVariable::one(); - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); - let z = FlatVariable::new(2); - - let lhs = Add( - box Add( - box Mult(box Number(FieldPrime::from(4)), box Identifier(y)), - box Mult(box Number(FieldPrime::from(3)), box Identifier(x)), - ), - box Mult(box Number(FieldPrime::from(3)), box Identifier(z)), - ); - let rhs = Mult( - box Add( - box Add( - box Mult(box Number(FieldPrime::from(3)), box Identifier(x)), - box Mult(box Number(FieldPrime::from(6)), box Identifier(y)), - ), - box Mult(box Number(FieldPrime::from(4)), box Identifier(z)), - ), - box Add( - box Mult(box Number(FieldPrime::from(31)), box Identifier(x)), - box Mult(box Number(FieldPrime::from(4)), box Identifier(z)), - ), - ); - - let mut variables: HashMap = HashMap::new(); - variables.insert(one, 0); - variables.insert(x, 1); - variables.insert(y, 2); - variables.insert(z, 3); - - let mut a_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut b_row: Vec<(usize, FieldPrime)> = Vec::new(); - let mut c_row: Vec<(usize, FieldPrime)> = Vec::new(); - - r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row); - a_row.sort_by(sort_tup); - b_row.sort_by(sort_tup); - c_row.sort_by(sort_tup); - assert_eq!( - vec![ - (1, FieldPrime::from(3)), - (2, FieldPrime::from(6)), - (3, FieldPrime::from(4)), - ], - a_row - ); - assert_eq!( - vec![(1, FieldPrime::from(31)), (3, FieldPrime::from(4))], - b_row - ); - assert_eq!( - vec![ - (1, FieldPrime::from(3)), - (2, FieldPrime::from(4)), - (3, FieldPrime::from(3)), - ], - c_row - ); - } - } -} diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index c52958e2..3c4d41fd 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -201,7 +201,7 @@ impl Checker { "Duplicate definition for function {} with signature {}", funct.id, funct.signature ), - }) + }); } 0 => {} _ => panic!("duplicate function declaration should have been caught"), diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index d8e0ecd8..6e851baa 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -19,7 +19,7 @@ use types::Type; pub use self::folder::Folder; -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct TypedProg { /// Functions of the program pub functions: Vec>, @@ -76,7 +76,7 @@ impl fmt::Debug for TypedProg { } } -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct TypedFunction { /// Name of the program pub id: String, @@ -130,7 +130,7 @@ impl fmt::Debug for TypedFunction { } } -#[derive(Clone, Serialize, Deserialize, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee { Identifier(Variable), ArrayElement(Box>, Box>), @@ -166,7 +166,7 @@ impl fmt::Display for TypedAssignee { } } -#[derive(Clone, Serialize, Deserialize, PartialEq)] +#[derive(Clone, PartialEq)] pub enum TypedStatement { Return(Vec>), Definition(TypedAssignee, TypedExpression), @@ -250,7 +250,7 @@ pub trait Typed { fn get_type(&self) -> Type; } -#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedExpression { Boolean(BooleanExpression), FieldElement(FieldElementExpression), @@ -319,7 +319,7 @@ pub trait MultiTyped { fn get_types(&self) -> &Vec; } -#[derive(Clone, PartialEq, Serialize, Deserialize)] +#[derive(Clone, PartialEq)] pub enum TypedExpressionList { FunctionCall(String, Vec>, Vec), } @@ -332,7 +332,7 @@ impl MultiTyped for TypedExpressionList { } } -#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq)] pub enum FieldElementExpression { Number(T), Identifier(String), @@ -368,7 +368,7 @@ pub enum FieldElementExpression { ), } -#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq)] pub enum BooleanExpression { Identifier(String), Value(bool), @@ -398,7 +398,7 @@ pub enum BooleanExpression { } // for now we store the array size in the variants -#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq)] pub enum FieldElementArrayExpression { Identifier(usize, String), Value(usize, Vec>),