From a946299c71035282e7c0909b4d3ce8bf6c00696e Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 28 May 2019 14:51:50 +0200 Subject: [PATCH] adapt flatten tests --- zokrates_core/src/flat_absy/mod.rs | 2 +- zokrates_core/src/flatten/mod.rs | 1567 ++++++++++------------ zokrates_core/src/static_analysis/mod.rs | 2 +- 3 files changed, 674 insertions(+), 897 deletions(-) diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index bac310a0..29024e39 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; use std::fmt; use zokrates_field::field::Field; -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub struct FlatProg { /// FlatFunctions of the program pub main: FlatFunction, diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 2b1d973a..7d9bb967 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1244,912 +1244,689 @@ mod tests { use zokrates_field::field::FieldPrime; #[test] - fn multiple_definition() { - // def foo() - // return 1, 2 - // def main() - // a, b = foo() + fn powers() { + // def main(): + // field a = 7 + // field b = a**4 + // return b - let mut flattener = Flattener::new(); - let foo = TypedFunction { + // def main(): + // _0 = 7 + // _1 = (_0 * _0) + // _2 = (_1 * _0) + // _3 = (_2 * _0) + // return _3 + + let main_fun = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(FieldPrime::from(1)).into(), - FieldElementExpression::Number(FieldPrime::from(2)).into(), - ])], - signature: Signature::new() - .inputs(vec![]) - .outputs(vec![Type::FieldElement, Type::FieldElement]), - }; - let mut statements_flattened = vec![]; - let statement = TypedStatement::MultipleDefinition( - vec![ - Variable::field_element("a".to_string()), - Variable::field_element("b".to_string()), - ], - TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new().outputs(vec![Type::FieldElement, Type::FieldElement]), + statements: vec![ + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Number(FieldPrime::from(7)).into(), ), - vec![], - vec![Type::FieldElement, Type::FieldElement], + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("b")), + FieldElementExpression::Pow( + box FieldElementExpression::Identifier(String::from("a")), + box FieldElementExpression::Number(FieldPrime::from(4)), + ) + .into(), + ), + TypedStatement::Return(vec![ + FieldElementExpression::Identifier(String::from("b")).into() + ]), + ], + signature: Signature { + inputs: vec![], + outputs: vec![Type::FieldElement], + }, + }; + + let program = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: vec![( + FunctionKey::with_id("main") + .signature(Signature::new().outputs(vec![Type::FieldElement])), + TypedFunctionSymbol::Here(main_fun), + )] + .into_iter() + .collect(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let flat_prog = Flattener::flatten(program); + + let expected = FlatProg { + main: FlatFunction { + arguments: vec![], + statements: vec![ + FlatStatement::Definition( + FlatVariable::new(0), + FlatExpression::Number(FieldPrime::from(7)), + ), + FlatStatement::Definition( + FlatVariable::new(1), + FlatExpression::Mult( + box FlatExpression::Identifier(FlatVariable::new(0)), + box FlatExpression::Identifier(FlatVariable::new(0)), + ), + ), + FlatStatement::Definition( + FlatVariable::new(2), + FlatExpression::Mult( + box FlatExpression::Identifier(FlatVariable::new(1)), + box FlatExpression::Identifier(FlatVariable::new(0)), + ), + ), + FlatStatement::Definition( + FlatVariable::new(3), + FlatExpression::Mult( + box FlatExpression::Identifier(FlatVariable::new(2)), + box FlatExpression::Identifier(FlatVariable::new(0)), + ), + ), + FlatStatement::Return(FlatExpressionList { + expressions: vec![FlatExpression::Identifier(FlatVariable::new(3))], + }), + ], + signature: Signature::new().outputs(vec![Type::FieldElement]), + }, + }; + + assert_eq!(flat_prog, expected); + } + + #[test] + fn if_else() { + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + use static_analysis::CoreLibInjector; + let program = CoreLibInjector::inject(program); + + let expression = FieldElementExpression::IfElse( + box BooleanExpression::Eq( + box FieldElementExpression::Number(FieldPrime::from(32)), + box FieldElementExpression::Number(FieldPrime::from(4)), ), + box FieldElementExpression::Number(FieldPrime::from(12)), + box FieldElementExpression::Number(FieldPrime::from(51)), ); - let symbols = vec![( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement, Type::FieldElement])), - TypedFunctionSymbol::Here(foo), - )] - .into_iter() - .collect(); + let mut statements_flattened = vec![]; - flattener.flatten_statement(&symbols, &mut statements_flattened, statement); + let functions = program.modules.get("main").unwrap().functions.clone(); - let a = FlatVariable::new(0); + let mut flattener = Flattener::with_modules(program.modules); + + flattener.flatten_field_expression(&functions, &mut statements_flattened, expression); + } + + #[test] + fn geq_leq() { + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + use static_analysis::CoreLibInjector; + let program = CoreLibInjector::inject(program); + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + let expression_le = BooleanExpression::Le( + box FieldElementExpression::Number(FieldPrime::from(32)), + box FieldElementExpression::Number(FieldPrime::from(4)), + ); + + let expression_ge = BooleanExpression::Ge( + box FieldElementExpression::Number(FieldPrime::from(32)), + box FieldElementExpression::Number(FieldPrime::from(4)), + ); + + flattener.flatten_boolean_expression(&functions, &mut vec![], expression_le); + + flattener.flatten_boolean_expression(&functions, &mut vec![], expression_ge); + } + + #[test] + fn bool_and() { + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + use static_analysis::CoreLibInjector; + let program = CoreLibInjector::inject(program); + + let mut statements_flattened = vec![]; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + + let expression = FieldElementExpression::IfElse( + box BooleanExpression::And( + box BooleanExpression::Eq( + box FieldElementExpression::Number(FieldPrime::from(4)), + box FieldElementExpression::Number(FieldPrime::from(4)), + ), + box BooleanExpression::Lt( + box FieldElementExpression::Number(FieldPrime::from(4)), + box FieldElementExpression::Number(FieldPrime::from(20)), + ), + ), + box FieldElementExpression::Number(FieldPrime::from(12)), + box FieldElementExpression::Number(FieldPrime::from(51)), + ); + + flattener.flatten_field_expression(&functions, &mut statements_flattened, expression); + } + + #[test] + fn div() { + // a = 5 / b / b + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let mut statements_flattened = vec![]; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + + let definition = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("b")), + FieldElementExpression::Number(FieldPrime::from(42)).into(), + ); + + let statement = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Div( + box FieldElementExpression::Div( + box FieldElementExpression::Number(FieldPrime::from(5)), + box FieldElementExpression::Identifier(String::from("b")), + ), + box FieldElementExpression::Identifier(String::from("b")), + ) + .into(), + ); + + flattener.flatten_statement(&functions, &mut statements_flattened, definition); + + flattener.flatten_statement(&functions, &mut statements_flattened, statement); + + // define b + let b = FlatVariable::new(0); + // define new wires for members of Div + let five = FlatVariable::new(1); + let b0 = FlatVariable::new(2); + // Define inverse of denominator to prevent div by 0 + let invb0 = FlatVariable::new(3); + // Define inverse + let sym_0 = FlatVariable::new(4); + // Define result, which is first member to next Div + let sym_1 = FlatVariable::new(5); + // Define second member + let b1 = FlatVariable::new(6); + // Define inverse of denominator to prevent div by 0 + let invb1 = FlatVariable::new(7); + // Define inverse + let sym_2 = FlatVariable::new(8); + // Define left hand side + let a = FlatVariable::new(9); assert_eq!( - statements_flattened[0], - FlatStatement::Definition(a, FlatExpression::Number(FieldPrime::from(1))) + statements_flattened, + vec![ + FlatStatement::Definition(b, FlatExpression::Number(FieldPrime::from(42))), + // inputs to first div (5/b) + FlatStatement::Definition(five, FlatExpression::Number(FieldPrime::from(5))), + FlatStatement::Definition(b0, b.into()), + // check div by 0 + FlatStatement::Directive(DirectiveStatement::new( + vec![invb0], + Helper::Rust(RustHelper::Div), + vec![FlatExpression::Number(FieldPrime::from(1)), b0.into()] + )), + FlatStatement::Condition( + FlatExpression::Number(FieldPrime::from(1)), + FlatExpression::Mult(box invb0.into(), box b0.into()), + ), + // execute div + FlatStatement::Directive(DirectiveStatement::new( + vec![sym_0], + Helper::Rust(RustHelper::Div), + vec![five, b0] + )), + FlatStatement::Condition( + five.into(), + FlatExpression::Mult(box b0.into(), box sym_0.into()), + ), + // inputs to second div (res/b) + FlatStatement::Definition(sym_1, sym_0.into()), + FlatStatement::Definition(b1, b.into()), + // check div by 0 + FlatStatement::Directive(DirectiveStatement::new( + vec![invb1], + Helper::Rust(RustHelper::Div), + vec![FlatExpression::Number(FieldPrime::from(1)), b1.into()] + )), + FlatStatement::Condition( + FlatExpression::Number(FieldPrime::from(1)), + FlatExpression::Mult(box invb1.into(), box b1.into()), + ), + // execute div + FlatStatement::Directive(DirectiveStatement::new( + vec![sym_2], + Helper::Rust(RustHelper::Div), + vec![sym_1, b1] + )), + FlatStatement::Condition( + sym_1.into(), + FlatExpression::Mult(box b1.into(), box sym_2.into()), + ), + // result + FlatStatement::Definition(a, sym_2.into()), + ] ); } - // #[test] - // fn multiple_definition2() { - // // def dup(x) - // // return x, x - // // def main() - // // a, b = dup(2) - - // let a = FlatVariable::new(0); - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![FlatFunction { - // id: "dup".to_string(), - // arguments: vec![FlatParameter { - // id: a, - // private: true, - // }], - // statements: vec![FlatStatement::Return(FlatExpressionList { - // expressions: vec![FlatExpression::Identifier(a), FlatExpression::Identifier(a)], - // })], - // signature: Signature::new() - // .inputs(vec![Type::FieldElement]) - // .outputs(vec![Type::FieldElement, Type::FieldElement]), - // }]; - // let statement = TypedStatement::MultipleDefinition( - // vec![ - // Variable::field_element("a".to_string()), - // Variable::field_element("b".to_string()), - // ], - // TypedExpressionList::FunctionCall( - // "dup".to_string(), - // vec![TypedExpression::FieldElement( - // FieldElementExpression::Number(FieldPrime::from(2)), - // )], - // vec![Type::FieldElement, Type::FieldElement], - // ), - // ); - - // let fun = TypedFunction { - // id: String::from("main"), - // arguments: vec![], - // statements: vec![statement], - // signature: Signature { - // inputs: vec![], - // outputs: vec![], - // }, - // }; - - // let f = flattener.flatten_function(&mut functions_flattened, fun); - - // let a = FlatVariable::new(0); - - // assert_eq!( - // f.statements[0], - // FlatStatement::Definition(a, FlatExpression::Number(FieldPrime::from(2))) - // ); - // } - - // #[test] - // fn simple_definition() { - // // def foo() - // // return 1 - // // def main() - // // a = foo() - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![FlatFunction { - // id: "foo".to_string(), - // arguments: vec![], - // statements: vec![FlatStatement::Return(FlatExpressionList { - // expressions: vec![FlatExpression::Number(FieldPrime::from(1))], - // })], - // signature: Signature::new() - // .inputs(vec![]) - // .outputs(vec![Type::FieldElement]), - // }]; - // let mut statements_flattened = vec![]; - // let statement = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // TypedExpression::FieldElement(FieldElementExpression::FunctionCall( - // "foo".to_string(), - // vec![], - // )), - // ); - - // flattener.flatten_statement( - // &mut functions_flattened, - // &mut statements_flattened, - // statement, - // ); - - // let a = FlatVariable::new(0); - - // assert_eq!( - // statements_flattened[0], - // FlatStatement::Definition(a, FlatExpression::Number(FieldPrime::from(1))) - // ); - // } - - // #[test] - // fn redefine_argument() { - // // def foo(a) - // // a = a + 1 - // // return 1 - - // // should flatten to no redefinition - // // def foo(a) - // // a_0 = a + 1 - // // return 1 - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - - // let funct = TypedFunction { - // id: "foo".to_string(), - // signature: Signature::new() - // .inputs(vec![Type::FieldElement]) - // .outputs(vec![Type::FieldElement]), - // arguments: vec![Parameter { - // id: Variable::field_element("a"), - // private: true, - // }], - // statements: vec![ - // TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // FieldElementExpression::Add( - // box FieldElementExpression::Identifier("a".to_string()), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // ) - // .into(), - // ), - // TypedStatement::Return(vec![ - // FieldElementExpression::Number(FieldPrime::from(1)).into() - // ]), - // ], - // }; - - // let flat_funct = flattener.flatten_function(&mut functions_flattened, funct); - - // let a = FlatVariable::new(0); - // let a_0 = FlatVariable::new(1); - - // assert_eq!( - // flat_funct.statements[0], - // FlatStatement::Definition( - // a_0, - // FlatExpression::Add( - // box FlatExpression::Identifier(a), - // box FlatExpression::Number(FieldPrime::from(1)) - // ) - // ) - // ); - // } - - // #[test] - // fn call_with_def() { - // // def foo(): - // // a = 3 - // // return a - - // // def main(): - // // return foo() - - // let foo = TypedFunction { - // id: String::from("foo"), - // arguments: vec![], - // statements: vec![ - // TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // FieldElementExpression::Number(FieldPrime::from(3)).into(), - // ), - // TypedStatement::Return(vec![ - // FieldElementExpression::Identifier(String::from("a")).into() - // ]), - // ], - // signature: Signature { - // inputs: vec![], - // outputs: vec![Type::FieldElement], - // }, - // }; - - // let main = TypedFunction { - // id: String::from("main"), - // arguments: vec![], - // statements: vec![TypedStatement::Return(vec![ - // FieldElementExpression::FunctionCall(String::from("foo"), vec![]).into(), - // ])], - // signature: Signature { - // inputs: vec![], - // outputs: vec![Type::FieldElement], - // }, - // }; - - // let mut flattener = Flattener::new(); - - // let foo_flattened = flattener.flatten_function(&mut vec![], foo); - - // let expected = FlatFunction { - // id: String::from("main"), - // arguments: vec![], - // statements: vec![ - // FlatStatement::Definition( - // FlatVariable::new(0), - // FlatExpression::Number(FieldPrime::from(3)), - // ), - // FlatStatement::Return(FlatExpressionList { - // expressions: vec![FlatExpression::Identifier(FlatVariable::new(0))], - // }), - // ], - // signature: Signature::new().outputs(vec![Type::FieldElement]), - // }; - - // let main_flattened = flattener.flatten_function(&mut vec![foo_flattened], main); - - // assert_eq!(main_flattened, expected); - // } - - // #[test] - // fn powers() { - // // def main(): - // // field a = 7 - // // field b = a**4 - // // return b - - // // def main(): - // // _0 = 7 - // // _1 = (_0 * _0) - // // _2 = (_1 * _0) - // // _3 = (_2 * _0) - // // return _3 - - // let function = TypedFunction { - // id: String::from("main"), - // arguments: vec![], - // statements: vec![ - // TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // FieldElementExpression::Number(FieldPrime::from(7)).into(), - // ), - // TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("b")), - // FieldElementExpression::Pow( - // box FieldElementExpression::Identifier(String::from("a")), - // box FieldElementExpression::Number(FieldPrime::from(4)), - // ) - // .into(), - // ), - // TypedStatement::Return(vec![ - // FieldElementExpression::Identifier(String::from("b")).into() - // ]), - // ], - // signature: Signature { - // inputs: vec![], - // outputs: vec![Type::FieldElement], - // }, - // }; - - // let mut flattener = Flattener::new(); - - // let expected = FlatFunction { - // id: String::from("main"), - // arguments: vec![], - // statements: vec![ - // FlatStatement::Definition( - // FlatVariable::new(0), - // FlatExpression::Number(FieldPrime::from(7)), - // ), - // FlatStatement::Definition( - // FlatVariable::new(1), - // FlatExpression::Mult( - // box FlatExpression::Identifier(FlatVariable::new(0)), - // box FlatExpression::Identifier(FlatVariable::new(0)), - // ), - // ), - // FlatStatement::Definition( - // FlatVariable::new(2), - // FlatExpression::Mult( - // box FlatExpression::Identifier(FlatVariable::new(1)), - // box FlatExpression::Identifier(FlatVariable::new(0)), - // ), - // ), - // FlatStatement::Definition( - // FlatVariable::new(3), - // FlatExpression::Mult( - // box FlatExpression::Identifier(FlatVariable::new(2)), - // box FlatExpression::Identifier(FlatVariable::new(0)), - // ), - // ), - // FlatStatement::Return(FlatExpressionList { - // expressions: vec![FlatExpression::Identifier(FlatVariable::new(3))], - // }), - // ], - // signature: Signature::new().outputs(vec![Type::FieldElement]), - // }; - - // let flattened = flattener.flatten_function(&mut vec![], function); - - // assert_eq!(flattened, expected); - // } - - // #[test] - // fn overload() { - // // def foo() - // // return 1 - // // def foo() - // // return 1, 2 - // // def main() - // // a = foo() - // // b, c = foo() - // // return 1 - // // - // // should not panic - // // - - // let mut flattener = Flattener::new(); - // let functions = vec![ - // TypedFunction { - // id: "foo".to_string(), - // arguments: vec![], - // statements: vec![TypedStatement::Return(vec![TypedExpression::FieldElement( - // FieldElementExpression::Number(FieldPrime::from(1)), - // )])], - // signature: Signature::new() - // .inputs(vec![]) - // .outputs(vec![Type::FieldElement]), - // }, - // TypedFunction { - // id: "foo".to_string(), - // arguments: vec![], - // statements: vec![TypedStatement::Return(vec![ - // TypedExpression::FieldElement(FieldElementExpression::Number( - // FieldPrime::from(1), - // )), - // TypedExpression::FieldElement(FieldElementExpression::Number( - // FieldPrime::from(2), - // )), - // ])], - // signature: Signature::new() - // .inputs(vec![]) - // .outputs(vec![Type::FieldElement, Type::FieldElement]), - // }, - // TypedFunction { - // id: "main".to_string(), - // arguments: vec![], - // statements: vec![ - // TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // TypedExpression::FieldElement(FieldElementExpression::FunctionCall( - // "foo".to_string(), - // vec![], - // )), - // ), - // TypedStatement::MultipleDefinition( - // vec![ - // Variable::field_element("b".to_string()), - // Variable::field_element("c".to_string()), - // ], - // TypedExpressionList::FunctionCall( - // "foo".to_string(), - // vec![], - // vec![Type::FieldElement, Type::FieldElement], - // ), - // ), - // TypedStatement::Return(vec![TypedExpression::FieldElement( - // FieldElementExpression::Number(FieldPrime::from(1)), - // )]), - // ], - // signature: Signature::new() - // .inputs(vec![]) - // .outputs(vec![Type::FieldElement]), - // }, - // ]; - - // flattener.flatten_program(TypedModule { - // functions: functions - // .into_iter() - // .map(|f| { - // ( - // FunctionKey { - // id: f.id.clone(), - // signature: f.signature.clone(), - // }, - // f, - // ) - // }) - // .collect(), - // imported_functions: vec![], - // imports: vec![], - // }); - - // // shouldn't panic - // } - - // #[test] - // fn if_else() { - // let mut flattener = Flattener::new(); - // let expression = FieldElementExpression::IfElse( - // box BooleanExpression::Eq( - // box FieldElementExpression::Number(FieldPrime::from(32)), - // box FieldElementExpression::Number(FieldPrime::from(4)), - // ), - // box FieldElementExpression::Number(FieldPrime::from(12)), - // box FieldElementExpression::Number(FieldPrime::from(51)), - // ); - - // let mut functions_flattened = vec![]; - // flattener.load_corelib(&mut functions_flattened); - - // flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression); - // } - - // #[test] - // fn geq_leq() { - // let mut flattener = Flattener::new(); - // let expression_le = BooleanExpression::Le( - // box FieldElementExpression::Number(FieldPrime::from(32)), - // box FieldElementExpression::Number(FieldPrime::from(4)), - // ); - - // let expression_ge = BooleanExpression::Ge( - // box FieldElementExpression::Number(FieldPrime::from(32)), - // box FieldElementExpression::Number(FieldPrime::from(4)), - // ); - - // flattener.flatten_boolean_expression(&mut vec![], &mut vec![], expression_le); - - // flattener.flatten_boolean_expression(&mut vec![], &mut vec![], expression_ge); - // } - - // #[test] - // fn bool_and() { - // let expression = FieldElementExpression::IfElse( - // box BooleanExpression::And( - // box BooleanExpression::Eq( - // box FieldElementExpression::Number(FieldPrime::from(4)), - // box FieldElementExpression::Number(FieldPrime::from(4)), - // ), - // box BooleanExpression::Lt( - // box FieldElementExpression::Number(FieldPrime::from(4)), - // box FieldElementExpression::Number(FieldPrime::from(20)), - // ), - // ), - // box FieldElementExpression::Number(FieldPrime::from(12)), - // box FieldElementExpression::Number(FieldPrime::from(51)), - // ); - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // flattener.load_corelib(&mut functions_flattened); - // flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression); - // } - - // #[test] - // fn div() { - // // a = 5 / b / b - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // let mut statements_flattened = vec![]; - - // let definition = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("b")), - // FieldElementExpression::Number(FieldPrime::from(42)).into(), - // ); - - // let statement = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("a")), - // FieldElementExpression::Div( - // box FieldElementExpression::Div( - // box FieldElementExpression::Number(FieldPrime::from(5)), - // box FieldElementExpression::Identifier(String::from("b")), - // ), - // box FieldElementExpression::Identifier(String::from("b")), - // ) - // .into(), - // ); - - // flattener.flatten_statement( - // &mut functions_flattened, - // &mut statements_flattened, - // definition, - // ); - - // flattener.flatten_statement( - // &mut functions_flattened, - // &mut statements_flattened, - // statement, - // ); - - // // define b - // let b = FlatVariable::new(0); - // // define new wires for members of Div - // let five = FlatVariable::new(1); - // let b0 = FlatVariable::new(2); - // // Define inverse of denominator to prevent div by 0 - // let invb0 = FlatVariable::new(3); - // // Define inverse - // let sym_0 = FlatVariable::new(4); - // // Define result, which is first member to next Div - // let sym_1 = FlatVariable::new(5); - // // Define second member - // let b1 = FlatVariable::new(6); - // // Define inverse of denominator to prevent div by 0 - // let invb1 = FlatVariable::new(7); - // // Define inverse - // let sym_2 = FlatVariable::new(8); - // // Define left hand side - // let a = FlatVariable::new(9); - - // assert_eq!( - // statements_flattened, - // vec![ - // FlatStatement::Definition(b, FlatExpression::Number(FieldPrime::from(42))), - // // inputs to first div (5/b) - // FlatStatement::Definition(five, FlatExpression::Number(FieldPrime::from(5))), - // FlatStatement::Definition(b0, b.into()), - // // check div by 0 - // FlatStatement::Directive(DirectiveStatement::new( - // vec![invb0], - // Helper::Rust(RustHelper::Div), - // vec![FlatExpression::Number(FieldPrime::from(1)), b0.into()] - // )), - // FlatStatement::Condition( - // FlatExpression::Number(FieldPrime::from(1)), - // FlatExpression::Mult(box invb0.into(), box b0.into()), - // ), - // // execute div - // FlatStatement::Directive(DirectiveStatement::new( - // vec![sym_0], - // Helper::Rust(RustHelper::Div), - // vec![five, b0] - // )), - // FlatStatement::Condition( - // five.into(), - // FlatExpression::Mult(box b0.into(), box sym_0.into()), - // ), - // // inputs to second div (res/b) - // FlatStatement::Definition(sym_1, sym_0.into()), - // FlatStatement::Definition(b1, b.into()), - // // check div by 0 - // FlatStatement::Directive(DirectiveStatement::new( - // vec![invb1], - // Helper::Rust(RustHelper::Div), - // vec![FlatExpression::Number(FieldPrime::from(1)), b1.into()] - // )), - // FlatStatement::Condition( - // FlatExpression::Number(FieldPrime::from(1)), - // FlatExpression::Mult(box invb1.into(), box b1.into()), - // ), - // // execute div - // FlatStatement::Directive(DirectiveStatement::new( - // vec![sym_2], - // Helper::Rust(RustHelper::Div), - // vec![sym_1, b1] - // )), - // FlatStatement::Condition( - // sym_1.into(), - // FlatExpression::Mult(box b1.into(), box sym_2.into()), - // ), - // // result - // FlatStatement::Definition(a, sym_2.into()), - // ] - // ); - // } - - // #[test] - // fn field_array() { - // // foo = [ , , ] - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // let mut statements_flattened = vec![]; - // let statement = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_array("foo", 3)), - // FieldElementArrayExpression::Value( - // 3, - // vec![ - // FieldElementExpression::Number(FieldPrime::from(1)), - // FieldElementExpression::Number(FieldPrime::from(2)), - // FieldElementExpression::Number(FieldPrime::from(3)), - // ], - // ) - // .into(), - // ); - // let expression = FieldElementArrayExpression::Identifier(3, String::from("foo")); - - // flattener.flatten_statement( - // &mut functions_flattened, - // &mut statements_flattened, - // statement, - // ); - - // let expressions = flattener.flatten_field_array_expression( - // &mut functions_flattened, - // &mut statements_flattened, - // expression, - // ); - - // assert_eq!( - // expressions, - // vec![ - // FlatExpression::Identifier(FlatVariable::new(0)), - // FlatExpression::Identifier(FlatVariable::new(1)), - // FlatExpression::Identifier(FlatVariable::new(2)), - // ] - // ); - // } - - // #[test] - // fn array_definition() { - // // field[3] foo = [1, 2, 3] - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // let mut statements_flattened = vec![]; - // let statement = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_array("foo", 3)), - // FieldElementArrayExpression::Value( - // 3, - // vec![ - // FieldElementExpression::Number(FieldPrime::from(1)), - // FieldElementExpression::Number(FieldPrime::from(2)), - // FieldElementExpression::Number(FieldPrime::from(3)), - // ], - // ) - // .into(), - // ); - - // flattener.flatten_statement( - // &mut functions_flattened, - // &mut statements_flattened, - // statement, - // ); - - // assert_eq!( - // statements_flattened, - // vec![ - // FlatStatement::Definition( - // FlatVariable::new(0), - // FlatExpression::Number(FieldPrime::from(1)) - // ), - // FlatStatement::Definition( - // FlatVariable::new(1), - // FlatExpression::Number(FieldPrime::from(2)) - // ), - // FlatStatement::Definition( - // FlatVariable::new(2), - // FlatExpression::Number(FieldPrime::from(3)) - // ), - // ] - // ); - // } - - // #[test] - // fn array_selection() { - // // field[3] foo = [1, 2, 3] - // // foo[1] - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // let mut statements_flattened = vec![]; - // let statement = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_array("foo", 3)), - // FieldElementArrayExpression::Value( - // 3, - // vec![ - // FieldElementExpression::Number(FieldPrime::from(1)), - // FieldElementExpression::Number(FieldPrime::from(2)), - // FieldElementExpression::Number(FieldPrime::from(3)), - // ], - // ) - // .into(), - // ); - - // let expression = FieldElementExpression::Select( - // box FieldElementArrayExpression::Identifier(3, String::from("foo")), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // ); - - // flattener.flatten_statement::( - // &mut functions_flattened, - // &mut statements_flattened, - // statement, - // ); - - // let flat_expression = flattener.flatten_field_expression::( - // &mut functions_flattened, - // &mut statements_flattened, - // expression, - // ); - - // assert_eq!( - // flat_expression, - // FlatExpression::Identifier(FlatVariable::new(1)), - // ); - // } - - // #[test] - // fn array_sum() { - // // field[3] foo = [1, 2, 3] - // // bar = foo[0] + foo[1] + foo[2] - // // we don't optimise detecting constants, this will be done in an optimiser pass - - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // let mut statements_flattened = vec![]; - // let def = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_array("foo", 3)), - // FieldElementArrayExpression::Value( - // 3, - // vec![ - // FieldElementExpression::Number(FieldPrime::from(1)), - // FieldElementExpression::Number(FieldPrime::from(2)), - // FieldElementExpression::Number(FieldPrime::from(3)), - // ], - // ) - // .into(), - // ); - - // let sum = TypedStatement::Definition( - // TypedAssignee::Identifier(Variable::field_element("bar")), - // FieldElementExpression::Add( - // box FieldElementExpression::Add( - // box FieldElementExpression::Select( - // box FieldElementArrayExpression::Identifier(3, String::from("foo")), - // box FieldElementExpression::Number(FieldPrime::from(0)), - // ), - // box FieldElementExpression::Select( - // box FieldElementArrayExpression::Identifier(3, String::from("foo")), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // ), - // ), - // box FieldElementExpression::Select( - // box FieldElementArrayExpression::Identifier(3, String::from("foo")), - // box FieldElementExpression::Number(FieldPrime::from(2)), - // ), - // ) - // .into(), - // ); - - // flattener.flatten_statement::( - // &mut functions_flattened, - // &mut statements_flattened, - // def, - // ); - - // flattener.flatten_statement::( - // &mut functions_flattened, - // &mut statements_flattened, - // sum, - // ); - - // assert_eq!( - // statements_flattened[3], - // FlatStatement::Definition( - // FlatVariable::new(3), - // FlatExpression::Add( - // box FlatExpression::Add( - // box FlatExpression::Identifier(FlatVariable::new(0)), - // box FlatExpression::Identifier(FlatVariable::new(1)), - // ), - // box FlatExpression::Identifier(FlatVariable::new(2)), - // ) - // ) - // ); - // } - - // #[test] - // fn array_if() { - // // if 1 == 1 then [1] else [3] fi - - // let with_arrays = { - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // flattener.load_corelib(&mut functions_flattened); - // let mut statements_flattened = vec![]; - - // let e = FieldElementArrayExpression::IfElse( - // box BooleanExpression::Eq( - // box FieldElementExpression::Number(FieldPrime::from(1)), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // ), - // box FieldElementArrayExpression::Value( - // 1, - // vec![FieldElementExpression::Number(FieldPrime::from(1))], - // ), - // box FieldElementArrayExpression::Value( - // 1, - // vec![FieldElementExpression::Number(FieldPrime::from(3))], - // ), - // ); - - // ( - // flattener.flatten_field_array_expression( - // &mut functions_flattened, - // &mut statements_flattened, - // e, - // )[0] - // .clone(), - // statements_flattened, - // ) - // }; - - // let without_arrays = { - // let mut flattener = Flattener::new(); - // let mut functions_flattened = vec![]; - // flattener.load_corelib(&mut functions_flattened); - // let mut statements_flattened = vec![]; - - // // if 1 == 1 then 1 else 3 fi - // let e = FieldElementExpression::IfElse( - // box BooleanExpression::Eq( - // box FieldElementExpression::Number(FieldPrime::from(1)), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // ), - // box FieldElementExpression::Number(FieldPrime::from(1)), - // box FieldElementExpression::Number(FieldPrime::from(3)), - // ); - - // ( - // flattener.flatten_field_expression( - // &mut functions_flattened, - // &mut statements_flattened, - // e, - // ), - // statements_flattened, - // ) - // }; - - // assert_eq!(with_arrays, without_arrays); - // } - - // #[test] - // fn next_variable() { - // let mut flattener = Flattener::new(); - // assert_eq!( - // FlatVariable::new(0), - // flattener.use_variable(&String::from("a")) - // ); - // assert_eq!( - // flattener.get_latest_var_substitution(&String::from("a")), - // FlatVariable::new(0) - // ); - // assert_eq!( - // FlatVariable::new(1), - // flattener.use_variable(&String::from("a")) - // ); - // assert_eq!( - // flattener.get_latest_var_substitution(&String::from("a")), - // FlatVariable::new(1) - // ); - // assert_eq!( - // FlatVariable::new(2), - // flattener.use_variable(&String::from("a")) - // ); - // assert_eq!( - // flattener.get_latest_var_substitution(&String::from("a")), - // FlatVariable::new(2) - // ); - // } + #[test] + fn field_array() { + // foo = [ , , ] + + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + let mut statements_flattened = vec![]; + let statement = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_array("foo", 3)), + FieldElementArrayExpression::Value( + 3, + vec![ + FieldElementExpression::Number(FieldPrime::from(1)), + FieldElementExpression::Number(FieldPrime::from(2)), + FieldElementExpression::Number(FieldPrime::from(3)), + ], + ) + .into(), + ); + let expression = FieldElementArrayExpression::Identifier(3, String::from("foo")); + + flattener.flatten_statement(&functions, &mut statements_flattened, statement); + + let expressions = flattener.flatten_field_array_expression( + &functions, + &mut statements_flattened, + expression, + ); + + assert_eq!( + expressions, + vec![ + FlatExpression::Identifier(FlatVariable::new(0)), + FlatExpression::Identifier(FlatVariable::new(1)), + FlatExpression::Identifier(FlatVariable::new(2)), + ] + ); + } + + #[test] + fn array_definition() { + // field[3] foo = [1, 2, 3] + + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + let mut statements_flattened = vec![]; + let statement = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_array("foo", 3)), + FieldElementArrayExpression::Value( + 3, + vec![ + FieldElementExpression::Number(FieldPrime::from(1)), + FieldElementExpression::Number(FieldPrime::from(2)), + FieldElementExpression::Number(FieldPrime::from(3)), + ], + ) + .into(), + ); + + flattener.flatten_statement(&functions, &mut statements_flattened, statement); + + assert_eq!( + statements_flattened, + vec![ + FlatStatement::Definition( + FlatVariable::new(0), + FlatExpression::Number(FieldPrime::from(1)) + ), + FlatStatement::Definition( + FlatVariable::new(1), + FlatExpression::Number(FieldPrime::from(2)) + ), + FlatStatement::Definition( + FlatVariable::new(2), + FlatExpression::Number(FieldPrime::from(3)) + ), + ] + ); + } + + #[test] + fn array_selection() { + // field[3] foo = [1, 2, 3] + // foo[1] + + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + let mut statements_flattened = vec![]; + let statement = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_array("foo", 3)), + FieldElementArrayExpression::Value( + 3, + vec![ + FieldElementExpression::Number(FieldPrime::from(1)), + FieldElementExpression::Number(FieldPrime::from(2)), + FieldElementExpression::Number(FieldPrime::from(3)), + ], + ) + .into(), + ); + + let expression = FieldElementExpression::Select( + box FieldElementArrayExpression::Identifier(3, String::from("foo")), + box FieldElementExpression::Number(FieldPrime::from(1)), + ); + + flattener.flatten_statement(&functions, &mut statements_flattened, statement); + + let flat_expression = + flattener.flatten_field_expression(&functions, &mut statements_flattened, expression); + + assert_eq!( + flat_expression, + FlatExpression::Identifier(FlatVariable::new(1)), + ); + } + + #[test] + fn array_sum() { + // field[3] foo = [1, 2, 3] + // bar = foo[0] + foo[1] + foo[2] + // we don't optimise detecting constants, this will be done in an optimiser pass + + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + + let mut statements_flattened = vec![]; + let def = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_array("foo", 3)), + FieldElementArrayExpression::Value( + 3, + vec![ + FieldElementExpression::Number(FieldPrime::from(1)), + FieldElementExpression::Number(FieldPrime::from(2)), + FieldElementExpression::Number(FieldPrime::from(3)), + ], + ) + .into(), + ); + + let sum = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("bar")), + FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::Select( + box FieldElementArrayExpression::Identifier(3, String::from("foo")), + box FieldElementExpression::Number(FieldPrime::from(0)), + ), + box FieldElementExpression::Select( + box FieldElementArrayExpression::Identifier(3, String::from("foo")), + box FieldElementExpression::Number(FieldPrime::from(1)), + ), + ), + box FieldElementExpression::Select( + box FieldElementArrayExpression::Identifier(3, String::from("foo")), + box FieldElementExpression::Number(FieldPrime::from(2)), + ), + ) + .into(), + ); + + flattener.flatten_statement(&functions, &mut statements_flattened, def); + + flattener.flatten_statement(&functions, &mut statements_flattened, sum); + + assert_eq!( + statements_flattened[3], + FlatStatement::Definition( + FlatVariable::new(3), + FlatExpression::Add( + box FlatExpression::Add( + box FlatExpression::Identifier(FlatVariable::new(0)), + box FlatExpression::Identifier(FlatVariable::new(1)), + ), + box FlatExpression::Identifier(FlatVariable::new(2)), + ) + ) + ); + } + + #[test] + fn array_if() { + // if 1 == 1 then [1] else [3] fi + + let with_arrays = { + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + use static_analysis::CoreLibInjector; + let program = CoreLibInjector::inject(program); + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + let mut statements_flattened = vec![]; + + let e = FieldElementArrayExpression::IfElse( + box BooleanExpression::Eq( + box FieldElementExpression::Number(FieldPrime::from(1)), + box FieldElementExpression::Number(FieldPrime::from(1)), + ), + box FieldElementArrayExpression::Value( + 1, + vec![FieldElementExpression::Number(FieldPrime::from(1))], + ), + box FieldElementArrayExpression::Value( + 1, + vec![FieldElementExpression::Number(FieldPrime::from(3))], + ), + ); + + ( + flattener.flatten_field_array_expression(&functions, &mut statements_flattened, e) + [0] + .clone(), + statements_flattened, + ) + }; + + let without_arrays = { + let program: TypedProgram = TypedProgram { + modules: vec![( + String::from("main"), + TypedModule { + functions: HashMap::new(), + imports: vec![], + }, + )] + .into_iter() + .collect(), + main: String::from("main"), + }; + + use static_analysis::CoreLibInjector; + let program = CoreLibInjector::inject(program); + let mut statements_flattened = vec![]; + + let functions = program.modules.get("main").unwrap().functions.clone(); + + let mut flattener = Flattener::with_modules(program.modules); + // if 1 == 1 then 1 else 3 fi + let e = FieldElementExpression::IfElse( + box BooleanExpression::Eq( + box FieldElementExpression::Number(FieldPrime::from(1)), + box FieldElementExpression::Number(FieldPrime::from(1)), + ), + box FieldElementExpression::Number(FieldPrime::from(1)), + box FieldElementExpression::Number(FieldPrime::from(3)), + ); + + ( + flattener.flatten_field_expression(&functions, &mut statements_flattened, e), + statements_flattened, + ) + }; + + assert_eq!(with_arrays, without_arrays); + } + + #[test] + fn next_variable() { + let mut flattener: Flattener = Flattener::new(); + assert_eq!( + FlatVariable::new(0), + flattener.use_variable(&String::from("a")) + ); + assert_eq!( + flattener.get_latest_var_substitution(&String::from("a")), + FlatVariable::new(0) + ); + assert_eq!( + FlatVariable::new(1), + flattener.use_variable(&String::from("a")) + ); + assert_eq!( + flattener.get_latest_var_substitution(&String::from("a")), + FlatVariable::new(1) + ); + assert_eq!( + FlatVariable::new(2), + flattener.use_variable(&String::from("a")) + ); + assert_eq!( + flattener.get_latest_var_substitution(&String::from("a")), + FlatVariable::new(2) + ); + } } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index f25d25eb..b732269d 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -11,7 +11,7 @@ mod power_check; mod propagation; mod unroll; -use self::core_lib_injector::CoreLibInjector; +pub use self::core_lib_injector::CoreLibInjector; use self::inline::Inliner; use self::power_check::PowerChecker; use self::propagation::Propagator;