From 6bcf021124412039c2905227e8268d9c1c351d2b Mon Sep 17 00:00:00 2001 From: Dennis Kuhnert Date: Wed, 1 Feb 2017 11:56:48 +0100 Subject: [PATCH] Add new flattening for Add & Mult; Add new examples --- examples/add.code | 6 ++ examples/flatten.code | 6 ++ examples/sub.code | 6 ++ src/ast.rs | 4 +- src/parser.rs | 155 ++++++++++++++++++++++-------------------- src/r1cs.rs | 130 +++++++++++++++++++++++++++-------- 6 files changed, 203 insertions(+), 104 deletions(-) create mode 100644 examples/add.code create mode 100644 examples/flatten.code create mode 100644 examples/sub.code diff --git a/examples/add.code b/examples/add.code new file mode 100644 index 00000000..9a5528c2 --- /dev/null +++ b/examples/add.code @@ -0,0 +1,6 @@ +// only using add, no need to flatten +def qeval(a): + b = a + 5 + c = a + b + a + 4 + d = a + c + a + b + return b + c + d diff --git a/examples/flatten.code b/examples/flatten.code new file mode 100644 index 00000000..6277c9af --- /dev/null +++ b/examples/flatten.code @@ -0,0 +1,6 @@ +// this code needs flattening +def qeval(a): + b = a + 5 + a * a + c = b + a + a * b * b + d = a * b + c * c + return b + c + d diff --git a/examples/sub.code b/examples/sub.code new file mode 100644 index 00000000..206f9f08 --- /dev/null +++ b/examples/sub.code @@ -0,0 +1,6 @@ +// only using sub, no need to flatten +def qeval(a): + b = a - 5 + c = b + a + b + d = b - a - 3 - a + return d + c diff --git a/src/ast.rs b/src/ast.rs index 4ce27147..f0e6ca51 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -86,7 +86,7 @@ impl Expression { } } - fn is_linear(&self) -> bool { + pub fn is_linear(&self) -> bool { match *self { Expression::NumberLiteral(_) | Expression::VariableReference(_) => true, @@ -108,7 +108,7 @@ impl Expression { Expression::NumberLiteral(_) | Expression::VariableReference(_) => true, Expression::Add(ref x, ref y) | - Expression::Sub(ref x, ref y) => x.is_flattened() && y.is_flattened(), + Expression::Sub(ref x, ref y) => x.is_linear() && y.is_linear(), Expression::Mult(ref x, ref y) | Expression::Div(ref x, ref y) => x.is_linear() && y.is_linear(), _ => false, diff --git a/src/parser.rs b/src/parser.rs index a44ce8a5..38e0e9f9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -121,87 +121,96 @@ fn flatten_expression(defs_flattened: &mut Vec, num_variables: &mut match expr { x @ Expression::NumberLiteral(_) | x @ Expression::VariableReference(_) => x, - // ref x @ Expression::Add(..) | - // ref x @ Expression::Sub(..) | - // ref x @ Expression::Mult(..) | - // ref x @ Expression::Div(..) if x.is_flattened() => x.clone(), - Expression::Add(left, right) => { - // TODO currently assuming that left is always Number or Variable - let new_right = match right { - box Expression::NumberLiteral(x) => Expression::NumberLiteral(x), - box Expression::VariableReference(ref x) => Expression::VariableReference(x.to_string()), - box expr => { - let tmp_expression = flatten_expression( - defs_flattened, - num_variables, - expr - ); - let new_name = format!("sym_{}", num_variables); - *num_variables += 1; - defs_flattened.push(Definition::Definition(new_name.to_string(), tmp_expression)); - Expression::VariableReference(new_name) - }, + ref x @ Expression::Add(..) | + ref x @ Expression::Sub(..) | + ref x @ Expression::Mult(..) | + ref x @ Expression::Div(..) if x.is_flattened() => x.clone(), + Expression::Add(box left, box right) => { + let left_flattened = flatten_expression(defs_flattened, num_variables, left); + let right_flattened = flatten_expression(defs_flattened, num_variables, right); + let new_left = if left_flattened.is_linear() { + left_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), left_flattened)); + Expression::VariableReference(new_name) }; - Expression::Add(left, box new_right) + let new_right = if right_flattened.is_linear() { + right_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), right_flattened)); + Expression::VariableReference(new_name) + }; + Expression::Add(box new_left, box new_right) }, - Expression::Sub(left, right) => { - // TODO currently assuming that left is always Number or Variable - let new_right = match right { - box Expression::NumberLiteral(x) => Expression::NumberLiteral(x), - box Expression::VariableReference(ref x) => Expression::VariableReference(x.to_string()), - box expr => { - let tmp_expression = flatten_expression( - defs_flattened, - num_variables, - expr - ); - let new_name = format!("sym_{}", num_variables); - *num_variables += 1; - defs_flattened.push(Definition::Definition(new_name.to_string(), tmp_expression)); - Expression::VariableReference(new_name) - }, + Expression::Sub(box left, box right) => { + let left_flattened = flatten_expression(defs_flattened, num_variables, left); + let right_flattened = flatten_expression(defs_flattened, num_variables, right); + let new_left = if left_flattened.is_linear() { + left_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), left_flattened)); + Expression::VariableReference(new_name) }; - Expression::Sub(left, box new_right) + let new_right = if right_flattened.is_linear() { + right_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), right_flattened)); + Expression::VariableReference(new_name) + }; + Expression::Sub(box new_left, box new_right) }, - Expression::Mult(left, right) => { - // TODO currently assuming that left is always Number or Variable - let new_right = match right { - box Expression::NumberLiteral(x) => Expression::NumberLiteral(x), - box Expression::VariableReference(ref x) => Expression::VariableReference(x.to_string()), - box expr => { - let tmp_expression = flatten_expression( - defs_flattened, - num_variables, - expr - ); - let new_name = format!("sym_{}", num_variables); - *num_variables += 1; - defs_flattened.push(Definition::Definition(new_name.to_string(), tmp_expression)); - Expression::VariableReference(new_name) - }, + Expression::Mult(box left, box right) => { + let left_flattened = flatten_expression(defs_flattened, num_variables, left); + let right_flattened = flatten_expression(defs_flattened, num_variables, right); + let new_left = if left_flattened.is_linear() { + left_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), left_flattened)); + Expression::VariableReference(new_name) }; - Expression::Mult(left, box new_right) + let new_right = if right_flattened.is_linear() { + right_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), right_flattened)); + Expression::VariableReference(new_name) + }; + Expression::Mult(box new_left, box new_right) }, - Expression::Div(left, right) => { - // TODO currently assuming that left is always Number or Variable - let new_right = match right { - box Expression::NumberLiteral(x) => Expression::NumberLiteral(x), - box Expression::VariableReference(ref x) => Expression::VariableReference(x.to_string()), - box expr => { - let tmp_expression = flatten_expression( - defs_flattened, - num_variables, - expr - ); - let new_name = format!("sym_{}", num_variables); - *num_variables += 1; - defs_flattened.push(Definition::Definition(new_name.to_string(), tmp_expression)); - Expression::VariableReference(new_name) - }, + Expression::Div(box left, box right) => { + let left_flattened = flatten_expression(defs_flattened, num_variables, left); + let right_flattened = flatten_expression(defs_flattened, num_variables, right); + let new_left = if left_flattened.is_linear() { + left_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), left_flattened)); + Expression::VariableReference(new_name) }; - Expression::Div(left, box new_right) + let new_right = if right_flattened.is_linear() { + right_flattened + } else { + let new_name = format!("sym_{}", num_variables); + *num_variables += 1; + defs_flattened.push(Definition::Definition(new_name.to_string(), right_flattened)); + Expression::VariableReference(new_name) + }; + Expression::Div(box new_left, box new_right) }, Expression::Pow(base, exponent) => { + // TODO currently assuming that base is number or variable match exponent { box Expression::NumberLiteral(x) if x > 1 => { match base { @@ -234,7 +243,7 @@ fn flatten_expression(defs_flattened: &mut Vec, num_variables: &mut _ => panic!("Only variables and numbers allowed in pow base") } } - _ => panic!("Expected number as pow exponent"), + _ => panic!("Expected number > 1 as pow exponent"), } }, Expression::IfElse(box condition, consequent, alternative) => { diff --git a/src/r1cs.rs b/src/r1cs.rs index 35adabcb..3fa64585 100644 --- a/src/r1cs.rs +++ b/src/r1cs.rs @@ -1,52 +1,124 @@ use ast::*; +use std::collections::HashMap; + +fn count_variables_add(expr: Expression) -> HashMap { + let mut count = HashMap::new(); + match expr { + Expression::Add(box lhs, box rhs) => { + match (lhs, rhs) { + (Expression::NumberLiteral(x), Expression::NumberLiteral(y)) => { + let num = count.entry("~one".to_string()).or_insert(0); + *num += x + y; + }, + (Expression::VariableReference(v), Expression::NumberLiteral(x)) | + (Expression::NumberLiteral(x), Expression::VariableReference(v)) => { + { + let num = count.entry("~one".to_string()).or_insert(0); + *num += x; + } + let var = count.entry(v).or_insert(0); + *var += 1; + }, + (Expression::VariableReference(v1), Expression::VariableReference(v2)) => { + { + let var1 = count.entry(v1).or_insert(0); + *var1 += 1; + } + let var2 = count.entry(v2).or_insert(0); + *var2 += 1; + }, + (Expression::NumberLiteral(x), e @ Expression::Add(..)) | + (e @ Expression::Add(..), Expression::NumberLiteral(x)) => { + { + let num = count.entry("~one".to_string()).or_insert(0); + *num += x; + } + let vars = count_variables_add(e); + for (key, value) in &vars { + let val = count.entry(key.to_string()).or_insert(0); + *val += *value; + } + }, + (Expression::VariableReference(v), e @ Expression::Add(..)) | + (e @ Expression::Add(..), Expression::VariableReference(v)) => { + { + let var = count.entry(v).or_insert(0); + *var += 1; + } + let vars = count_variables_add(e); + for (key, value) in &vars { + let val = count.entry(key.to_string()).or_insert(0); + *val += *value; + } + }, + (Expression::NumberLiteral(x), Expression::Mult(box Expression::NumberLiteral(n), box Expression::VariableReference(v))) | + (Expression::NumberLiteral(x), Expression::Mult(box Expression::VariableReference(v), box Expression::NumberLiteral(n))) | + (Expression::Mult(box Expression::NumberLiteral(n), box Expression::VariableReference(v)), Expression::NumberLiteral(x)) | + (Expression::Mult(box Expression::VariableReference(v), box Expression::NumberLiteral(n)), Expression::NumberLiteral(x)) => { + { + let num = count.entry("~one".to_string()).or_insert(0); + *num += x; + } + let var = count.entry(v).or_insert(0); + *var += n; + }, + e @ _ => panic!("Error: Add({}, {})", e.0, e.1), + } + }, + e @ _ => panic!("Definition::Add expected, got: {}", e), + } + count +} pub fn r1cs_expression(idx: usize, expr: Expression, variables: &mut Vec, a_row: &mut Vec<(usize, i32)>, b_row: &mut Vec<(usize, i32)>, c_row: &mut Vec<(usize, i32)>) { match expr { - Expression::Add(lhs, rhs) => { - c_row.push((idx, 1)); - match (lhs, rhs) { - (box Expression::VariableReference(ref x1), box Expression::VariableReference(ref x2)) if x1 == x2 => { - a_row.push((variables.iter().position(|r| r == x1).unwrap(), 2)); - b_row.push((0, 1)); - }, - (box Expression::VariableReference(ref x1), box Expression::VariableReference(ref x2)) /*if x1 != x2*/ => { - a_row.push((variables.iter().position(|r| r == x1).unwrap(), 1)); - a_row.push((variables.iter().position(|r| r == x2).unwrap(), 1)); - b_row.push((0, 1)); - }, - (box Expression::NumberLiteral(num), box Expression::VariableReference(ref x)) | - (box Expression::VariableReference(ref x), box Expression::NumberLiteral(num)) => { - a_row.push((0, num)); - a_row.push((variables.iter().position(|r| r == x).unwrap(), 1)); - b_row.push((0, 1)); - } - _ => panic!("Not flattened!"), + e @ Expression::Add(..) => { + for (key, value) in count_variables_add(e) { + a_row.push((variables.iter().position(|r| r == &key).unwrap(), value)); } + b_row.push((0, 1)); + c_row.push((idx, 1)); }, - Expression::Sub(lhs, rhs) => { // a - b = c --> c + b = a - a_row.push((idx, 1)); + Expression::Sub(lhs, rhs) => { // a - b = c --> b + c = a + for (key, value) in count_variables_add(Expression::Add(rhs, box Expression::VariableReference(variables[idx].to_string()))) { + a_row.push((variables.iter().position(|r| r == &key).unwrap(), value)); + } + b_row.push((0, 1)); match lhs { box Expression::NumberLiteral(x) => c_row.push((0, x)), box Expression::VariableReference(x) => c_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), - _ => panic!("Not flattened!"), - }; - match rhs { - box Expression::NumberLiteral(x) => b_row.push((0, x)), - box Expression::VariableReference(x) => b_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), - _ => panic!("Not flattened!"), + e @ _ => panic!("unimplemented: {}", e), }; + + // a_row.push((idx, 1)); + // match lhs { + // box Expression::NumberLiteral(x) => c_row.push((0, x)), + // box Expression::VariableReference(x) => c_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), + // _ => panic!("Not flattened!"), + // }; + // match rhs { + // box Expression::NumberLiteral(x) => b_row.push((0, x)), + // box Expression::VariableReference(x) => b_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), + // _ => panic!("Not flattened!"), + // }; }, Expression::Mult(lhs, rhs) => { c_row.push((idx, 1)); match lhs { box Expression::NumberLiteral(x) => a_row.push((0, x)), box Expression::VariableReference(x) => a_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), - _ => panic!("Not flattened!"), + box e @ Expression::Add(..) => for (key, value) in count_variables_add(e) { + a_row.push((variables.iter().position(|r| r == &key).unwrap(), value)); + }, + e @ _ => panic!("Not flattened: {}", e), }; match rhs { box Expression::NumberLiteral(x) => b_row.push((0, x)), box Expression::VariableReference(x) => b_row.push((variables.iter().position(|r| r == &x).unwrap(), 1)), - _ => panic!("Not flattened!"), + box e @ Expression::Add(..) => for (key, value) in count_variables_add(e) { + b_row.push((variables.iter().position(|r| r == &key).unwrap(), value)); + }, + e @ _ => panic!("Not flattened: {}", e), }; }, Expression::Div(lhs, rhs) => { // a / b = c --> c * b = a