diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index d2c4aeb1..7d5f953b 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1714,8 +1714,29 @@ impl<'ast, T: Field> Checker<'ast, T> { let e_checked = e .map(|e| { - self.check_expression(e, module_id, types) - .map_err(|e| vec![e]) + match e.value { + Expression::FunctionCall( + box fun_id_expression, + generics, + arguments, + ) => { + let ty = + crate::typed_absy::types::try_from_g_type(return_type.clone()) + .map(Some) + .unwrap(); + + self.check_function_call_expression( + fun_id_expression, + generics, + arguments, + ty, + module_id, + types, + ) + } + _ => self.check_expression(e, module_id, types), + } + .map_err(|e| vec![e]) }) .unwrap_or_else(|| { Ok(TupleExpressionInner::Value(vec![]) @@ -1773,142 +1794,24 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e]) } Statement::Definition(assignee, expr) => { + // check that the assignee is declared and is well formed + let assignee = self + .check_assignee(assignee, module_id, types) + .map_err(|e| vec![e])?; + match expr.value { - Expression::FunctionCall(fun_id_expression, generics, arguments) => { - let fun_id = match fun_id_expression.value { - Expression::Identifier(id) => Ok(id), - e => Err(vec![ErrorInner { - pos: Some(pos), - message: format!( - "Expected function in function call to be an identifier, found {}", - e - ), - }]) - }?; - - // check the generic arguments, if any - let generics_checked: Option>>> = generics - .map(|generics| { - generics - .into_iter() - .map(|g| { - g.map(|g| { - let pos = g.pos(); - self.check_expression(g, module_id, types).and_then( - |g| { - UExpression::try_from_typed(g, &UBitwidth::B32) - .map_err(|e| ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} to be of type u32, found {}", - e, - e.get_type(), - ), - }) - }, - ) - }) - .transpose() - }) - .collect::>() - }) - .transpose() + Expression::FunctionCall(box fun_id_expression, generics, arguments) => { + let e = self + .check_function_call_expression( + fun_id_expression, + generics, + arguments, + Some(assignee.get_type()), + module_id, + types, + ) .map_err(|e| vec![e])?; - - let assignee = self - .check_assignee(assignee, module_id, types) - .map_err(|e| vec![e])?; - - // find argument types - let mut arguments_checked = vec![]; - for arg in arguments { - let arg_checked = self - .check_expression(arg, module_id, types) - .map_err(|e| vec![e])?; - arguments_checked.push(arg_checked); - } - - let arguments_types: Vec<_> = - arguments_checked.iter().map(|a| a.get_type()).collect(); - - let query = FunctionQuery::new( - fun_id, - &generics_checked, - &arguments_types, - Some(assignee.get_type()), - ); - - let functions = self.find_functions(&query); - - match functions.len() { - // the function has to be defined - 1 => { - - let mut functions = functions; - let f = functions.pop().unwrap(); - - let signature = f.signature; - - let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type()) - }).map_err(|e| vec![e])?; - - let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]); - - let output_type = assignee.get_type(); - - let function_key = DeclarationFunctionKey { - module: module_id.to_path_buf(), - id: f.id, - signature: signature.clone(), - }; - - let call = match output_type { - Type::Int => unreachable!(), - Type::FieldElement => FieldElementExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).into(), - Type::Boolean => BooleanExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).into(), - Type::Uint(bitwidth) => UExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(bitwidth).into(), - Type::Struct(struct_ty) => StructExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(struct_ty).into(), - Type::Array(array_ty) => ArrayExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(*array_ty.ty, *array_ty.size).into(), - Type::Tuple(tuple_ty) => TupleExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(tuple_ty).into(), - }; - - Ok(TypedStatement::Definition(assignee, call)) - }, - 0 => Err(ErrorInner { - pos: Some(pos), - message: format!("Function definition for function {} with signature {} not found.", fun_id, query) - }), - n => Err(ErrorInner { - pos: Some(pos), - message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) - }) - }.map_err(|e| vec![e]) + Ok(TypedStatement::Definition(assignee, e)) } _ => { // check the expression to be assigned @@ -1916,15 +1819,10 @@ impl<'ast, T: Field> Checker<'ast, T> { .check_expression(expr, module_id, types) .map_err(|e| vec![e])?; - // check that the assignee is declared and is well formed - let var = self - .check_assignee(assignee, module_id, types) - .map_err(|e| vec![e])?; - - let var_type = var.get_type(); + let assignee_type = assignee.get_type(); // make sure the assignee has the same type as the rhs - match var_type { + match assignee_type { Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr) .map(TypedExpression::from), Type::Boolean => { @@ -1952,11 +1850,11 @@ impl<'ast, T: Field> Checker<'ast, T> { "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, e.get_type(), - var.clone(), - var_type + assignee.clone(), + assignee_type ), }) - .map(|rhs| TypedStatement::Definition(var, rhs)) + .map(|rhs| TypedStatement::Definition(assignee, rhs)) .map_err(|e| vec![e]) } } @@ -2155,6 +2053,154 @@ impl<'ast, T: Field> Checker<'ast, T> { } } + fn check_function_call_expression( + &mut self, + function_id: ExpressionNode<'ast>, + generics: Option>>>, + arguments: Vec>, + assignee_type: Option>, + module_id: &ModuleId, + types: &TypeMap<'ast, T>, + ) -> Result, ErrorInner> { + let pos = function_id.pos(); + let fun_id = match function_id.value { + Expression::Identifier(id) => Ok(id), + e => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected function in function call to be an identifier, found `{}`", + e + ), + }), + }?; + + // check the generic arguments, if any + let generics_checked: Option>>> = generics + .map(|generics| { + generics + .into_iter() + .map(|g| { + g.map(|g| { + let pos = g.pos(); + self.check_expression(g, module_id, types).and_then(|g| { + UExpression::try_from_typed(g, &UBitwidth::B32).map_err(|e| { + ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} to be of type u32, found {}", + e, + e.get_type(), + ), + } + }) + }) + }) + .transpose() + }) + .collect::>() + }) + .transpose()?; + + // check the arguments + let mut arguments_checked = vec![]; + for arg in arguments { + let arg_checked = self.check_expression(arg, module_id, types)?; + arguments_checked.push(arg_checked); + } + + let arguments_types: Vec<_> = arguments_checked.iter().map(|a| a.get_type()).collect(); + + // we use type inference to determine the type of the return, so we don't specify it + let query = FunctionQuery::new( + fun_id, + &generics_checked, + &arguments_types, + assignee_type.clone(), + ); + + let functions = self.find_functions(&query); + + match functions.len() { + // the function has to be defined + 1 => { + let mut functions = functions; + + let f = functions.pop().unwrap(); + + let signature = f.signature; + + let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type()) + })?; + + let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]); + + let output_type = assignee_type.map(Ok).unwrap_or_else(|| signature.get_output_type( + generics_checked.clone(), + arguments_checked.iter().map(|a| a.get_type()).collect() + ).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Failed to infer value for generic parameter `{}`, try providing an explicit value", + e, + ), + }))?; + + let function_key = DeclarationFunctionKey { + module: module_id.to_path_buf(), + id: f.id, + signature: signature.clone(), + }; + + match output_type { + Type::Int => unreachable!(), + Type::FieldElement => Ok(FieldElementExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).into()), + Type::Boolean => Ok(BooleanExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).into()), + Type::Uint(bitwidth) => Ok(UExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).annotate(bitwidth).into()), + Type::Struct(struct_ty) => Ok(StructExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).annotate(struct_ty).into()), + Type::Array(array_ty) => Ok(ArrayExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).annotate(*array_ty.ty, *array_ty.size).into()), + Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call( + function_key, + generics_checked, + arguments_checked, + ).annotate(tuple_ty).into()), + } + } + 0 => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Function definition for function {} with signature {} not found.", + fun_id, query + ), + }), + n => Err(ErrorInner { + pos: Some(pos), + message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) + }), + } + } + fn check_expression( &mut self, expr: ExpressionNode<'ast>, @@ -2523,140 +2569,15 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()), Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()), Expression::U64Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(64).into()), - Expression::FunctionCall(fun_id_expression, generics, arguments) => { - let fun_id = match fun_id_expression.value { - Expression::Identifier(id) => Ok(id), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected function in function call to be an identifier, found `{}`", - e - ), - }), - }?; - - // check the generic arguments, if any - let generics_checked: Option>>> = generics - .map(|generics| { - generics - .into_iter() - .map(|g| { - g.map(|g| { - let pos = g.pos(); - self.check_expression(g, module_id, types).and_then(|g| { - UExpression::try_from_typed(g, &UBitwidth::B32).map_err( - |e| ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} to be of type u32, found {}", - e, - e.get_type(), - ), - }, - ) - }) - }) - .transpose() - }) - .collect::>() - }) - .transpose()?; - - // check the arguments - let mut arguments_checked = vec![]; - for arg in arguments { - let arg_checked = self.check_expression(arg, module_id, types)?; - arguments_checked.push(arg_checked); - } - - let arguments_types: Vec<_> = - arguments_checked.iter().map(|a| a.get_type()).collect(); - - // we use type inference to determine the type of the return, so we don't specify it - let query = FunctionQuery::new(fun_id, &generics_checked, &arguments_types, None); - - let functions = self.find_functions(&query); - - match functions.len() { - // the function has to be defined - 1 => { - let mut functions = functions; - - let f = functions.pop().unwrap(); - - let signature = f.signature; - - let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type()) - })?; - - let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]); - - let output_type = signature.get_output_type( - generics_checked.clone(), - arguments_checked.iter().map(|a| a.get_type()).collect() - ).map_err(|e| ErrorInner { - pos: Some(pos), - message: format!( - "Failed to infer value for generic parameter `{}`, try providing an explicit value", - e, - ), - })?; - - let function_key = DeclarationFunctionKey { - module: module_id.to_path_buf(), - id: f.id, - signature: signature.clone(), - }; - - match output_type { - Type::Int => unreachable!(), - Type::FieldElement => Ok(FieldElementExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).into()), - Type::Boolean => Ok(BooleanExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).into()), - Type::Uint(bitwidth) => Ok(UExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(bitwidth).into()), - Type::Struct(struct_ty) => Ok(StructExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(struct_ty).into()), - Type::Array(array_ty) => Ok(ArrayExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(*array_ty.ty, *array_ty.size).into()), - Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call( - function_key, - generics_checked, - arguments_checked, - ).annotate(tuple_ty).into()), - } - } - 0 => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Function definition for function {} with signature {} not found.", - fun_id, query - ), - }), - n => Err(ErrorInner { - pos: Some(pos), - message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) - }), - } - } + Expression::FunctionCall(box fun_id_expression, generics, arguments) => self + .check_function_call_expression( + fun_id_expression, + generics, + arguments, + None, + module_id, + types, + ), Expression::Lt(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, types)?; let e2_checked = self.check_expression(e2, module_id, types)?; diff --git a/zokrates_core_test/tests/tests/generics/return_inference.json b/zokrates_core_test/tests/tests/generics/return_inference.json new file mode 100644 index 00000000..0e0b40d2 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/return_inference.json @@ -0,0 +1,15 @@ +{ + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "value": ["1", "2"] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/return_inference.zok b/zokrates_core_test/tests/tests/generics/return_inference.zok new file mode 100644 index 00000000..28c2ea18 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/return_inference.zok @@ -0,0 +1,11 @@ +def bar() -> field[N] { + return [42; N]; +} + +def foo() -> field[N] { + return bar(); +} + +def main() -> field[2] { + return foo(); +} \ No newline at end of file