use signature ouput for inference
This commit is contained in:
parent
aa9dc9740f
commit
3f1358acfe
3 changed files with 227 additions and 280 deletions
|
@ -1714,7 +1714,28 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let e_checked = e
|
||||
.map(|e| {
|
||||
self.check_expression(e, module_id, types)
|
||||
match e.value {
|
||||
Expression::FunctionCall(
|
||||
box fun_id_expression,
|
||||
generics,
|
||||
arguments,
|
||||
) => {
|
||||
let ty =
|
||||
crate::typed_absy::types::try_from_g_type(return_type.clone())
|
||||
.map(Some)
|
||||
.unwrap();
|
||||
|
||||
self.check_function_call_expression(
|
||||
fun_id_expression,
|
||||
generics,
|
||||
arguments,
|
||||
ty,
|
||||
module_id,
|
||||
types,
|
||||
)
|
||||
}
|
||||
_ => self.check_expression(e, module_id, types),
|
||||
}
|
||||
.map_err(|e| vec![e])
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
|
@ -1773,142 +1794,24 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.map_err(|e| vec![e])
|
||||
}
|
||||
Statement::Definition(assignee, expr) => {
|
||||
match expr.value {
|
||||
Expression::FunctionCall(fun_id_expression, generics, arguments) => {
|
||||
let fun_id = match fun_id_expression.value {
|
||||
Expression::Identifier(id) => Ok(id),
|
||||
e => Err(vec![ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected function in function call to be an identifier, found {}",
|
||||
e
|
||||
),
|
||||
}])
|
||||
}?;
|
||||
|
||||
// check the generic arguments, if any
|
||||
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
|
||||
.map(|generics| {
|
||||
generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let pos = g.pos();
|
||||
self.check_expression(g, module_id, types).and_then(
|
||||
|g| {
|
||||
UExpression::try_from_typed(g, &UBitwidth::B32)
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} to be of type u32, found {}",
|
||||
e,
|
||||
e.get_type(),
|
||||
),
|
||||
})
|
||||
},
|
||||
)
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
})
|
||||
.transpose()
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
// check that the assignee is declared and is well formed
|
||||
let assignee = self
|
||||
.check_assignee(assignee, module_id, types)
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
// find argument types
|
||||
let mut arguments_checked = vec![];
|
||||
for arg in arguments {
|
||||
let arg_checked = self
|
||||
.check_expression(arg, module_id, types)
|
||||
.map_err(|e| vec![e])?;
|
||||
arguments_checked.push(arg_checked);
|
||||
}
|
||||
|
||||
let arguments_types: Vec<_> =
|
||||
arguments_checked.iter().map(|a| a.get_type()).collect();
|
||||
|
||||
let query = FunctionQuery::new(
|
||||
fun_id,
|
||||
&generics_checked,
|
||||
&arguments_types,
|
||||
match expr.value {
|
||||
Expression::FunctionCall(box fun_id_expression, generics, arguments) => {
|
||||
let e = self
|
||||
.check_function_call_expression(
|
||||
fun_id_expression,
|
||||
generics,
|
||||
arguments,
|
||||
Some(assignee.get_type()),
|
||||
);
|
||||
|
||||
let functions = self.find_functions(&query);
|
||||
|
||||
match functions.len() {
|
||||
// the function has to be defined
|
||||
1 => {
|
||||
|
||||
let mut functions = functions;
|
||||
let f = functions.pop().unwrap();
|
||||
|
||||
let signature = f.signature;
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
|
||||
}).map_err(|e| vec![e])?;
|
||||
|
||||
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
|
||||
|
||||
let output_type = assignee.get_type();
|
||||
|
||||
let function_key = DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
};
|
||||
|
||||
let call = match output_type {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => FieldElementExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into(),
|
||||
Type::Boolean => BooleanExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into(),
|
||||
Type::Uint(bitwidth) => UExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(bitwidth).into(),
|
||||
Type::Struct(struct_ty) => StructExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(struct_ty).into(),
|
||||
Type::Array(array_ty) => ArrayExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(*array_ty.ty, *array_ty.size).into(),
|
||||
Type::Tuple(tuple_ty) => TupleExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(tuple_ty).into(),
|
||||
};
|
||||
|
||||
Ok(TypedStatement::Definition(assignee, call))
|
||||
},
|
||||
0 => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Function definition for function {} with signature {} not found.", fun_id, query)
|
||||
}),
|
||||
n => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
|
||||
})
|
||||
}.map_err(|e| vec![e])
|
||||
module_id,
|
||||
types,
|
||||
)
|
||||
.map_err(|e| vec![e])?;
|
||||
Ok(TypedStatement::Definition(assignee, e))
|
||||
}
|
||||
_ => {
|
||||
// check the expression to be assigned
|
||||
|
@ -1916,15 +1819,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.check_expression(expr, module_id, types)
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
// check that the assignee is declared and is well formed
|
||||
let var = self
|
||||
.check_assignee(assignee, module_id, types)
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
let var_type = var.get_type();
|
||||
let assignee_type = assignee.get_type();
|
||||
|
||||
// make sure the assignee has the same type as the rhs
|
||||
match var_type {
|
||||
match assignee_type {
|
||||
Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr)
|
||||
.map(TypedExpression::from),
|
||||
Type::Boolean => {
|
||||
|
@ -1952,11 +1850,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
"Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`",
|
||||
e,
|
||||
e.get_type(),
|
||||
var.clone(),
|
||||
var_type
|
||||
assignee.clone(),
|
||||
assignee_type
|
||||
),
|
||||
})
|
||||
.map(|rhs| TypedStatement::Definition(var, rhs))
|
||||
.map(|rhs| TypedStatement::Definition(assignee, rhs))
|
||||
.map_err(|e| vec![e])
|
||||
}
|
||||
}
|
||||
|
@ -2155,6 +2053,154 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_function_call_expression(
|
||||
&mut self,
|
||||
function_id: ExpressionNode<'ast>,
|
||||
generics: Option<Vec<Option<ExpressionNode<'ast>>>>,
|
||||
arguments: Vec<ExpressionNode<'ast>>,
|
||||
assignee_type: Option<Type<'ast, T>>,
|
||||
module_id: &ModuleId,
|
||||
types: &TypeMap<'ast, T>,
|
||||
) -> Result<TypedExpression<'ast, T>, ErrorInner> {
|
||||
let pos = function_id.pos();
|
||||
let fun_id = match function_id.value {
|
||||
Expression::Identifier(id) => Ok(id),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected function in function call to be an identifier, found `{}`",
|
||||
e
|
||||
),
|
||||
}),
|
||||
}?;
|
||||
|
||||
// check the generic arguments, if any
|
||||
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
|
||||
.map(|generics| {
|
||||
generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let pos = g.pos();
|
||||
self.check_expression(g, module_id, types).and_then(|g| {
|
||||
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(|e| {
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} to be of type u32, found {}",
|
||||
e,
|
||||
e.get_type(),
|
||||
),
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
// check the arguments
|
||||
let mut arguments_checked = vec![];
|
||||
for arg in arguments {
|
||||
let arg_checked = self.check_expression(arg, module_id, types)?;
|
||||
arguments_checked.push(arg_checked);
|
||||
}
|
||||
|
||||
let arguments_types: Vec<_> = arguments_checked.iter().map(|a| a.get_type()).collect();
|
||||
|
||||
// we use type inference to determine the type of the return, so we don't specify it
|
||||
let query = FunctionQuery::new(
|
||||
fun_id,
|
||||
&generics_checked,
|
||||
&arguments_types,
|
||||
assignee_type.clone(),
|
||||
);
|
||||
|
||||
let functions = self.find_functions(&query);
|
||||
|
||||
match functions.len() {
|
||||
// the function has to be defined
|
||||
1 => {
|
||||
let mut functions = functions;
|
||||
|
||||
let f = functions.pop().unwrap();
|
||||
|
||||
let signature = f.signature;
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
|
||||
})?;
|
||||
|
||||
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
|
||||
|
||||
let output_type = assignee_type.map(Ok).unwrap_or_else(|| signature.get_output_type(
|
||||
generics_checked.clone(),
|
||||
arguments_checked.iter().map(|a| a.get_type()).collect()
|
||||
).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Failed to infer value for generic parameter `{}`, try providing an explicit value",
|
||||
e,
|
||||
),
|
||||
}))?;
|
||||
|
||||
let function_key = DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
};
|
||||
|
||||
match output_type {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => Ok(FieldElementExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into()),
|
||||
Type::Boolean => Ok(BooleanExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(bitwidth).into()),
|
||||
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(struct_ty).into()),
|
||||
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(*array_ty.ty, *array_ty.size).into()),
|
||||
Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(tuple_ty).into()),
|
||||
}
|
||||
}
|
||||
0 => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Function definition for function {} with signature {} not found.",
|
||||
fun_id, query
|
||||
),
|
||||
}),
|
||||
n => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_expression(
|
||||
&mut self,
|
||||
expr: ExpressionNode<'ast>,
|
||||
|
@ -2523,140 +2569,15 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()),
|
||||
Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()),
|
||||
Expression::U64Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(64).into()),
|
||||
Expression::FunctionCall(fun_id_expression, generics, arguments) => {
|
||||
let fun_id = match fun_id_expression.value {
|
||||
Expression::Identifier(id) => Ok(id),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected function in function call to be an identifier, found `{}`",
|
||||
e
|
||||
Expression::FunctionCall(box fun_id_expression, generics, arguments) => self
|
||||
.check_function_call_expression(
|
||||
fun_id_expression,
|
||||
generics,
|
||||
arguments,
|
||||
None,
|
||||
module_id,
|
||||
types,
|
||||
),
|
||||
}),
|
||||
}?;
|
||||
|
||||
// check the generic arguments, if any
|
||||
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
|
||||
.map(|generics| {
|
||||
generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let pos = g.pos();
|
||||
self.check_expression(g, module_id, types).and_then(|g| {
|
||||
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|
||||
|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} to be of type u32, found {}",
|
||||
e,
|
||||
e.get_type(),
|
||||
),
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
// check the arguments
|
||||
let mut arguments_checked = vec![];
|
||||
for arg in arguments {
|
||||
let arg_checked = self.check_expression(arg, module_id, types)?;
|
||||
arguments_checked.push(arg_checked);
|
||||
}
|
||||
|
||||
let arguments_types: Vec<_> =
|
||||
arguments_checked.iter().map(|a| a.get_type()).collect();
|
||||
|
||||
// we use type inference to determine the type of the return, so we don't specify it
|
||||
let query = FunctionQuery::new(fun_id, &generics_checked, &arguments_types, None);
|
||||
|
||||
let functions = self.find_functions(&query);
|
||||
|
||||
match functions.len() {
|
||||
// the function has to be defined
|
||||
1 => {
|
||||
let mut functions = functions;
|
||||
|
||||
let f = functions.pop().unwrap();
|
||||
|
||||
let signature = f.signature;
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
|
||||
})?;
|
||||
|
||||
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
|
||||
|
||||
let output_type = signature.get_output_type(
|
||||
generics_checked.clone(),
|
||||
arguments_checked.iter().map(|a| a.get_type()).collect()
|
||||
).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Failed to infer value for generic parameter `{}`, try providing an explicit value",
|
||||
e,
|
||||
),
|
||||
})?;
|
||||
|
||||
let function_key = DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
};
|
||||
|
||||
match output_type {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => Ok(FieldElementExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into()),
|
||||
Type::Boolean => Ok(BooleanExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(bitwidth).into()),
|
||||
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(struct_ty).into()),
|
||||
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(*array_ty.ty, *array_ty.size).into()),
|
||||
Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
).annotate(tuple_ty).into()),
|
||||
}
|
||||
}
|
||||
0 => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Function definition for function {} with signature {} not found.",
|
||||
fun_id, query
|
||||
),
|
||||
}),
|
||||
n => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
|
||||
}),
|
||||
}
|
||||
}
|
||||
Expression::Lt(box e1, box e2) => {
|
||||
let e1_checked = self.check_expression(e1, module_id, types)?;
|
||||
let e2_checked = self.check_expression(e2, module_id, types)?;
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": ["1", "2"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
11
zokrates_core_test/tests/tests/generics/return_inference.zok
Normal file
11
zokrates_core_test/tests/tests/generics/return_inference.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
def bar<N>() -> field[N] {
|
||||
return [42; N];
|
||||
}
|
||||
|
||||
def foo<N>() -> field[N] {
|
||||
return bar();
|
||||
}
|
||||
|
||||
def main() -> field[2] {
|
||||
return foo();
|
||||
}
|
Loading…
Reference in a new issue