Function flattening rewritten
This commit is contained in:
parent
607be16d00
commit
c213330ed0
4 changed files with 72 additions and 37 deletions
|
@ -2,7 +2,7 @@ def add(a,b):
|
|||
v = a + b
|
||||
return v
|
||||
|
||||
// Expected for inputs 1,1: c=4, d=4, e=4
|
||||
// Expected for inputs 1,1: c=4, d=7, e=10
|
||||
def main(a,b):
|
||||
c = add(a*2+3*b-a,b-1)
|
||||
d = add(a*b+2, a*b*c)
|
||||
|
|
|
@ -2,7 +2,6 @@ def add(a,b):
|
|||
v = a + b
|
||||
return v
|
||||
|
||||
// is the direct return a problem?
|
||||
def sub(a,b):
|
||||
return a-b
|
||||
|
||||
|
|
|
@ -57,17 +57,17 @@ counter = 0 // globally counts duplicate entries in boxes, rows and columns
|
|||
|
||||
// check row correctnes
|
||||
|
||||
counter = counter + checkEquality(a11,a21,b11,b21)
|
||||
counter = counter + checkEquality(a11,a12,b11,b12)
|
||||
counter = counter + checkEquality(a21,a22,b21,b22)
|
||||
counter = counter + checkEquality(c11,c21,d11,d12)
|
||||
counter = counter + checkEquality(c21,c22,d21,d22)
|
||||
|
||||
// check column correctnes
|
||||
|
||||
counter = counter + checkEquality(a11,a21,b11,b21)
|
||||
counter = counter + checkEquality(a21,a22,b21,b22)
|
||||
counter = counter + checkEquality(c11,c21,d11,d12)
|
||||
counter = counter + checkEquality(c21,c22,d21,d22)
|
||||
counter = counter + checkEquality(a11,a21,c11,c21)
|
||||
counter = counter + checkEquality(a12,a22,c12,c22)
|
||||
counter = counter + checkEquality(b11,b21,d11,d21)
|
||||
counter = counter + checkEquality(b12,b22,d12,d22)
|
||||
|
||||
// assert counter is 0
|
||||
counter == 0
|
||||
|
|
|
@ -448,8 +448,16 @@ impl Flattener {
|
|||
|
||||
for funct in functions_flattened {
|
||||
if funct.id == *id && funct.arguments.len() == (*param_expressions).len() {
|
||||
// add temporary substitution for the parameters
|
||||
let mut temp_substitution: HashMap<String, String> = HashMap::new(); // substitutions for parameters are only valid during the function call's processing
|
||||
// funct is now the called function
|
||||
|
||||
// add all variables of caller to ensure no conflicting variable can be used in funct
|
||||
// update with new variables created during processing
|
||||
let mut used_vars: HashSet<String> = self.variables.clone();
|
||||
// if conflicting variable is found, a replacement variable needs to be created
|
||||
// and the substitution needs to be added to replacement map.
|
||||
let mut replacement_map: HashMap<String, String> = HashMap::new();
|
||||
|
||||
println!("used variables: {:?}", used_vars);
|
||||
|
||||
println!("Called Function's Arguments: {:?}", funct.arguments);
|
||||
println!("Calling Function's Arguments: {:?}", params_flattened);
|
||||
|
@ -459,10 +467,10 @@ impl Flattener {
|
|||
let identifier_called: String =
|
||||
funct.arguments.get(i).unwrap().id.clone();
|
||||
if identifier_called != identifier_call{
|
||||
temp_substitution.insert(identifier_called, identifier_call);
|
||||
replacement_map.insert(identifier_called, identifier_call);
|
||||
}
|
||||
}
|
||||
println!("Param substitutions: {:?}", temp_substitution);
|
||||
println!("Param substitutions: {:?}", replacement_map);
|
||||
|
||||
|
||||
// add all flattened statements, adapt return statement
|
||||
|
@ -471,36 +479,43 @@ impl Flattener {
|
|||
match stat {
|
||||
// set return statements right side as expression result
|
||||
Statement::Return(x) => {
|
||||
let result = x.apply_substitution(&temp_substitution)
|
||||
.apply_substitution(&self.substitution);
|
||||
println!("function return substitutions:\n {:?}", &self.substitution);
|
||||
let result = x.apply_substitution(&replacement_map);
|
||||
// add back variables and substitutions to calling function
|
||||
for v in used_vars{
|
||||
self.variables.insert(v);
|
||||
}
|
||||
for (k,v) in replacement_map{
|
||||
self.substitution.insert(k,v);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
Statement::Definition(var, rhs) => {
|
||||
let new_rhs = rhs.apply_substitution(&temp_substitution)
|
||||
.apply_substitution(&self.substitution);
|
||||
let subsvar: String = self.use_variable(&var);
|
||||
if var != subsvar {
|
||||
temp_substitution.insert(var.clone(), subsvar.clone());
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
let mut new_var: String = var.clone();
|
||||
if used_vars.contains(&var){
|
||||
new_var = self.get_new_var(&var, &mut used_vars);
|
||||
replacement_map.insert(var, new_var.clone());
|
||||
} else{
|
||||
used_vars.insert(new_var.clone()); // variables must not be used again
|
||||
}
|
||||
statements_flattened.push(
|
||||
Statement::Definition(subsvar, new_rhs)
|
||||
Statement::Definition(new_var, new_rhs)
|
||||
);
|
||||
}
|
||||
Statement::Compiler(var, rhs) => {
|
||||
let new_rhs = rhs.apply_substitution(&temp_substitution)
|
||||
.apply_substitution(&self.substitution);
|
||||
let subsvar = self.use_variable(&var);
|
||||
if var != subsvar {
|
||||
temp_substitution.insert(var.clone(), subsvar.clone());
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
let mut new_var: String = var.clone();
|
||||
if used_vars.contains(&var){
|
||||
new_var = self.get_new_var(&var, &mut used_vars);
|
||||
replacement_map.insert(var, new_var.clone());
|
||||
} else{
|
||||
used_vars.insert(new_var.clone()); // variables must not be used again
|
||||
}
|
||||
statements_flattened.push(Statement::Compiler(subsvar, new_rhs));
|
||||
statements_flattened.push(Statement::Compiler(new_var, new_rhs));
|
||||
}
|
||||
Statement::Condition(lhs, rhs) => {
|
||||
let new_lhs = lhs.apply_substitution(&temp_substitution)
|
||||
.apply_substitution(&self.substitution);
|
||||
let new_rhs = rhs.apply_substitution(&temp_substitution)
|
||||
.apply_substitution(&self.substitution);
|
||||
let new_lhs = lhs.apply_substitution(&replacement_map);
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
statements_flattened
|
||||
.push(Statement::Condition(new_lhs, new_rhs));
|
||||
}
|
||||
|
@ -534,7 +549,6 @@ impl Flattener {
|
|||
statements_flattened,
|
||||
expr_subbed,
|
||||
);
|
||||
println!("substitution at return: {:?}", &self.substitution);
|
||||
statements_flattened.push(Statement::Return(rhs));
|
||||
}
|
||||
Statement::Definition(ref id, ref expr) => {
|
||||
|
@ -545,7 +559,13 @@ impl Flattener {
|
|||
statements_flattened,
|
||||
expr_subbed,
|
||||
);
|
||||
statements_flattened.push(Statement::Definition(self.use_variable(&id), rhs));
|
||||
let var = self.use_variable(&id);
|
||||
// handle return of function call
|
||||
if !(var == *id) && self.variables.contains(id) && !self.substitution.contains_key(id){
|
||||
self.substitution.insert(id.clone().to_string(),var.clone());
|
||||
}
|
||||
println!("substitution: {:?}",self.substitution);
|
||||
statements_flattened.push(Statement::Definition(var, rhs));
|
||||
}
|
||||
Statement::Condition(ref expr1, ref expr2) => {
|
||||
let expr1_subbed = expr1.apply_substitution(&self.substitution);
|
||||
|
@ -608,12 +628,15 @@ impl Flattener {
|
|||
functions_flattened: &mut Vec<Function<T>>,
|
||||
funct: Function<T>,
|
||||
) -> Function<T> {
|
||||
self.variables = HashSet::new();
|
||||
self.substitution = HashMap::new();
|
||||
self.next_var_idx = 0;
|
||||
let mut arguments_flattened: Vec<Parameter> = Vec::new();
|
||||
let mut statements_flattened: Vec<Statement<T>> = Vec::new();
|
||||
// flatten parameters (substitute name to guarantee global uniqueness)
|
||||
// push parameters
|
||||
for arg in funct.arguments {
|
||||
arguments_flattened.push(Parameter {
|
||||
id: self.use_variable(&arg.id),
|
||||
id: arg.id.to_string(),
|
||||
});
|
||||
}
|
||||
// flatten statements in functions and apply substitution
|
||||
|
@ -639,9 +662,6 @@ impl Flattener {
|
|||
/// * `prog` - `Prog`ram that will be flattened.
|
||||
pub fn flatten_program<T: Field>(&mut self, prog: Prog<T>) -> Prog<T> {
|
||||
let mut functions_flattened = Vec::new();
|
||||
self.variables = HashSet::new();
|
||||
self.substitution = HashMap::new();
|
||||
self.next_var_idx = 0;
|
||||
for func in prog.functions {
|
||||
let flattened_func = self.flatten_function(&mut functions_flattened, func);
|
||||
functions_flattened.push(flattened_func);
|
||||
|
@ -676,4 +696,20 @@ impl Flattener {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// used for function call flattening
|
||||
fn get_new_var(&mut self, name: &String, used_vars: &mut HashSet<String>) -> String{
|
||||
let mut i = 0;
|
||||
let mut new_name = name.to_string();
|
||||
loop {
|
||||
if used_vars.contains(&new_name) {
|
||||
new_name = format!("{}_{}", &name, i);
|
||||
i += 1;
|
||||
} else {
|
||||
used_vars.insert(new_name.to_string());
|
||||
return new_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue