1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fix conflicts

This commit is contained in:
schaeff 2018-12-07 18:01:17 +01:00
commit 9160aa4164
23 changed files with 857 additions and 715 deletions

View file

@ -30,6 +30,9 @@ jobs:
- run: - run:
name: Run integration tests name: Run integration tests
command: WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark RUSTFLAGS="-D warnings" cargo test --release -- --ignored command: WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark RUSTFLAGS="-D warnings" cargo test --release -- --ignored
- run:
name: Generate code coverage report
command: ./scripts/cov.sh
- save_cache: - save_cache:
paths: paths:
- /usr/local/cargo/registry - /usr/local/cargo/registry

1
.codecov.yml Normal file
View file

@ -0,0 +1 @@
comment: off

16
scripts/cov.sh Executable file
View file

@ -0,0 +1,16 @@
#!/bin/bash
# Exit if any subcommand fails
set -e
apt-get update
apt-get install -qq curl zlib1g-dev build-essential python
apt-get install -qq cmake g++ pkg-config jq
apt-get install -qq libcurl4-openssl-dev libelf-dev libdw-dev binutils-dev libiberty-dev
cargo install cargo-kcov
cargo kcov --print-install-kcov-sh | sh
cd zokrates_fs_resolver && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd ..
cd zokrates_core && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd ..
cd zokrates_cli && WITH_LIBSNARK=1 LIBSNARK_SOURCE_PATH=$HOME/libsnark cargo kcov && cd ..
bash <(curl -s https://codecov.io/bash)
echo "Uploaded code coverage"

View file

@ -20,6 +20,16 @@ def main() -> (field):
return 1 return 1
``` ```
### If expressions
An if expression allows you to branch your code depending on conditions.
```zokrates
def main(field x) -> (field):
field y = if x + 2 == 3 then 1 else 5 fi
return y
```
### For loops ### For loops
For loops are available with the following syntax: For loops are available with the following syntax:

View file

@ -21,11 +21,9 @@ use std::path::{Path, PathBuf};
use std::string::String; use std::string::String;
use zokrates_core::compile::compile; use zokrates_core::compile::compile;
use zokrates_core::field::{Field, FieldPrime}; use zokrates_core::field::{Field, FieldPrime};
use zokrates_core::flat_absy::FlatProg; use zokrates_core::ir;
#[cfg(feature = "libsnark")] #[cfg(feature = "libsnark")]
use zokrates_core::proof_system::{ProofSystem, GM17, PGHR13}; use zokrates_core::proof_system::{ProofSystem, GM17, PGHR13};
#[cfg(feature = "libsnark")]
use zokrates_core::r1cs::r1cs_program;
use zokrates_fs_resolver::resolve as fs_resolve; use zokrates_fs_resolver::resolve as fs_resolve;
fn main() { fn main() {
@ -245,20 +243,14 @@ fn main() {
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
let program_flattened: FlatProg<FieldPrime> = let program_flattened: ir::Prog<FieldPrime> =
match compile(&mut reader, Some(location), Some(fs_resolve)) { match compile(&mut reader, Some(location), Some(fs_resolve)) {
Ok(p) => p, Ok(p) => p,
Err(why) => panic!("Compilation failed: {}", why), Err(why) => panic!("Compilation failed: {}", why),
}; };
// number of constraints the flattened program will translate to. // number of constraints the flattened program will translate to.
let num_constraints = &program_flattened let num_constraints = program_flattened.constraint_count();
.functions
.iter()
.find(|x| x.id == "main")
.unwrap()
.statements
.len();
// serialize flattened program and write to binary file // serialize flattened program and write to binary file
let mut bin_output_file = match File::create(&bin_output_path) { let mut bin_output_file = match File::create(&bin_output_path) {
@ -305,7 +297,7 @@ fn main() {
Err(why) => panic!("couldn't open {}: {}", path.display(), why), Err(why) => panic!("couldn't open {}: {}", path.display(), why),
}; };
let program_ast: FlatProg<FieldPrime> = match deserialize_from(&mut file, Infinite) { let program_ast: ir::Prog<FieldPrime> = match deserialize_from(&mut file, Infinite) {
Ok(x) => x, Ok(x) => x,
Err(why) => { Err(why) => {
println!("{:?}", why); println!("{:?}", why);
@ -313,14 +305,8 @@ fn main() {
} }
}; };
let main_flattened = program_ast
.functions
.iter()
.find(|x| x.id == "main")
.unwrap();
// print deserialized flattened program // print deserialized flattened program
println!("{}", main_flattened); println!("{}", program_ast);
// validate #arguments // validate #arguments
let mut cli_arguments: Vec<FieldPrime> = Vec::new(); let mut cli_arguments: Vec<FieldPrime> = Vec::new();
@ -339,11 +325,11 @@ fn main() {
let is_interactive = sub_matches.occurrences_of("interactive") > 0; let is_interactive = sub_matches.occurrences_of("interactive") > 0;
// in interactive mode, only public inputs are expected // in interactive mode, only public inputs are expected
let expected_cli_args_count = main_flattened let expected_cli_args_count = if is_interactive {
.arguments program_ast.public_arguments_count()
.iter() } else {
.filter(|x| !(x.private && is_interactive)) program_ast.public_arguments_count() + program_ast.private_arguments_count()
.count(); };
if cli_arguments.len() != expected_cli_args_count { if cli_arguments.len() != expected_cli_args_count {
println!( println!(
@ -355,10 +341,9 @@ fn main() {
} }
let mut cli_arguments_iter = cli_arguments.into_iter(); let mut cli_arguments_iter = cli_arguments.into_iter();
let arguments = main_flattened let arguments: Vec<FieldPrime> = program_ast
.arguments .parameters()
.clone() .iter()
.into_iter()
.map(|x| { .map(|x| {
match x.private && is_interactive { match x.private && is_interactive {
// private inputs are passed interactively when the flag is present // private inputs are passed interactively when the flag is present
@ -383,7 +368,9 @@ fn main() {
}) })
.collect(); .collect();
let witness_map = main_flattened.get_witness(arguments).unwrap(); let witness_map = program_ast
.execute(arguments)
.unwrap_or_else(|e| panic!(format!("Execution failed: {}", e)));
println!( println!(
"\nWitness: \n\n{}", "\nWitness: \n\n{}",
@ -431,7 +418,7 @@ fn main() {
Err(why) => panic!("couldn't open {}: {}", path.display(), why), Err(why) => panic!("couldn't open {}: {}", path.display(), why),
}; };
let program_ast: FlatProg<FieldPrime> = match deserialize_from(&mut file, Infinite) { let program: ir::Prog<FieldPrime> = match deserialize_from(&mut file, Infinite) {
Ok(x) => x, Ok(x) => x,
Err(why) => { Err(why) => {
println!("{:?}", why); println!("{:?}", why);
@ -439,17 +426,11 @@ fn main() {
} }
}; };
let main_flattened = program_ast
.functions
.iter()
.find(|x| x.id == "main")
.unwrap();
// print deserialized flattened program // print deserialized flattened program
println!("{}", main_flattened); println!("{}", program);
// transform to R1CS // transform to R1CS
let (variables, public_variables_count, a, b, c) = r1cs_program(&program_ast); let (variables, public_variables_count, a, b, c) = r1cs_program(program);
// write variables meta information to file // write variables meta information to file
let var_inf_path = Path::new(sub_matches.value_of("meta-information").unwrap()); let var_inf_path = Path::new(sub_matches.value_of("meta-information").unwrap());
@ -629,6 +610,7 @@ mod tests {
extern crate glob; extern crate glob;
use self::glob::glob; use self::glob::glob;
use super::*; use super::*;
use zokrates_core::ir::r1cs_program;
#[test] #[test]
fn examples() { fn examples() {
@ -655,10 +637,10 @@ mod tests {
.into_string() .into_string()
.unwrap(); .unwrap();
let program_flattened: FlatProg<FieldPrime> = let program_flattened: ir::Prog<FieldPrime> =
compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); compile(&mut reader, Some(location), Some(fs_resolve)).unwrap();
let (..) = r1cs_program(&program_flattened); let (..) = r1cs_program(program_flattened);
} }
} }
@ -684,12 +666,12 @@ mod tests {
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
let program_flattened: FlatProg<FieldPrime> = let program_flattened: ir::Prog<FieldPrime> =
compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); compile(&mut reader, Some(location), Some(fs_resolve)).unwrap();
let (..) = r1cs_program(&program_flattened); let (..) = r1cs_program(program_flattened.clone());
let _ = program_flattened let _ = program_flattened
.get_witness(vec![FieldPrime::from(0)]) .execute(vec![FieldPrime::from(0)])
.unwrap(); .unwrap();
} }
} }
@ -716,14 +698,14 @@ mod tests {
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
let program_flattened: FlatProg<FieldPrime> = let program_flattened: ir::Prog<FieldPrime> =
compile(&mut reader, Some(location), Some(fs_resolve)).unwrap(); compile(&mut reader, Some(location), Some(fs_resolve)).unwrap();
let (..) = r1cs_program(&program_flattened); let (..) = r1cs_program(program_flattened.clone());
let result = std::panic::catch_unwind(|| { let result = std::panic::catch_unwind(|| {
let _ = program_flattened let _ = program_flattened
.get_witness(vec![FieldPrime::from(0)]) .execute(vec![FieldPrime::from(0)])
.unwrap(); .unwrap();
}); });
assert!(result.is_err()); assert!(result.is_err());

View file

@ -17,7 +17,7 @@ use flat_absy::*;
use imports::Import; use imports::Import;
use std::fmt; use std::fmt;
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct Prog<T: Field> { pub struct Prog<T: Field> {
/// Functions of the program /// Functions of the program
pub functions: Vec<Function<T>>, pub functions: Vec<Function<T>>,
@ -74,7 +74,7 @@ impl<T: Field> fmt::Debug for Prog<T> {
} }
} }
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct Function<T: Field> { pub struct Function<T: Field> {
/// Name of the program /// Name of the program
pub id: String, pub id: String,
@ -122,7 +122,7 @@ impl<T: Field> fmt::Debug for Function<T> {
} }
} }
#[derive(Clone, Serialize, Deserialize, PartialEq)] #[derive(Clone, PartialEq)]
pub enum Assignee<T: Field> { pub enum Assignee<T: Field> {
Identifier(String), Identifier(String),
ArrayElement(Box<Assignee<T>>, Box<Expression<T>>), ArrayElement(Box<Assignee<T>>, Box<Expression<T>>),
@ -154,7 +154,7 @@ impl<T: Field> From<Expression<T>> for Assignee<T> {
} }
} }
#[derive(Clone, Serialize, Deserialize, PartialEq)] #[derive(Clone, PartialEq)]
pub enum Statement<T: Field> { pub enum Statement<T: Field> {
Return(ExpressionList<T>), Return(ExpressionList<T>),
Declaration(Variable), Declaration(Variable),

View file

@ -8,6 +8,7 @@ use field::Field;
use flat_absy::FlatProg; use flat_absy::FlatProg;
use flatten::Flattener; use flatten::Flattener;
use imports::{self, Importer}; use imports::{self, Importer};
use ir;
use optimizer::Optimizer; use optimizer::Optimizer;
use parser::{self, parse_program}; use parser::{self, parse_program};
use semantics::{self, Checker}; use semantics::{self, Checker};
@ -64,9 +65,9 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
reader: &mut R, reader: &mut R,
location: Option<String>, location: Option<String>,
resolve_option: Option<fn(&Option<String>, &String) -> Result<(S, String, String), E>>, resolve_option: Option<fn(&Option<String>, &String) -> Result<(S, String, String), E>>,
) -> Result<FlatProg<T>, CompileError<T>> { ) -> Result<ir::Prog<T>, CompileError<T>> {
let compiled = compile_aux(reader, location, resolve_option)?; let compiled = compile_aux(reader, location, resolve_option)?;
Ok(Optimizer::new().optimize_program(compiled)) Ok(ir::Prog::from(Optimizer::new().optimize_program(compiled)))
} }
pub fn compile_aux<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>( pub fn compile_aux<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
@ -113,7 +114,7 @@ mod test {
"# "#
.as_bytes(), .as_bytes(),
); );
let res: Result<FlatProg<FieldPrime>, CompileError<FieldPrime>> = compile( let res: Result<ir::Prog<FieldPrime>, CompileError<FieldPrime>> = compile(
&mut r, &mut r,
Some(String::from("./path/to/file")), Some(String::from("./path/to/file")),
None::< None::<
@ -138,7 +139,7 @@ mod test {
"# "#
.as_bytes(), .as_bytes(),
); );
let res: Result<FlatProg<FieldPrime>, CompileError<FieldPrime>> = compile( let res: Result<ir::Prog<FieldPrime>, CompileError<FieldPrime>> = compile(
&mut r, &mut r,
Some(String::from("./path/to/file")), Some(String::from("./path/to/file")),
None::< None::<

View file

@ -37,7 +37,7 @@ impl fmt::Display for FlatVariable {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.id { match self.id {
0 => write!(f, "~one"), 0 => write!(f, "~one"),
i if i > 0 => write!(f, "_{}", i + 1), i if i > 0 => write!(f, "_{}", i - 1),
i => write!(f, "~out_{}", -(i + 1)), i => write!(f, "~out_{}", -(i + 1)),
} }
} }
@ -72,18 +72,18 @@ mod tests {
#[test] #[test]
fn one() { fn one() {
assert_eq!(FlatVariable::one().id, 0); assert_eq!(format!("{}", FlatVariable::one()), "~one");
} }
#[test] #[test]
fn public() { fn public() {
assert_eq!(FlatVariable::public(0).id, -1); assert_eq!(format!("{}", FlatVariable::public(0)), "~out_0");
assert_eq!(FlatVariable::public(42).id, -43); assert_eq!(format!("{}", FlatVariable::public(42)), "~out_42");
} }
#[test] #[test]
fn private() { fn private() {
assert_eq!(FlatVariable::new(0).id, 1); assert_eq!(format!("{}", FlatVariable::new(0)), "_0");
assert_eq!(FlatVariable::new(42).id, 43); assert_eq!(format!("{}", FlatVariable::new(42)), "_42");
} }
} }

View file

@ -19,7 +19,7 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt; use std::fmt;
use types::Signature; use types::Signature;
#[derive(Serialize, Deserialize, Clone)] #[derive(Clone)]
pub struct FlatProg<T: Field> { pub struct FlatProg<T: Field> {
/// FlatFunctions of the program /// FlatFunctions of the program
pub functions: Vec<FlatFunction<T>>, pub functions: Vec<FlatFunction<T>>,
@ -71,7 +71,7 @@ impl<T: Field> From<standard::DirectiveR1CS> for FlatProg<T> {
} }
} }
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct FlatFunction<T: Field> { pub struct FlatFunction<T: Field> {
/// Name of the program /// Name of the program
pub id: String, pub id: String,
@ -180,7 +180,7 @@ impl<T: Field> fmt::Debug for FlatFunction<T> {
/// ///
/// * r1cs - R1CS in standard JSON data format /// * r1cs - R1CS in standard JSON data format
#[derive(Clone, Serialize, Deserialize, PartialEq)] #[derive(Clone, PartialEq)]
pub enum FlatStatement<T: Field> { pub enum FlatStatement<T: Field> {
Return(FlatExpressionList<T>), Return(FlatExpressionList<T>),
Condition(FlatExpression<T>, FlatExpression<T>), Condition(FlatExpression<T>, FlatExpression<T>),

View file

@ -492,7 +492,7 @@ impl Flattener {
.into_iter() .into_iter()
.map(|x| x.apply_direct_substitution(&replacement_map)) .map(|x| x.apply_direct_substitution(&replacement_map))
.collect(), .collect(),
} };
} }
FlatStatement::Definition(var, rhs) => { FlatStatement::Definition(var, rhs) => {
let new_var = self.issue_new_variable(); let new_var = self.issue_new_variable();

View file

@ -179,7 +179,7 @@ impl Importer {
return Err(CompileError::ImportError(Error::new(format!( return Err(CompileError::ImportError(Error::new(format!(
"Gadget {} not found", "Gadget {} not found",
s s
)))) ))));
} }
} }
} }
@ -207,7 +207,7 @@ impl Importer {
return Err(CompileError::ImportError(Error::new(format!( return Err(CompileError::ImportError(Error::new(format!(
"Packing helper {} not found", "Packing helper {} not found",
s s
)))) ))));
} }
} }
} else { } else {
@ -226,7 +226,7 @@ impl Importer {
Err(err) => return Err(CompileError::ImportError(err.into())), Err(err) => return Err(CompileError::ImportError(err.into())),
}, },
None => { None => {
return Err(Error::new("Can't resolve import without a resolver").into()) return Err(Error::new("Can't resolve import without a resolver").into());
} }
} }
} }

