Merge pull request #83 from Schaeff/remove-compiler-statements
Refactor Compiler statements to rust Directives
This commit is contained in:
commit
3c36637521
6 changed files with 59 additions and 107 deletions
57
src/absy.rs
57
src/absy.rs
|
@ -5,10 +5,7 @@
|
|||
//! @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>
|
||||
//! @date 2017
|
||||
|
||||
const BINARY_SEPARATOR: &str = "_b";
|
||||
|
||||
use std::fmt;
|
||||
use std::collections::{BTreeMap};
|
||||
use substitution::Substitution;
|
||||
use field::Field;
|
||||
use imports::Import;
|
||||
|
@ -235,50 +232,6 @@ impl<T: Field> Expression<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn solve(&self, inputs: &mut BTreeMap<String, T>) -> T {
|
||||
match *self {
|
||||
Expression::Number(ref x) => x.clone(),
|
||||
Expression::Identifier(ref var) => {
|
||||
if let None = inputs.get(var) {
|
||||
if var.contains(BINARY_SEPARATOR) {
|
||||
let var_name = var.split(BINARY_SEPARATOR).collect::<Vec<_>>()[0];
|
||||
let mut num = inputs[var_name].clone();
|
||||
let bits = T::get_required_bits();
|
||||
for i in (0..bits).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
inputs.insert(format!("{}{}{}", &var_name, BINARY_SEPARATOR, i), T::one());
|
||||
} else {
|
||||
inputs.insert(format!("{}{}{}", &var_name, BINARY_SEPARATOR, i), T::zero());
|
||||
}
|
||||
}
|
||||
assert_eq!(num, T::zero());
|
||||
} else {
|
||||
panic!(
|
||||
"Variable {:?} is undeclared in inputs: {:?}",
|
||||
var,
|
||||
inputs
|
||||
);
|
||||
}
|
||||
}
|
||||
inputs[var].clone()
|
||||
}
|
||||
Expression::Add(ref x, ref y) => x.solve(inputs) + y.solve(inputs),
|
||||
Expression::Sub(ref x, ref y) => x.solve(inputs) - y.solve(inputs),
|
||||
Expression::Mult(ref x, ref y) => x.solve(inputs) * y.solve(inputs),
|
||||
Expression::Div(ref x, ref y) => x.solve(inputs) / y.solve(inputs),
|
||||
Expression::Pow(ref x, ref y) => x.solve(inputs).pow(y.solve(inputs)),
|
||||
Expression::IfElse(ref condition, ref consequent, ref alternative) => {
|
||||
if condition.solve(inputs) {
|
||||
consequent.solve(inputs)
|
||||
} else {
|
||||
alternative.solve(inputs)
|
||||
}
|
||||
}
|
||||
Expression::FunctionCall(_, _) => unimplemented!(), // should not happen, since never part of flattened functions
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_linear(&self) -> bool {
|
||||
match *self {
|
||||
Expression::Number(_) | Expression::Identifier(_) => true,
|
||||
|
@ -427,16 +380,6 @@ impl<T: Field> Condition<T> {
|
|||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn solve(&self, inputs: &mut BTreeMap<String, T>) -> bool {
|
||||
match *self {
|
||||
Condition::Lt(ref lhs, ref rhs) => lhs.solve(inputs) < rhs.solve(inputs),
|
||||
Condition::Le(ref lhs, ref rhs) => lhs.solve(inputs) <= rhs.solve(inputs),
|
||||
Condition::Eq(ref lhs, ref rhs) => lhs.solve(inputs) == rhs.solve(inputs),
|
||||
Condition::Ge(ref lhs, ref rhs) => lhs.solve(inputs) >= rhs.solve(inputs),
|
||||
Condition::Gt(ref lhs, ref rhs) => lhs.solve(inputs) > rhs.solve(inputs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for Condition<T> {
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
const BINARY_SEPARATOR: &str = "_b";
|
||||
|
||||
use absy::Expression;
|
||||
use std::fmt;
|
||||
use std::collections::{BTreeMap};
|
||||
use field::Field;
|
||||
|
@ -101,10 +100,6 @@ impl<T: Field> FlatFunction<T> {
|
|||
let s = expr.solve(&mut witness);
|
||||
witness.insert(id.to_string(), s);
|
||||
},
|
||||
FlatStatement::Compiler(ref id, ref expr) => {
|
||||
let s = expr.solve(&mut witness);
|
||||
witness.insert(id.to_string(), s);
|
||||
},
|
||||
FlatStatement::Condition(ref lhs, ref rhs) => {
|
||||
if lhs.solve(&mut witness) != rhs.solve(&mut witness) {
|
||||
return Err(Error {
|
||||
|
@ -184,7 +179,6 @@ impl<T: Field> fmt::Debug for FlatFunction<T> {
|
|||
pub enum FlatStatement<T: Field> {
|
||||
Return(FlatExpressionList<T>),
|
||||
Condition(FlatExpression<T>, FlatExpression<T>),
|
||||
Compiler(String, Expression<T>),
|
||||
Definition(String, FlatExpression<T>),
|
||||
Directive(DirectiveStatement)
|
||||
}
|
||||
|
@ -195,7 +189,6 @@ impl<T: Field> fmt::Display for FlatStatement<T> {
|
|||
FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
|
||||
FlatStatement::Return(ref expr) => write!(f, "return {}", expr),
|
||||
FlatStatement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
FlatStatement::Compiler(ref lhs, ref rhs) => write!(f, "# {} = {}", lhs, rhs),
|
||||
FlatStatement::Directive(ref d) => write!(f, "{}", d),
|
||||
}
|
||||
}
|
||||
|
@ -207,7 +200,6 @@ impl<T: Field> fmt::Debug for FlatStatement<T> {
|
|||
FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
|
||||
FlatStatement::Return(ref expr) => write!(f, "FlatReturn({:?})", expr),
|
||||
FlatStatement::Condition(ref lhs, ref rhs) => write!(f, "FlatCondition({:?}, {:?})", lhs, rhs),
|
||||
FlatStatement::Compiler(ref lhs, ref rhs) => write!(f, "Compiler({:?}, {:?})", lhs, rhs),
|
||||
FlatStatement::Directive(ref d) => write!(f, "{:?}", d),
|
||||
}
|
||||
}
|
||||
|
@ -224,10 +216,6 @@ impl<T: Field> FlatStatement<T> {
|
|||
x.apply_substitution(substitution)
|
||||
),
|
||||
FlatStatement::Return(ref x) => FlatStatement::Return(x.apply_substitution(substitution)),
|
||||
FlatStatement::Compiler(ref lhs, ref rhs) => FlatStatement::Compiler(match substitution.get(lhs) {
|
||||
Some(z) => z.clone(),
|
||||
None => lhs.clone()
|
||||
}, rhs.clone().apply_substitution(substitution)),
|
||||
FlatStatement::Condition(ref x, ref y) => {
|
||||
FlatStatement::Condition(x.apply_substitution(substitution), y.apply_substitution(substitution))
|
||||
},
|
||||
|
|
|
@ -15,7 +15,7 @@ use flat_absy::*;
|
|||
use parameter::Parameter;
|
||||
use direct_substitution::DirectSubstitution;
|
||||
use substitution::Substitution;
|
||||
use helpers::{DirectiveStatement};
|
||||
use helpers::{DirectiveStatement, Helper, RustHelper};
|
||||
|
||||
/// Flattener, computes flattened program.
|
||||
pub struct Flattener {
|
||||
|
@ -170,22 +170,13 @@ impl Flattener {
|
|||
Expression::Sub(box lhs, box rhs),
|
||||
);
|
||||
statements_flattened.push(FlatStatement::Definition(name_x.to_string(), x));
|
||||
statements_flattened.push(FlatStatement::Compiler(
|
||||
name_y.to_string(),
|
||||
Expression::IfElse(
|
||||
box Condition::Eq(Expression::Identifier(name_x.to_string()), Expression::Number(T::zero())),
|
||||
box Expression::Number(T::zero()),
|
||||
box Expression::Number(T::one()),
|
||||
),
|
||||
));
|
||||
statements_flattened.push(FlatStatement::Compiler(
|
||||
name_m.to_string(),
|
||||
Expression::IfElse(
|
||||
box Condition::Eq(Expression::Identifier(name_x.to_string()), Expression::Number(T::zero())),
|
||||
box Expression::Number(T::one()),
|
||||
box Expression::Div(box Expression::Number(T::one()), box Expression::Identifier(name_x.to_string())),
|
||||
),
|
||||
));
|
||||
statements_flattened.push(
|
||||
FlatStatement::Directive(DirectiveStatement {
|
||||
outputs: vec![name_y.to_string(), name_m.to_string()],
|
||||
inputs: vec![name_x.to_string()],
|
||||
helper: Helper::Rust(RustHelper::ConditionEq)
|
||||
})
|
||||
);
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y.to_string()),
|
||||
FlatExpression::Mult(box FlatExpression::Identifier(name_x.to_string()), box FlatExpression::Identifier(name_m)),
|
||||
|
@ -281,12 +272,6 @@ impl Flattener {
|
|||
FlatStatement::Definition(new_var, new_rhs)
|
||||
);
|
||||
},
|
||||
FlatStatement::Compiler(var, rhs) => {
|
||||
let new_var: String = format!("{}{}", prefix, var.clone());
|
||||
replacement_map.insert(var, new_var.clone());
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
statements_flattened.push(FlatStatement::Compiler(new_var, new_rhs));
|
||||
},
|
||||
FlatStatement::Condition(lhs, rhs) => {
|
||||
let new_lhs = lhs.apply_substitution(&replacement_map);
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
|
|
|
@ -50,13 +50,15 @@ impl fmt::Display for LibsnarkGadgetHelper {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
|
||||
pub enum RustHelper {
|
||||
Identity
|
||||
Identity,
|
||||
ConditionEq,
|
||||
}
|
||||
|
||||
impl fmt::Display for RustHelper {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
RustHelper::Identity => write!(f, "Identity"),
|
||||
RustHelper::ConditionEq => write!(f, "ConditionEq"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -106,6 +108,12 @@ impl<T: Field> Executable<T> for RustHelper {
|
|||
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
|
||||
match self {
|
||||
RustHelper::Identity => Ok(inputs.clone()),
|
||||
RustHelper::ConditionEq => {
|
||||
match inputs[0].is_zero() {
|
||||
true => Ok(vec![T::zero(), T::one()]),
|
||||
false => Ok(vec![T::one(), T::one() / inputs[0].clone()])
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -114,6 +122,7 @@ impl Signed for RustHelper {
|
|||
fn get_signature(&self) -> (usize, usize) {
|
||||
match self {
|
||||
RustHelper::Identity => (1, 1),
|
||||
RustHelper::ConditionEq => (1, 2),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -149,15 +158,47 @@ mod tests {
|
|||
use field::FieldPrime;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn execute_sha() {
|
||||
let sha = LibsnarkGadgetHelper::Sha256Compress;
|
||||
// second vector here https://homes.esat.kuleuven.be/~nsmart/MPC/sha-256-test.txt
|
||||
let inputs = vec![0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,0,0,0,1,1,0,1,0,0,0,0,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,0,1,0,0,1,1,0,0,0,1,0,1,0,0,0,0,0,1,0,1,0,1,0,0,0,1,0,1,1,0,0,0,0,1,0,1,1,1,0,0,0,1,1,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1,1,0,1,0,0,0,0,1,1,0,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1,0,0,1,0,0,1,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,1,0,0,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,1,0,1,0,0,0,1,0,1,0,1,1,0,0,1,0,1,1,0,0,0,0,1,0,1,1,0,1,0,0,1,0,1,1,1,0,0,0,1,0,1,1,1,1,0,0,1,1,0,0,0,0,0,0,1,1,0,0,0,1,0,0,1,1,0,0,1,0,0,0,1,1,0,0,1,1,0,0,1,1,0,1,0,0,0,0,1,1,0,1,0,1,0,0,1,1,0,1,1,0,0,0,1,1,0,1,1,1,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0,1,0,0,1,1,1,0,1,0,0,0,1,1,1,0,1,1,0,0,1,1,1,1,0,0,0,0,1,1,1,1,0,1,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,1];
|
||||
let r = sha.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()).unwrap();
|
||||
let r1 = &r[513..769]; // index of the result
|
||||
let res: Vec<FieldPrime> = vec![1,1,1,1,1,1,0,0,1,0,0,1,1,0,0,1,1,0,1,0,0,0,1,0,1,1,0,1,1,1,1,1,1,0,0,0,1,0,0,0,1,1,1,1,0,1,0,0,0,0,1,0,1,0,1,0,0,1,1,1,1,0,1,0,0,1,1,1,1,0,1,1,1,0,1,1,1,0,0,1,1,1,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,1,1,0,0,1,1,1,0,1,0,1,0,1,0,1,1,1,1,1,1,0,0,1,1,1,0,1,0,1,0,1,1,0,1,1,1,0,0,1,1,0,1,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1,0,0,1,0,1,0,1,0,0,1,1,1,0,0,1,1,0,0,0,0,1,1,0,0,0,1,0,1,0,1,1,0,1,0,1,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,1,0,1,0,0,1,1,1].iter().map(|&i| FieldPrime::from(i)).collect();
|
||||
assert_eq!(r1, &res[..]);
|
||||
mod sha256libsnark {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let sha = LibsnarkGadgetHelper::Sha256Compress;
|
||||
// second vector here https://homes.esat.kuleuven.be/~nsmart/MPC/sha-256-test.txt
|
||||
let inputs = vec![0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,0,0,0,1,1,0,1,0,0,0,0,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,0,1,0,0,1,1,0,0,0,1,0,1,0,0,0,0,0,1,0,1,0,1,0,0,0,1,0,1,1,0,0,0,0,1,0,1,1,1,0,0,0,1,1,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1,1,0,1,0,0,0,0,1,1,0,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1,0,0,1,0,0,1,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,1,0,0,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,1,0,1,0,0,0,1,0,1,0,1,1,0,0,1,0,1,1,0,0,0,0,1,0,1,1,0,1,0,0,1,0,1,1,1,0,0,0,1,0,1,1,1,1,0,0,1,1,0,0,0,0,0,0,1,1,0,0,0,1,0,0,1,1,0,0,1,0,0,0,1,1,0,0,1,1,0,0,1,1,0,1,0,0,0,0,1,1,0,1,0,1,0,0,1,1,0,1,1,0,0,0,1,1,0,1,1,1,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0,1,0,0,1,1,1,0,1,0,0,0,1,1,1,0,1,1,0,0,1,1,1,1,0,0,0,0,1,1,1,1,0,1,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,1];
|
||||
let r = sha.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()).unwrap();
|
||||
let r1 = &r[513..769]; // index of the result
|
||||
let res: Vec<FieldPrime> = vec![1,1,1,1,1,1,0,0,1,0,0,1,1,0,0,1,1,0,1,0,0,0,1,0,1,1,0,1,1,1,1,1,1,0,0,0,1,0,0,0,1,1,1,1,0,1,0,0,0,0,1,0,1,0,1,0,0,1,1,1,1,0,1,0,0,1,1,1,1,0,1,1,1,0,1,1,1,0,0,1,1,1,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,1,1,0,0,1,1,1,0,1,0,1,0,1,0,1,1,1,1,1,1,0,0,1,1,1,0,1,0,1,0,1,1,0,1,1,1,0,0,1,1,0,1,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1,0,0,1,0,1,0,1,0,0,1,1,1,0,0,1,1,0,0,0,0,1,1,0,0,0,1,0,1,0,1,1,0,1,0,1,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,1,0,1,0,0,1,1,1].iter().map(|&i| FieldPrime::from(i)).collect();
|
||||
assert_eq!(r1, &res[..]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
mod eq_condition {
|
||||
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let cond_eq = RustHelper::ConditionEq;
|
||||
let inputs = vec![0];
|
||||
let r = cond_eq.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()).unwrap();
|
||||
let res: Vec<FieldPrime> = vec![0, 1].iter().map(|&i| FieldPrime::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute_non_eq() {
|
||||
let cond_eq = RustHelper::ConditionEq;
|
||||
let inputs = vec![1];
|
||||
let r = cond_eq.execute(&inputs.iter().map(|&i| FieldPrime::from(i)).collect()).unwrap();
|
||||
let res: Vec<FieldPrime> = vec![1, 1].iter().map(|&i| FieldPrime::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -83,10 +83,6 @@ impl Optimizer {
|
|||
FlatStatement::Definition(ref left, _) => {
|
||||
self.substitution.insert(left.clone(), format!("_{}", self.next_var_idx.increment()));
|
||||
},
|
||||
// Compiler statements introduce variables before they are defined, so add them to the substitution
|
||||
FlatStatement::Compiler(ref id, _) => {
|
||||
self.substitution.insert(id.clone(), format!("_{}", self.next_var_idx.increment()));
|
||||
},
|
||||
FlatStatement::Directive(ref d) => {
|
||||
for o in d.outputs.iter() {
|
||||
self.substitution.insert(o.clone(), format!("_{}", self.next_var_idx.increment()));
|
||||
|
|
|
@ -335,7 +335,6 @@ pub fn r1cs_program<T: Field>(
|
|||
b.push(b_row);
|
||||
c.push(c_row);
|
||||
},
|
||||
FlatStatement::Compiler(..) => continue,
|
||||
FlatStatement::Directive(..) => continue
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue