From fb8442ff6ac12e62aba528ec9a66c126944de661 Mon Sep 17 00:00:00 2001 From: Dennis Kuhnert Date: Thu, 6 Jul 2017 13:28:11 +0200 Subject: [PATCH] Fix multiple bugs in function flattening; add examples --- Dockerfile | 8 +- examples/multi_functions.code | 12 +++ examples/n_choose_k.code | 17 +++++ src/absy.rs | 12 ++- src/flatten.rs | 138 ++++++++++++++++++++++++---------- src/libsnark.rs | 10 +-- src/main.rs | 4 + src/parser.rs | 5 +- 8 files changed, 148 insertions(+), 58 deletions(-) create mode 100644 examples/multi_functions.code create mode 100644 examples/n_choose_k.code diff --git a/Dockerfile b/Dockerfile index 0db2d5a4..b3419c4b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,9 +26,9 @@ RUN cd libsnark-$libsnarkcommit \ && make install lib PREFIX=/usr/local \ NO_PROCPS=1 NO_GTEST=1 NO_DOCS=1 CURVE=ALT_BN128 FEATUREFLAGS="-DBINARY_OUTPUT=1 -DMONTGOMERY_OUTPUT=1 -DNO_PT_COMPRESSION=1" -COPY . /root/VerifiableStatementCompiler - ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/usr/local/lib -RUN cd VerifiableStatementCompiler \ - && cargo build +#COPY . /root/VerifiableStatementCompiler + +#RUN cd VerifiableStatementCompiler \ +# && cargo build --release diff --git a/examples/multi_functions.code b/examples/multi_functions.code new file mode 100644 index 00000000..8e265d93 --- /dev/null +++ b/examples/multi_functions.code @@ -0,0 +1,12 @@ +def add(a,b): + v = a + b + return v + a + +def main(a,b,c,d): + g = a + b + x = add(a,b) + y = add(c,d) + g = add(x, g) + g = add(x, g) + f = c + d + a + return x + y + g + f diff --git a/examples/n_choose_k.code b/examples/n_choose_k.code new file mode 100644 index 00000000..aecb8726 --- /dev/null +++ b/examples/n_choose_k.code @@ -0,0 +1,17 @@ +// working with 9988 choose 14 +// +def fac(x): + f = 1 + counter = 0 + for i in 1..20000 do + f = if counter == x then f else f * i fi + counter = if counter == x then counter else counter + 1 fi + endfor + return f +def main(n, k): + top = fac(n) + bot1 = fac(k) + sub = n - k + bot2 = fac(sub) + bot = bot1 * bot2 + return top / bot diff --git a/src/absy.rs b/src/absy.rs index 87efb4d9..e79e05d1 100644 --- a/src/absy.rs +++ b/src/absy.rs @@ -29,7 +29,7 @@ impl Prog { impl fmt::Display for Prog { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "prog({}\t)", self.functions.iter().map(|x| format!("\t{}", x)).collect::>().join("\n")) + write!(f, "{}", self.functions.iter().map(|x| format!("{}", x)).collect::>().join("\n")) } } @@ -99,17 +99,15 @@ pub enum Statement { } impl Statement { - pub fn is_flattened(&self) -> bool { match *self { Statement::Return(ref x) | - Statement::Definition(_,ref x) | - Statement::Compiler(_,ref x) => x.is_flattened(), - Statement::Condition(ref x,ref y) => x.is_flattened() && y.is_flattened(), - Statement::For(_, _, _, _) => unimplemented!(), // should not be required, can be implemented later + Statement::Definition(_,ref x) => x.is_flattened(), + Statement::Compiler(..) => true, + Statement::Condition(ref x,ref y) => (x.is_linear() && y.is_flattened()) || (x.is_flattened() && y.is_linear()), + Statement::For(..) => unimplemented!(), // should not be required, can be implemented later } } - } diff --git a/src/flatten.rs b/src/flatten.rs index 46ab084f..3833dc7a 100644 --- a/src/flatten.rs +++ b/src/flatten.rs @@ -323,26 +323,104 @@ impl Flattener { FunctionCall(ref id, ref params) => { for funct in functions_flattened { if funct.id == *id && funct.arguments.len() == (*params).len() { - // add all flattened statements except return statement - for stat in funct.statements.clone() { - assert!(stat.is_flattened()); - match stat { - Statement::Return(x) =>{ - // set return statements right side as expression result - return x - }, - _ => statements_flattened.push(stat), + let mut old_sub: HashMap = HashMap::new(); + for i in 0..funct.arguments.len() { + let input_param = match self.substitution.get(&funct.arguments[i].id) { + Some(val) => val.to_string(), + None => funct.arguments[i].id.to_string(), + }; + if params[i].id != input_param { + match self.substitution.get(¶ms[i].id) { + Some(val) => { + old_sub.insert(val.to_string(), params[i].id.to_string()); + }, + None => {}, + } + // self.variables.insert(params[i].id.to_string()); + self.substitution.insert(input_param, params[i].id.to_string()); + } + } + // add all flattened statements except return statement + for stat in funct.statements.clone() { + assert!(stat.is_flattened(), format!("Not flattened: {}", &stat)); + match stat { + // set return statements right side as expression result + Statement::Return(x) => { + let result = x.apply_substitution(&self.substitution); + // restore previous substitution + for param in &funct.arguments { + match old_sub.get(¶m.id) { + Some(val) => { + self.substitution.insert(param.id.to_string(), val.to_string()); + }, + None => { + self.substitution.remove(¶m.id); + }, + } + } + return result; + }, + Statement::Definition(var, rhs) => { + let new_rhs = rhs.apply_substitution(&self.substitution); + statements_flattened.push(Statement::Definition(self.use_variable(&var), new_rhs)); + }, + Statement::Compiler(var, rhs) => { + let new_rhs = rhs.apply_substitution(&self.substitution); + statements_flattened.push(Statement::Compiler(self.use_variable(&var), new_rhs)); + }, + Statement::Condition(lhs, rhs) => { + let new_lhs = lhs.apply_substitution(&self.substitution); + let new_rhs = rhs.apply_substitution(&self.substitution); + statements_flattened.push(Statement::Condition(new_lhs, new_rhs)); + }, + Statement::For(..) => panic!("Not flattened!"), } } - } else { - panic!("Function definition for function {} with {} argument(s) not found.",funct.id , funct.arguments.len()); } } - panic!("Should never happen.") + panic!("Function definition for function {} with {:?} argument(s) not found.",id , params); }, } } + pub fn flatten_statement(&mut self, functions_flattened: &mut Vec>, statements_flattened: &mut Vec>, stat: &Statement) { + match *stat { + Statement::Return(ref expr) => { + let expr_subbed = expr.apply_substitution(&self.substitution); + let rhs = self.flatten_expression(functions_flattened, statements_flattened, expr_subbed); + statements_flattened.push(Statement::Return(rhs)); + }, + Statement::Definition(ref id, ref expr) => { + let expr_subbed = expr.apply_substitution(&self.substitution); + let rhs = self.flatten_expression(functions_flattened, statements_flattened, expr_subbed); + statements_flattened.push(Statement::Definition(self.use_variable(&id), rhs)); + }, + Statement::Condition(ref expr1, ref expr2) => { + let expr1_subbed = expr1.apply_substitution(&self.substitution); + let expr2_subbed = expr2.apply_substitution(&self.substitution); + let (lhs, rhs) = if expr1_subbed.is_linear() { + (expr1_subbed, self.flatten_expression(functions_flattened, statements_flattened, expr2_subbed)) + } else if expr2_subbed.is_linear() { + (expr2_subbed, self.flatten_expression(functions_flattened, statements_flattened, expr1_subbed)) + } else { + unimplemented!() + }; + statements_flattened.push(Statement::Condition(lhs, rhs)); + }, + Statement::For(ref var, ref start, ref end, ref statements) => { + let mut current = start.clone(); + while ¤t < end { + statements_flattened.push(Statement::Definition(self.use_variable(&var), Expression::Number(current.clone()))); + for s in statements { + self.flatten_statement(functions_flattened, statements_flattened, s); + } + current = T::one() + ¤t; + } + }, + ref s @ Statement::Compiler(..) => statements_flattened.push(s.clone()), + } + } + /// Returns a flattened `Function` based on the given `funct`. /// /// # Arguments @@ -350,34 +428,9 @@ impl Flattener { /// * `functions_flattened` - Vector where new flattened statements can be added. /// * `funct` - `Function` that will be flattened. pub fn flatten_function(&mut self, functions_flattened: &mut Vec>, funct: Function) -> Function { - let mut statements_flattened = Vec::new(); + let mut statements_flattened: Vec> = Vec::new(); for stat in funct.statements { - match stat { - Statement::Return(expr) => { - let expr_subbed = expr.apply_substitution(&self.substitution); - let rhs = self.flatten_expression(&functions_flattened, &mut statements_flattened, expr_subbed); - statements_flattened.push(Statement::Return(rhs)); - }, - Statement::Definition(id, expr) => { - let expr_subbed = expr.apply_substitution(&self.substitution); - let rhs = self.flatten_expression(&functions_flattened, &mut statements_flattened, expr_subbed); - statements_flattened.push(Statement::Definition(self.use_variable(id), rhs)); - }, - Statement::Condition(expr1, expr2) => { - let expr1_subbed = expr1.apply_substitution(&self.substitution); - let expr2_subbed = expr2.apply_substitution(&self.substitution); - let (lhs, rhs) = if expr1_subbed.is_linear() { - (expr1_subbed, self.flatten_expression(&functions_flattened, &mut statements_flattened, expr2_subbed)) - } else if expr2_subbed.is_linear() { - (expr2_subbed, self.flatten_expression(&functions_flattened, &mut statements_flattened, expr1_subbed)) - } else { - unimplemented!() - }; - statements_flattened.push(Statement::Condition(lhs, rhs)); - }, - Statement::For(..) => unimplemented!(), - s @ Statement::Compiler(..) => statements_flattened.push(s), - } + self.flatten_statement(functions_flattened, &mut statements_flattened, &stat); } Function { id: funct.id, arguments: funct.arguments, statements: statements_flattened } } @@ -393,6 +446,9 @@ impl Flattener { self.substitution = HashMap::new(); self.next_var_idx = 0; for func in prog.functions{ + self.variables = HashSet::new(); + self.substitution = HashMap::new(); + self.next_var_idx = 0; let flattened_func = self.flatten_function(&mut functions_flattened, func); functions_flattened.push(flattened_func); } @@ -405,7 +461,7 @@ impl Flattener { /// # Arguments /// /// * `name` - A String that holds the name of the variable - fn use_variable(&mut self, name: String) -> String { + fn use_variable(&mut self, name: &String) -> String { let mut i = 0; let mut new_name = name.to_string(); loop { @@ -415,7 +471,7 @@ impl Flattener { } else { self.variables.insert(new_name.to_string()); if i == 1 { - self.substitution.insert(name, new_name.to_string()); + self.substitution.insert(name.to_string(), new_name.to_string()); } else if i > 1 { self.substitution.insert(format!("{}_{}", name, i - 2), new_name.to_string()); } diff --git a/src/libsnark.rs b/src/libsnark.rs index db0e2e15..914a50f6 100644 --- a/src/libsnark.rs +++ b/src/libsnark.rs @@ -46,11 +46,11 @@ pub fn run_libsnark(variables: Vec, a: Vec>, b } //debugging output - println!("debugging output:"); - println!("a_arr {:?}", a_arr); - println!("b_arr {:?}", b_arr); - println!("c_arr {:?}", c_arr); - println!("w_arr {:?}", w_arr); + //println!("debugging output:"); + //println!("a_arr {:?}", a_arr); + //println!("b_arr {:?}", b_arr); + //println!("c_arr {:?}", c_arr); + //println!("w_arr {:?}", w_arr); unsafe { _run_libsnark(a_arr[0].as_ptr(),b_arr[0].as_ptr(), c_arr[0].as_ptr(), w_arr[0].as_ptr(), num_constraints as i32, num_variables as i32, num_inputs as i32) diff --git a/src/main.rs b/src/main.rs index 62d31dba..71badf56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,6 +75,10 @@ fn main() { // generate wittness let witness_map = program_flattened.get_witness(inputs); println!("witness_map {:?}", witness_map); + match witness_map.get("~out") { + Some(out) => println!("~out: {}", out), + None => println!("~out not found") + } let witness: Vec<_> = variables.iter().map(|x| witness_map[x].clone()).collect(); println!("witness {:?}", witness); diff --git a/src/parser.rs b/src/parser.rs index 7f0853bd..bc34fecf 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -464,7 +464,10 @@ fn parse_statement1(ide: String, input: String, pos: Position) -> Resu assert_eq!(s3, ""); Ok((Statement::Definition(ide, e2), s2, p2)) }, - (t3, _, p3) => Err(Error { expected: vec![Token::Add, Token::Sub, Token::Pow, Token::Mult, Token::Div, Token::Unknown("".to_string())], got: t3 , pos: p3 }), + (t3, _, p3) => { + println!("here {}", input); + Err(Error { expected: vec![Token::Add, Token::Sub, Token::Pow, Token::Mult, Token::Div, Token::Unknown("".to_string())], got: t3 , pos: p3 }) + }, }, Err(err) => Err(err), },