View file

@ -0,0 +1,174 @@
use field::Field;
use flat_absy::FlatVariable;
use num::Zero;
use std::collections::BTreeMap;
use std::fmt;
use std::ops::{Add, Sub};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QuadComb<T: Field> {
pub left: LinComb<T>,
pub right: LinComb<T>,
}
impl<T: Field> QuadComb<T> {
pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self {
QuadComb { left, right }
}
}
impl<T: Field> From<FlatVariable> for QuadComb<T> {
fn from(v: FlatVariable) -> QuadComb<T> {
LinComb::from(v).into()
}
}
impl<T: Field> fmt::Display for QuadComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "({}) * ({})", self.left, self.right,)
}
}
impl<T: Field> From<LinComb<T>> for QuadComb<T> {
fn from(lc: LinComb<T>) -> QuadComb<T> {
QuadComb::from_linear_combinations(LinComb::one(), lc)
}
}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct LinComb<T: Field>(pub BTreeMap<FlatVariable, T>);
impl<T: Field> LinComb<T> {
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
let mut res = BTreeMap::new();
res.insert(var, mult.into());
LinComb(res)
}
pub fn one() -> LinComb<T> {
Self::summand(1, FlatVariable::one())
}
}
impl<T: Field> fmt::Display for LinComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
self.0
.iter()
.map(|(k, v)| format!("{} * {}", v, k))
.collect::<Vec<_>>()
.join(" + ")
)
}
}
impl<T: Field> From<FlatVariable> for LinComb<T> {
fn from(v: FlatVariable) -> LinComb<T> {
let mut r = BTreeMap::new();
r.insert(v, T::one());
LinComb(r)
}
}
impl<T: Field> Add<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn add(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Sub<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn sub(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = T::zero() - v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Zero for LinComb<T> {
fn zero() -> LinComb<T> {
LinComb(BTreeMap::new())
}
fn is_zero(&self) -> bool {
self.0.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
mod linear {
use super::*;
#[test]
fn add_zero() {
let a: LinComb<FieldPrime> = LinComb::zero();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
assert_eq!(c, b);
}
#[test]
fn add() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
let mut expected_map = BTreeMap::new();
expected_map.insert(FlatVariable::new(42), FieldPrime::from(2));
assert_eq!(c, LinComb(expected_map));
}
#[test]
fn sub() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a - b.clone();
assert_eq!(c, LinComb::zero());
}
}
mod quadratic {
use super::*;
#[test]
fn from_linear() {
let a: LinComb<FieldPrime> = LinComb::summand(3, FlatVariable::new(42))
+ LinComb::summand(4, FlatVariable::new(33));
let expected = QuadComb {
left: LinComb::one(),
right: a.clone(),
};
assert_eq!(QuadComb::from(a), expected);
}
#[test]
fn zero() {
let a: LinComb<FieldPrime> = LinComb::zero();
let expected: QuadComb<FieldPrime> = QuadComb {
left: LinComb::one(),
right: LinComb::zero(),
};
assert_eq!(QuadComb::from(a), expected);
}
}
}

View file

@ -0,0 +1,214 @@
use field::Field;
use flat_absy::{FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable};
use helpers;
use ir::{DirectiveStatement, Function, LinComb, Prog, QuadComb, Statement};
use num::Zero;
impl<T: Field> From<FlatFunction<T>> for Function<T> {
fn from(flat_function: FlatFunction<T>) -> Function<T> {
let return_expressions: Vec<FlatExpression<T>> = flat_function
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Return(el) => Some(el.expressions.clone()),
_ => None,
})
.next()
.unwrap();
Function {
id: flat_function.id,
arguments: flat_function.arguments.into_iter().map(|p| p.id).collect(),
returns: return_expressions.into_iter().map(|e| e.into()).collect(),
statements: flat_function
.statements
.into_iter()
.filter_map(|s| match s {
FlatStatement::Return(..) => None,
s => Some(s.into()),
})
.collect(),
}
}
}
impl<T: Field> From<FlatProg<T>> for Prog<T> {
fn from(flat_prog: FlatProg<T>) -> Prog<T> {
// get the main function as all calls have been resolved
let main = flat_prog
.functions
.into_iter()
.find(|f| f.id == "main")
.unwrap();
// get the interface of the program, ie which inputs are private and public
let private = main.arguments.iter().map(|p| p.private).collect();
// convert the main function to this IR for functions
let main: Function<T> = main.into();
// contrary to other functions, we need to make sure that return values are identifiers, so we define new (public) variables
let definitions =
main.returns.iter().enumerate().map(|(index, e)| {
Statement::Constraint(e.clone(), FlatVariable::public(index).into())
});
// update the main function with the extra definition statements and replace the return values
let main = Function {
returns: (0..main.returns.len())
.map(|i| FlatVariable::public(i).into())
.collect(),
statements: main.statements.into_iter().chain(definitions).collect(),
..main
};
let main = Function::from(main);
Prog { private, main }
}
}
impl<T: Field> From<FlatExpression<T>> for QuadComb<T> {
fn from(flat_expression: FlatExpression<T>) -> QuadComb<T> {
match flat_expression.is_linear() {
true => LinComb::from(flat_expression).into(),
false => match flat_expression {
FlatExpression::Mult(box e1, box e2) => {
QuadComb::from_linear_combinations(e1.into(), e2.into())
}
e => unimplemented!("{}", e),
},
}
}
}
impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
fn from(flat_expression: FlatExpression<T>) -> LinComb<T> {
assert!(flat_expression.is_linear());
match flat_expression {
FlatExpression::Number(ref n) if *n == T::from(0) => LinComb::zero(),
FlatExpression::Number(n) => LinComb::summand(n, FlatVariable::one()),
FlatExpression::Identifier(id) => LinComb::from(id),
FlatExpression::Add(box e1, box e2) => LinComb::from(e1) + LinComb::from(e2),
FlatExpression::Sub(box e1, box e2) => LinComb::from(e1) - LinComb::from(e2),
FlatExpression::Mult(
box FlatExpression::Number(n1),
box FlatExpression::Identifier(v1),
)
| FlatExpression::Mult(
box FlatExpression::Identifier(v1),
box FlatExpression::Number(n1),
) => LinComb::summand(n1, v1),
e => unimplemented!("{}", e),
}
}
}
impl<T: Field> From<FlatStatement<T>> for Statement<T> {
fn from(flat_statement: FlatStatement<T>) -> Statement<T> {
match flat_statement {
FlatStatement::Condition(linear, quadratic) => match quadratic {
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
QuadComb::from_linear_combinations(lhs.into(), rhs.into()),
linear.into(),
),
e => Statement::Constraint(LinComb::from(e).into(), linear.into()),
},
FlatStatement::Definition(var, quadratic) => match quadratic {
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
QuadComb::from_linear_combinations(lhs.into(), rhs.into()),
var.into(),
),
e => Statement::Constraint(LinComb::from(e).into(), var.into()),
},
FlatStatement::Directive(ds) => Statement::Directive(ds.into()),
_ => panic!("return should be handled at the function level"),
}
}
}
impl<T: Field> From<helpers::DirectiveStatement<T>> for DirectiveStatement<T> {
fn from(ds: helpers::DirectiveStatement<T>) -> DirectiveStatement<T> {
DirectiveStatement {
inputs: ds.inputs.into_iter().map(|i| i.into()).collect(),
helper: ds.helper,
outputs: ds.outputs,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
#[test]
fn zero() {
// 0
let zero = FlatExpression::Number(FieldPrime::from(0));
let expected: LinComb<FieldPrime> = LinComb::zero();
assert_eq!(LinComb::from(zero), expected);
}
#[test]
fn one() {
// 1
let one = FlatExpression::Number(FieldPrime::from(1));
let expected: LinComb<FieldPrime> = FlatVariable::one().into();
assert_eq!(LinComb::from(one), expected);
}
#[test]
fn forty_two() {
// 42
let one = FlatExpression::Number(FieldPrime::from(42));
let expected: LinComb<FieldPrime> = LinComb::summand(42, FlatVariable::one());
assert_eq!(LinComb::from(one), expected);
}
#[test]
fn add() {
// x + y
let add = FlatExpression::Add(
box FlatExpression::Identifier(FlatVariable::new(42)),
box FlatExpression::Identifier(FlatVariable::new(21)),
);
let expected: LinComb<FieldPrime> =
LinComb::summand(1, FlatVariable::new(42)) + LinComb::summand(1, FlatVariable::new(21));
assert_eq!(LinComb::from(add), expected);
}
#[test]
fn linear_combination() {
// 42*x + 21*y
let add = FlatExpression::Add(
box FlatExpression::Mult(
box FlatExpression::Number(FieldPrime::from(42)),
box FlatExpression::Identifier(FlatVariable::new(42)),
),
box FlatExpression::Mult(
box FlatExpression::Number(FieldPrime::from(21)),
box FlatExpression::Identifier(FlatVariable::new(21)),
),
);
let expected: LinComb<FieldPrime> = LinComb::summand(42, FlatVariable::new(42))
+ LinComb::summand(21, FlatVariable::new(21));
assert_eq!(LinComb::from(add), expected);
}
#[test]
fn linear_combination_inverted() {
// x*42 + y*21
let add = FlatExpression::Add(
box FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(42)),
box FlatExpression::Number(FieldPrime::from(42)),
),
box FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(21)),
box FlatExpression::Number(FieldPrime::from(21)),
),
);
let expected: LinComb<FieldPrime> = LinComb::summand(42, FlatVariable::new(42))
+ LinComb::summand(21, FlatVariable::new(21));
assert_eq!(LinComb::from(add), expected);
}
}

