From 5c528535f2d7bee1fbfdd5d6e6b0aaa72c839717 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 13 May 2021 15:09:06 +0200 Subject: [PATCH 1/6] support constants in declaration types --- .../constant_array_size_type_mismatch.zok | 4 + .../generics/conflicting_constant.zok | 7 + zokrates_core/src/semantics.rs | 295 ++++++++++-------- .../src/static_analysis/constant_inliner.rs | 74 ++++- zokrates_core/src/typed_absy/folder.rs | 6 +- zokrates_core/src/typed_absy/mod.rs | 4 +- zokrates_core/src/typed_absy/result_folder.rs | 9 +- zokrates_core/src/typed_absy/types.rs | 9 +- .../tests/tests/constants/array_size.json | 16 + .../tests/tests/constants/array_size.zok | 5 + .../tests/tests/constants/propagate.json | 16 + .../tests/tests/constants/propagate.zok | 5 + .../tests/tests/constants/struct.zok | 10 +- 13 files changed, 310 insertions(+), 150 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok create mode 100644 zokrates_cli/examples/compile_errors/generics/conflicting_constant.zok create mode 100644 zokrates_core_test/tests/tests/constants/array_size.json create mode 100644 zokrates_core_test/tests/tests/constants/array_size.zok create mode 100644 zokrates_core_test/tests/tests/constants/propagate.json create mode 100644 zokrates_core_test/tests/tests/constants/propagate.zok diff --git a/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok b/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok new file mode 100644 index 00000000..8c34a04c --- /dev/null +++ b/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok @@ -0,0 +1,4 @@ +const field SIZE = 2 + +def main(field[SIZE] n): + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/conflicting_constant.zok b/zokrates_cli/examples/compile_errors/generics/conflicting_constant.zok new file mode 100644 index 00000000..dcc9e6dc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/conflicting_constant.zok @@ -0,0 +1,7 @@ +const u32 N = 42 + +def foo(field[N] a) -> bool: + return true + +def main(): + return \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 0796a6c3..3847a808 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -353,11 +353,11 @@ impl<'ast, T: Field> Checker<'ast, T> { id: &'ast str, c: ConstantDefinitionNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, ErrorInner> { let pos = c.pos(); - let ty = self.check_type(c.value.ty.clone(), module_id, &types)?; - let checked_expr = self.check_expression(c.value.expression.clone(), module_id, types)?; + let ty = self.check_type(c.value.ty.clone(), module_id, state)?; + let checked_expr = self.check_expression(c.value.expression.clone(), module_id, state)?; match ty { Type::FieldElement => { @@ -397,7 +397,7 @@ impl<'ast, T: Field> Checker<'ast, T> { id: String, s: StructDefinitionNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -409,7 +409,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for field in s.fields { let member_id = field.value.id.to_string(); match self - .check_declaration_type(field.value.ty, module_id, &types, &HashMap::new()) + .check_declaration_type(field.value.ty, module_id, state, &HashMap::new()) .map(|t| (member_id, t)) { Ok(f) => match fields_set.insert(f.0.clone()) { @@ -460,7 +460,7 @@ impl<'ast, T: Field> Checker<'ast, T> { declaration_id.to_string(), t.clone(), module_id, - &state.types, + state, ) { Ok(ty) => { match symbol_unifier.insert_type(declaration_id) { @@ -492,7 +492,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Symbol::Here(SymbolDefinition::Constant(c)) => { - match self.check_constant_definition(declaration_id, c, module_id, &state.types) { + match self.check_constant_definition(declaration_id, c, module_id, state) { Ok(c) => { match symbol_unifier.insert_constant(declaration_id) { false => errors.push( @@ -527,7 +527,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Symbol::Here(SymbolDefinition::Function(f)) => { - match self.check_function(f, module_id, &state.types) { + match self.check_function(f, module_id, state) { Ok(funct) => { match symbol_unifier .insert_function(declaration_id, funct.signature.clone()) @@ -831,7 +831,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, funct_node: FunctionNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { assert!(self.return_types.is_none()); @@ -849,7 +849,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut statements_checked = vec![]; - match self.check_signature(funct.signature, module_id, types) { + match self.check_signature(funct.signature, module_id, state) { Ok(s) => { // define variables for the constants for generic in &s.generics { @@ -905,7 +905,7 @@ impl<'ast, T: Field> Checker<'ast, T> { found_return = true; } - match self.check_statement(stat, module_id, types) { + match self.check_statement(stat, module_id, state) { Ok(statement) => { if let TypedStatement::Return(e) = &statement { match e.iter().map(|e| e.get_type()).collect::>() @@ -971,7 +971,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, signature: UnresolvedSignature<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; @@ -981,24 +981,39 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut generics_map = HashMap::new(); for (index, g) in signature.generics.iter().enumerate() { - match generics_map.insert(g.value, index).is_none() { - true => { - generics.push(Some(Constant::Generic(GenericIdentifier { - name: g.value, - index, - }))); - } - false => { - errors.push(ErrorInner { - pos: Some(g.pos()), - message: format!("Generic parameter {} is already declared", g.value), - }); + if let Some((key, _)) = state + .constants + .get(module_id) + .unwrap() + .get_key_value(g.value) + { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!( + "Generic parameter {} conflicts with constant symbol {}", + g.value, key + ), + }); + } else { + match generics_map.insert(g.value, index).is_none() { + true => { + generics.push(Some(Constant::Generic(GenericIdentifier { + name: g.value, + index, + }))); + } + false => { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!("Generic parameter {} is already declared", g.value), + }); + } } } } for t in signature.inputs { - match self.check_declaration_type(t, module_id, types, &generics_map) { + match self.check_declaration_type(t, module_id, state, &generics_map) { Ok(t) => { inputs.push(t); } @@ -1009,7 +1024,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } for t in signature.outputs { - match self.check_declaration_type(t, module_id, types, &generics_map) { + match self.check_declaration_type(t, module_id, state, &generics_map) { Ok(t) => { outputs.push(t); } @@ -1036,7 +1051,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -1046,7 +1061,7 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Boolean => Ok(Type::Boolean), UnresolvedType::Uint(bitwidth) => Ok(Type::uint(bitwidth)), UnresolvedType::Array(t, size) => { - let size = self.check_expression(size, module_id, types)?; + let size = self.check_expression(size, module_id, state)?; let ty = size.get_type(); @@ -1079,11 +1094,12 @@ impl<'ast, T: Field> Checker<'ast, T> { }?; Ok(Type::Array(ArrayType::new( - self.check_type(*t, module_id, types)?, + self.check_type(*t, module_id, state)?, size, ))) } - UnresolvedType::User(id) => types + UnresolvedType::User(id) => state + .types .get(module_id) .unwrap() .get(&id) @@ -1099,6 +1115,7 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_generic_expression( &mut self, expr: ExpressionNode<'ast>, + constants_map: &HashMap, Type<'ast, T>>, generics_map: &HashMap, usize>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1121,13 +1138,24 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Identifier(name) => { - // check that this generic parameter is defined - match generics_map.get(&name) { - Some(index) => Ok(Constant::Generic(GenericIdentifier {name, index: *index})), - None => Err(ErrorInner { - pos: Some(pos), - message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name) - }) + match (constants_map.get(name), generics_map.get(&name)) { + (Some(c), None) => { + match c { + Type::Uint(bitwidth) => Ok(Constant::Identifier(name, bitwidth.to_usize())), + _ => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant or an identifier, found {} of type {}", + name, c + ), + }) + } + } + (None, Some(index)) => Ok(Constant::Generic(GenericIdentifier { name, index: *index })), + _ => Err(ErrorInner { + pos: Some(pos), + message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name) + }) } } e => Err(ErrorInner { @@ -1144,7 +1172,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, generics_map: &HashMap, usize>, ) -> Result, ErrorInner> { let pos = ty.pos(); @@ -1155,24 +1183,27 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Boolean => Ok(DeclarationType::Boolean), UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)), UnresolvedType::Array(t, size) => { - let checked_size = self.check_generic_expression(size.clone(), &generics_map)?; + let checked_size = self.check_generic_expression( + size.clone(), + state.constants.get(module_id).unwrap(), + generics_map, + )?; Ok(DeclarationType::Array(DeclarationArrayType::new( - self.check_declaration_type(*t, module_id, types, generics_map)?, + self.check_declaration_type(*t, module_id, state, generics_map)?, checked_size, ))) } - UnresolvedType::User(id) => { - types - .get(module_id) - .unwrap() - .get(&id) - .cloned() - .ok_or_else(|| ErrorInner { - pos: Some(pos), - message: format!("Undefined type {}", id), - }) - } + UnresolvedType::User(id) => state + .types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + }), } } @@ -1180,11 +1211,11 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, v: crate::absy::VariableNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { Ok(Variable::with_id_and_type( v.value.id, - self.check_type(v.value._type, module_id, types) + self.check_type(v.value._type, module_id, state) .map_err(|e| vec![e])?, )) } @@ -1196,17 +1227,17 @@ impl<'ast, T: Field> Checker<'ast, T> { statements: Vec>, pos: (Position, Position), module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { self.check_for_var(&var).map_err(|e| vec![e])?; - let var = self.check_variable(var, module_id, types).unwrap(); + let var = self.check_variable(var, module_id, state).unwrap(); let from = self - .check_expression(range.0, module_id, &types) + .check_expression(range.0, module_id, state) .map_err(|e| vec![e])?; let to = self - .check_expression(range.1, module_id, &types) + .check_expression(range.1, module_id, state) .map_err(|e| vec![e])?; let from = match from { @@ -1274,7 +1305,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut checked_statements = vec![]; for stat in statements { - let checked_stat = self.check_statement(stat, module_id, types)?; + let checked_stat = self.check_statement(stat, module_id, state)?; checked_statements.push(checked_stat); } @@ -1285,7 +1316,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, stat: StatementNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, Vec> { let pos = stat.pos(); @@ -1299,7 +1330,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for e in e.value.expressions.into_iter() { let e_checked = self - .check_expression(e, module_id, &types) + .check_expression(e, module_id, state) .map_err(|e| vec![e])?; expression_list_checked.push(e_checked); } @@ -1367,7 +1398,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(res) } Statement::Declaration(var) => { - let var = self.check_variable(var, module_id, types)?; + let var = self.check_variable(var, module_id, state)?; match self.insert_into_scope(var.clone()) { true => Ok(TypedStatement::Declaration(var)), false => Err(ErrorInner { @@ -1386,12 +1417,12 @@ impl<'ast, T: Field> Checker<'ast, T> { // check the expression to be assigned let checked_expr = self - .check_expression(expr, module_id, &types) + .check_expression(expr, module_id, state) .map_err(|e| vec![e])?; // check that the assignee is declared and is well formed let var = self - .check_assignee(assignee, module_id, &types) + .check_assignee(assignee, module_id, state) .map_err(|e| vec![e])?; let var_type = var.get_type(); @@ -1430,7 +1461,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } Statement::Assertion(e) => { let e = self - .check_expression(e, module_id, &types) + .check_expression(e, module_id, state) .map_err(|e| vec![e])?; match e { @@ -1449,7 +1480,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::For(var, from, to, statements) => { self.enter_scope(); - let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types); + let res = self.check_for_loop(var, (from, to), statements, pos, module_id, state); self.exit_scope(); @@ -1465,7 +1496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics.into_iter().map(|g| g.map(|g| { let pos = g.pos(); - self.check_expression(g, module_id, &types).and_then(|g| { + self.check_expression(g, module_id, state).and_then(|g| { UExpression::try_from_typed(g, UBitwidth::B32).map_err( |e| ErrorInner { pos: Some(pos), @@ -1484,7 +1515,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ).transpose().map_err(|e| vec![e])?; // check lhs assignees are defined - let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, types)).partition(|r| r.is_ok()); + let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, state)).partition(|r| r.is_ok()); if !errors.is_empty() { return Err(errors.into_iter().map(|e| e.unwrap_err()).collect()); @@ -1497,7 +1528,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // find argument types let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, &types).map_err(|e| vec![e])?; + let arg_checked = self.check_expression(arg, module_id, state).map_err(|e| vec![e])?; arguments_checked.push(arg_checked); } @@ -1545,7 +1576,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, assignee: AssigneeNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, ErrorInner> { let pos = assignee.pos(); // check that the assignee is declared @@ -1567,14 +1598,14 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, Assignee::Select(box assignee, box index) => { - let checked_assignee = self.check_assignee(assignee, module_id, &types)?; + let checked_assignee = self.check_assignee(assignee, module_id, state)?; let ty = checked_assignee.get_type(); match ty { Type::Array(..) => { let checked_index = match index { RangeOrExpression::Expression(e) => { - self.check_expression(e, module_id, &types)? + self.check_expression(e, module_id, state)? } r => unimplemented!( "Using slices in assignments is not supported yet, found {}", @@ -1609,7 +1640,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Assignee::Member(box assignee, box member) => { - let checked_assignee = self.check_assignee(assignee, module_id, &types)?; + let checked_assignee = self.check_assignee(assignee, module_id, state)?; let ty = checked_assignee.get_type(); match &ty { @@ -1646,14 +1677,14 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, spread_or_expression: SpreadOrExpression<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { let pos = s.pos(); let checked_expression = - self.check_expression(s.value.expression, module_id, &types)?; + self.check_expression(s.value.expression, module_id, state)?; match checked_expression { TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())), @@ -1666,9 +1697,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - SpreadOrExpression::Expression(e) => self - .check_expression(e, module_id, &types) - .map(|r| r.into()), + SpreadOrExpression::Expression(e) => { + self.check_expression(e, module_id, state).map(|r| r.into()) + } } } @@ -1676,7 +1707,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, expr: ExpressionNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + state: &State<'ast, T>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1711,8 +1742,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Add(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; use self::TypedExpression::*; @@ -1746,8 +1777,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Sub(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; use self::TypedExpression::*; @@ -1777,8 +1808,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Mult(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; use self::TypedExpression::*; @@ -1812,8 +1843,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Div(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; use self::TypedExpression::*; @@ -1847,8 +1878,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Rem(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -1876,8 +1907,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Pow(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) { Ok(e) => e.into(), @@ -1904,7 +1935,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Neg(box e) => { - let e = self.check_expression(e, module_id, &types)?; + let e = self.check_expression(e, module_id, state)?; match e { TypedExpression::Int(e) => Ok(IntExpression::Neg(box e).into()), @@ -1923,7 +1954,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Pos(box e) => { - let e = self.check_expression(e, module_id, &types)?; + let e = self.check_expression(e, module_id, state)?; match e { TypedExpression::Int(e) => Ok(IntExpression::Pos(box e).into()), @@ -1942,9 +1973,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::IfElse(box condition, box consequence, box alternative) => { - let condition_checked = self.check_expression(condition, module_id, &types)?; - let consequence_checked = self.check_expression(consequence, module_id, &types)?; - let alternative_checked = self.check_expression(alternative, module_id, &types)?; + let condition_checked = self.check_expression(condition, module_id, state)?; + let consequence_checked = self.check_expression(consequence, module_id, state)?; + let alternative_checked = self.check_expression(alternative, module_id, state)?; let (consequence_checked, alternative_checked) = TypedExpression::align_without_integers( @@ -2020,7 +2051,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|g| { g.map(|g| { let pos = g.pos(); - self.check_expression(g, module_id, &types).and_then(|g| { + self.check_expression(g, module_id, state).and_then(|g| { UExpression::try_from_typed(g, UBitwidth::B32).map_err( |e| ErrorInner { pos: Some(pos), @@ -2042,7 +2073,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check the arguments let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, &types)?; + let arg_checked = self.check_expression(arg, module_id, state)?; arguments_checked.push(arg_checked); } @@ -2168,8 +2199,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Lt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2218,8 +2249,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Le(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2268,8 +2299,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Eq(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2318,8 +2349,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Ge(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2368,8 +2399,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Gt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2418,7 +2449,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Select(box array, box index) => { - let array = self.check_expression(array, module_id, &types)?; + let array = self.check_expression(array, module_id, state)?; match index { RangeOrExpression::Range(r) => { @@ -2432,13 +2463,13 @@ impl<'ast, T: Field> Checker<'ast, T> { let from = r .value .from - .map(|e| self.check_expression(e, module_id, &types)) + .map(|e| self.check_expression(e, module_id, state)) .unwrap_or_else(|| Ok(UExpression::from(0u32).into()))?; let to = r .value .to - .map(|e| self.check_expression(e, module_id, &types)) + .map(|e| self.check_expression(e, module_id, state)) .unwrap_or_else(|| Ok(array_size.clone().into()))?; let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2478,7 +2509,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } RangeOrExpression::Expression(index) => { - let index = self.check_expression(index, module_id, &types)?; + let index = self.check_expression(index, module_id, state)?; let index = UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| { @@ -2519,7 +2550,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Member(box e, box id) => { - let e = self.check_expression(e, module_id, &types)?; + let e = self.check_expression(e, module_id, state)?; match e { TypedExpression::Struct(s) => { @@ -2575,7 +2606,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check each expression, getting its type let mut expressions_or_spreads_checked = vec![]; for e in expressions_or_spreads { - let e_checked = self.check_spread_or_expression(e, module_id, &types)?; + let e_checked = self.check_spread_or_expression(e, module_id, state)?; expressions_or_spreads_checked.push(e_checked); } @@ -2642,10 +2673,10 @@ impl<'ast, T: Field> Checker<'ast, T> { ) } Expression::ArrayInitializer(box e, box count) => { - let e = self.check_expression(e, module_id, &types)?; + let e = self.check_expression(e, module_id, state)?; let ty = e.get_type(); - let count = self.check_expression(count, module_id, &types)?; + let count = self.check_expression(count, module_id, state)?; let count = UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2665,7 +2696,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let ty = self.check_type( UnresolvedType::User(id.clone()).at(42, 42, 42), module_id, - &types, + state, )?; let struct_type = match ty { Type::Struct(struct_type) => struct_type, @@ -2705,7 +2736,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match inline_members_map.remove(member.id.as_str()) { Some(value) => { let expression_checked = - self.check_expression(value, module_id, &types)?; + self.check_expression(value, module_id, state)?; let expression_checked = TypedExpression::align_to_type( expression_checked, @@ -2750,8 +2781,8 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::And(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2784,8 +2815,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Or(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::Or(box e1, box e2).into()) @@ -2801,8 +2832,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::LeftShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, &types)?; - let e2 = self.check_expression(e2, module_id, &types)?; + let e1 = self.check_expression(e1, module_id, state)?; + let e2 = self.check_expression(e2, module_id, state)?; let e2 = UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2828,8 +2859,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::RightShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, &types)?; - let e2 = self.check_expression(e2, module_id, &types)?; + let e1 = self.check_expression(e1, module_id, state)?; + let e2 = self.check_expression(e2, module_id, state)?; let e2 = UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2857,8 +2888,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitOr(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2889,8 +2920,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitAnd(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2921,8 +2952,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitXor(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = self.check_expression(e1, module_id, state)?; + let e2_checked = self.check_expression(e2, module_id, state)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2953,7 +2984,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Not(box e) => { - let e_checked = self.check_expression(e, module_id, &types)?; + let e_checked = self.check_expression(e, module_id, state)?; match e_checked { TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()), TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 360927d7..9afecd4a 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,20 +1,26 @@ +use crate::static_analysis::propagation::Propagator; use crate::typed_absy::folder::*; +use crate::typed_absy::result_folder::ResultFolder; +use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember}; use crate::typed_absy::*; +use std::collections::HashMap; use std::convert::TryInto; use zokrates_field::Field; -pub struct ConstantInliner<'ast, T: Field> { +pub struct ConstantInliner<'ast, 'a, T: Field> { modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, + propagator: Propagator<'ast, 'a, T>, } -impl<'ast, T: Field> ConstantInliner<'ast, T> { - pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self { - ConstantInliner { modules, location } - } - +impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> { pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone()); + let mut constants = HashMap::new(); + let mut inliner = ConstantInliner { + modules: p.modules.clone(), + location: p.main.clone(), + propagator: Propagator::with_constants(&mut constants), + }; inliner.fold_program(p) } @@ -51,12 +57,18 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> { let _ = self.change_location(location); symbol } - TypedConstantSymbol::Here(tc) => self.fold_constant(tc), + TypedConstantSymbol::Here(tc) => { + let tc: TypedConstant = self.fold_constant(tc); + TypedConstant { + expression: self.propagator.fold_expression(tc.expression).unwrap(), + ..tc + } + } } } } -impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { +impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> { fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { TypedProgram { modules: p @@ -71,6 +83,50 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } } + fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> { + DeclarationSignature { + generics: s.generics, + inputs: s + .inputs + .into_iter() + .map(|ty| self.fold_declaration_type(ty)) + .collect(), + outputs: s + .outputs + .into_iter() + .map(|ty| self.fold_declaration_type(ty)) + .collect(), + } + } + + fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> { + match t { + DeclarationType::Array(ref array_ty) => match array_ty.size { + Constant::Identifier(name, _) => { + let tc = self.get_constant(&name.into()).unwrap(); + let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap(); + match expression.inner { + UExpressionInner::Value(v) => DeclarationType::array(( + *array_ty.ty.clone(), + Constant::Concrete(v as u32), + )), + _ => unreachable!("expected u32 value"), + } + } + _ => t, + }, + DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType { + members: struct_ty + .members + .into_iter() + .map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty))) + .collect(), + ..struct_ty + }), + _ => t, + } + } + fn fold_constant_symbol( &mut self, s: TypedConstantSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 36c1453f..0284fb3e 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -35,6 +35,10 @@ pub trait Folder<'ast, T: Field>: Sized { fold_function(self, f) } + fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> { + s + } + fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> { DeclarationParameter { id: self.fold_declaration_variable(p.id), @@ -668,7 +672,7 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - ..fun + signature: f.fold_signature(fun.signature), } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index e00c56fa..4dca7ded 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -290,8 +290,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { #[derive(Clone, PartialEq, Debug)] pub struct TypedConstant<'ast, T> { - ty: Type<'ast, T>, - expression: TypedExpression<'ast, T>, + pub ty: Type<'ast, T>, + pub expression: TypedExpression<'ast, T>, } impl<'ast, T> TypedConstant<'ast, T> { diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 245ab2a5..8a961e7f 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -49,6 +49,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_function(self, f) } + fn fold_signature( + &mut self, + s: DeclarationSignature<'ast>, + ) -> Result, Self::Error> { + Ok(s) + } + fn fold_parameter( &mut self, p: DeclarationParameter<'ast>, @@ -741,7 +748,7 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( .into_iter() .flatten() .collect(), - ..fun + signature: f.fold_signature(fun.signature)?, }) } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 96ef84e5..cb02133d 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,4 +1,4 @@ -use crate::typed_absy::{OwnedTypedModuleId, UExpression, UExpressionInner}; +use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner}; use crate::typed_absy::{TryFrom, TryInto}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; use std::collections::BTreeMap; @@ -54,6 +54,7 @@ pub struct SpecializationError; pub enum Constant<'ast> { Generic(GenericIdentifier<'ast>), Concrete(u32), + Identifier(&'ast str, usize), } impl<'ast> From for Constant<'ast> { @@ -79,6 +80,7 @@ impl<'ast> fmt::Display for Constant<'ast> { match self { Constant::Generic(i) => write!(f, "{}", i), Constant::Concrete(v) => write!(f, "{}", v), + Constant::Identifier(v, _) => write!(f, "{}", v), } } } @@ -96,6 +98,9 @@ impl<'ast, T> From> for UExpression<'ast, T> { UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32) } Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32), + Constant::Identifier(v, size) => { + UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size)) + } } } } @@ -920,6 +925,7 @@ pub mod signature { } }, Constant::Concrete(s0) => s1 == *s0 as usize, + Constant::Identifier(_, s0) => s1 == *s0, } } (DeclarationType::FieldElement, GType::FieldElement) @@ -945,6 +951,7 @@ pub mod signature { let size = match t0.size { Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s), Constant::Concrete(s) => Ok(s.into()), + Constant::Identifier(_, s) => Ok((s as u32).into()), }?; GType::Array(GArrayType { size, ty }) diff --git a/zokrates_core_test/tests/tests/constants/array_size.json b/zokrates_core_test/tests/tests/constants/array_size.json new file mode 100644 index 00000000..8f4bc52a --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array_size.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/array_size.zok", + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["42", "42"] + }, + "output": { + "Ok": { + "values": ["42", "42"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/array_size.zok b/zokrates_core_test/tests/tests/constants/array_size.zok new file mode 100644 index 00000000..5853da74 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/array_size.zok @@ -0,0 +1,5 @@ +const u32 SIZE = 2 + +def main(field[SIZE] a) -> field[SIZE]: + field[SIZE] b = a + return b \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/propagate.json b/zokrates_core_test/tests/tests/constants/propagate.json new file mode 100644 index 00000000..3d34ff24 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/propagate.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/propagate.zok", + "max_constraint_count": 4, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["42", "42", "42", "42"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/propagate.zok b/zokrates_core_test/tests/tests/constants/propagate.zok new file mode 100644 index 00000000..151d0206 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/propagate.zok @@ -0,0 +1,5 @@ +const u32 TWO = 2 +const u32 FOUR = TWO * TWO + +def main() -> field[FOUR]: + return [42; FOUR] \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/struct.zok b/zokrates_core_test/tests/tests/constants/struct.zok index 92e705ca..545bfef5 100644 --- a/zokrates_core_test/tests/tests/constants/struct.zok +++ b/zokrates_core_test/tests/tests/constants/struct.zok @@ -1,9 +1,11 @@ -struct Foo { - field a +const u32 A_SIZE = 2 + +struct State { + field[A_SIZE] a field b } -const Foo FOO = Foo { a: 2, b: 2 } +const State STATE = State { a: [1, 1], b: 2 } def main() -> field: - return FOO.a + FOO.b \ No newline at end of file + return STATE.a[0] + STATE.a[1] + STATE.b \ No newline at end of file From b2201a38ff0d58eb623df9af81a57acbb8023379 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 13 May 2021 15:12:20 +0200 Subject: [PATCH 2/6] add changelog --- changelogs/unreleased/864-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/864-dark64 diff --git a/changelogs/unreleased/864-dark64 b/changelogs/unreleased/864-dark64 new file mode 100644 index 00000000..a7322d31 --- /dev/null +++ b/changelogs/unreleased/864-dark64 @@ -0,0 +1 @@ +Support the use of constants in declaration types \ No newline at end of file From 53a851fb1ff2a8846609b028b8ddfaf9dbe1e563 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 13 May 2021 15:26:15 +0200 Subject: [PATCH 3/6] fix possible unwrapping panic in semantics --- zokrates_core/src/semantics.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 3847a808..8314c6ed 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -397,7 +397,7 @@ impl<'ast, T: Field> Checker<'ast, T> { id: String, s: StructDefinitionNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + state: &mut State<'ast, T>, ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -831,7 +831,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, funct_node: FunctionNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + state: &mut State<'ast, T>, ) -> Result, Vec> { assert!(self.return_types.is_none()); @@ -971,7 +971,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, signature: UnresolvedSignature<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + state: &mut State<'ast, T>, ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; @@ -1154,7 +1154,7 @@ impl<'ast, T: Field> Checker<'ast, T> { (None, Some(index)) => Ok(Constant::Generic(GenericIdentifier { name, index: *index })), _ => Err(ErrorInner { pos: Some(pos), - message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name) + message: format!("Undeclared symbol `{}` in function definition", name) }) } } @@ -1172,7 +1172,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + state: &mut State<'ast, T>, generics_map: &HashMap, usize>, ) -> Result, ErrorInner> { let pos = ty.pos(); @@ -1185,7 +1185,7 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Array(t, size) => { let checked_size = self.check_generic_expression( size.clone(), - state.constants.get(module_id).unwrap(), + state.constants.entry(module_id.to_path_buf()).or_default(), generics_map, )?; From 31949e5bc40efab2da5a908379d836ba94b7420e Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 13 May 2021 17:23:30 +0200 Subject: [PATCH 4/6] fix tests --- zokrates_core/src/semantics.rs | 196 ++++++++++-------- .../src/static_analysis/constant_inliner.rs | 15 +- 2 files changed, 119 insertions(+), 92 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8314c6ed..a9ad7ef0 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -983,8 +983,8 @@ impl<'ast, T: Field> Checker<'ast, T> { for (index, g) in signature.generics.iter().enumerate() { if let Some((key, _)) = state .constants - .get(module_id) - .unwrap() + .entry(module_id.to_path_buf()) + .or_default() .get_key_value(g.value) { errors.push(ErrorInner { @@ -3053,11 +3053,12 @@ mod tests { fn field_in_range() { // The value of `P - 1` is a valid field literal - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let expr = Expression::FieldConstant(Bn128Field::max_value().to_biguint()).mock(); assert!(Checker::::new() - .check_expression(expr, &*MODULE_ID, &types) + .check_expression(expr, &*MODULE_ID, &state) .is_ok()); } @@ -3065,13 +3066,14 @@ mod tests { fn field_overflow() { // the value of `P` is an invalid field literal - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let value = Bn128Field::max_value().to_biguint().add(1u32); let expr = Expression::FieldConstant(value).mock(); assert!(Checker::::new() - .check_expression(expr, &*MODULE_ID, &types) + .check_expression(expr, &*MODULE_ID, &state) .is_err()); } } @@ -3086,7 +3088,9 @@ mod tests { // in the case of arrays, lengths do *not* have to match, as at this point they can be // generic, so we cannot tell yet - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); + // [3, true] let a = Expression::InlineArray(vec![ Expression::IntConstant(3usize.into()).mock().into(), @@ -3094,7 +3098,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &*MODULE_ID, &state) .is_err()); // [[0f], [0f, 0f]] @@ -3114,7 +3118,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &*MODULE_ID, &state) .is_ok()); // [[0f], true] @@ -3130,7 +3134,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &*MODULE_ID, &state) .is_err()); } } @@ -3513,12 +3517,10 @@ mod tests { let mut checker: Checker = Checker::new(); assert_eq!( - checker - .check_module(&*MODULE_ID, &mut state) - .unwrap_err()[0] + checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner .message, - "Undeclared generic parameter in function definition: `P` isn\'t declared as a generic constant" + "Undeclared symbol `P` in function definition" ); } } @@ -3766,20 +3768,28 @@ mod tests { #[test] fn undeclared_generic() { + let modules = Modules::new(); + let mut state = State::new(modules); + let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array( box UnresolvedType::FieldElement.mock(), Expression::Identifier("K").mock(), ) .mock()]); - assert_eq!(Checker::::new().check_signature(signature, &*MODULE_ID, &TypeMap::default()), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), - message: "Undeclared generic parameter in function definition: `K` isn\'t declared as a generic constant".to_string() - }])); + assert_eq!( + Checker::::new().check_signature(signature, &*MODULE_ID, &mut state), + Err(vec![ErrorInner { + pos: Some((Position::mock(), Position::mock())), + message: "Undeclared symbol `K` in function definition".to_string() + }]) + ); } #[test] fn success() { // (field[L][K]) -> field[L][K] + let modules = Modules::new(); + let mut state = State::new(modules); let signature = UnresolvedSignature::new() .generics(vec!["K".mock(), "L".mock(), "M".mock()]) @@ -3802,11 +3812,7 @@ mod tests { ) .mock()]); assert_eq!( - Checker::::new().check_signature( - signature, - &*MODULE_ID, - &TypeMap::default() - ), + Checker::::new().check_signature(signature, &*MODULE_ID, &mut state), Ok(DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::array(( @@ -3836,13 +3842,14 @@ mod tests { ) .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut checker: Checker = Checker::new(); checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &types), + checker.check_statement(statement, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".into() @@ -3860,7 +3867,8 @@ mod tests { ) .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut scope = HashSet::new(); scope.insert(ScopedVariable { @@ -3874,7 +3882,7 @@ mod tests { let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &types), + checker.check_statement(statement, &*MODULE_ID, &state), Ok(TypedStatement::Definition( TypedAssignee::Identifier(typed_absy::Variable::field_element("a")), FieldElementExpression::Identifier("b".into()).into() @@ -4104,11 +4112,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function(foo, &*MODULE_ID, &types), + checker.check_function(foo, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"i\" is undefined".into() @@ -4187,11 +4196,12 @@ mod tests { signature: DeclarationSignature::default(), }; - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function(foo, &*MODULE_ID, &types), + checker.check_function(foo, &*MODULE_ID, &mut state), Ok(foo_checked) ); } @@ -4240,11 +4250,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: @@ -4298,11 +4309,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Function definition for function foo with signature () -> _ not found." @@ -4343,11 +4355,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -4685,11 +4698,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -4725,11 +4739,12 @@ mod tests { } .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"a\" is undefined".into() @@ -4832,11 +4847,12 @@ mod tests { .outputs(vec![DeclarationType::FieldElement]), }; - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &types), + checker.check_function(bar, &*MODULE_ID, &mut state), Ok(bar_checked) ); } @@ -4864,10 +4880,13 @@ mod tests { UnresolvedType::Boolean.mock(), ]); + let modules = Modules::new(); + let mut state = State::new(modules); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( checker - .check_function(f, &*MODULE_ID, &HashMap::new()) + .check_function(f, &*MODULE_ID, &mut state) .unwrap_err()[0] .message, "Duplicate name in function definition: `a` was previously declared as an argument or a generic constant" @@ -4966,7 +4985,9 @@ mod tests { // // should fail - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); + let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( @@ -4974,7 +4995,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ); let s2_checked: Result, Vec> = checker .check_statement( @@ -4983,7 +5004,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ); assert_eq!( s2_checked, @@ -5001,7 +5022,8 @@ mod tests { // // should fail - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( @@ -5010,7 +5032,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ); let s2_checked: Result, Vec> = checker .check_statement( @@ -5019,7 +5041,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ); assert_eq!( s2_checked, @@ -5064,7 +5086,9 @@ mod tests { #[test] fn empty_def() { // an empty struct should be allowed to be defined - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); + let declaration: StructDefinitionNode = StructDefinition { fields: vec![] }.mock(); let expected_type = DeclarationType::Struct(DeclarationStructType::new( @@ -5078,7 +5102,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &types + &mut state ), Ok(expected_type) ); @@ -5087,7 +5111,9 @@ mod tests { #[test] fn valid_def() { // a valid struct should be allowed to be defined - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); + let declaration: StructDefinitionNode = StructDefinition { fields: vec![ StructDefinitionField { @@ -5118,7 +5144,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &types + &mut state ), Ok(expected_type) ); @@ -5127,7 +5153,8 @@ mod tests { #[test] fn duplicate_member_def() { // definition of a struct with a duplicate member should be rejected - let types = HashMap::new(); + let modules = Modules::new(); + let mut state = State::new(modules); let declaration: StructDefinitionNode = StructDefinition { fields: vec![ @@ -5151,7 +5178,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &types + &mut state ) .unwrap_err()[0] .message, @@ -5356,7 +5383,7 @@ mod tests { // an undefined type cannot be checked // Bar - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5368,7 +5395,7 @@ mod tests { checker.check_type( UnresolvedType::User("Foo".into()).mock(), &*MODULE_ID, - &state.types + &mut state ), Ok(Type::Struct(StructType::new( "".into(), @@ -5382,7 +5409,7 @@ mod tests { .check_type( UnresolvedType::User("Bar".into()).mock(), &*MODULE_ID, - &state.types + &mut state ) .unwrap_err() .message, @@ -5402,7 +5429,7 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.foo - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5422,7 +5449,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ), Ok(FieldElementExpression::Member( box StructExpressionInner::Value(vec![FieldElementExpression::Number( @@ -5447,7 +5474,7 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.bar - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5468,7 +5495,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ) .unwrap_err() .message, @@ -5485,7 +5512,7 @@ mod tests { fn wrong_name() { // a A value cannot be defined with B as id, even if A and B have the same members - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5502,7 +5529,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ) .unwrap_err() .message, @@ -5517,7 +5544,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42, bar: true } - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5543,7 +5570,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ), Ok(StructExpressionInner::Value(vec![ FieldElementExpression::Number(Bn128Field::from(42u32)).into(), @@ -5568,7 +5595,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { bar: true, foo: 42 } - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5594,7 +5621,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ), Ok(StructExpressionInner::Value(vec![ FieldElementExpression::Number(Bn128Field::from(42u32)).into(), @@ -5619,7 +5646,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42 } - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5643,7 +5670,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ) .unwrap_err() .message, @@ -5660,7 +5687,7 @@ mod tests { // Foo { foo: 42, baz: bool } // error // Foo { foo: 42, baz: 42 } // error - let (mut checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, mut state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5690,7 +5717,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ).unwrap_err() .message, "Member bar of struct Foo {foo: field, bar: bool} not found in value Foo {baz: true, foo: 42}" @@ -5708,7 +5735,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state.types + &mut state ) .unwrap_err() .message, @@ -5835,7 +5862,8 @@ mod tests { // a = 42 let a = Assignee::Identifier("a").mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5847,12 +5875,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &types), + checker.check_assignee(a, &*MODULE_ID, &state), Ok(TypedAssignee::Identifier( typed_absy::Variable::field_element("a") )) @@ -5869,7 +5897,8 @@ mod tests { ) .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5889,12 +5918,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &types), + checker.check_assignee(a, &*MODULE_ID, &state), Ok(TypedAssignee::Select( box TypedAssignee::Identifier(typed_absy::Variable::field_array( "a", @@ -5921,7 +5950,8 @@ mod tests { ) .mock(); - let types = HashMap::new(); + let modules = Modules::new(); + let state = State::new(modules); let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5945,12 +5975,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &types, + &state, ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &types), + checker.check_assignee(a, &*MODULE_ID, &state), Ok(TypedAssignee::Select( box TypedAssignee::Select( box TypedAssignee::Identifier(typed_absy::Variable::array( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 9afecd4a..ed014464 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -107,7 +107,7 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> { let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap(); match expression.inner { UExpressionInner::Value(v) => DeclarationType::array(( - *array_ty.ty.clone(), + self.fold_declaration_type(*array_ty.ty.clone()), Constant::Concrete(v as u32), )), _ => unreachable!("expected u32 value"), @@ -692,11 +692,9 @@ mod tests { let expected_main = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), - ) - .into()])], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ])], signature: DeclarationSignature::new() .inputs(vec![]) .outputs(vec![DeclarationType::FieldElement]), @@ -731,9 +729,8 @@ mod tests { const_b_id, TypedConstantSymbol::Here(TypedConstant::new( GType::FieldElement, - TypedExpression::FieldElement(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(2), )), )), ), From 7406f8769dbb0b0fb2ee31d4f895c2fe02b62f9a Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 13 May 2021 19:02:38 +0200 Subject: [PATCH 5/6] visit struct type, improve tests --- .../src/static_analysis/constant_inliner.rs | 28 +++++++++++++++++ zokrates_core/src/typed_absy/folder.rs | 30 +++++++++++++++++-- zokrates_core/src/typed_absy/result_folder.rs | 22 ++++++++++++-- .../tests/tests/constants/array.zok | 5 ++-- .../tests/tests/constants/mixed.json | 16 ++++++++++ .../tests/tests/constants/mixed.zok | 15 ++++++++++ .../tests/tests/constants/struct.json | 4 +-- .../tests/tests/constants/struct.zok | 15 ++++++---- 8 files changed, 120 insertions(+), 15 deletions(-) create mode 100644 zokrates_core_test/tests/tests/constants/mixed.json create mode 100644 zokrates_core_test/tests/tests/constants/mixed.zok diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index ed014464..6edf4453 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -127,6 +127,34 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> { } } + fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> { + use self::GType::*; + match t { + Array(ref array_type) => match &array_type.size.inner { + UExpressionInner::Identifier(v) => match self.get_constant(v) { + Some(tc) => { + let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap(); + Type::array(GArrayType::new( + self.fold_type(*array_type.ty.clone()), + expression, + )) + } + None => t, + }, + _ => t, + }, + Struct(struct_type) => Type::struc(GStructType { + members: struct_type + .members + .into_iter() + .map(|m| GStructMember::new(m.id, self.fold_type(*m.ty))) + .collect(), + ..struct_type + }), + _ => t, + } + } + fn fold_constant_symbol( &mut self, s: TypedConstantSymbol<'ast, T>, diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 0284fb3e..dd3b9211 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -31,6 +31,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_function_symbol(self, s) } + fn fold_declaration_function_key( + &mut self, + key: DeclarationFunctionKey<'ast>, + ) -> DeclarationFunctionKey<'ast> { + fold_declaration_function_key(self, key) + } + fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { fold_function(self, f) } @@ -194,6 +201,7 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> ArrayExpressionInner<'ast, T> { fold_array_expression_inner(self, ty, e) } + fn fold_struct_expression_inner( &mut self, ty: &StructType<'ast, T>, @@ -216,7 +224,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( functions: m .functions .into_iter() - .map(|(key, fun)| (key, f.fold_function_symbol(fun))) + .map(|(key, fun)| { + ( + f.fold_declaration_function_key(key), + f.fold_function_symbol(fun), + ) + }) .collect(), } } @@ -657,6 +670,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + key: DeclarationFunctionKey<'ast>, +) -> DeclarationFunctionKey<'ast> { + DeclarationFunctionKey { + signature: f.fold_signature(key.signature), + ..key + } +} + pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, fun: TypedFunction<'ast, T>, @@ -725,9 +748,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: StructExpression<'ast, T>, ) -> StructExpression<'ast, T> { + let ty = f.fold_struct_type(e.ty); StructExpression { - inner: f.fold_struct_expression_inner(&e.ty, e.inner), - ..e + inner: f.fold_struct_expression_inner(&ty, e.inner), + ty, } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 8a961e7f..2f7fe2af 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -42,6 +42,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_function_symbol(self, s) } + fn fold_declaration_function_key( + &mut self, + key: DeclarationFunctionKey<'ast>, + ) -> Result, Self::Error> { + fold_declaration_function_key(self, key) + } + fn fold_function( &mut self, f: TypedFunction<'ast, T>, @@ -730,6 +737,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } +pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + key: DeclarationFunctionKey<'ast>, +) -> Result, F::Error> { + Ok(DeclarationFunctionKey { + signature: f.fold_signature(key.signature)?, + ..key + }) +} + pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, fun: TypedFunction<'ast, T>, @@ -808,9 +825,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: StructExpression<'ast, T>, ) -> Result, F::Error> { + let ty = f.fold_struct_type(e.ty)?; Ok(StructExpression { - inner: f.fold_struct_expression_inner(&e.ty, e.inner)?, - ..e + inner: f.fold_struct_expression_inner(&ty, e.inner)?, + ty, }) } diff --git a/zokrates_core_test/tests/tests/constants/array.zok b/zokrates_core_test/tests/tests/constants/array.zok index cce74dc7..60191184 100644 --- a/zokrates_core_test/tests/tests/constants/array.zok +++ b/zokrates_core_test/tests/tests/constants/array.zok @@ -1,4 +1,5 @@ -const field[2] ARRAY = [1, 2] +const u32 N = 2 +const field[N] ARRAY = [1, 2] -def main() -> field[2]: +def main() -> field[N]: return ARRAY \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/mixed.json b/zokrates_core_test/tests/tests/constants/mixed.json new file mode 100644 index 00000000..7b375291 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/mixed.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/constants/mixed.zok", + "max_constraint_count": 6, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": ["1", "2", "1", "3", "4", "0"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/constants/mixed.zok b/zokrates_core_test/tests/tests/constants/mixed.zok new file mode 100644 index 00000000..851043c6 --- /dev/null +++ b/zokrates_core_test/tests/tests/constants/mixed.zok @@ -0,0 +1,15 @@ +const u32 N = 2 +const bool B = true + +struct Foo { + field[N] a + bool b +} + +const Foo[N] F = [ + Foo { a: [1, 2], b: B }, + Foo { a: [3, 4], b: !B } +] + +def main() -> Foo[N]: + return F \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/constants/struct.json b/zokrates_core_test/tests/tests/constants/struct.json index 4d77d484..5ff07d9d 100644 --- a/zokrates_core_test/tests/tests/constants/struct.json +++ b/zokrates_core_test/tests/tests/constants/struct.json @@ -1,6 +1,6 @@ { "entry_point": "./tests/tests/constants/struct.zok", - "max_constraint_count": 1, + "max_constraint_count": 6, "tests": [ { "input": { @@ -8,7 +8,7 @@ }, "output": { "Ok": { - "values": ["4"] + "values": ["1", "2", "3", "4", "5", "6"] } } } diff --git a/zokrates_core_test/tests/tests/constants/struct.zok b/zokrates_core_test/tests/tests/constants/struct.zok index 545bfef5..c40a5513 100644 --- a/zokrates_core_test/tests/tests/constants/struct.zok +++ b/zokrates_core_test/tests/tests/constants/struct.zok @@ -1,11 +1,14 @@ -const u32 A_SIZE = 2 +const u32 N = 2 struct State { - field[A_SIZE] a - field b + field[N] a + field[N][N] b } -const State STATE = State { a: [1, 1], b: 2 } +const State STATE = State { + a: [1, 2], + b: [[3, 4], [5, 6]] +} -def main() -> field: - return STATE.a[0] + STATE.a[1] + STATE.b \ No newline at end of file +def main() -> State: + return STATE \ No newline at end of file From 96952d6e0455fc7bfe2d8c1430275cfe72d6560c Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 14 May 2021 19:49:55 +0200 Subject: [PATCH 6/6] refactor --- changelogs/unreleased/864-dark64 | 2 +- ...nstant_array_size_type_mismatch_field.zok} | 0 .../constant_array_size_type_mismatch_u8.zok | 4 + zokrates_core/src/semantics.rs | 367 ++++++++---------- .../src/static_analysis/constant_inliner.rs | 37 +- zokrates_core/src/typed_absy/folder.rs | 21 +- zokrates_core/src/typed_absy/result_folder.rs | 21 +- 7 files changed, 229 insertions(+), 223 deletions(-) rename zokrates_cli/examples/compile_errors/{constant_array_size_type_mismatch.zok => constant_array_size_type_mismatch_field.zok} (100%) create mode 100644 zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_u8.zok diff --git a/changelogs/unreleased/864-dark64 b/changelogs/unreleased/864-dark64 index a7322d31..b596601e 100644 --- a/changelogs/unreleased/864-dark64 +++ b/changelogs/unreleased/864-dark64 @@ -1 +1 @@ -Support the use of constants in declaration types \ No newline at end of file +Support the use of constants in struct and function declarations \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok b/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_field.zok similarity index 100% rename from zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch.zok rename to zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_field.zok diff --git a/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_u8.zok b/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_u8.zok new file mode 100644 index 00000000..5a086af7 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/constant_array_size_type_mismatch_u8.zok @@ -0,0 +1,4 @@ +const u8 SIZE = 0x02 + +def main(field[SIZE] n): + return \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index a9ad7ef0..e76e33b8 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -356,8 +356,9 @@ impl<'ast, T: Field> Checker<'ast, T> { state: &State<'ast, T>, ) -> Result, ErrorInner> { let pos = c.pos(); - let ty = self.check_type(c.value.ty.clone(), module_id, state)?; - let checked_expr = self.check_expression(c.value.expression.clone(), module_id, state)?; + let ty = self.check_type(c.value.ty.clone(), module_id, &state.types)?; + let checked_expr = + self.check_expression(c.value.expression.clone(), module_id, &state.types)?; match ty { Type::FieldElement => { @@ -397,7 +398,7 @@ impl<'ast, T: Field> Checker<'ast, T> { id: String, s: StructDefinitionNode<'ast>, module_id: &ModuleId, - state: &mut State<'ast, T>, + state: &State<'ast, T>, ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -761,8 +762,9 @@ impl<'ast, T: Field> Checker<'ast, T> { None => None, // if it was not, check it Some(module) => { - // we need to create an entry in the types map to store types for this module + // create default entries for this module state.types.entry(module_id.to_path_buf()).or_default(); + state.constants.entry(module_id.to_path_buf()).or_default(); // we keep track of the introduced symbols to avoid collisions between types and functions let mut symbol_unifier = SymbolUnifier::default(); @@ -831,7 +833,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, funct_node: FunctionNode<'ast>, module_id: &ModuleId, - state: &mut State<'ast, T>, + state: &State<'ast, T>, ) -> Result, Vec> { assert!(self.return_types.is_none()); @@ -905,7 +907,7 @@ impl<'ast, T: Field> Checker<'ast, T> { found_return = true; } - match self.check_statement(stat, module_id, state) { + match self.check_statement(stat, module_id, &state.types) { Ok(statement) => { if let TypedStatement::Return(e) = &statement { match e.iter().map(|e| e.get_type()).collect::>() @@ -971,7 +973,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, signature: UnresolvedSignature<'ast>, module_id: &ModuleId, - state: &mut State<'ast, T>, + state: &State<'ast, T>, ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; @@ -981,17 +983,17 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut generics_map = HashMap::new(); for (index, g) in signature.generics.iter().enumerate() { - if let Some((key, _)) = state + if state .constants - .entry(module_id.to_path_buf()) - .or_default() - .get_key_value(g.value) + .get(module_id) + .and_then(|m| m.get(g.value)) + .is_some() { errors.push(ErrorInner { pos: Some(g.pos()), message: format!( - "Generic parameter {} conflicts with constant symbol {}", - g.value, key + "Generic parameter {p} conflicts with constant symbol {p}", + p = g.value ), }); } else { @@ -1051,7 +1053,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -1061,7 +1063,7 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Boolean => Ok(Type::Boolean), UnresolvedType::Uint(bitwidth) => Ok(Type::uint(bitwidth)), UnresolvedType::Array(t, size) => { - let size = self.check_expression(size, module_id, state)?; + let size = self.check_expression(size, module_id, types)?; let ty = size.get_type(); @@ -1094,12 +1096,11 @@ impl<'ast, T: Field> Checker<'ast, T> { }?; Ok(Type::Array(ArrayType::new( - self.check_type(*t, module_id, state)?, + self.check_type(*t, module_id, types)?, size, ))) } - UnresolvedType::User(id) => state - .types + UnresolvedType::User(id) => types .get(module_id) .unwrap() .get(&id) @@ -1111,7 +1112,6 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|t| t.into()), } } - fn check_generic_expression( &mut self, expr: ExpressionNode<'ast>, @@ -1139,14 +1139,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } Expression::Identifier(name) => { match (constants_map.get(name), generics_map.get(&name)) { - (Some(c), None) => { - match c { - Type::Uint(bitwidth) => Ok(Constant::Identifier(name, bitwidth.to_usize())), + (Some(ty), None) => { + match ty { + Type::Uint(UBitwidth::B32) => Ok(Constant::Identifier(name, 32usize)), _ => Err(ErrorInner { pos: Some(pos), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {} of type {}", - name, c + name, ty ), }) } @@ -1172,7 +1172,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - state: &mut State<'ast, T>, + state: &State<'ast, T>, generics_map: &HashMap, usize>, ) -> Result, ErrorInner> { let pos = ty.pos(); @@ -1185,7 +1185,7 @@ impl<'ast, T: Field> Checker<'ast, T> { UnresolvedType::Array(t, size) => { let checked_size = self.check_generic_expression( size.clone(), - state.constants.entry(module_id.to_path_buf()).or_default(), + state.constants.get(module_id).unwrap_or(&HashMap::new()), generics_map, )?; @@ -1211,11 +1211,11 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, v: crate::absy::VariableNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, Vec> { Ok(Variable::with_id_and_type( v.value.id, - self.check_type(v.value._type, module_id, state) + self.check_type(v.value._type, module_id, types) .map_err(|e| vec![e])?, )) } @@ -1227,17 +1227,17 @@ impl<'ast, T: Field> Checker<'ast, T> { statements: Vec>, pos: (Position, Position), module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, Vec> { self.check_for_var(&var).map_err(|e| vec![e])?; - let var = self.check_variable(var, module_id, state).unwrap(); + let var = self.check_variable(var, module_id, types).unwrap(); let from = self - .check_expression(range.0, module_id, state) + .check_expression(range.0, module_id, types) .map_err(|e| vec![e])?; let to = self - .check_expression(range.1, module_id, state) + .check_expression(range.1, module_id, types) .map_err(|e| vec![e])?; let from = match from { @@ -1305,7 +1305,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut checked_statements = vec![]; for stat in statements { - let checked_stat = self.check_statement(stat, module_id, state)?; + let checked_stat = self.check_statement(stat, module_id, types)?; checked_statements.push(checked_stat); } @@ -1316,7 +1316,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, stat: StatementNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, Vec> { let pos = stat.pos(); @@ -1330,7 +1330,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for e in e.value.expressions.into_iter() { let e_checked = self - .check_expression(e, module_id, state) + .check_expression(e, module_id, types) .map_err(|e| vec![e])?; expression_list_checked.push(e_checked); } @@ -1398,7 +1398,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(res) } Statement::Declaration(var) => { - let var = self.check_variable(var, module_id, state)?; + let var = self.check_variable(var, module_id, types)?; match self.insert_into_scope(var.clone()) { true => Ok(TypedStatement::Declaration(var)), false => Err(ErrorInner { @@ -1417,12 +1417,12 @@ impl<'ast, T: Field> Checker<'ast, T> { // check the expression to be assigned let checked_expr = self - .check_expression(expr, module_id, state) + .check_expression(expr, module_id, types) .map_err(|e| vec![e])?; // check that the assignee is declared and is well formed let var = self - .check_assignee(assignee, module_id, state) + .check_assignee(assignee, module_id, types) .map_err(|e| vec![e])?; let var_type = var.get_type(); @@ -1461,7 +1461,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } Statement::Assertion(e) => { let e = self - .check_expression(e, module_id, state) + .check_expression(e, module_id, types) .map_err(|e| vec![e])?; match e { @@ -1480,7 +1480,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::For(var, from, to, statements) => { self.enter_scope(); - let res = self.check_for_loop(var, (from, to), statements, pos, module_id, state); + let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types); self.exit_scope(); @@ -1496,7 +1496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics.into_iter().map(|g| g.map(|g| { let pos = g.pos(); - self.check_expression(g, module_id, state).and_then(|g| { + self.check_expression(g, module_id, types).and_then(|g| { UExpression::try_from_typed(g, UBitwidth::B32).map_err( |e| ErrorInner { pos: Some(pos), @@ -1515,7 +1515,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ).transpose().map_err(|e| vec![e])?; // check lhs assignees are defined - let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, state)).partition(|r| r.is_ok()); + let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, types)).partition(|r| r.is_ok()); if !errors.is_empty() { return Err(errors.into_iter().map(|e| e.unwrap_err()).collect()); @@ -1528,7 +1528,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // find argument types let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, state).map_err(|e| vec![e])?; + let arg_checked = self.check_expression(arg, module_id, types).map_err(|e| vec![e])?; arguments_checked.push(arg_checked); } @@ -1576,7 +1576,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, assignee: AssigneeNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = assignee.pos(); // check that the assignee is declared @@ -1598,14 +1598,14 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, Assignee::Select(box assignee, box index) => { - let checked_assignee = self.check_assignee(assignee, module_id, state)?; + let checked_assignee = self.check_assignee(assignee, module_id, types)?; let ty = checked_assignee.get_type(); match ty { Type::Array(..) => { let checked_index = match index { RangeOrExpression::Expression(e) => { - self.check_expression(e, module_id, state)? + self.check_expression(e, module_id, types)? } r => unimplemented!( "Using slices in assignments is not supported yet, found {}", @@ -1640,7 +1640,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Assignee::Member(box assignee, box member) => { - let checked_assignee = self.check_assignee(assignee, module_id, state)?; + let checked_assignee = self.check_assignee(assignee, module_id, types)?; let ty = checked_assignee.get_type(); match &ty { @@ -1677,14 +1677,14 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, spread_or_expression: SpreadOrExpression<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { let pos = s.pos(); let checked_expression = - self.check_expression(s.value.expression, module_id, state)?; + self.check_expression(s.value.expression, module_id, types)?; match checked_expression { TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())), @@ -1698,7 +1698,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } SpreadOrExpression::Expression(e) => { - self.check_expression(e, module_id, state).map(|r| r.into()) + self.check_expression(e, module_id, types).map(|r| r.into()) } } } @@ -1707,7 +1707,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, expr: ExpressionNode<'ast>, module_id: &ModuleId, - state: &State<'ast, T>, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1742,8 +1742,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Add(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; use self::TypedExpression::*; @@ -1777,8 +1777,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Sub(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; use self::TypedExpression::*; @@ -1808,8 +1808,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Mult(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; use self::TypedExpression::*; @@ -1843,8 +1843,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Div(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; use self::TypedExpression::*; @@ -1878,8 +1878,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Rem(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -1907,8 +1907,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Pow(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) { Ok(e) => e.into(), @@ -1935,7 +1935,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Neg(box e) => { - let e = self.check_expression(e, module_id, state)?; + let e = self.check_expression(e, module_id, types)?; match e { TypedExpression::Int(e) => Ok(IntExpression::Neg(box e).into()), @@ -1954,7 +1954,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Pos(box e) => { - let e = self.check_expression(e, module_id, state)?; + let e = self.check_expression(e, module_id, types)?; match e { TypedExpression::Int(e) => Ok(IntExpression::Pos(box e).into()), @@ -1973,9 +1973,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::IfElse(box condition, box consequence, box alternative) => { - let condition_checked = self.check_expression(condition, module_id, state)?; - let consequence_checked = self.check_expression(consequence, module_id, state)?; - let alternative_checked = self.check_expression(alternative, module_id, state)?; + let condition_checked = self.check_expression(condition, module_id, types)?; + let consequence_checked = self.check_expression(consequence, module_id, types)?; + let alternative_checked = self.check_expression(alternative, module_id, types)?; let (consequence_checked, alternative_checked) = TypedExpression::align_without_integers( @@ -2051,7 +2051,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|g| { g.map(|g| { let pos = g.pos(); - self.check_expression(g, module_id, state).and_then(|g| { + self.check_expression(g, module_id, types).and_then(|g| { UExpression::try_from_typed(g, UBitwidth::B32).map_err( |e| ErrorInner { pos: Some(pos), @@ -2073,7 +2073,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check the arguments let mut arguments_checked = vec![]; for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, state)?; + let arg_checked = self.check_expression(arg, module_id, types)?; arguments_checked.push(arg_checked); } @@ -2199,8 +2199,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Lt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2249,8 +2249,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Le(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2299,8 +2299,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Eq(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2349,8 +2349,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Ge(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2399,8 +2399,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Gt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2449,7 +2449,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Select(box array, box index) => { - let array = self.check_expression(array, module_id, state)?; + let array = self.check_expression(array, module_id, types)?; match index { RangeOrExpression::Range(r) => { @@ -2463,13 +2463,13 @@ impl<'ast, T: Field> Checker<'ast, T> { let from = r .value .from - .map(|e| self.check_expression(e, module_id, state)) + .map(|e| self.check_expression(e, module_id, types)) .unwrap_or_else(|| Ok(UExpression::from(0u32).into()))?; let to = r .value .to - .map(|e| self.check_expression(e, module_id, state)) + .map(|e| self.check_expression(e, module_id, types)) .unwrap_or_else(|| Ok(array_size.clone().into()))?; let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2509,7 +2509,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } RangeOrExpression::Expression(index) => { - let index = self.check_expression(index, module_id, state)?; + let index = self.check_expression(index, module_id, types)?; let index = UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| { @@ -2550,7 +2550,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Member(box e, box id) => { - let e = self.check_expression(e, module_id, state)?; + let e = self.check_expression(e, module_id, types)?; match e { TypedExpression::Struct(s) => { @@ -2606,7 +2606,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check each expression, getting its type let mut expressions_or_spreads_checked = vec![]; for e in expressions_or_spreads { - let e_checked = self.check_spread_or_expression(e, module_id, state)?; + let e_checked = self.check_spread_or_expression(e, module_id, types)?; expressions_or_spreads_checked.push(e_checked); } @@ -2673,10 +2673,10 @@ impl<'ast, T: Field> Checker<'ast, T> { ) } Expression::ArrayInitializer(box e, box count) => { - let e = self.check_expression(e, module_id, state)?; + let e = self.check_expression(e, module_id, types)?; let ty = e.get_type(); - let count = self.check_expression(count, module_id, state)?; + let count = self.check_expression(count, module_id, types)?; let count = UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2696,7 +2696,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let ty = self.check_type( UnresolvedType::User(id.clone()).at(42, 42, 42), module_id, - state, + types, )?; let struct_type = match ty { Type::Struct(struct_type) => struct_type, @@ -2736,7 +2736,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match inline_members_map.remove(member.id.as_str()) { Some(value) => { let expression_checked = - self.check_expression(value, module_id, state)?; + self.check_expression(value, module_id, types)?; let expression_checked = TypedExpression::align_to_type( expression_checked, @@ -2781,8 +2781,8 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::And(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2815,8 +2815,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Or(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::Or(box e1, box e2).into()) @@ -2832,8 +2832,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::LeftShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, state)?; - let e2 = self.check_expression(e2, module_id, state)?; + let e1 = self.check_expression(e1, module_id, types)?; + let e2 = self.check_expression(e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2859,8 +2859,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::RightShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, state)?; - let e2 = self.check_expression(e2, module_id, state)?; + let e1 = self.check_expression(e1, module_id, types)?; + let e2 = self.check_expression(e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { @@ -2888,8 +2888,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitOr(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2920,8 +2920,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitAnd(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2952,8 +2952,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::BitXor(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, state)?; - let e2_checked = self.check_expression(e2, module_id, state)?; + let e1_checked = self.check_expression(e1, module_id, types)?; + let e2_checked = self.check_expression(e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2984,7 +2984,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Expression::Not(box e) => { - let e_checked = self.check_expression(e, module_id, state)?; + let e_checked = self.check_expression(e, module_id, types)?; match e_checked { TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()), TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), @@ -3052,28 +3052,20 @@ mod tests { #[test] fn field_in_range() { // The value of `P - 1` is a valid field literal - - let modules = Modules::new(); - let state = State::new(modules); - let expr = Expression::FieldConstant(Bn128Field::max_value().to_biguint()).mock(); assert!(Checker::::new() - .check_expression(expr, &*MODULE_ID, &state) + .check_expression(expr, &*MODULE_ID, &TypeMap::new()) .is_ok()); } #[test] fn field_overflow() { // the value of `P` is an invalid field literal - - let modules = Modules::new(); - let state = State::new(modules); - let value = Bn128Field::max_value().to_biguint().add(1u32); let expr = Expression::FieldConstant(value).mock(); assert!(Checker::::new() - .check_expression(expr, &*MODULE_ID, &state) + .check_expression(expr, &*MODULE_ID, &TypeMap::new()) .is_err()); } } @@ -3087,9 +3079,7 @@ mod tests { // having different types in an array isn't allowed // in the case of arrays, lengths do *not* have to match, as at this point they can be // generic, so we cannot tell yet - - let modules = Modules::new(); - let state = State::new(modules); + let types = TypeMap::new(); // [3, true] let a = Expression::InlineArray(vec![ @@ -3098,7 +3088,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &state) + .check_expression(a, &*MODULE_ID, &types) .is_err()); // [[0f], [0f, 0f]] @@ -3118,7 +3108,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &state) + .check_expression(a, &*MODULE_ID, &types) .is_ok()); // [[0f], true] @@ -3134,7 +3124,7 @@ mod tests { ]) .mock(); assert!(Checker::::new() - .check_expression(a, &*MODULE_ID, &state) + .check_expression(a, &*MODULE_ID, &types) .is_err()); } } @@ -3769,7 +3759,7 @@ mod tests { #[test] fn undeclared_generic() { let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array( box UnresolvedType::FieldElement.mock(), @@ -3777,7 +3767,7 @@ mod tests { ) .mock()]); assert_eq!( - Checker::::new().check_signature(signature, &*MODULE_ID, &mut state), + Checker::::new().check_signature(signature, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Undeclared symbol `K` in function definition".to_string() @@ -3789,7 +3779,7 @@ mod tests { fn success() { // (field[L][K]) -> field[L][K] let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let signature = UnresolvedSignature::new() .generics(vec!["K".mock(), "L".mock(), "M".mock()]) @@ -3812,7 +3802,7 @@ mod tests { ) .mock()]); assert_eq!( - Checker::::new().check_signature(signature, &*MODULE_ID, &mut state), + Checker::::new().check_signature(signature, &*MODULE_ID, &state), Ok(DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::array(( @@ -3842,14 +3832,11 @@ mod tests { ) .mock(); - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &state), + checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".into() @@ -3867,9 +3854,6 @@ mod tests { ) .mock(); - let modules = Modules::new(); - let state = State::new(modules); - let mut scope = HashSet::new(); scope.insert(ScopedVariable { id: Variable::field_element("a"), @@ -3882,7 +3866,7 @@ mod tests { let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &state), + checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), Ok(TypedStatement::Definition( TypedAssignee::Identifier(typed_absy::Variable::field_element("a")), FieldElementExpression::Identifier("b".into()).into() @@ -4113,11 +4097,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function(foo, &*MODULE_ID, &mut state), + checker.check_function(foo, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"i\" is undefined".into() @@ -4197,11 +4181,11 @@ mod tests { }; let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function(foo, &*MODULE_ID, &mut state), + checker.check_function(foo, &*MODULE_ID, &state), Ok(foo_checked) ); } @@ -4251,11 +4235,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: @@ -4310,11 +4294,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Function definition for function foo with signature () -> _ not found." @@ -4356,11 +4340,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -4699,11 +4683,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -4740,11 +4724,11 @@ mod tests { .mock(); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"a\" is undefined".into() @@ -4848,11 +4832,11 @@ mod tests { }; let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &*MODULE_ID, &mut state), + checker.check_function(bar, &*MODULE_ID, &state), Ok(bar_checked) ); } @@ -4881,12 +4865,12 @@ mod tests { ]); let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( checker - .check_function(f, &*MODULE_ID, &mut state) + .check_function(f, &*MODULE_ID, &state) .unwrap_err()[0] .message, "Duplicate name in function definition: `a` was previously declared as an argument or a generic constant" @@ -4985,9 +4969,6 @@ mod tests { // // should fail - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( @@ -4995,7 +4976,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ); let s2_checked: Result, Vec> = checker .check_statement( @@ -5004,7 +4985,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ); assert_eq!( s2_checked, @@ -5022,9 +5003,6 @@ mod tests { // // should fail - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( @@ -5032,7 +5010,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ); let s2_checked: Result, Vec> = checker .check_statement( @@ -5041,7 +5019,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ); assert_eq!( s2_checked, @@ -5087,7 +5065,7 @@ mod tests { fn empty_def() { // an empty struct should be allowed to be defined let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let declaration: StructDefinitionNode = StructDefinition { fields: vec![] }.mock(); @@ -5102,7 +5080,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &mut state + &state ), Ok(expected_type) ); @@ -5112,7 +5090,7 @@ mod tests { fn valid_def() { // a valid struct should be allowed to be defined let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let declaration: StructDefinitionNode = StructDefinition { fields: vec![ @@ -5144,7 +5122,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &mut state + &state ), Ok(expected_type) ); @@ -5154,7 +5132,7 @@ mod tests { fn duplicate_member_def() { // definition of a struct with a duplicate member should be rejected let modules = Modules::new(); - let mut state = State::new(modules); + let state = State::new(modules); let declaration: StructDefinitionNode = StructDefinition { fields: vec![ @@ -5178,7 +5156,7 @@ mod tests { "Foo".into(), declaration, &*MODULE_ID, - &mut state + &state ) .unwrap_err()[0] .message, @@ -5383,7 +5361,7 @@ mod tests { // an undefined type cannot be checked // Bar - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5395,7 +5373,7 @@ mod tests { checker.check_type( UnresolvedType::User("Foo".into()).mock(), &*MODULE_ID, - &mut state + &state.types ), Ok(Type::Struct(StructType::new( "".into(), @@ -5409,7 +5387,7 @@ mod tests { .check_type( UnresolvedType::User("Bar".into()).mock(), &*MODULE_ID, - &mut state + &state.types ) .unwrap_err() .message, @@ -5429,7 +5407,7 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.foo - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5449,7 +5427,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ), Ok(FieldElementExpression::Member( box StructExpressionInner::Value(vec![FieldElementExpression::Number( @@ -5474,7 +5452,7 @@ mod tests { // struct Foo = { foo: field } // Foo { foo: 42 }.bar - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5495,7 +5473,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ) .unwrap_err() .message, @@ -5512,7 +5490,7 @@ mod tests { fn wrong_name() { // a A value cannot be defined with B as id, even if A and B have the same members - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -5529,7 +5507,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ) .unwrap_err() .message, @@ -5544,7 +5522,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42, bar: true } - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5570,7 +5548,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ), Ok(StructExpressionInner::Value(vec![ FieldElementExpression::Number(Bn128Field::from(42u32)).into(), @@ -5595,7 +5573,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { bar: true, foo: 42 } - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5621,7 +5599,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ), Ok(StructExpressionInner::Value(vec![ FieldElementExpression::Number(Bn128Field::from(42u32)).into(), @@ -5646,7 +5624,7 @@ mod tests { // struct Foo = { foo: field, bar: bool } // Foo foo = Foo { foo: 42 } - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5670,7 +5648,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ) .unwrap_err() .message, @@ -5687,7 +5665,7 @@ mod tests { // Foo { foo: 42, baz: bool } // error // Foo { foo: 42, baz: 42 } // error - let (mut checker, mut state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -5717,7 +5695,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ).unwrap_err() .message, "Member bar of struct Foo {foo: field, bar: bool} not found in value Foo {baz: true, foo: 42}" @@ -5735,7 +5713,7 @@ mod tests { ) .mock(), &*MODULE_ID, - &mut state + &state.types ) .unwrap_err() .message, @@ -5862,9 +5840,6 @@ mod tests { // a = 42 let a = Assignee::Identifier("a").mock(); - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5875,12 +5850,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &state), + checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Identifier( typed_absy::Variable::field_element("a") )) @@ -5897,9 +5872,6 @@ mod tests { ) .mock(); - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5918,12 +5890,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &state), + checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( box TypedAssignee::Identifier(typed_absy::Variable::field_array( "a", @@ -5950,9 +5922,6 @@ mod tests { ) .mock(); - let modules = Modules::new(); - let state = State::new(modules); - let mut checker: Checker = Checker::new(); checker.enter_scope(); @@ -5975,12 +5944,12 @@ mod tests { ) .mock(), &*MODULE_ID, - &state, + &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &state), + checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( box TypedAssignee::Select( box TypedAssignee::Identifier(typed_absy::Variable::array( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 6edf4453..4a363fa2 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -14,13 +14,24 @@ pub struct ConstantInliner<'ast, 'a, T: Field> { } impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> { + pub fn new( + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, + propagator: Propagator<'ast, 'a, T>, + ) -> Self { + ConstantInliner { + modules, + location, + propagator, + } + } pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { let mut constants = HashMap::new(); - let mut inliner = ConstantInliner { - modules: p.modules.clone(), - location: p.main.clone(), - propagator: Propagator::with_constants(&mut constants), - }; + let mut inliner = ConstantInliner::new( + p.modules.clone(), + p.main.clone(), + Propagator::with_constants(&mut constants), + ); inliner.fold_program(p) } @@ -83,22 +94,6 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> { } } - fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> { - DeclarationSignature { - generics: s.generics, - inputs: s - .inputs - .into_iter() - .map(|ty| self.fold_declaration_type(ty)) - .collect(), - outputs: s - .outputs - .into_iter() - .map(|ty| self.fold_declaration_type(ty)) - .collect(), - } - } - fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> { match t { DeclarationType::Array(ref array_ty) => match array_ty.size { diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index dd3b9211..897ab097 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -43,7 +43,7 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> { - s + fold_signature(self, s) } fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> { @@ -699,6 +699,25 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( } } +fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: DeclarationSignature<'ast>, +) -> DeclarationSignature<'ast> { + DeclarationSignature { + generics: s.generics, + inputs: s + .inputs + .into_iter() + .map(|o| f.fold_declaration_type(o)) + .collect(), + outputs: s + .outputs + .into_iter() + .map(|o| f.fold_declaration_type(o)) + .collect(), + } +} + pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: ArrayExpression<'ast, T>, diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 2f7fe2af..76e73d93 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -60,7 +60,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { &mut self, s: DeclarationSignature<'ast>, ) -> Result, Self::Error> { - Ok(s) + fold_signature(self, s) } fn fold_parameter( @@ -769,6 +769,25 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } +fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: DeclarationSignature<'ast>, +) -> Result, F::Error> { + Ok(DeclarationSignature { + generics: s.generics, + inputs: s + .inputs + .into_iter() + .map(|o| f.fold_declaration_type(o)) + .collect::>()?, + outputs: s + .outputs + .into_iter() + .map(|o| f.fold_declaration_type(o)) + .collect::>()?, + }) +} + pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: ArrayExpression<'ast, T>,