1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

upgrade flattening for multi return

This commit is contained in:
schaeff 2018-01-22 15:07:08 +01:00
parent 6ec373ae4d
commit 8d76c39ff6
5 changed files with 312 additions and 58 deletions

View file

@ -1,6 +1,7 @@
def main(x):
x, y = dup(x)
return 1
def foo(a):
b = 12*a
return a, 2*a, 5*b, a*b
def dup(x):
return x, x
def main(i):
x, y, z, t = foo(i)
return 1

View file

@ -80,8 +80,13 @@ impl<T: Field> Function<T> {
for statement in &self.statements {
match *statement {
Statement::Return(ref expr) => {
let s = expr.solve(&mut witness);
witness.insert("~out".to_string(), s);
match expr.clone() {
Expression::List(values) => {
let s = values[0].solve(&mut witness);
witness.insert("~out".to_string(), s);
},
_ => panic!("should return a list")
}
}
Statement::Compiler(ref id, ref expr) | Statement::Definition(ref id, ref expr) => {
let s = expr.solve(&mut witness);
@ -91,7 +96,7 @@ impl<T: Field> Function<T> {
Statement::Condition(ref lhs, ref rhs) => {
assert_eq!(lhs.solve(&mut witness), rhs.solve(&mut witness))
},
Statement::MultipleDefinition(ref ids, ref expr) => unimplemented!()
Statement::MultipleDefinition(..) => panic!("No MultipleDefinition allowed in flattened code"),
}
}
witness
@ -134,7 +139,7 @@ impl<T: Field> fmt::Debug for Function<T> {
}
}
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize, PartialEq)]
pub enum Statement<T: Field> {
Return(Expression<T>),
Definition(String, Expression<T>),
@ -147,7 +152,8 @@ pub enum Statement<T: Field> {
impl<T: Field> Statement<T> {
pub fn is_flattened(&self) -> bool {
match *self {
Statement::Return(ref x) | Statement::Definition(_, ref x) | Statement::MultipleDefinition(_, ref x) => x.is_flattened(),
Statement::Definition(_, ref x) | Statement::MultipleDefinition(_, ref x) => x.is_flattened(),
Statement::Return(ref x) => x.is_flattened(),
Statement::Compiler(..) => true,
Statement::Condition(ref x, ref y) => {
(x.is_linear() && y.is_flattened()) || (x.is_flattened() && y.is_linear())
@ -230,7 +236,7 @@ pub enum Expression<T: Field> {
Pow(Box<Expression<T>>, Box<Expression<T>>),
IfElse(Box<Condition<T>>, Box<Expression<T>>, Box<Expression<T>>),
FunctionCall(String, Vec<Expression<T>>),
Destructure(Vec<Expression<T>>),
List(Vec<Expression<T>>),
}
impl<T: Field> Expression<T> {
@ -276,11 +282,10 @@ impl<T: Field> Expression<T> {
param.apply_substitution(substitution);
}
Expression::FunctionCall(i.clone(), p.clone())
},
Expression::List(ref exprs) => {
Expression::List(exprs.iter().map(|e| e.apply_substitution(substitution)).collect::<Vec<_>>())
}
Expression::Destructure(ref ids) => unimplemented!()
// Expression::Destructure(
// ids.iter().map(|id| Expression::Identifier(id).apply_substitution(substitution)).collect::<Vec<_>>()
// )
}
}
@ -333,7 +338,7 @@ impl<T: Field> Expression<T> {
}
}
Expression::FunctionCall(_, _) => unimplemented!(), // should not happen, since never part of flattened functions
Expression::Destructure(_) => unimplemented!() // same
Expression::List(_) => unimplemented!() // same
}
}
@ -366,7 +371,10 @@ impl<T: Field> Expression<T> {
(box Expression::Sub(..), _) | (_, box Expression::Sub(..)) => false,
(box x, box y) => x.is_linear() && y.is_linear(),
}
}
},
Expression::List(ref exprs) => {
exprs.into_iter().fold(true, |acc, x| acc && x.is_flattened())
},
_ => false,
}
}
@ -394,12 +402,20 @@ impl<T: Field> fmt::Display for Expression<T> {
for (i, param) in p.iter().enumerate() {
try!(write!(f, "{}", param));
if i < p.len() - 1 {
try!(write!(f, ","));
try!(write!(f, ", "));
}
}
write!(f, ")")
}
Expression::Destructure(..) => unimplemented!()
Expression::List(ref exprs) => {
for (i, param) in exprs.iter().enumerate() {
try!(write!(f, "{}", param));
if i < exprs.len() - 1 {
try!(write!(f, ", "));
}
}
write!(f, "")
},
}
}
}
@ -426,7 +442,7 @@ impl<T: Field> fmt::Debug for Expression<T> {
try!(f.debug_list().entries(p.iter()).finish());
write!(f, ")")
}
Expression::Destructure(ref ids) => write!(f, "Destructure({:?})", ids),
Expression::List(ref exprs) => write!(f, "List({:?})", exprs),
}
}
}