View file

@ -0,0 +1,89 @@
use field::Field;
use helpers::Executable;
use ir::*;
use std::collections::BTreeMap;
impl<T: Field> Prog<T> {
pub fn execute(self, inputs: Vec<T>) -> Result<BTreeMap<FlatVariable, T>, Error<T>> {
let main = self.main;
assert_eq!(main.arguments.len(), inputs.len());
let mut witness = BTreeMap::new();
witness.insert(FlatVariable::one(), T::one());
for (arg, value) in main.arguments.iter().zip(inputs.iter()) {
witness.insert(arg.clone(), value.clone());
}
for statement in main.statements {
match statement {
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
true => {
let val = quad.evaluate(&witness);
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
false => {
let lhs_value = quad.evaluate(&witness);
let rhs_value = lin.evaluate(&witness);
if lhs_value != rhs_value {
return Err(Error::Constraint(quad, lin, lhs_value, rhs_value));
}
}
},
Statement::Directive(ref d) => {
let input_values: Vec<T> =
d.inputs.iter().map(|i| i.evaluate(&witness)).collect();
match d.helper.execute(&input_values) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone());
}
continue;
}
Err(_) => return Err(Error::Solver),
};
}
}
}
Ok(witness)
}
}
impl<T: Field> LinComb<T> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> T {
self.0
.iter()
.map(|(var, val)| witness.get(var).unwrap().clone() * val)
.fold(T::from(0), |acc, t| acc + t)
}
fn is_assignee(&self, witness: &BTreeMap<FlatVariable, T>) -> bool {
self.0.iter().count() == 1
&& self.0.iter().next().unwrap().1 == &T::from(1)
&& !witness.contains_key(self.0.iter().next().unwrap().0)
}
}
impl<T: Field> QuadComb<T> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> T {
self.left.evaluate(&witness) * self.right.evaluate(&witness)
}
}
#[derive(PartialEq, Debug)]
pub enum Error<T: Field> {
Constraint(QuadComb<T>, LinComb<T>, T, T),
Solver,
}
impl<T: Field> fmt::Display for Error<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Constraint(ref quad, ref lin, ref left_value, ref right_value) => write!(
f,
"Expected {} to equal {}, but {} != {}",
quad, lin, left_value, right_value
),
Error::Solver => write!(f, ""),
}
}
}

