1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

use signature ouput for inference

This commit is contained in:
dark64 2022-06-25 01:58:53 +02:00
parent aa9dc9740f
commit 3f1358acfe
3 changed files with 227 additions and 280 deletions

View file

@ -1714,7 +1714,28 @@ impl<'ast, T: Field> Checker<'ast, T> {
let e_checked = e
.map(|e| {
self.check_expression(e, module_id, types)
match e.value {
Expression::FunctionCall(
box fun_id_expression,
generics,
arguments,
) => {
let ty =
crate::typed_absy::types::try_from_g_type(return_type.clone())
.map(Some)
.unwrap();
self.check_function_call_expression(
fun_id_expression,
generics,
arguments,
ty,
module_id,
types,
)
}
_ => self.check_expression(e, module_id, types),
}
.map_err(|e| vec![e])
})
.unwrap_or_else(|| {
@ -1773,142 +1794,24 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map_err(|e| vec![e])
}
Statement::Definition(assignee, expr) => {
match expr.value {
Expression::FunctionCall(fun_id_expression, generics, arguments) => {
let fun_id = match fun_id_expression.value {
Expression::Identifier(id) => Ok(id),
e => Err(vec![ErrorInner {
pos: Some(pos),
message: format!(
"Expected function in function call to be an identifier, found {}",
e
),
}])
}?;
// check the generic arguments, if any
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
.map(|generics| {
generics
.into_iter()
.map(|g| {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, types).and_then(
|g| {
UExpression::try_from_typed(g, &UBitwidth::B32)
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} to be of type u32, found {}",
e,
e.get_type(),
),
})
},
)
})
.transpose()
})
.collect::<Result<_, _>>()
})
.transpose()
.map_err(|e| vec![e])?;
// check that the assignee is declared and is well formed
let assignee = self
.check_assignee(assignee, module_id, types)
.map_err(|e| vec![e])?;
// find argument types
let mut arguments_checked = vec![];
for arg in arguments {
let arg_checked = self
.check_expression(arg, module_id, types)
.map_err(|e| vec![e])?;
arguments_checked.push(arg_checked);
}
let arguments_types: Vec<_> =
arguments_checked.iter().map(|a| a.get_type()).collect();
let query = FunctionQuery::new(
fun_id,
&generics_checked,
&arguments_types,
match expr.value {
Expression::FunctionCall(box fun_id_expression, generics, arguments) => {
let e = self
.check_function_call_expression(
fun_id_expression,
generics,
arguments,
Some(assignee.get_type()),
);
let functions = self.find_functions(&query);
match functions.len() {
// the function has to be defined
1 => {
let mut functions = functions;
let f = functions.pop().unwrap();
let signature = f.signature;
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
}).map_err(|e| vec![e])?;
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
let output_type = assignee.get_type();
let function_key = DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
};
let call = match output_type {
Type::Int => unreachable!(),
Type::FieldElement => FieldElementExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into(),
Type::Boolean => BooleanExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into(),
Type::Uint(bitwidth) => UExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(bitwidth).into(),
Type::Struct(struct_ty) => StructExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(struct_ty).into(),
Type::Array(array_ty) => ArrayExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(*array_ty.ty, *array_ty.size).into(),
Type::Tuple(tuple_ty) => TupleExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(tuple_ty).into(),
};
Ok(TypedStatement::Definition(assignee, call))
},
0 => Err(ErrorInner {
pos: Some(pos),
message: format!("Function definition for function {} with signature {} not found.", fun_id, query)
}),
n => Err(ErrorInner {
pos: Some(pos),
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
})
}.map_err(|e| vec![e])
module_id,
types,
)
.map_err(|e| vec![e])?;
Ok(TypedStatement::Definition(assignee, e))
}
_ => {
// check the expression to be assigned
@ -1916,15 +1819,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
.check_expression(expr, module_id, types)
.map_err(|e| vec![e])?;
// check that the assignee is declared and is well formed
let var = self
.check_assignee(assignee, module_id, types)
.map_err(|e| vec![e])?;
let var_type = var.get_type();
let assignee_type = assignee.get_type();
// make sure the assignee has the same type as the rhs
match var_type {
match assignee_type {
Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr)
.map(TypedExpression::from),
Type::Boolean => {
@ -1952,11 +1850,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
"Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`",
e,
e.get_type(),
var.clone(),
var_type
assignee.clone(),
assignee_type
),
})
.map(|rhs| TypedStatement::Definition(var, rhs))
.map(|rhs| TypedStatement::Definition(assignee, rhs))
.map_err(|e| vec![e])
}
}
@ -2155,6 +2053,154 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
fn check_function_call_expression(
&mut self,
function_id: ExpressionNode<'ast>,
generics: Option<Vec<Option<ExpressionNode<'ast>>>>,
arguments: Vec<ExpressionNode<'ast>>,
assignee_type: Option<Type<'ast, T>>,
module_id: &ModuleId,
types: &TypeMap<'ast, T>,
) -> Result<TypedExpression<'ast, T>, ErrorInner> {
let pos = function_id.pos();
let fun_id = match function_id.value {
Expression::Identifier(id) => Ok(id),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected function in function call to be an identifier, found `{}`",
e
),
}),
}?;
// check the generic arguments, if any
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
.map(|generics| {
generics
.into_iter()
.map(|g| {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, types).and_then(|g| {
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(|e| {
ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} to be of type u32, found {}",
e,
e.get_type(),
),
}
})
})
})
.transpose()
})
.collect::<Result<_, _>>()
})
.transpose()?;
// check the arguments
let mut arguments_checked = vec![];
for arg in arguments {
let arg_checked = self.check_expression(arg, module_id, types)?;
arguments_checked.push(arg_checked);
}
let arguments_types: Vec<_> = arguments_checked.iter().map(|a| a.get_type()).collect();
// we use type inference to determine the type of the return, so we don't specify it
let query = FunctionQuery::new(
fun_id,
&generics_checked,
&arguments_types,
assignee_type.clone(),
);
let functions = self.find_functions(&query);
match functions.len() {
// the function has to be defined
1 => {
let mut functions = functions;
let f = functions.pop().unwrap();
let signature = f.signature;
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
})?;
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
let output_type = assignee_type.map(Ok).unwrap_or_else(|| signature.get_output_type(
generics_checked.clone(),
arguments_checked.iter().map(|a| a.get_type()).collect()
).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Failed to infer value for generic parameter `{}`, try providing an explicit value",
e,
),
}))?;
let function_key = DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
};
match output_type {
Type::Int => unreachable!(),
Type::FieldElement => Ok(FieldElementExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into()),
Type::Boolean => Ok(BooleanExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into()),
Type::Uint(bitwidth) => Ok(UExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(bitwidth).into()),
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(struct_ty).into()),
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(*array_ty.ty, *array_ty.size).into()),
Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(tuple_ty).into()),
}
}
0 => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Function definition for function {} with signature {} not found.",
fun_id, query
),
}),
n => Err(ErrorInner {
pos: Some(pos),
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
}),
}
}
fn check_expression(
&mut self,
expr: ExpressionNode<'ast>,
@ -2523,140 +2569,15 @@ impl<'ast, T: Field> Checker<'ast, T> {
Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()),
Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()),
Expression::U64Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(64).into()),
Expression::FunctionCall(fun_id_expression, generics, arguments) => {
let fun_id = match fun_id_expression.value {
Expression::Identifier(id) => Ok(id),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected function in function call to be an identifier, found `{}`",
e
Expression::FunctionCall(box fun_id_expression, generics, arguments) => self
.check_function_call_expression(
fun_id_expression,
generics,
arguments,
None,
module_id,
types,
),
}),
}?;
// check the generic arguments, if any
let generics_checked: Option<Vec<Option<UExpression<'ast, T>>>> = generics
.map(|generics| {
generics
.into_iter()
.map(|g| {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, types).and_then(|g| {
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} to be of type u32, found {}",
e,
e.get_type(),
),
},
)
})
})
.transpose()
})
.collect::<Result<_, _>>()
})
.transpose()?;
// check the arguments
let mut arguments_checked = vec![];
for arg in arguments {
let arg_checked = self.check_expression(arg, module_id, types)?;
arguments_checked.push(arg_checked);
}
let arguments_types: Vec<_> =
arguments_checked.iter().map(|a| a.get_type()).collect();
// we use type inference to determine the type of the return, so we don't specify it
let query = FunctionQuery::new(fun_id, &generics_checked, &arguments_types, None);
let functions = self.find_functions(&query);
match functions.len() {
// the function has to be defined
1 => {
let mut functions = functions;
let f = functions.pop().unwrap();
let signature = f.signature;
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type())
})?;
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
let output_type = signature.get_output_type(
generics_checked.clone(),
arguments_checked.iter().map(|a| a.get_type()).collect()
).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Failed to infer value for generic parameter `{}`, try providing an explicit value",
e,
),
})?;
let function_key = DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
};
match output_type {
Type::Int => unreachable!(),
Type::FieldElement => Ok(FieldElementExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into()),
Type::Boolean => Ok(BooleanExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).into()),
Type::Uint(bitwidth) => Ok(UExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(bitwidth).into()),
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(struct_ty).into()),
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(*array_ty.ty, *array_ty.size).into()),
Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call(
function_key,
generics_checked,
arguments_checked,
).annotate(tuple_ty).into()),
}
}
0 => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Function definition for function {} with signature {} not found.",
fun_id, query
),
}),
n => Err(ErrorInner {
pos: Some(pos),
message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n)
}),
}
}
Expression::Lt(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, types)?;
let e2_checked = self.check_expression(e2, module_id, types)?;

View file

@ -0,0 +1,15 @@
{
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"value": ["1", "2"]
}
}
}
]
}

View file

@ -0,0 +1,11 @@
def bar<N>() -> field[N] {
return [42; N];
}
def foo<N>() -> field[N] {
return bar();
}
def main() -> field[2] {
return foo();
}