View file

@ -219,6 +219,9 @@ impl Flattener {
{
x.clone()
}
List(ref exprs) if exprs.iter().fold(true, |acc, x| acc && x.is_flattened()) => {
List(exprs.clone())
},
Add(box left, box right) => {
let left_flattened = self.flatten_expression(
functions_flattened,
@ -423,6 +426,7 @@ impl Flattener {
FunctionCall(ref id, ref param_expressions) => {
for funct in functions_flattened {
if funct.id == *id && funct.arguments.len() == (*param_expressions).len() {
// funct is now the called function
// Idea: variables are given a prefix.
@ -476,8 +480,13 @@ impl Flattener {
match stat {
// set return statements right side as expression result
Statement::Return(x) => {
let result = x.apply_substitution(&replacement_map);
return result;
match x {
List(values) => {
let new_values = values.into_iter().map(|x| x.apply_substitution(&replacement_map)).collect::<Vec<_>>();
return List(new_values)
},
_ => panic!("should return a List")
}
},
Statement::Definition(var, rhs) => {
let new_rhs = rhs.apply_substitution(&replacement_map);
@ -500,7 +509,32 @@ impl Flattener {
.push(Statement::Condition(new_lhs, new_rhs));
},
Statement::For(..) => panic!("Not flattened!"),
Statement::MultipleDefinition(..) => unimplemented!(),
// Statement::MultipleDefinition(e1, e2) => {
// let new_rhs = e2.apply_substitution(&replacement_map);
// match new_rhs {
// Expression::List(rhslist) => {
// match e1 {
// Expression::List(exprs) => {
// for (i, e) in exprs.into_iter().enumerate() {
// match e {
// Expression::Identifier(var) => {
// let new_var: String = format!("{}{}", prefix, var.clone());
// replacement_map.insert(var, new_var.clone());
// statements_flattened.push(
// Statement::Definition(new_var, rhslist[i].clone())
// );
// },
// _ => panic!("")
// }
// }
// },
// _ => panic!("")
// }
// },
// _ => panic!("")
// }
// },
_ => unimplemented!()
}
}
}
@ -523,15 +557,21 @@ impl Flattener {
stat: &Statement<T>,
) {
match *stat {
Statement::Return(ref expr) => {
let expr_subbed = expr.apply_substitution(&self.substitution);
Statement::Return(ref exprs) => {
let exprs_subbed = exprs.clone().apply_substitution(&self.substitution);
let rhs = self.flatten_expression(
functions_flattened,
arguments_flattened,
statements_flattened,
expr_subbed,
exprs_subbed,
);
statements_flattened.push(Statement::Return(rhs));
match rhs.clone() {
List(_) => {
statements_flattened.push(Statement::Return(rhs));
},
_ => panic!("")
}
}
Statement::Definition(ref id, ref expr) => {
let expr_subbed = expr.apply_substitution(&self.substitution);
@ -596,7 +636,46 @@ impl Flattener {
}
}
ref s @ Statement::Compiler(..) => statements_flattened.push(s.clone()),
Statement::MultipleDefinition(..) => unimplemented!(),
Statement::MultipleDefinition(ref e1, ref e2) => {
match *e1 {
Expression::List(ref exprs) => {
match *e2 {
FunctionCall(..) => {
let expr_subbed = e2.apply_substitution(&self.substitution);
let rhs = self.flatten_expression(
functions_flattened,
arguments_flattened,
statements_flattened,
expr_subbed,
);
match rhs {
Expression::List(rhslis) => {
let rhslist = rhslis.clone();
for (i, exp) in exprs.into_iter().enumerate() {
match *exp {
Expression::Identifier(ref id) => {
let var = self.use_variable(&id);
// handle return of function call
let var_to_replace = self.get_latest_var_substitution(&id);
if !(var == var_to_replace) && self.variables.contains(&var_to_replace) && !self.substitution.contains_key(&var_to_replace){
self.substitution.insert(var_to_replace.clone().to_string(),var.clone());
}
statements_flattened.push(Statement::Definition(var, rhslist[i].clone()));
},
_ => panic!("Only identifiers can be on the left side of a definition")
}
}
},
_ => panic!("")
}
},
_ => panic!("")
}
},
_ => panic!("")
}
},
}
}
@ -635,7 +714,7 @@ impl Flattener {
id: funct.id,
arguments: arguments_flattened,
statements: statements_flattened,
return_count: 1
return_count: funct.return_count
}
}
@ -690,5 +769,142 @@ impl Flattener {
}
}
}
}
#[cfg(test)]
mod multiple_definition {
use super::*;
use field::FieldPrime;
#[test]
fn multiple_definition() {
// def foo()
// return 1, 2
// def main()
// a, b = foo()
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
let mut functions_flattened = vec![
Function {
id: "foo".to_string(),
arguments: vec![],
statements: vec![Statement::Return(Expression::List(vec![
Expression::Number(FieldPrime::from(1)),
Expression::Number(FieldPrime::from(2))
]))],
return_count: 2,
}
];
let arguments_flattened = vec![];
let mut statements_flattened = vec![];
let statement = Statement::MultipleDefinition(
Expression::List(
vec![
Expression::Identifier("a".to_string()),
Expression::Identifier("b".to_string())
]
),
Expression::FunctionCall("foo".to_string(), vec![])
);
flattener.flatten_statement(
&mut functions_flattened,
&arguments_flattened,
&mut statements_flattened,
&statement,
);
assert_eq!(
statements_flattened[0]
,
Statement::Definition("a".to_string(), Expression::Number(FieldPrime::from(1)))
);
}
#[test]
fn multiple_definition2() {
// def dup(x)
// return x, x
// def main()
// a, b = dup(2)
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
let mut functions_flattened = vec![
Function {
id: "dup".to_string(),
arguments: vec![Parameter { id: "x".to_string() }],
statements: vec![Statement::Return(Expression::List(vec![
Expression::Identifier("x".to_string()),
Expression::Identifier("x".to_string()),
]))],
return_count: 2,
}
];
let arguments_flattened = vec![];
let mut statements_flattened = vec![];
let statement = Statement::MultipleDefinition(
Expression::List(
vec![
Expression::Identifier("a".to_string()),
Expression::Identifier("b".to_string())
]
),
Expression::FunctionCall("dup".to_string(), vec![Expression::Number(FieldPrime::from(2))])
);
flattener.flatten_statement(
&mut functions_flattened,
&arguments_flattened,
&mut statements_flattened,
&statement,
);
assert_eq!(
statements_flattened[0]
,
Statement::Definition("a".to_string(), Expression::Number(FieldPrime::from(2)))
);
}
#[test]
fn simple_definition() {
// def foo()
// return 1
// def main()
// a = foo()
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
let mut functions_flattened = vec![
Function {
id: "foo".to_string(),
arguments: vec![],
statements: vec![Statement::Return(Expression::List(vec![
Expression::Number(FieldPrime::from(1))
]))],
return_count: 1,
}
];
let arguments_flattened = vec![];
let mut statements_flattened = vec![];
let statement = Statement::Definition(
"a".to_string(),
Expression::FunctionCall("foo".to_string(), vec![])
);
flattener.flatten_statement(
&mut functions_flattened,
&arguments_flattened,
&mut statements_flattened,
&statement,
);
assert_eq!(
statements_flattened[0]
,
Statement::Definition("a".to_string(), Expression::Number(FieldPrime::from(1)))
);
}
}

View file

@ -790,7 +790,7 @@ fn parse_expression_list1<T: Field>(
let mut res = Vec::new();
res.push(Expression::Identifier(head));
match parse_comma_separated_list_rec(input, pos, &mut res) {
Ok((list, s1, p1)) => Ok((Expression::Destructure(list), s1, p1)),
Ok((list, s1, p1)) => Ok((Expression::List(list), s1, p1)),
Err(err) => Err(err)
}
}
@ -802,7 +802,7 @@ fn parse_expression_list<T: Field>(
) -> Result<(Expression<T>, String, Position), Error<T>> {
let mut res = Vec::new();
match parse_comma_separated_list_rec(input, pos, &mut res) {
Ok((list, s1, p1)) => Ok((Expression::Destructure(list), s1, p1)),
Ok((list, s1, p1)) => Ok((Expression::List(list), s1, p1)),
Err(err) => Err(err)
}
}
@ -1111,15 +1111,15 @@ fn parse_statement<T: Field>(
pos: p2,
}),
},
(Token::Return, s1, p1) => match parse_expr(&s1, &p1) {
Ok((expr, s2, p2)) => match next_token(&s2, &p2) {
(Token::Return, s1, p1) => match parse_expression_list(s1, p1) {
Ok((Expression::List(xs), s2, p2)) => match next_token(&s2, &p2) {
(Token::InlineComment(_), ref s3, _) => {
assert_eq!(s3, "");
Ok((Statement::Return(expr), s2, p2))
Ok((Statement::Return(Expression::List(xs)), s2, p2))
}
(Token::Unknown(ref t3), ref s3, _) if t3 == "" => {
assert_eq!(s3, "");
Ok((Statement::Return(expr), s2, p2))
Ok((Statement::Return(Expression::List(xs)), s2, p2))
}
(t4, _, p4) => Err(Error {
expected: vec![
@ -1134,6 +1134,7 @@ fn parse_statement<T: Field>(
pos: p4,
}),
},
Ok(..) => unimplemented!(),
Err(err) => Err(err),
},
(Token::Def, _, p1) => Err(Error {
@ -1266,7 +1267,7 @@ fn parse_function<T: Field>(
// parse function body
let mut stats = Vec::new();
let mut return_count;
let return_count;
loop {
match lines.next() {
Some(Ok(ref x)) if x.trim().starts_with("//") || x.trim() == "" => {} // skip
@ -1278,9 +1279,9 @@ fn parse_function<T: Field>(
col: 1,
},
) {
Ok((statement @ Statement::Return(_), ..)) => {
return_count = 1;
stats.push(statement);
Ok((Statement::Return(Expression::List(exprs)), ..)) => {
return_count = exprs.len();
stats.push(Statement::Return(Expression::List(exprs)));
break;
}
Ok((statement, _, pos)) => {
@ -1387,14 +1388,6 @@ pub fn parse_program<T: Field>(file: File) -> Result<Prog<T>, Error<T>> {
Ok(Prog { functions })
}
fn parse_comma_separated_list<T: Field>(
input: String,
pos: Position
) -> Result<(Vec<Expression<T>>, String, Position), Error<T>> {
let mut res = Vec::new();
parse_comma_separated_list_rec(input, pos, &mut res)
}
fn parse_comma_separated_list_rec<T: Field>(
input: String,
pos: Position,
@ -1599,7 +1592,7 @@ mod tests {
fn destructure1() {
let pos = Position { line: 45, col: 121 };
let string = String::from("b, c");
let expr = Expression::Destructure::<FieldPrime>(vec![Expression::Identifier(String::from("a")),Expression::Identifier(String::from("b")),Expression::Identifier(String::from("c"))]);
let expr = Expression::List::<FieldPrime>(vec![Expression::Identifier(String::from("a")),Expression::Identifier(String::from("b")),Expression::Identifier(String::from("c"))]);
assert_eq!(
Ok((expr, String::from(""), pos.col(string.len() as isize))),
parse_expression_list1(String::from("a"), string, pos)
@ -1610,7 +1603,7 @@ mod tests {
fn destructure() {
let pos = Position { line: 45, col: 121 };
let string = String::from("b, c");
let expr = Expression::Destructure::<FieldPrime>(vec![Expression::Identifier(String::from("b")),Expression::Identifier(String::from("c"))]);
let expr = Expression::List::<FieldPrime>(vec![Expression::Identifier(String::from("b")),Expression::Identifier(String::from("c"))]);
assert_eq!(
Ok((expr, String::from(""), pos.col(string.len() as isize))),
parse_expression_list(string, pos)
@ -1653,6 +1646,29 @@ mod tests {
// test impossible to run forever?
}
// #[cfg(test)]
// mod parse_program {
// use super::*;
// #[test]
// fn single_output() {
// let pos = Position { line: 45, col: 121 };
// let string = "
// def foo():
// return 1
// ";
// let fun = Function {
// id: "foo".to_string(),
// arguments: vec![],
// statements: vec![Expression::Return(Expression::Number(FieldPrime::from(1)))],
// return_count: 1
// };
// assert_eq!(
// Ok((fun, String::from(""), pos.col(string.len() as isize))),
// parse_function(string, pos)
// )
// }
// }
// parse function
// parse_term1
// parse_term

View file

@ -304,14 +304,19 @@ pub fn r1cs_program<T: Field>(
let mut b_row: Vec<(usize, T)> = Vec::new();
let mut c_row: Vec<(usize, T)> = Vec::new();
match *def {
Statement::Return(ref expr) => r1cs_expression(
Identifier("~out".to_string()),
expr.clone(),
&mut variables,
&mut a_row,
&mut b_row,
&mut c_row,
),
Statement::Return(ref expr) => {
match expr.clone() {
Expression::List(values) => r1cs_expression(
Identifier("~out".to_string()),
values[0].clone(),
&mut variables,
&mut a_row,
&mut b_row,
&mut c_row,
),
_ => panic!("should return a List")
}
},
Statement::Definition(ref id, ref expr) => r1cs_expression(
Identifier(id.to_string()),
expr.clone(),