270
zokrates_core/src/ir/mod.rs Normal file
View file

@ -0,0 +1,270 @@
use field::Field;
use flat_absy::flat_parameter::FlatParameter;
use flat_absy::FlatVariable;
use helpers::Helper;
use std::collections::HashMap;
use std::fmt;
use std::mem;
mod expression;
mod from_flat;
mod interpreter;
use self::expression::LinComb;
use self::expression::QuadComb;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Statement<T: Field> {
Constraint(QuadComb<T>, LinComb<T>),
Directive(DirectiveStatement<T>),
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct DirectiveStatement<T: Field> {
pub inputs: Vec<LinComb<T>>,
pub outputs: Vec<FlatVariable>,
pub helper: Helper,
}
impl<T: Field> fmt::Display for DirectiveStatement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"# {} = {}({})",
self.outputs
.iter()
.map(|o| format!("{}", o))
.collect::<Vec<_>>()
.join(", "),
self.helper,
self.inputs
.iter()
.map(|i| format!("{}", i))
.collect::<Vec<_>>()
.join(", ")
)
}
}
impl<T: Field> fmt::Display for Statement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Statement::Constraint(ref quad, ref lin) => write!(f, "{} == {}", quad, lin),
Statement::Directive(ref s) => write!(f, "{}", s),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Function<T: Field> {
pub id: String,
pub statements: Vec<Statement<T>>,
pub arguments: Vec<FlatVariable>,
pub returns: Vec<QuadComb<T>>,
}
impl<T: Field> fmt::Display for Function<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"def {}({}) -> ({}):\n{}\n\t return {}",
self.id,
self.arguments
.iter()
.map(|v| format!("{}", v))
.collect::<Vec<_>>()
.join(", "),
self.returns.len(),
self.statements
.iter()
.map(|s| format!("\t{}", s))
.collect::<Vec<_>>()
.join("\n"),
self.returns
.iter()
.map(|e| format!("{}", e))
.collect::<Vec<_>>()
.join(", ")
)
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Prog<T: Field> {
pub main: Function<T>,
pub private: Vec<bool>,
}
impl<T: Field> Prog<T> {
pub fn constraint_count(&self) -> usize {
self.main
.statements
.iter()
.filter(|s| match s {
Statement::Constraint(..) => true,
_ => false,
})
.count()
}
pub fn public_arguments_count(&self) -> usize {
self.private.iter().filter(|b| !**b).count()
}
pub fn private_arguments_count(&self) -> usize {
self.private.iter().filter(|b| **b).count()
}
pub fn parameters(&self) -> Vec<FlatParameter> {
self.main
.arguments
.iter()
.zip(self.private.iter())
.map(|(id, private)| FlatParameter {
private: *private,
id: *id,
})
.collect()
}
}
impl<T: Field> fmt::Display for Prog<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.main)
}
}
/// Returns the index of `var` in `variables`, adding `var` with incremented index if it not yet exists.
///
/// # Arguments
///
/// * `variables` - A mutual map that maps all existing variables to their index.
/// * `var` - Variable to be searched for.
pub fn provide_variable_idx(
variables: &mut HashMap<FlatVariable, usize>,
var: &FlatVariable,
) -> usize {
let index = variables.len();
*variables.entry(*var).or_insert(index)
}
/// Calculates one R1CS row representation of a program and returns (V, A, B, C) so that:
/// * `V` contains all used variables and the index in the vector represents the used number in `A`, `B`, `C`
/// * `<A,x>*<B,x> = <C,x>` for a witness `x`
///
/// # Arguments
///
/// * `prog` - The program the representation is calculated for.
pub fn r1cs_program<T: Field>(
prog: Prog<T>,
) -> (
Vec<FlatVariable>,
usize,
Vec<Vec<(usize, T)>>,
Vec<Vec<(usize, T)>>,
Vec<Vec<(usize, T)>>,
) {
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
provide_variable_idx(&mut variables, &FlatVariable::one());
for x in prog
.main
.arguments
.iter()
.enumerate()
.filter(|(index, _)| !prog.private[*index])
{
provide_variable_idx(&mut variables, &x.1);
}
//Only the main function is relevant in this step, since all calls to other functions were resolved during flattening
let main = prog.main;
//~out are added after main's arguments as we want variables (columns)
//in the r1cs to be aligned like "public inputs | private inputs"
let main_return_count = main.returns.len();
for i in 0..main_return_count {
provide_variable_idx(&mut variables, &FlatVariable::public(i));
}
// position where private part of witness starts
let private_inputs_offset = variables.len();
// first pass through statements to populate `variables`
for (quad, lin) in main.statements.iter().filter_map(|s| match s {
Statement::Constraint(quad, lin) => Some((quad, lin)),
Statement::Directive(..) => None,
}) {
for (k, _) in &quad.left.0 {
provide_variable_idx(&mut variables, &k);
}
for (k, _) in &quad.right.0 {
provide_variable_idx(&mut variables, &k);
}
for (k, _) in &lin.0 {
provide_variable_idx(&mut variables, &k);
}
}
let mut a = vec![];
let mut b = vec![];
let mut c = vec![];
// second pass to convert program to raw sparse vectors
for (quad, lin) in main.statements.into_iter().filter_map(|s| match s {
Statement::Constraint(quad, lin) => Some((quad, lin)),
Statement::Directive(..) => None,
}) {
a.push(
quad.left
.0
.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
b.push(
quad.right
.0
.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
c.push(
lin.0
.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
}
// Convert map back into list ordered by index
let mut variables_list = vec![FlatVariable::new(0); variables.len()];
for (k, v) in variables.drain() {
assert_eq!(variables_list[v], FlatVariable::new(0));
mem::replace(&mut variables_list[v], k);
}
(variables_list, private_inputs_offset, a, b, c)
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
mod statement {
use super::*;
#[test]
fn print_constraint() {
let c: Statement<FieldPrime> = Statement::Constraint(
QuadComb::from_linear_combinations(
FlatVariable::new(42).into(),
FlatVariable::new(42).into(),
),
FlatVariable::new(42).into(),
);
assert_eq!(format!("{}", c), "(1 * _42) * (1 * _42) == 1 * _42")
}
}
}

View file

@ -29,8 +29,8 @@ pub mod absy;
pub mod compile; pub mod compile;
pub mod field; pub mod field;
pub mod flat_absy; pub mod flat_absy;
pub mod ir;
#[cfg(feature = "libsnark")] #[cfg(feature = "libsnark")]
pub mod libsnark; pub mod libsnark;
#[cfg(feature = "libsnark")] #[cfg(feature = "libsnark")]
pub mod proof_system; pub mod proof_system;
pub mod r1cs;

View file

@ -362,14 +362,14 @@ pub fn parse_function_call<T: Field>(
p = p2; p = p2;
} }
(Token::Close, s2, p2) => { (Token::Close, s2, p2) => {
return parse_term1(Expression::FunctionCall(ide, args), s2, p2) return parse_term1(Expression::FunctionCall(ide, args), s2, p2);
} }
(t2, _, p2) => { (t2, _, p2) => {
return Err(Error { return Err(Error {
expected: vec![Token::Comma, Token::Close], expected: vec![Token::Comma, Token::Close],
got: t2, got: t2,
pos: p2, pos: p2,
}) });
} }
} }
} }
@ -417,7 +417,7 @@ pub fn parse_inline_array<T: Field>(
expected: vec![Token::Comma, Token::RightBracket], expected: vec![Token::Comma, Token::RightBracket],
got: t2, got: t2,
pos: p2, pos: p2,
}) });
} }
} }
} }

