Fix multiple bugs in function flattening; add examples
This commit is contained in:
parent
45cff4fb01
commit
fb8442ff6a
8 changed files with 148 additions and 58 deletions
|
@ -26,9 +26,9 @@ RUN cd libsnark-$libsnarkcommit \
|
|||
&& make install lib PREFIX=/usr/local \
|
||||
NO_PROCPS=1 NO_GTEST=1 NO_DOCS=1 CURVE=ALT_BN128 FEATUREFLAGS="-DBINARY_OUTPUT=1 -DMONTGOMERY_OUTPUT=1 -DNO_PT_COMPRESSION=1"
|
||||
|
||||
COPY . /root/VerifiableStatementCompiler
|
||||
|
||||
ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/usr/local/lib
|
||||
|
||||
RUN cd VerifiableStatementCompiler \
|
||||
&& cargo build
|
||||
#COPY . /root/VerifiableStatementCompiler
|
||||
|
||||
#RUN cd VerifiableStatementCompiler \
|
||||
# && cargo build --release
|
||||
|
|
12
examples/multi_functions.code
Normal file
12
examples/multi_functions.code
Normal file
|
@ -0,0 +1,12 @@
|
|||
def add(a,b):
|
||||
v = a + b
|
||||
return v + a
|
||||
|
||||
def main(a,b,c,d):
|
||||
g = a + b
|
||||
x = add(a,b)
|
||||
y = add(c,d)
|
||||
g = add(x, g)
|
||||
g = add(x, g)
|
||||
f = c + d + a
|
||||
return x + y + g + f
|
17
examples/n_choose_k.code
Normal file
17
examples/n_choose_k.code
Normal file
|
@ -0,0 +1,17 @@
|
|||
// working with 9988 choose 14
|
||||
//
|
||||
def fac(x):
|
||||
f = 1
|
||||
counter = 0
|
||||
for i in 1..20000 do
|
||||
f = if counter == x then f else f * i fi
|
||||
counter = if counter == x then counter else counter + 1 fi
|
||||
endfor
|
||||
return f
|
||||
def main(n, k):
|
||||
top = fac(n)
|
||||
bot1 = fac(k)
|
||||
sub = n - k
|
||||
bot2 = fac(sub)
|
||||
bot = bot1 * bot2
|
||||
return top / bot
|
12
src/absy.rs
12
src/absy.rs
|
@ -29,7 +29,7 @@ impl<T: Field> Prog<T> {
|
|||
|
||||
impl<T: Field> fmt::Display for Prog<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "prog({}\t)", self.functions.iter().map(|x| format!("\t{}", x)).collect::<Vec<_>>().join("\n"))
|
||||
write!(f, "{}", self.functions.iter().map(|x| format!("{}", x)).collect::<Vec<_>>().join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,17 +99,15 @@ pub enum Statement<T: Field> {
|
|||
}
|
||||
|
||||
impl<T: Field> Statement<T> {
|
||||
|
||||
pub fn is_flattened(&self) -> bool {
|
||||
match *self {
|
||||
Statement::Return(ref x) |
|
||||
Statement::Definition(_,ref x) |
|
||||
Statement::Compiler(_,ref x) => x.is_flattened(),
|
||||
Statement::Condition(ref x,ref y) => x.is_flattened() && y.is_flattened(),
|
||||
Statement::For(_, _, _, _) => unimplemented!(), // should not be required, can be implemented later
|
||||
Statement::Definition(_,ref x) => x.is_flattened(),
|
||||
Statement::Compiler(..) => true,
|
||||
Statement::Condition(ref x,ref y) => (x.is_linear() && y.is_flattened()) || (x.is_flattened() && y.is_linear()),
|
||||
Statement::For(..) => unimplemented!(), // should not be required, can be implemented later
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
138
src/flatten.rs
138
src/flatten.rs
|
@ -323,26 +323,104 @@ impl Flattener {
|
|||
FunctionCall(ref id, ref params) => {
|
||||
for funct in functions_flattened {
|
||||
if funct.id == *id && funct.arguments.len() == (*params).len() {
|
||||
// add all flattened statements except return statement
|
||||
for stat in funct.statements.clone() {
|
||||
assert!(stat.is_flattened());
|
||||
match stat {
|
||||
Statement::Return(x) =>{
|
||||
// set return statements right side as expression result
|
||||
return x
|
||||
},
|
||||
_ => statements_flattened.push(stat),
|
||||
let mut old_sub: HashMap<String, String> = HashMap::new();
|
||||
for i in 0..funct.arguments.len() {
|
||||
let input_param = match self.substitution.get(&funct.arguments[i].id) {
|
||||
Some(val) => val.to_string(),
|
||||
None => funct.arguments[i].id.to_string(),
|
||||
};
|
||||
if params[i].id != input_param {
|
||||
match self.substitution.get(¶ms[i].id) {
|
||||
Some(val) => {
|
||||
old_sub.insert(val.to_string(), params[i].id.to_string());
|
||||
},
|
||||
None => {},
|
||||
}
|
||||
// self.variables.insert(params[i].id.to_string());
|
||||
self.substitution.insert(input_param, params[i].id.to_string());
|
||||
}
|
||||
}
|
||||
// add all flattened statements except return statement
|
||||
for stat in funct.statements.clone() {
|
||||
assert!(stat.is_flattened(), format!("Not flattened: {}", &stat));
|
||||
match stat {
|
||||
// set return statements right side as expression result
|
||||
Statement::Return(x) => {
|
||||
let result = x.apply_substitution(&self.substitution);
|
||||
// restore previous substitution
|
||||
for param in &funct.arguments {
|
||||
match old_sub.get(¶m.id) {
|
||||
Some(val) => {
|
||||
self.substitution.insert(param.id.to_string(), val.to_string());
|
||||
},
|
||||
None => {
|
||||
self.substitution.remove(¶m.id);
|
||||
},
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
Statement::Definition(var, rhs) => {
|
||||
let new_rhs = rhs.apply_substitution(&self.substitution);
|
||||
statements_flattened.push(Statement::Definition(self.use_variable(&var), new_rhs));
|
||||
},
|
||||
Statement::Compiler(var, rhs) => {
|
||||
let new_rhs = rhs.apply_substitution(&self.substitution);
|
||||
statements_flattened.push(Statement::Compiler(self.use_variable(&var), new_rhs));
|
||||
},
|
||||
Statement::Condition(lhs, rhs) => {
|
||||
let new_lhs = lhs.apply_substitution(&self.substitution);
|
||||
let new_rhs = rhs.apply_substitution(&self.substitution);
|
||||
statements_flattened.push(Statement::Condition(new_lhs, new_rhs));
|
||||
},
|
||||
Statement::For(..) => panic!("Not flattened!"),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Function definition for function {} with {} argument(s) not found.",funct.id , funct.arguments.len());
|
||||
}
|
||||
}
|
||||
panic!("Should never happen.")
|
||||
panic!("Function definition for function {} with {:?} argument(s) not found.",id , params);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flatten_statement<T: Field>(&mut self, functions_flattened: &mut Vec<Function<T>>, statements_flattened: &mut Vec<Statement<T>>, stat: &Statement<T>) {
|
||||
match *stat {
|
||||
Statement::Return(ref expr) => {
|
||||
let expr_subbed = expr.apply_substitution(&self.substitution);
|
||||
let rhs = self.flatten_expression(functions_flattened, statements_flattened, expr_subbed);
|
||||
statements_flattened.push(Statement::Return(rhs));
|
||||
},
|
||||
Statement::Definition(ref id, ref expr) => {
|
||||
let expr_subbed = expr.apply_substitution(&self.substitution);
|
||||
let rhs = self.flatten_expression(functions_flattened, statements_flattened, expr_subbed);
|
||||
statements_flattened.push(Statement::Definition(self.use_variable(&id), rhs));
|
||||
},
|
||||
Statement::Condition(ref expr1, ref expr2) => {
|
||||
let expr1_subbed = expr1.apply_substitution(&self.substitution);
|
||||
let expr2_subbed = expr2.apply_substitution(&self.substitution);
|
||||
let (lhs, rhs) = if expr1_subbed.is_linear() {
|
||||
(expr1_subbed, self.flatten_expression(functions_flattened, statements_flattened, expr2_subbed))
|
||||
} else if expr2_subbed.is_linear() {
|
||||
(expr2_subbed, self.flatten_expression(functions_flattened, statements_flattened, expr1_subbed))
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
statements_flattened.push(Statement::Condition(lhs, rhs));
|
||||
},
|
||||
Statement::For(ref var, ref start, ref end, ref statements) => {
|
||||
let mut current = start.clone();
|
||||
while ¤t < end {
|
||||
statements_flattened.push(Statement::Definition(self.use_variable(&var), Expression::Number(current.clone())));
|
||||
for s in statements {
|
||||
self.flatten_statement(functions_flattened, statements_flattened, s);
|
||||
}
|
||||
current = T::one() + ¤t;
|
||||
}
|
||||
},
|
||||
ref s @ Statement::Compiler(..) => statements_flattened.push(s.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a flattened `Function` based on the given `funct`.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -350,34 +428,9 @@ impl Flattener {
|
|||
/// * `functions_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `funct` - `Function` that will be flattened.
|
||||
pub fn flatten_function<T: Field>(&mut self, functions_flattened: &mut Vec<Function<T>>, funct: Function<T>) -> Function<T> {
|
||||
let mut statements_flattened = Vec::new();
|
||||
let mut statements_flattened: Vec<Statement<T>> = Vec::new();
|
||||
for stat in funct.statements {
|
||||
match stat {
|
||||
Statement::Return(expr) => {
|
||||
let expr_subbed = expr.apply_substitution(&self.substitution);
|
||||
let rhs = self.flatten_expression(&functions_flattened, &mut statements_flattened, expr_subbed);
|
||||
statements_flattened.push(Statement::Return(rhs));
|
||||
},
|
||||
Statement::Definition(id, expr) => {
|
||||
let expr_subbed = expr.apply_substitution(&self.substitution);
|
||||
let rhs = self.flatten_expression(&functions_flattened, &mut statements_flattened, expr_subbed);
|
||||
statements_flattened.push(Statement::Definition(self.use_variable(id), rhs));
|
||||
},
|
||||
Statement::Condition(expr1, expr2) => {
|
||||
let expr1_subbed = expr1.apply_substitution(&self.substitution);
|
||||
let expr2_subbed = expr2.apply_substitution(&self.substitution);
|
||||
let (lhs, rhs) = if expr1_subbed.is_linear() {
|
||||
(expr1_subbed, self.flatten_expression(&functions_flattened, &mut statements_flattened, expr2_subbed))
|
||||
} else if expr2_subbed.is_linear() {
|
||||
(expr2_subbed, self.flatten_expression(&functions_flattened, &mut statements_flattened, expr1_subbed))
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
statements_flattened.push(Statement::Condition(lhs, rhs));
|
||||
},
|
||||
Statement::For(..) => unimplemented!(),
|
||||
s @ Statement::Compiler(..) => statements_flattened.push(s),
|
||||
}
|
||||
self.flatten_statement(functions_flattened, &mut statements_flattened, &stat);
|
||||
}
|
||||
Function { id: funct.id, arguments: funct.arguments, statements: statements_flattened }
|
||||
}
|
||||
|
@ -393,6 +446,9 @@ impl Flattener {
|
|||
self.substitution = HashMap::new();
|
||||
self.next_var_idx = 0;
|
||||
for func in prog.functions{
|
||||
self.variables = HashSet::new();
|
||||
self.substitution = HashMap::new();
|
||||
self.next_var_idx = 0;
|
||||
let flattened_func = self.flatten_function(&mut functions_flattened, func);
|
||||
functions_flattened.push(flattened_func);
|
||||
}
|
||||
|
@ -405,7 +461,7 @@ impl Flattener {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `name` - A String that holds the name of the variable
|
||||
fn use_variable(&mut self, name: String) -> String {
|
||||
fn use_variable(&mut self, name: &String) -> String {
|
||||
let mut i = 0;
|
||||
let mut new_name = name.to_string();
|
||||
loop {
|
||||
|
@ -415,7 +471,7 @@ impl Flattener {
|
|||
} else {
|
||||
self.variables.insert(new_name.to_string());
|
||||
if i == 1 {
|
||||
self.substitution.insert(name, new_name.to_string());
|
||||
self.substitution.insert(name.to_string(), new_name.to_string());
|
||||
} else if i > 1 {
|
||||
self.substitution.insert(format!("{}_{}", name, i - 2), new_name.to_string());
|
||||
}
|
||||
|
|
|
@ -46,11 +46,11 @@ pub fn run_libsnark<T: Field>(variables: Vec<String>, a: Vec<Vec<(usize, T)>>, b
|
|||
}
|
||||
|
||||
//debugging output
|
||||
println!("debugging output:");
|
||||
println!("a_arr {:?}", a_arr);
|
||||
println!("b_arr {:?}", b_arr);
|
||||
println!("c_arr {:?}", c_arr);
|
||||
println!("w_arr {:?}", w_arr);
|
||||
//println!("debugging output:");
|
||||
//println!("a_arr {:?}", a_arr);
|
||||
//println!("b_arr {:?}", b_arr);
|
||||
//println!("c_arr {:?}", c_arr);
|
||||
//println!("w_arr {:?}", w_arr);
|
||||
|
||||
unsafe {
|
||||
_run_libsnark(a_arr[0].as_ptr(),b_arr[0].as_ptr(), c_arr[0].as_ptr(), w_arr[0].as_ptr(), num_constraints as i32, num_variables as i32, num_inputs as i32)
|
||||
|
|
|
@ -75,6 +75,10 @@ fn main() {
|
|||
// generate wittness
|
||||
let witness_map = program_flattened.get_witness(inputs);
|
||||
println!("witness_map {:?}", witness_map);
|
||||
match witness_map.get("~out") {
|
||||
Some(out) => println!("~out: {}", out),
|
||||
None => println!("~out not found")
|
||||
}
|
||||
let witness: Vec<_> = variables.iter().map(|x| witness_map[x].clone()).collect();
|
||||
println!("witness {:?}", witness);
|
||||
|
||||
|
|
|
@ -464,7 +464,10 @@ fn parse_statement1<T: Field>(ide: String, input: String, pos: Position) -> Resu
|
|||
assert_eq!(s3, "");
|
||||
Ok((Statement::Definition(ide, e2), s2, p2))
|
||||
},
|
||||
(t3, _, p3) => Err(Error { expected: vec![Token::Add, Token::Sub, Token::Pow, Token::Mult, Token::Div, Token::Unknown("".to_string())], got: t3 , pos: p3 }),
|
||||
(t3, _, p3) => {
|
||||
println!("here {}", input);
|
||||
Err(Error { expected: vec![Token::Add, Token::Sub, Token::Pow, Token::Mult, Token::Div, Token::Unknown("".to_string())], got: t3 , pos: p3 })
|
||||
},
|
||||
},
|
||||
Err(err) => Err(err),
|
||||
},
|
||||
|
|
Loading…
Reference in a new issue