View file

@ -42,7 +42,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Close], expected: vec![Token::Close],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
}, },
Err(e) => return Err(e), Err(e) => return Err(e),
@ -52,7 +52,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Open], expected: vec![Token::Open],
got: t1, got: t1,
pos: p1, pos: p1,
}) });
} }
}?; }?;
@ -67,7 +67,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Close], expected: vec![Token::Close],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
}, },
Err(e) => return Err(e), Err(e) => return Err(e),
@ -77,7 +77,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Open], expected: vec![Token::Open],
got: t1, got: t1,
pos: p1, pos: p1,
}) });
} }
}, },
(t0, _, p0) => { (t0, _, p0) => {
@ -85,7 +85,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Arrow], expected: vec![Token::Arrow],
got: t0, got: t0,
pos: p0, pos: p0,
}) });
} }
}?; }?;
@ -103,7 +103,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Unknown("".to_string())], expected: vec![Token::Unknown("".to_string())],
got: t6, got: t6,
pos: p6, pos: p6,
}) });
} }
}, },
(t5, _, p5) => { (t5, _, p5) => {
@ -111,7 +111,7 @@ fn parse_function_header<T: Field>(
expected: vec![Token::Colon], expected: vec![Token::Colon],
got: t5, got: t5,
pos: p5, pos: p5,
}) });
} }
} }
} }
@ -167,7 +167,7 @@ fn parse_function_arguments<T: Field>(
expected: vec![Token::Comma, Token::Close], expected: vec![Token::Comma, Token::Close],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
} }
} }
@ -188,7 +188,7 @@ fn parse_function_arguments<T: Field>(
expected: vec![Token::Comma, Token::Close], expected: vec![Token::Comma, Token::Close],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
} }
} }
@ -202,7 +202,7 @@ fn parse_function_arguments<T: Field>(
], ],
got: t4, got: t4,
pos: p4, pos: p4,
}) });
} }
} }
} }
@ -231,7 +231,7 @@ fn parse_function_return_types<T: Field>(
expected: vec![Token::Comma, Token::Close], expected: vec![Token::Comma, Token::Close],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
} }
} }
@ -245,7 +245,7 @@ fn parse_function_return_types<T: Field>(
], ],
got: t4, got: t4,
pos: p4, pos: p4,
}) });
} }
} }
} }

View file

@ -16,14 +16,14 @@ pub fn parse_import<T: Field>(
(Token::Path(code_path), s2, p2) => match next_token::<T>(&s2, &p2) { (Token::Path(code_path), s2, p2) => match next_token::<T>(&s2, &p2) {
(Token::As, s3, p3) => match next_token(&s3, &p3) { (Token::As, s3, p3) => match next_token(&s3, &p3) {
(Token::Ide(id), _, p4) => { (Token::Ide(id), _, p4) => {
return Ok((Import::new_with_alias(code_path, &id), p4)) return Ok((Import::new_with_alias(code_path, &id), p4));
} }
(t4, _, p4) => { (t4, _, p4) => {
return Err(Error { return Err(Error {
expected: vec![Token::Ide("ide".to_string())], expected: vec![Token::Ide("ide".to_string())],
got: t4, got: t4,
pos: p4, pos: p4,
}) });
} }
}, },
(Token::Unknown(_), _, p3) => return Ok((Import::new(code_path), p3)), (Token::Unknown(_), _, p3) => return Ok((Import::new(code_path), p3)),
@ -35,7 +35,7 @@ pub fn parse_import<T: Field>(
], ],
got: t3, got: t3,
pos: p3, pos: p3,
}) });
} }
}, },
(t2, _, p2) => Err(Error { (t2, _, p2) => Err(Error {

View file

@ -47,7 +47,7 @@ pub fn parse_program<T: Field, R: BufRead>(reader: &mut R) -> Result<Prog<T>, Er
expected: vec![Token::Def], expected: vec![Token::Def],
got: t1, got: t1,
pos: p1, pos: p1,
}) });
} }
}, },
None => break, None => break,

View file

@ -1,618 +0,0 @@
//! Module containing necessary functions to convert a flattened program or expression to r1cs.
//!
//! @file r1cs.rs
//! @author Dennis Kuhnert <dennis.kuhnert@campus.tu-berlin.de>
//! @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de
//! @date 2017
use field::Field;
use flat_absy::flat_variable::FlatVariable;
use flat_absy::FlatExpression::*;
use flat_absy::*;
use std::collections::HashMap;
use std::mem;
/// Returns a vector of summands of the given `FlatExpression`.
///
/// # Arguments
///
/// * `expr` - `FlatExpression` to be split to summands.
///
/// # Example
///
/// a + 2*b + (c - d) -> [a, 2*b, c-d]
fn get_summands<T: Field>(expr: &FlatExpression<T>) -> Vec<&FlatExpression<T>> {
let mut trace = Vec::new();
let mut add = Vec::new();
trace.push(expr);
loop {
if let Some(e) = trace.pop() {
match *e {
ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) | ref e @ Sub(..)
if e.is_linear() =>
{
add.push(e)
}
Add(ref l, ref r) => {
trace.push(l);
trace.push(r);
}
ref e => panic!("Not covered: {}", e),
}
} else {
return add;
}
}
}
/// Returns a `HashMap` containing variables and the number of occurrences
///
/// # Arguments
///
/// * `expr` - FlatExpression only containing Numbers, Variables, Add and Mult
///
/// # Example
///
/// `7 * x + 4 * y + x` -> { x => 8, y = 4 }
fn count_variables_add<T: Field>(expr: &FlatExpression<T>) -> HashMap<FlatVariable, T> {
let summands = get_summands(expr);
let mut count = HashMap::new();
for s in summands {
match *s {
Number(ref x) => {
let num = count.entry(FlatVariable::one()).or_insert(T::zero());
*num = num.clone() + x;
}
Identifier(ref v) => {
let num = count.entry(*v).or_insert(T::zero());
*num = num.clone() + T::one();
}
Mult(box Number(ref x1), box Number(ref x2)) => {
let num = count.entry(FlatVariable::one()).or_insert(T::zero());
*num = num.clone() + x1 + x2;
}
Mult(box Number(ref x), box Identifier(ref v))
| Mult(box Identifier(ref v), box Number(ref x)) => {
let num = count.entry(*v).or_insert(T::zero());
*num = num.clone() + x;
}
ref e => panic!("Not covered: {}", e),
}
}
count
}
/// Returns an equation equivalent to `lhs == rhs` only using `Add` and `Mult`
///
/// # Arguments
///
/// * `lhs` - Left hand side of the equation
/// * `rhs` - Right hand side of the equation
fn swap_sub<T: Field>(
lhs: &FlatExpression<T>,
rhs: &FlatExpression<T>,
) -> (FlatExpression<T>, FlatExpression<T>) {
let mut left = get_summands(lhs);
let mut right = get_summands(rhs);
let mut run = true;
while run {
run = false;
for i in 0..left.len() {
match *left[i] {
ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) if e.is_linear() => {}
Sub(ref l, ref r) => {
run = true;
left.swap_remove(i);
left.extend(get_summands(l));
right.extend(get_summands(r));
}
ref e => panic!("Unexpected: {}", e),
}
}
for i in 0..right.len() {
match *right[i] {
ref e @ Number(_) | ref e @ Identifier(_) | ref e @ Mult(..) if e.is_linear() => {}
Sub(ref l, ref r) => {
run = true;
right.swap_remove(i);
right.extend(get_summands(l));
left.extend(get_summands(r));
}
ref e => panic!("Unexpected: {}", e),
}
}
}
if let Some(left_init) = left.pop() {
if let Some(right_init) = right.pop() {
return (
left.iter()
.fold(left_init.clone(), |acc, &x| Add(box acc, box x.clone())),
right
.iter()
.fold(right_init.clone(), |acc, &x| Add(box acc, box x.clone())),
);
}
}
panic!("Unexpected");
}
/// Calculates one R1CS row representation for `linear_expr` = `expr`.
/// (<C,x> = <a,x>*<B,x>)
///
/// # Arguments
///
/// * `linear_expr` - Left hand side of the equation, has to be linear
/// * `expr` - Right hand side of the equation
/// * `variables` - a mutual vector that contains all existing variables. Not found variables will be added.
/// * `a_row` - Result row of matrix a
/// * `b_row` - Result row of matrix B
/// * `c_row` - Result row of matrix C
fn r1cs_expression<T: Field>(
linear_expr: FlatExpression<T>,
expr: FlatExpression<T>,
variables: &mut HashMap<FlatVariable, usize>,
a_row: &mut Vec<(usize, T)>,
b_row: &mut Vec<(usize, T)>,
c_row: &mut Vec<(usize, T)>,
) {
assert!(linear_expr.is_linear());
match expr {
e @ Add(..) | e @ Sub(..) => {
let (lhs, rhs) = swap_sub(&linear_expr, &e);
for (key, value) in count_variables_add(&rhs) {
a_row.push((provide_variable_idx(variables, &key), value));
}
b_row.push((0, T::one()));
for (key, value) in count_variables_add(&lhs) {
c_row.push((provide_variable_idx(variables, &key), value));
}
}
Mult(lhs, rhs) => {
match lhs {
box Number(x) => a_row.push((0, x)),
box Identifier(x) => a_row.push((provide_variable_idx(variables, &x), T::one())),
box e @ Add(..) => {
for (key, value) in count_variables_add(&e) {
a_row.push((provide_variable_idx(variables, &key), value));
}
}
e @ _ => panic!("Not flattened: {}", e),
};
match rhs {
box Number(x) => b_row.push((0, x)),
box Identifier(x) => b_row.push((provide_variable_idx(variables, &x), T::one())),
box e @ Add(..) => {
for (key, value) in count_variables_add(&e) {
b_row.push((provide_variable_idx(variables, &key), value));
}
}
e @ _ => panic!("Not flattened: {}", e),
};
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((provide_variable_idx(variables, &key), value));
}
}
Identifier(var) => {
a_row.push((provide_variable_idx(variables, &var), T::one()));
b_row.push((0, T::one()));
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((provide_variable_idx(variables, &key), value));
}
}
Number(x) => {
a_row.push((0, x));
b_row.push((0, T::one()));
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((provide_variable_idx(variables, &key), value));
}
}
}
}
/// Returns the index of `var` in `variables`, adding `var` with incremented index if it not yet exists.
///
/// # Arguments
///
/// * `variables` - A mutual map that maps all existing variables to their index.
/// * `var` - Variable to be searched for.
fn provide_variable_idx(variables: &mut HashMap<FlatVariable, usize>, var: &FlatVariable) -> usize {
let index = variables.len();
*variables.entry(*var).or_insert(index)
}
/// Calculates one R1CS row representation of a program and returns (V, A, B, C) so that:
/// * `V` contains all used variables and the index in the vector represents the used number in `A`, `B`, `C`
/// * `<A,x>*<B,x> = <C,x>` for a witness `x`
///
/// # Arguments
///
/// * `prog` - The program the representation is calculated for.
pub fn r1cs_program<T: Field>(
prog: &FlatProg<T>,
) -> (
Vec<FlatVariable>,
usize,
Vec<Vec<(usize, T)>>,
Vec<Vec<(usize, T)>>,
Vec<Vec<(usize, T)>>,
) {
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
provide_variable_idx(&mut variables, &FlatVariable::one());
let mut a: Vec<Vec<(usize, T)>> = Vec::new();
let mut b: Vec<Vec<(usize, T)>> = Vec::new();
let mut c: Vec<Vec<(usize, T)>> = Vec::new();
//Only the main function is relevant in this step, since all calls to other functions were resolved during flattening
let main = prog
.clone()
.functions
.into_iter()
.find(|x: &FlatFunction<T>| x.id == "main".to_string())
.unwrap();
for x in main.arguments.iter().filter(|x| !x.private) {
provide_variable_idx(&mut variables, &x.id);
}
// ~out is added after main's arguments as we want variables (columns)
// in the r1cs to be aligned like "public inputs | private inputs"
let main_return_count = main
.signature
.outputs
.iter()
.map(|t| t.get_primitive_count())
.fold(0, |acc, x| acc + x);
for i in 0..main_return_count {
provide_variable_idx(&mut variables, &FlatVariable::public(i));
}
// position where private part of witness starts
let private_inputs_offset = variables.len();
for def in &main.statements {
match *def {
FlatStatement::Return(ref list) => {
for (i, val) in list.expressions.iter().enumerate() {
let mut a_row = Vec::new();
let mut b_row = Vec::new();
let mut c_row = Vec::new();
r1cs_expression(
Identifier(FlatVariable::public(i)),
val.clone(),
&mut variables,
&mut a_row,
&mut b_row,
&mut c_row,
);
a.push(a_row);
b.push(b_row);
c.push(c_row);
}
}
FlatStatement::Definition(ref id, ref rhs) => {
let mut a_row = Vec::new();
let mut b_row = Vec::new();
let mut c_row = Vec::new();
r1cs_expression(
FlatExpression::Identifier(*id),
rhs.clone(),
&mut variables,
&mut a_row,
&mut b_row,
&mut c_row,
);
a.push(a_row);
b.push(b_row);
c.push(c_row);
}
FlatStatement::Condition(ref expr1, ref expr2) => {
let mut a_row = Vec::new();
let mut b_row = Vec::new();
let mut c_row = Vec::new();
r1cs_expression(
expr1.clone(),
expr2.clone(),
&mut variables,
&mut a_row,
&mut b_row,
&mut c_row,
);
a.push(a_row);
b.push(b_row);
c.push(c_row);
}
FlatStatement::Directive(..) => continue,
}
}
// Convert map back into list ordered by index
let mut variables_list = vec![FlatVariable::new(0); variables.len()];
for (k, v) in variables.drain() {
assert_eq!(variables_list[v], FlatVariable::new(0));
mem::replace(&mut variables_list[v], k);
}
(variables_list, private_inputs_offset, a, b, c)
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
use std::cmp::Ordering;
/// Sort function for tuples `(x, y)` which sorts based on `x` first.
/// If `x` is equal, `y` is used for comparison.
fn sort_tup<A: Ord, B: Ord>(a: &(A, B), b: &(A, B)) -> Ordering {
if a.0 == b.0 {
a.1.cmp(&b.1)
} else {
a.0.cmp(&b.0)
}
}
#[cfg(test)]
mod r1cs_expression {
use super::*;
#[test]
fn add() {
// x = y + 5
let one = FlatVariable::one();
let x = FlatVariable::new(0);
let y = FlatVariable::new(1);
let lhs = Identifier(x);
let rhs = Add(box Identifier(y), box Number(FieldPrime::from(5)));
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
variables.insert(one, 0);
variables.insert(x, 1);
variables.insert(y, 2);
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
a_row.sort_by(sort_tup);
b_row.sort_by(sort_tup);
c_row.sort_by(sort_tup);
assert_eq!(
vec![(0, FieldPrime::from(5)), (2, FieldPrime::from(1))],
a_row
);
assert_eq!(vec![(0, FieldPrime::from(1))], b_row);
assert_eq!(vec![(1, FieldPrime::from(1))], c_row);
}
#[test]
fn add_sub_mix() {
// (x + y) - ((z + 3*x) - y) == (x - y) + ((2*x - 4*y) + (4*y - 2*z))
// --> (x + y) + y + 4y + 2z + y == x + 2x + 4y + (z + 3x)
// <=> x + 7*y + 2*z == 6*x + 4y + z
let one = FlatVariable::one();
let x = FlatVariable::new(0);
let y = FlatVariable::new(1);
let z = FlatVariable::new(2);
let lhs = Sub(
box Add(box Identifier(x), box Identifier(y)),
box Sub(
box Add(
box Identifier(z),
box Mult(box Number(FieldPrime::from(3)), box Identifier(x)),
),
box Identifier(y),
),
);
let rhs = Add(
box Sub(box Identifier(x), box Identifier(y)),
box Add(
box Sub(
box Mult(box Number(FieldPrime::from(2)), box Identifier(x)),
box Mult(box Number(FieldPrime::from(4)), box Identifier(y)),
),
box Sub(
box Mult(box Number(FieldPrime::from(4)), box Identifier(y)),
box Mult(box Number(FieldPrime::from(2)), box Identifier(z)),
),
),
);
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
variables.insert(one, 0);
variables.insert(x, 1);
variables.insert(y, 2);
variables.insert(z, 3);
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
a_row.sort_by(sort_tup);
b_row.sort_by(sort_tup);
c_row.sort_by(sort_tup);
assert_eq!(
vec![
(1, FieldPrime::from(6)),
(2, FieldPrime::from(4)),
(3, FieldPrime::from(1)),
],
a_row
);
assert_eq!(vec![(0, FieldPrime::from(1))], b_row);
assert_eq!(
vec![
(1, FieldPrime::from(1)),
(2, FieldPrime::from(7)),
(3, FieldPrime::from(2)),
],
c_row
);
}
#[test]
fn sub() {
// 7 * x + y == 3 * y - z * 6
let one = FlatVariable::one();
let x = FlatVariable::new(0);
let y = FlatVariable::new(1);
let z = FlatVariable::new(2);
let lhs = Add(
box Mult(box Number(FieldPrime::from(7)), box Identifier(x)),
box Identifier(y),
);
let rhs = Sub(
box Mult(box Number(FieldPrime::from(3)), box Identifier(y)),
box Mult(box Identifier(z), box Number(FieldPrime::from(6))),
);
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
variables.insert(one, 0);
variables.insert(x, 1);
variables.insert(y, 2);
variables.insert(z, 3);
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
a_row.sort_by(sort_tup);
b_row.sort_by(sort_tup);
c_row.sort_by(sort_tup);
assert_eq!(vec![(2, FieldPrime::from(3))], a_row); // 3 * y
assert_eq!(vec![(0, FieldPrime::from(1))], b_row); // 1
assert_eq!(
vec![
(1, FieldPrime::from(7)),
(2, FieldPrime::from(1)),
(3, FieldPrime::from(6)),
],
c_row
); // (7 * x + y) + z * 6
}
#[test]
fn sub_multiple() {
// (((3 * y) - (z * 2)) - (x * 12)) == (a - x)
// --> 3*y + x == a + 12*x + 2*z
let one = FlatVariable::one();
let x = FlatVariable::new(0);
let y = FlatVariable::new(1);
let z = FlatVariable::new(2);
let a = FlatVariable::new(3);
let lhs = Sub(
box Sub(
box Mult(box Number(FieldPrime::from(3)), box Identifier(y)),
box Mult(box Identifier(z), box Number(FieldPrime::from(2))),
),
box Mult(box Identifier(x), box Number(FieldPrime::from(12))),
);
let rhs = Sub(box Identifier(a), box Identifier(x));
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
variables.insert(one, 0);
variables.insert(x, 1);
variables.insert(y, 2);
variables.insert(z, 3);
variables.insert(a, 4);
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
a_row.sort_by(sort_tup);
b_row.sort_by(sort_tup);
c_row.sort_by(sort_tup);
assert_eq!(
vec![
(1, FieldPrime::from(12)),
(3, FieldPrime::from(2)),
(4, FieldPrime::from(1)),
],
a_row
); // a + 12*x + 2*z
assert_eq!(vec![(0, FieldPrime::from(1))], b_row); // 1
assert_eq!(
vec![(1, FieldPrime::from(1)), (2, FieldPrime::from(3))],
c_row
); // 3*y + x
}
#[test]
fn add_mult() {
// 4 * y + 3 * x + 3 * z == (3 * x + 6 * y + 4 * z) * (31 * x + 4 * z)
let one = FlatVariable::one();
let x = FlatVariable::new(0);
let y = FlatVariable::new(1);
let z = FlatVariable::new(2);
let lhs = Add(
box Add(
box Mult(box Number(FieldPrime::from(4)), box Identifier(y)),
box Mult(box Number(FieldPrime::from(3)), box Identifier(x)),
),
box Mult(box Number(FieldPrime::from(3)), box Identifier(z)),
);
let rhs = Mult(
box Add(
box Add(
box Mult(box Number(FieldPrime::from(3)), box Identifier(x)),
box Mult(box Number(FieldPrime::from(6)), box Identifier(y)),
),
box Mult(box Number(FieldPrime::from(4)), box Identifier(z)),
),
box Add(
box Mult(box Number(FieldPrime::from(31)), box Identifier(x)),
box Mult(box Number(FieldPrime::from(4)), box Identifier(z)),
),
);
let mut variables: HashMap<FlatVariable, usize> = HashMap::new();
variables.insert(one, 0);
variables.insert(x, 1);
variables.insert(y, 2);
variables.insert(z, 3);
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
a_row.sort_by(sort_tup);
b_row.sort_by(sort_tup);
c_row.sort_by(sort_tup);
assert_eq!(
vec![
(1, FieldPrime::from(3)),
(2, FieldPrime::from(6)),
(3, FieldPrime::from(4)),
],
a_row
);
assert_eq!(
vec![(1, FieldPrime::from(31)), (3, FieldPrime::from(4))],
b_row
);
assert_eq!(
vec![
(1, FieldPrime::from(3)),
(2, FieldPrime::from(4)),
(3, FieldPrime::from(3)),
],
c_row
);
}
}
}

View file

@ -201,7 +201,7 @@ impl Checker {
"Duplicate definition for function {} with signature {}", "Duplicate definition for function {} with signature {}",
funct.id, funct.signature funct.id, funct.signature
), ),
}) });
} }
0 => {} 0 => {}
_ => panic!("duplicate function declaration should have been caught"), _ => panic!("duplicate function declaration should have been caught"),

View file

@ -19,7 +19,7 @@ use types::Type;
pub use self::folder::Folder; pub use self::folder::Folder;
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct TypedProg<T: Field> { pub struct TypedProg<T: Field> {
/// Functions of the program /// Functions of the program
pub functions: Vec<TypedFunction<T>>, pub functions: Vec<TypedFunction<T>>,
@ -76,7 +76,7 @@ impl<T: Field> fmt::Debug for TypedProg<T> {
} }
} }
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct TypedFunction<T: Field> { pub struct TypedFunction<T: Field> {
/// Name of the program /// Name of the program
pub id: String, pub id: String,
@ -130,7 +130,7 @@ impl<T: Field> fmt::Debug for TypedFunction<T> {
} }
} }
#[derive(Clone, Serialize, Deserialize, PartialEq, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<T: Field> { pub enum TypedAssignee<T: Field> {
Identifier(Variable), Identifier(Variable),
ArrayElement(Box<TypedAssignee<T>>, Box<FieldElementExpression<T>>), ArrayElement(Box<TypedAssignee<T>>, Box<FieldElementExpression<T>>),
@ -166,7 +166,7 @@ impl<T: Field> fmt::Display for TypedAssignee<T> {
} }
} }
#[derive(Clone, Serialize, Deserialize, PartialEq)] #[derive(Clone, PartialEq)]
pub enum TypedStatement<T: Field> { pub enum TypedStatement<T: Field> {
Return(Vec<TypedExpression<T>>), Return(Vec<TypedExpression<T>>),
Definition(TypedAssignee<T>, TypedExpression<T>), Definition(TypedAssignee<T>, TypedExpression<T>),
@ -250,7 +250,7 @@ pub trait Typed {
fn get_type(&self) -> Type; fn get_type(&self) -> Type;
} }
#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedExpression<T: Field> { pub enum TypedExpression<T: Field> {
Boolean(BooleanExpression<T>), Boolean(BooleanExpression<T>),
FieldElement(FieldElementExpression<T>), FieldElement(FieldElementExpression<T>),
@ -319,7 +319,7 @@ pub trait MultiTyped {
fn get_types(&self) -> &Vec<Type>; fn get_types(&self) -> &Vec<Type>;
} }
#[derive(Clone, PartialEq, Serialize, Deserialize)] #[derive(Clone, PartialEq)]
pub enum TypedExpressionList<T: Field> { pub enum TypedExpressionList<T: Field> {
FunctionCall(String, Vec<TypedExpression<T>>, Vec<Type>), FunctionCall(String, Vec<TypedExpression<T>>, Vec<Type>),
} }
@ -332,7 +332,7 @@ impl<T: Field> MultiTyped for TypedExpressionList<T> {
} }
} }
#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum FieldElementExpression<T: Field> { pub enum FieldElementExpression<T: Field> {
Number(T), Number(T),
Identifier(String), Identifier(String),
@ -368,7 +368,7 @@ pub enum FieldElementExpression<T: Field> {
), ),
} }
#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum BooleanExpression<T: Field> { pub enum BooleanExpression<T: Field> {
Identifier(String), Identifier(String),
Value(bool), Value(bool),
@ -398,7 +398,7 @@ pub enum BooleanExpression<T: Field> {
} }
// for now we store the array size in the variants // for now we store the array size in the variants
#[derive(Clone, PartialEq, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum FieldElementArrayExpression<T: Field> { pub enum FieldElementArrayExpression<T: Field> {
Identifier(usize, String), Identifier(usize, String),
Value(usize, Vec<FieldElementExpression<T>>), Value(usize, Vec<FieldElementExpression<T>>),