refactor types, implement struct generics
This commit is contained in:
parent
bfcc733714
commit
e5bcbed81f
14 changed files with 489 additions and 360 deletions
|
@ -432,6 +432,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -453,6 +454,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -470,6 +472,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -487,6 +490,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
|
|
@ -417,11 +417,13 @@ struct Bar { field a }
|
|||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"foo".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember {
|
||||
id: "b".into(),
|
||||
ty: box ConcreteType::Struct(ConcreteStructType::new(
|
||||
"bar".into(),
|
||||
"Bar".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember {
|
||||
id: "a".into(),
|
||||
ty: box ConcreteType::FieldElement
|
||||
|
|
|
@ -20,7 +20,7 @@ use crate::parser::Position;
|
|||
use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
|
||||
|
||||
use crate::typed_absy::types::{
|
||||
check_type, specialize_type, ArrayType, DeclarationArrayType, DeclarationConstant,
|
||||
check_type, specialize_declaration_type, ArrayType, DeclarationArrayType, DeclarationConstant,
|
||||
DeclarationFunctionKey, DeclarationSignature, DeclarationStructMember, DeclarationStructType,
|
||||
DeclarationType, GenericIdentifier, StructLocation, StructMember,
|
||||
};
|
||||
|
@ -357,29 +357,34 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
state: &State<'ast, T>,
|
||||
) -> Result<TypedConstant<'ast, T>, ErrorInner> {
|
||||
let pos = c.pos();
|
||||
let ty = self.check_type(c.value.ty.clone(), module_id, &state.types)?;
|
||||
let ty = self.check_declaration_type(
|
||||
c.value.ty.clone(),
|
||||
module_id,
|
||||
&state,
|
||||
&HashMap::default(),
|
||||
&mut HashSet::default(),
|
||||
)?;
|
||||
let checked_expr =
|
||||
self.check_expression(c.value.expression.clone(), module_id, &state.types)?;
|
||||
|
||||
match ty {
|
||||
Type::FieldElement => {
|
||||
DeclarationType::FieldElement => {
|
||||
FieldElementExpression::try_from_typed(checked_expr).map(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => {
|
||||
DeclarationType::Boolean => {
|
||||
BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from)
|
||||
}
|
||||
Type::Uint(bitwidth) => {
|
||||
UExpression::try_from_typed(checked_expr, bitwidth).map(TypedExpression::from)
|
||||
DeclarationType::Uint(bitwidth) => {
|
||||
UExpression::try_from_typed(checked_expr, &bitwidth).map(TypedExpression::from)
|
||||
}
|
||||
Type::Array(ref array_ty) => {
|
||||
ArrayExpression::try_from_typed(checked_expr, *array_ty.ty.clone())
|
||||
DeclarationType::Array(ref array_ty) => {
|
||||
ArrayExpression::try_from_typed(checked_expr, &array_ty).map(TypedExpression::from)
|
||||
}
|
||||
DeclarationType::Struct(ref struct_ty) => {
|
||||
StructExpression::try_from_typed(checked_expr, &struct_ty)
|
||||
.map(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(ref struct_ty) => {
|
||||
StructExpression::try_from_typed(checked_expr, struct_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
}
|
||||
Type::Int => Err(checked_expr), // Integers cannot be assigned
|
||||
DeclarationType::Int => Err(checked_expr), // Integers cannot be assigned
|
||||
}
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
@ -391,7 +396,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
ty
|
||||
),
|
||||
})
|
||||
.map(|e| TypedConstant::new(ty, e))
|
||||
.map(|e| TypedConstant::new(e))
|
||||
}
|
||||
|
||||
fn check_struct_type_declaration(
|
||||
|
@ -1152,14 +1157,17 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
),
|
||||
}),
|
||||
},
|
||||
TypedExpression::Int(v) => UExpression::try_from_int(v.clone(), UBitwidth::B32)
|
||||
.map_err(|_| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
TypedExpression::Int(v) => {
|
||||
UExpression::try_from_int(v.clone(), &UBitwidth::B32).map_err(|_| {
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected array dimension to be a u32 constant, found {} of type {}",
|
||||
v, ty
|
||||
),
|
||||
}),
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -1194,33 +1202,38 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
DeclarationType::Struct(struct_type) => {
|
||||
match struct_type.generics.len() == generics.len() {
|
||||
true => {
|
||||
let checked_generics = generics
|
||||
// downcast the generics to identifiers, as this is the only possibility here
|
||||
let generic_identifiers = struct_type.generics.iter().map(|c| {
|
||||
match c.as_ref().unwrap() {
|
||||
DeclarationConstant::Generic(g) => g.clone(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
});
|
||||
|
||||
// build the generic assignment for this type
|
||||
let assignment = GGenericsAssignment(generics
|
||||
.into_iter()
|
||||
.map(|e| match e {
|
||||
Some(e) => self
|
||||
.zip(generic_identifiers)
|
||||
.filter_map(|(e, g)| e.map(|e| {
|
||||
self
|
||||
.check_expression(e, module_id, types)
|
||||
.and_then(|e| {
|
||||
UExpression::try_from_typed(e, UBitwidth::B32)
|
||||
.map(Some)
|
||||
UExpression::try_from_typed(e, &UBitwidth::B32)
|
||||
.map(|e| (g, e))
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
|
||||
})
|
||||
}),
|
||||
None => Ok(None),
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
})
|
||||
}))
|
||||
.collect::<Result<_, _>>()?);
|
||||
|
||||
Ok(Type::Struct(StructType {
|
||||
canonical_location: struct_type.canonical_location,
|
||||
location: struct_type.location,
|
||||
generics: checked_generics,
|
||||
members: struct_type
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| m.into())
|
||||
.collect(),
|
||||
}))
|
||||
// specialize the declared type using the generic assignment
|
||||
Ok(specialize_declaration_type(
|
||||
DeclarationType::Struct(struct_type),
|
||||
&assignment,
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
@ -1384,7 +1397,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.into_iter()
|
||||
.map(|m| {
|
||||
Ok(DeclarationStructMember {
|
||||
ty: box specialize_type(*m.ty, &assignment)
|
||||
ty: box specialize_declaration_type(*m.ty, &assignment)
|
||||
.map_err(|_| unimplemented!())?,
|
||||
..m
|
||||
})
|
||||
|
@ -1465,7 +1478,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}),
|
||||
},
|
||||
TypedExpression::Int(v) => {
|
||||
UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner {
|
||||
UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected lower loop bound to be of type u32, found {}",
|
||||
|
@ -1495,7 +1508,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}),
|
||||
},
|
||||
TypedExpression::Int(v) => {
|
||||
UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner {
|
||||
UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected upper loop bound to be of type u32, found {}",
|
||||
|
@ -1550,9 +1563,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let res = match expression_list_checked.len() == return_types.len() {
|
||||
true => match expression_list_checked
|
||||
.iter()
|
||||
.zip(return_types.clone())
|
||||
.map(|(e, t)| TypedExpression::align_to_type(e.clone(), t.into()))
|
||||
.clone()
|
||||
.into_iter()
|
||||
.zip(return_types.iter())
|
||||
.map(|(e, t)| TypedExpression::align_to_type(e, t))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| {
|
||||
vec![ErrorInner {
|
||||
|
@ -1641,19 +1655,19 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let var_type = var.get_type();
|
||||
|
||||
// make sure the assignee has the same type as the rhs
|
||||
match var_type.clone() {
|
||||
match var_type {
|
||||
Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr)
|
||||
.map(TypedExpression::from),
|
||||
Type::Boolean => {
|
||||
BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from)
|
||||
}
|
||||
Type::Uint(bitwidth) => UExpression::try_from_typed(checked_expr, bitwidth)
|
||||
Type::Uint(bitwidth) => UExpression::try_from_typed(checked_expr, &bitwidth)
|
||||
.map(TypedExpression::from),
|
||||
Type::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(checked_expr, *array_ty.ty)
|
||||
Type::Array(ref array_ty) => {
|
||||
ArrayExpression::try_from_typed(checked_expr, array_ty)
|
||||
.map(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(struct_ty) => {
|
||||
Type::Struct(ref struct_ty) => {
|
||||
StructExpression::try_from_typed(checked_expr, struct_ty)
|
||||
.map(TypedExpression::from)
|
||||
}
|
||||
|
@ -1710,7 +1724,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
g.map(|g| {
|
||||
let pos = g.pos();
|
||||
self.check_expression(g, module_id, types).and_then(|g| {
|
||||
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
|
||||
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|
||||
|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -1759,7 +1773,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let mut functions = functions;
|
||||
let f = functions.pop().unwrap();
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
|
||||
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
|
||||
}])?;
|
||||
|
@ -1827,7 +1841,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
};
|
||||
|
||||
let checked_typed_index =
|
||||
UExpression::try_from_typed(checked_index, UBitwidth::B32).map_err(
|
||||
UExpression::try_from_typed(checked_index, &UBitwidth::B32).map_err(
|
||||
|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -2127,7 +2141,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
Ok(e) => e.into(),
|
||||
Err(e) => e,
|
||||
};
|
||||
let e2_checked = match UExpression::try_from_typed(e2_checked, UBitwidth::B32) {
|
||||
let e2_checked = match UExpression::try_from_typed(e2_checked, &UBitwidth::B32) {
|
||||
Ok(e) => e.into(),
|
||||
Err(e) => e,
|
||||
};
|
||||
|
@ -2261,7 +2275,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
g.map(|g| {
|
||||
let pos = g.pos();
|
||||
self.check_expression(g, module_id, types).and_then(|g| {
|
||||
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
|
||||
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|
||||
|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -2307,7 +2321,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let signature = f.signature;
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, &t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
|
||||
})?;
|
||||
|
@ -2513,9 +2527,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
|
||||
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
|
||||
}
|
||||
(TypedExpression::Struct(e1), TypedExpression::Struct(e2))
|
||||
if e1.get_type() == e2.get_type() =>
|
||||
{
|
||||
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
|
||||
Ok(BooleanExpression::StructEq(box e1, box e2).into())
|
||||
}
|
||||
(TypedExpression::Uint(e1), TypedExpression::Uint(e2))
|
||||
|
@ -2659,7 +2671,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.map(|e| self.check_expression(e, module_id, types))
|
||||
.unwrap_or_else(|| Ok(array_size.clone().into()))?;
|
||||
|
||||
let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
let from = UExpression::try_from_typed(from, &UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected the lower bound of the range to be a u32, found {} of type {}",
|
||||
|
@ -2668,7 +2680,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
),
|
||||
})?;
|
||||
|
||||
let to = UExpression::try_from_typed(to, UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
let to = UExpression::try_from_typed(to, &UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected the upper bound of the range to be a u32, found {} of type {}",
|
||||
|
@ -2699,7 +2711,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let index = self.check_expression(index, module_id, types)?;
|
||||
|
||||
let index =
|
||||
UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| {
|
||||
UExpression::try_from_typed(index, &UBitwidth::B32).map_err(|e| {
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -2813,19 +2825,24 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.next()
|
||||
.unwrap_or(Type::Int);
|
||||
|
||||
let unwrapped_expressions_or_spreads = match &inferred_type {
|
||||
let unwrapped_expressions_or_spreads = match inferred_type.clone() {
|
||||
Type::Int => expressions_or_spreads_checked,
|
||||
t => expressions_or_spreads_checked
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
TypedExpressionOrSpread::align_to_type(e, t.clone()).map_err(
|
||||
|(e, ty)| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected {} to have type {}", e, ty,),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
t => {
|
||||
let target_array_ty =
|
||||
ArrayType::new(t, UExpressionInner::Value(0).annotate(UBitwidth::B32));
|
||||
|
||||
expressions_or_spreads_checked
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
TypedExpressionOrSpread::align_to_type(e, &target_array_ty).map_err(
|
||||
|(e, ty)| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected {} to have type {}", e, ty,),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
}
|
||||
};
|
||||
|
||||
// the size of the inline array is the sum of the size of its elements. However expressed as a u32 expression,
|
||||
|
@ -2863,15 +2880,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let count = self.check_expression(count, module_id, types)?;
|
||||
|
||||
let count =
|
||||
UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
let count = UExpression::try_from_typed(count, &UBitwidth::B32).map_err(|e| {
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected array initializer count to be a u32, found {} of type {}",
|
||||
e,
|
||||
e.get_type(),
|
||||
),
|
||||
})?;
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(ArrayExpressionInner::Repeat(box e, box count.clone())
|
||||
.annotate(ty, count)
|
||||
|
@ -2881,7 +2899,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let ty = match types.get(module_id).unwrap().get(&id).cloned() {
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undeclared type of name `{}`", id),
|
||||
message: format!("Undefined type `{}`", id),
|
||||
}),
|
||||
Some(ty) => Ok(ty),
|
||||
}?;
|
||||
|
@ -2926,13 +2944,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let expression_checked =
|
||||
self.check_expression(value, module_id, types)?;
|
||||
|
||||
let expression_checked = TypedExpression::align_to_type(
|
||||
expression_checked,
|
||||
Type::from(*member.ty.clone()),
|
||||
)
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
let expression_checked =
|
||||
TypedExpression::align_to_type(expression_checked, &*member.ty)
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Member {} of struct {} has type {}, found {} of type {}",
|
||||
member.id,
|
||||
id.clone(),
|
||||
|
@ -2940,7 +2956,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
e.0,
|
||||
e.0.get_type(),
|
||||
),
|
||||
})?;
|
||||
})?;
|
||||
|
||||
Ok(expression_checked)
|
||||
}
|
||||
|
@ -3056,7 +3072,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let e2 = self.check_expression(e2, module_id, types)?;
|
||||
|
||||
let e2 =
|
||||
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected the left shift right operand to have type `u32`, found {}",
|
||||
|
@ -3083,7 +3099,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let e2 = self.check_expression(e2, module_id, types)?;
|
||||
|
||||
let e2 =
|
||||
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected the right shift right operand to be of type `u32`, found {}",
|
||||
|
@ -3402,11 +3418,16 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
fn struct0() -> StructDefinitionNode<'static> {
|
||||
StructDefinition { fields: vec![] }.mock()
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![],
|
||||
}
|
||||
.mock()
|
||||
}
|
||||
|
||||
fn struct1() -> StructDefinitionNode<'static> {
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -3730,7 +3751,7 @@ mod tests {
|
|||
checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0]
|
||||
.inner
|
||||
.message,
|
||||
"Undeclared symbol `P` in function definition"
|
||||
"Undeclared symbol `P`"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -3839,7 +3860,11 @@ mod tests {
|
|||
SymbolDeclaration {
|
||||
id: "foo",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition { fields: vec![] }.mock(),
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![],
|
||||
}
|
||||
.mock(),
|
||||
)),
|
||||
}
|
||||
.mock(),
|
||||
|
@ -3990,7 +4015,7 @@ mod tests {
|
|||
Checker::<Bn128Field>::new().check_signature(signature, &*MODULE_ID, &state),
|
||||
Err(vec![ErrorInner {
|
||||
pos: Some((Position::mock(), Position::mock())),
|
||||
message: "Undeclared symbol `K` in function definition".to_string()
|
||||
message: "Undeclared symbol `K`".to_string()
|
||||
}])
|
||||
);
|
||||
}
|
||||
|
@ -5287,12 +5312,17 @@ mod tests {
|
|||
let modules = Modules::new();
|
||||
let state = State::new(modules);
|
||||
|
||||
let declaration: StructDefinitionNode = StructDefinition { fields: vec![] }.mock();
|
||||
let declaration: StructDefinitionNode = StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![],
|
||||
}
|
||||
.mock();
|
||||
|
||||
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![],
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
|
@ -5313,6 +5343,7 @@ mod tests {
|
|||
let state = State::new(modules);
|
||||
|
||||
let declaration: StructDefinitionNode = StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
@ -5331,6 +5362,7 @@ mod tests {
|
|||
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement),
|
||||
DeclarationStructMember::new("bar".into(), DeclarationType::Boolean),
|
||||
|
@ -5355,6 +5387,7 @@ mod tests {
|
|||
let state = State::new(modules);
|
||||
|
||||
let declaration: StructDefinitionNode = StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
@ -5397,6 +5430,7 @@ mod tests {
|
|||
id: "Foo",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -5411,9 +5445,10 @@ mod tests {
|
|||
id: "Bar",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::User("Foo".into()).mock(),
|
||||
ty: UnresolvedType::User("Foo".into(), None).mock(),
|
||||
}
|
||||
.mock()],
|
||||
}
|
||||
|
@ -5439,11 +5474,13 @@ mod tests {
|
|||
&DeclarationType::Struct(DeclarationStructType::new(
|
||||
(*MODULE_ID).clone(),
|
||||
"Bar".into(),
|
||||
vec![],
|
||||
vec![DeclarationStructMember::new(
|
||||
"foo".into(),
|
||||
DeclarationType::Struct(DeclarationStructType::new(
|
||||
(*MODULE_ID).clone(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![DeclarationStructMember::new(
|
||||
"foo".into(),
|
||||
DeclarationType::FieldElement
|
||||
|
@ -5465,9 +5502,10 @@ mod tests {
|
|||
id: "Bar",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::User("Foo".into()).mock(),
|
||||
ty: UnresolvedType::User("Foo".into(), None).mock(),
|
||||
}
|
||||
.mock()],
|
||||
}
|
||||
|
@ -5497,9 +5535,10 @@ mod tests {
|
|||
id: "Foo",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::User("Foo".into()).mock(),
|
||||
ty: UnresolvedType::User("Foo".into(), None).mock(),
|
||||
}
|
||||
.mock()],
|
||||
}
|
||||
|
@ -5531,9 +5570,10 @@ mod tests {
|
|||
id: "Foo",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "bar",
|
||||
ty: UnresolvedType::User("Bar".into()).mock(),
|
||||
ty: UnresolvedType::User("Bar".into(), None).mock(),
|
||||
}
|
||||
.mock()],
|
||||
}
|
||||
|
@ -5545,9 +5585,10 @@ mod tests {
|
|||
id: "Bar",
|
||||
symbol: Symbol::Here(SymbolDefinition::Struct(
|
||||
StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::User("Foo".into()).mock(),
|
||||
ty: UnresolvedType::User("Foo".into(), None).mock(),
|
||||
}
|
||||
.mock()],
|
||||
}
|
||||
|
@ -5582,6 +5623,7 @@ mod tests {
|
|||
// Bar
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -5591,13 +5633,14 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
checker.check_type(
|
||||
UnresolvedType::User("Foo".into()).mock(),
|
||||
UnresolvedType::User("Foo".into(), None).mock(),
|
||||
&*MODULE_ID,
|
||||
&state.types
|
||||
),
|
||||
Ok(Type::Struct(StructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![StructMember::new("foo".into(), Type::FieldElement)]
|
||||
)))
|
||||
);
|
||||
|
@ -5605,7 +5648,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
checker
|
||||
.check_type(
|
||||
UnresolvedType::User("Bar".into()).mock(),
|
||||
UnresolvedType::User("Bar".into(), None).mock(),
|
||||
&*MODULE_ID,
|
||||
&state.types
|
||||
)
|
||||
|
@ -5628,6 +5671,7 @@ mod tests {
|
|||
// Foo { foo: 42 }.foo
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -5657,6 +5701,7 @@ mod tests {
|
|||
.annotate(StructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![StructMember::new("foo".into(), Type::FieldElement)]
|
||||
)),
|
||||
"foo".into()
|
||||
|
@ -5673,6 +5718,7 @@ mod tests {
|
|||
// Foo { foo: 42 }.bar
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -5711,6 +5757,7 @@ mod tests {
|
|||
// a A value cannot be defined with B as id, even if A and B have the same members
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![StructDefinitionField {
|
||||
id: "foo",
|
||||
ty: UnresolvedType::FieldElement.mock(),
|
||||
|
@ -5731,7 +5778,7 @@ mod tests {
|
|||
)
|
||||
.unwrap_err()
|
||||
.message,
|
||||
"Undefined type Bar"
|
||||
"Undefined type `Bar`"
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -5743,6 +5790,7 @@ mod tests {
|
|||
// Foo foo = Foo { foo: 42, bar: true }
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
@ -5777,6 +5825,7 @@ mod tests {
|
|||
.annotate(StructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
StructMember::new("foo".into(), Type::FieldElement),
|
||||
StructMember::new("bar".into(), Type::Boolean)
|
||||
|
@ -5794,6 +5843,7 @@ mod tests {
|
|||
// Foo foo = Foo { bar: true, foo: 42 }
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
@ -5828,6 +5878,7 @@ mod tests {
|
|||
.annotate(StructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
StructMember::new("foo".into(), Type::FieldElement),
|
||||
StructMember::new("bar".into(), Type::Boolean)
|
||||
|
@ -5845,6 +5896,7 @@ mod tests {
|
|||
// Foo foo = Foo { foo: 42 }
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
@ -5886,6 +5938,7 @@ mod tests {
|
|||
// Foo { foo: 42, baz: 42 } // error
|
||||
|
||||
let (mut checker, state) = create_module_with_foo(StructDefinition {
|
||||
generics: vec![],
|
||||
fields: vec![
|
||||
StructDefinitionField {
|
||||
id: "foo",
|
||||
|
|
|
@ -112,7 +112,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
(
|
||||
id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
ty: constant.get_type().clone(),
|
||||
expression: constant,
|
||||
}),
|
||||
)
|
||||
|
@ -268,10 +267,9 @@ mod tests {
|
|||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
|
||||
)),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -356,10 +354,9 @@ mod tests {
|
|||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Boolean,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(true)),
|
||||
)),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
|
||||
BooleanExpression::Value(true),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -446,7 +443,6 @@ mod tests {
|
|||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Uint(UBitwidth::B32),
|
||||
UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
|
@ -546,19 +542,16 @@ mod tests {
|
|||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::array(GArrayType::new(GType::FieldElement, 2usize)),
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
),
|
||||
)),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -686,7 +679,6 @@ mod tests {
|
|||
(
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
|
@ -695,7 +687,6 @@ mod tests {
|
|||
(
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from(
|
||||
const_a_id,
|
||||
|
@ -744,7 +735,6 @@ mod tests {
|
|||
(
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
|
@ -753,7 +743,6 @@ mod tests {
|
|||
(
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(2),
|
||||
)),
|
||||
|
@ -804,12 +793,9 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(42),
|
||||
)),
|
||||
)),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
@ -877,12 +863,9 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(42),
|
||||
)),
|
||||
)),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
|
|
@ -1101,7 +1101,11 @@ fn fold_function<'ast, T: Field>(
|
|||
.collect(),
|
||||
statements: main_statements_buffer,
|
||||
signature: typed_absy::types::ConcreteSignature::try_from(
|
||||
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
|
||||
crate::typed_absy::types::try_from_g_signature::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
crate::typed_absy::UExpression<'ast, T>,
|
||||
>(fun.signature)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
.into(),
|
||||
|
|
|
@ -72,8 +72,6 @@ impl fmt::Display for Error {
|
|||
|
||||
impl<'ast, T: Field> TypedProgram<'ast, T> {
|
||||
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
|
||||
println!("{}", self);
|
||||
|
||||
// inline user-defined constants
|
||||
let r = ConstantInliner::inline(self);
|
||||
// isolate branches
|
||||
|
@ -86,13 +84,12 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
// reduce the program to a single function
|
||||
let r = reduce_program(r).map_err(Error::from)?;
|
||||
|
||||
println!("{}", r);
|
||||
|
||||
// generate abi
|
||||
let abi = r.abi();
|
||||
|
||||
// propagate
|
||||
let r = Propagator::propagate(r).map_err(Error::from)?;
|
||||
|
||||
// remove assignment to variable index
|
||||
let r = VariableWriteRemover::apply(r);
|
||||
// detect non constant shifts
|
||||
|
|
|
@ -327,20 +327,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
// propagation to the defined variable if rhs is a constant
|
||||
TypedStatement::Definition(assignee, expr) => {
|
||||
let expr = self.fold_expression(expr)?;
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
|
||||
println!(
|
||||
"{:#?} vs {:#?}",
|
||||
assignee.get_type().clone(),
|
||||
expr.get_type().clone()
|
||||
);
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
|
||||
if let (Ok(a), Ok(e)) = (
|
||||
ConcreteType::try_from(assignee.get_type()),
|
||||
ConcreteType::try_from(expr.get_type()),
|
||||
) {
|
||||
println!("{} vs {}", a, e);
|
||||
|
||||
if a != e {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot assign {} of type {} to {} of type {}",
|
||||
|
|
|
@ -231,6 +231,7 @@ mod tests {
|
|||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
|
||||
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
|
||||
|
@ -240,6 +241,7 @@ mod tests {
|
|||
outputs: vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
|
||||
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
|
||||
|
@ -258,6 +260,7 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
|
@ -276,6 +279,7 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
|
@ -305,11 +309,13 @@ mod tests {
|
|||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
String::from("bar"),
|
||||
ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Bar".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(
|
||||
String::from("a"),
|
||||
|
@ -338,12 +344,14 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "bar",
|
||||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Bar",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
|
@ -378,6 +386,7 @@ mod tests {
|
|||
ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(
|
||||
String::from("b"),
|
||||
|
@ -406,6 +415,7 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "b",
|
||||
|
|
|
@ -138,6 +138,11 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> {
|
||||
StructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)))
|
||||
.collect(),
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -175,6 +180,11 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
t: DeclarationStructType<'ast>,
|
||||
) -> DeclarationStructType<'ast> {
|
||||
DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_declaration_constant(g)))
|
||||
.collect(),
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -1044,7 +1054,6 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
c: TypedConstant<'ast, T>,
|
||||
) -> TypedConstant<'ast, T> {
|
||||
TypedConstant {
|
||||
ty: f.fold_type(c.ty),
|
||||
expression: f.fold_expression(c.expression),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use crate::typed_absy::types::{ArrayType, Type};
|
||||
use crate::typed_absy::types::{
|
||||
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationType, GArrayType, GStructType,
|
||||
GType, StructMember, StructType, Type,
|
||||
};
|
||||
use crate::typed_absy::UBitwidth;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
|
||||
IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression,
|
||||
TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||
IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
|
||||
TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||
};
|
||||
use num_bigint::BigUint;
|
||||
use std::convert::TryFrom;
|
||||
|
@ -14,20 +17,46 @@ use zokrates_field::Field;
|
|||
type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>);
|
||||
|
||||
impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> {
|
||||
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
|
||||
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: Self,
|
||||
ty: &GArrayType<S>,
|
||||
) -> Result<Self, (Self, &GArrayType<S>)> {
|
||||
match e {
|
||||
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, ty)
|
||||
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, &ty.ty)
|
||||
.map(|e| e.into())
|
||||
.map_err(|(e, t)| (e.into(), t)),
|
||||
TypedExpressionOrSpread::Spread(s) => {
|
||||
ArrayExpression::try_from_int(s.array, ty.clone())
|
||||
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
|
||||
.map_err(|e| (e.into(), ty))
|
||||
}
|
||||
.map_err(|(e, _)| (e.into(), ty)),
|
||||
TypedExpressionOrSpread::Spread(s) => ArrayExpression::try_from_int(s.array, ty)
|
||||
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
|
||||
.map_err(|e| (e.into(), ty)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_common_type<'a, T: Field>(t: Type<'a, T>, u: Type<'a, T>) -> Result<Type<'a, T>, ()> {
|
||||
match (t, u) {
|
||||
(Type::Int, Type::Int) => Err(()),
|
||||
(Type::Int, u) => Ok(u),
|
||||
(t, Type::Int) => Ok(t),
|
||||
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
|
||||
get_common_type(*t.ty, *u.ty)?,
|
||||
t.size,
|
||||
))),
|
||||
(Type::Struct(t), Type::Struct(u)) => Ok(Type::Struct(StructType {
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
.zip(u.members.into_iter())
|
||||
.map(|(m_t, m_u)| {
|
||||
get_common_type(*m_t.ty.clone(), *m_u.ty)
|
||||
.map(|ty| StructMember { ty: box ty, ..m_t })
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
..t
|
||||
})),
|
||||
(t, _) => Ok(t),
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TypedExpression<'ast, T> {
|
||||
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
|
||||
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
|
||||
|
@ -51,7 +80,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
.into(),
|
||||
)),
|
||||
(Int(lhs), Uint(rhs)) => Ok((
|
||||
UExpression::try_from_int(lhs, rhs.bitwidth())
|
||||
UExpression::try_from_int(lhs, &rhs.bitwidth())
|
||||
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
|
||||
.into(),
|
||||
Uint(rhs),
|
||||
|
@ -60,47 +89,46 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
let bitwidth = lhs.bitwidth();
|
||||
Ok((
|
||||
Uint(lhs.clone()),
|
||||
UExpression::try_from_int(rhs, bitwidth)
|
||||
UExpression::try_from_int(rhs, &bitwidth)
|
||||
.map_err(|rhs| (lhs.into(), rhs.into()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Array(lhs), Array(rhs)) => {
|
||||
fn get_common_type<'a, T: Field>(
|
||||
t: Type<'a, T>,
|
||||
u: Type<'a, T>,
|
||||
) -> Result<Type<'a, T>, ()> {
|
||||
match (t, u) {
|
||||
(Type::Int, Type::Int) => Err(()),
|
||||
(Type::Int, u) => Ok(u),
|
||||
(t, Type::Int) => Ok(t),
|
||||
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
|
||||
get_common_type(*t.ty, *u.ty)?,
|
||||
t.size,
|
||||
))),
|
||||
(t, _) => Ok(t),
|
||||
}
|
||||
}
|
||||
let common_type = get_common_type(lhs.get_type().clone(), rhs.get_type().clone())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
|
||||
let common_type =
|
||||
get_common_type(lhs.inner_type().clone(), rhs.inner_type().clone())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
let common_type = match common_type {
|
||||
Type::Array(ty) => ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
ArrayExpression::try_from_int(lhs.clone(), common_type.clone())
|
||||
ArrayExpression::try_from_int(lhs.clone(), &common_type)
|
||||
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
|
||||
.into(),
|
||||
ArrayExpression::try_from_int(rhs, common_type)
|
||||
ArrayExpression::try_from_int(rhs, &common_type)
|
||||
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Struct(lhs), Struct(rhs)) => {
|
||||
if lhs.get_type() == rhs.get_type() {
|
||||
Ok((Struct(lhs), Struct(rhs)))
|
||||
} else {
|
||||
Err((Struct(lhs), Struct(rhs)))
|
||||
}
|
||||
let common_type = get_common_type(lhs.get_type(), rhs.get_type())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
|
||||
let common_type = match common_type {
|
||||
Type::Struct(ty) => ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
StructExpression::try_from_int(lhs.clone(), &common_type)
|
||||
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
|
||||
.into(),
|
||||
StructExpression::try_from_int(rhs, &common_type)
|
||||
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
|
@ -110,22 +138,25 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
|
||||
match ty.clone() {
|
||||
Type::FieldElement => {
|
||||
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: Self,
|
||||
ty: >ype<S>,
|
||||
) -> Result<Self, (Self, >ype<S>)> {
|
||||
match ty {
|
||||
GType::FieldElement => {
|
||||
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
|
||||
Type::Uint(bitwidth) => {
|
||||
GType::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
|
||||
GType::Uint(bitwidth) => {
|
||||
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
|
||||
}
|
||||
Type::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(e, *array_ty.ty).map(TypedExpression::from)
|
||||
GType::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(struct_ty) => {
|
||||
GType::Struct(struct_ty) => {
|
||||
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Int => Err(e),
|
||||
GType::Int => Err(e),
|
||||
}
|
||||
.map_err(|e| (e, ty))
|
||||
}
|
||||
|
@ -299,7 +330,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
)),
|
||||
IntExpression::Pow(box e1, box e2) => Ok(Self::Pow(
|
||||
box Self::try_from_int(e1)?,
|
||||
box UExpression::try_from_int(e2, UBitwidth::B32)?,
|
||||
box UExpression::try_from_int(e2, &UBitwidth::B32)?,
|
||||
)),
|
||||
IntExpression::Div(box e1, box e2) => Ok(Self::Div(
|
||||
box Self::try_from_int(e1)?,
|
||||
|
@ -323,15 +354,21 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
let values = values
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, Type::FieldElement)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
TypedExpressionOrSpread::align_to_type(
|
||||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Concrete(0),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(FieldElementExpression::select(
|
||||
|
@ -351,10 +388,10 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
e: TypedExpression<'ast, T>,
|
||||
bitwidth: UBitwidth,
|
||||
bitwidth: &UBitwidth,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Uint(e) => match e.bitwidth == bitwidth {
|
||||
TypedExpression::Uint(e) => match e.bitwidth == *bitwidth {
|
||||
true => Ok(e),
|
||||
_ => Err(TypedExpression::Uint(e)),
|
||||
},
|
||||
|
@ -367,7 +404,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
|
||||
pub fn try_from_int(
|
||||
i: IntExpression<'ast, T>,
|
||||
bitwidth: UBitwidth,
|
||||
bitwidth: &UBitwidth,
|
||||
) -> Result<Self, IntExpression<'ast, T>> {
|
||||
use self::IntExpression::*;
|
||||
|
||||
|
@ -377,7 +414,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
Ok(UExpressionInner::Value(
|
||||
u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(),
|
||||
)
|
||||
.annotate(bitwidth))
|
||||
.annotate(*bitwidth))
|
||||
} else {
|
||||
Err(Value(i))
|
||||
}
|
||||
|
@ -435,20 +472,26 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
let values = values
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, Type::Uint(bitwidth))
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
TypedExpressionOrSpread::align_to_type(
|
||||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::Uint(*bitwidth),
|
||||
DeclarationConstant::Concrete(0),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(UExpression::select(
|
||||
ArrayExpressionInner::Value(values.into())
|
||||
.annotate(Type::Uint(bitwidth), size),
|
||||
.annotate(Type::Uint(*bitwidth), size),
|
||||
index,
|
||||
))
|
||||
}
|
||||
|
@ -461,35 +504,35 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_inner_ty: Type<'ast, T>,
|
||||
target_array_ty: &GArrayType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_inner_ty)
|
||||
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_array_ty)
|
||||
.map_err(|_| TypedExpression::Array(e)),
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// precondition: `array` is only made of inline arrays and repeat constructs unless it does not contain the Integer type
|
||||
pub fn try_from_int(
|
||||
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
|
||||
array: Self,
|
||||
target_inner_ty: Type<'ast, T>,
|
||||
target_array_ty: &GArrayType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let array_ty = array.ty();
|
||||
|
||||
// elements must fit in the target type
|
||||
match array.into_inner() {
|
||||
ArrayExpressionInner::Value(inline_array) => {
|
||||
let res = match target_inner_ty.clone() {
|
||||
Type::Int => Ok(inline_array),
|
||||
t => {
|
||||
let res = match &*target_array_ty.ty {
|
||||
GType::Int => Ok(inline_array),
|
||||
_ => {
|
||||
// try to convert all elements to the target type
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, t.clone()).map_err(
|
||||
TypedExpressionOrSpread::align_to_type(v, &target_array_ty).map_err(
|
||||
|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => e,
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
|
@ -508,11 +551,11 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size))
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
match target_inner_ty.clone() {
|
||||
Type::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
|
||||
match &*target_array_ty.ty {
|
||||
GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
|
||||
.annotate(Type::Int, array_ty.size)),
|
||||
// try to align the repeated element to the target type
|
||||
t => TypedExpression::align_to_type(e, t)
|
||||
t => TypedExpression::align_to_type(e, &t)
|
||||
.map(|e| {
|
||||
let ty = e.get_type().clone();
|
||||
|
||||
|
@ -523,7 +566,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
a => {
|
||||
if array_ty.ty.weak_eq(&target_inner_ty) {
|
||||
if *target_array_ty.ty == *array_ty.ty {
|
||||
Ok(a.annotate(*array_ty.ty, array_ty.size))
|
||||
} else {
|
||||
Err(a.annotate(*array_ty.ty, array_ty.size).into())
|
||||
|
@ -533,6 +576,50 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> StructExpression<'ast, T> {
|
||||
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
|
||||
struc: Self,
|
||||
target_struct_ty: &GStructType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let struct_ty = struc.ty().clone();
|
||||
|
||||
match struc.into_inner() {
|
||||
StructExpressionInner::Value(inline_struct) => inline_struct
|
||||
.into_iter()
|
||||
.zip(target_struct_ty.members.iter())
|
||||
.map(|(value, target_member)| {
|
||||
TypedExpression::align_to_type(value, &*target_member.ty)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone()))
|
||||
.map_err(|_| unimplemented!()),
|
||||
s => {
|
||||
if struct_ty
|
||||
.members
|
||||
.iter()
|
||||
.zip(target_struct_ty.members.iter())
|
||||
.all(|(m, target_m)| *target_m.ty == *m.ty)
|
||||
{
|
||||
Ok(s.annotate(struct_ty.clone()))
|
||||
} else {
|
||||
Err(s.annotate(struct_ty.clone()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_struct_ty: &GStructType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Struct(e) => Self::try_from_int(e.clone(), target_struct_ty)
|
||||
.map_err(|_| TypedExpression::Struct(e)),
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<BigUint> for IntExpression<'ast, T> {
|
||||
fn from(v: BigUint) -> Self {
|
||||
IntExpression::Value(v)
|
||||
|
@ -652,7 +739,7 @@ mod tests {
|
|||
|
||||
for (r, e) in expressions
|
||||
.into_iter()
|
||||
.map(|e| UExpression::try_from_int(e, UBitwidth::B32).unwrap())
|
||||
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32).unwrap())
|
||||
.zip(expected)
|
||||
{
|
||||
assert_eq!(r, e);
|
||||
|
@ -665,7 +752,7 @@ mod tests {
|
|||
|
||||
for e in should_error
|
||||
.into_iter()
|
||||
.map(|e| UExpression::try_from_int(e, UBitwidth::B32))
|
||||
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32))
|
||||
{
|
||||
assert!(e.is_err());
|
||||
}
|
||||
|
|
|
@ -107,15 +107,19 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.arguments
|
||||
.iter()
|
||||
.map(|p| {
|
||||
println!("{:#?}", p);
|
||||
|
||||
types::ConcreteType::try_from(types::Type::<T>::from(p.id._type.clone()))
|
||||
.map(|ty| AbiInput {
|
||||
public: !p.private,
|
||||
name: p.id.id.to_string(),
|
||||
ty,
|
||||
})
|
||||
.unwrap()
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
UExpression<'ast, T>,
|
||||
>(p.id._type.clone())
|
||||
.unwrap(),
|
||||
)
|
||||
.map(|ty| AbiInput {
|
||||
public: !p.private,
|
||||
name: p.id.id.to_string(),
|
||||
ty,
|
||||
})
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
outputs: main
|
||||
|
@ -123,7 +127,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.outputs
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
types::ConcreteType::try_from(types::Type::<T>::from(ty.clone())).unwrap()
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
UExpression<'ast, T>,
|
||||
>(ty.clone())
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
@ -194,7 +205,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
|||
.iter()
|
||||
.map(|(id, symbol)| match symbol {
|
||||
TypedConstantSymbol::Here(ref tc) => {
|
||||
format!("const {} {} = {}", tc.ty, id.id, tc.expression)
|
||||
format!("const {} {} = {}", "todo", id.id, tc.expression)
|
||||
}
|
||||
TypedConstantSymbol::There(ref imported_id) => {
|
||||
format!(
|
||||
|
@ -300,27 +311,25 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct TypedConstant<'ast, T> {
|
||||
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
|
||||
pub ty: Type<'ast, T>,
|
||||
pub expression: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedConstant<'ast, T> {
|
||||
pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self {
|
||||
TypedConstant { ty, expression }
|
||||
pub fn new(expression: TypedExpression<'ast, T>) -> Self {
|
||||
TypedConstant { expression }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
|
||||
write!(f, "const {}({})", self.ty, self.expression)
|
||||
write!(f, "const {}({})", "todo", self.expression)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> {
|
||||
impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
|
||||
fn get_type(&self) -> Type<'ast, T> {
|
||||
self.ty.clone()
|
||||
self.expression.get_type()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1168,24 +1177,6 @@ pub struct StructExpression<'ast, T> {
|
|||
inner: StructExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> StructExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_struct_ty: StructType<'ast, T>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Struct(e) => {
|
||||
if e.ty() == &target_struct_ty {
|
||||
Ok(e)
|
||||
} else {
|
||||
Err(TypedExpression::Struct(e))
|
||||
}
|
||||
}
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> StructExpression<'ast, T> {
|
||||
pub fn ty(&self) -> &StructType<'ast, T> {
|
||||
&self.ty
|
||||
|
|
|
@ -225,6 +225,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
t: StructType<'ast, T>,
|
||||
) -> Result<StructType<'ast, T>, Self::Error> {
|
||||
Ok(StructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -260,6 +265,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
t: DeclarationStructType<'ast>,
|
||||
) -> Result<DeclarationStructType<'ast>, Self::Error> {
|
||||
Ok(DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_declaration_constant(g)).transpose())
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -1092,7 +1102,6 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
c: TypedConstant<'ast, T>,
|
||||
) -> Result<TypedConstant<'ast, T>, F::Error> {
|
||||
Ok(TypedConstant {
|
||||
ty: f.fold_type(c.ty)?,
|
||||
expression: f.fold_expression(c.expression)?,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -122,6 +122,21 @@ pub enum DeclarationConstant<'ast> {
|
|||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
|
||||
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
|
||||
match (self, other.as_inner()) {
|
||||
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
|
||||
other.eq(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for DeclarationConstant<'ast> {
|
||||
fn from(e: u32) -> Self {
|
||||
DeclarationConstant::Concrete(e)
|
||||
|
@ -198,7 +213,7 @@ impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
|
|||
|
||||
pub type MemberId = String;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct GStructMember<S> {
|
||||
#[serde(rename = "name")]
|
||||
pub id: MemberId,
|
||||
|
@ -210,8 +225,8 @@ pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>
|
|||
pub type ConcreteStructMember = GStructMember<usize>;
|
||||
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationStructMember<'ast>) -> bool {
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructMember<S>> for GStructMember<R> {
|
||||
fn eq(&self, other: &GStructMember<S>) -> bool {
|
||||
self.id == other.id && *self.ty == *other.ty
|
||||
}
|
||||
}
|
||||
|
@ -239,19 +254,19 @@ impl<'ast, T> From<ConcreteStructMember> for StructMember<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> {
|
||||
fn from(t: ConcreteStructMember) -> Self {
|
||||
try_from_g_struct_member(t).unwrap()
|
||||
}
|
||||
}
|
||||
// impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> {
|
||||
// fn from(t: ConcreteStructMember) -> Self {
|
||||
// try_from_g_struct_member(t).unwrap()
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
|
||||
fn from(t: DeclarationStructMember<'ast>) -> Self {
|
||||
try_from_g_struct_member(t).unwrap()
|
||||
}
|
||||
}
|
||||
// impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
|
||||
// fn from(t: DeclarationStructMember<'ast>) -> Self {
|
||||
// try_from_g_struct_member(t).unwrap()
|
||||
// }
|
||||
// }
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
|
||||
#[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
|
||||
pub struct GArrayType<S> {
|
||||
pub size: S,
|
||||
#[serde(flatten)]
|
||||
|
@ -262,13 +277,9 @@ pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteArrayType = GArrayType<usize>;
|
||||
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool {
|
||||
*self.ty == *other.ty
|
||||
&& match (self.size.as_inner(), &other.size) {
|
||||
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
|
||||
_ => true,
|
||||
}
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GArrayType<S>> for GArrayType<R> {
|
||||
fn eq(&self, other: &GArrayType<S>) -> bool {
|
||||
*self.ty == *other.ty && self.size == other.size
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,22 +309,6 @@ impl<S: fmt::Display> fmt::Display for GArrayType<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: PartialEq + fmt::Display> Type<'ast, T> {
|
||||
// array type equality with non-strict size checks
|
||||
// sizes always match unless they are different constants
|
||||
pub fn weak_eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Type::Array(t), Type::Array(u)) => t.ty.weak_eq(&u.ty),
|
||||
(Type::Struct(t), Type::Struct(u)) => t
|
||||
.members
|
||||
.iter()
|
||||
.zip(u.members.iter())
|
||||
.all(|(m, n)| m.ty.weak_eq(&n.ty)),
|
||||
(t, u) => t == u,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_from_g_array_type<T: TryInto<U>, U>(
|
||||
t: GArrayType<T>,
|
||||
) -> Result<GArrayType<U>, SpecializationError> {
|
||||
|
@ -370,9 +365,18 @@ pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteStructType = GStructType<usize>;
|
||||
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
|
||||
|
||||
impl<S: PartialEq> PartialEq for GStructType<S> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.canonical_location.eq(&other.canonical_location) && self.generics.eq(&other.generics)
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructType<S>> for GStructType<R> {
|
||||
fn eq(&self, other: &GStructType<S>) -> bool {
|
||||
self.canonical_location == other.canonical_location
|
||||
&& self
|
||||
.generics
|
||||
.iter()
|
||||
.zip(other.generics.iter())
|
||||
.all(|(a, b)| match (a, b) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -427,11 +431,11 @@ impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
|
||||
fn from(t: DeclarationStructType<'ast>) -> Self {
|
||||
try_from_g_struct_type(t).unwrap()
|
||||
}
|
||||
}
|
||||
// impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
|
||||
// fn from(t: DeclarationStructType<'ast>) -> Self {
|
||||
// try_from_g_struct_type(t).unwrap()
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<S> GStructType<S> {
|
||||
pub fn new(
|
||||
|
@ -514,7 +518,7 @@ impl fmt::Display for UBitwidth {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
#[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
pub enum GType<S> {
|
||||
FieldElement,
|
||||
Boolean,
|
||||
|
@ -624,13 +628,13 @@ pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteType = GType<usize>;
|
||||
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationType<'ast>) -> bool {
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GType<S>> for GType<R> {
|
||||
fn eq(&self, other: >ype<S>) -> bool {
|
||||
use self::GType::*;
|
||||
|
||||
match (self, other) {
|
||||
(Array(l), Array(r)) => l == r,
|
||||
(Struct(l), Struct(r)) => l.canonical_location == r.canonical_location,
|
||||
(Struct(l), Struct(r)) => l == r,
|
||||
(FieldElement, FieldElement) | (Boolean, Boolean) => true,
|
||||
(Uint(l), Uint(r)) => l == r,
|
||||
_ => false,
|
||||
|
@ -638,7 +642,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
|
||||
pub fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
|
||||
match t {
|
||||
GType::FieldElement => Ok(GType::FieldElement),
|
||||
GType::Boolean => Ok(GType::Boolean),
|
||||
|
@ -669,12 +673,6 @@ impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationType<'ast>> for Type<'ast, T> {
|
||||
fn from(t: DeclarationType<'ast>) -> Self {
|
||||
try_from_g_type(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
|
||||
fn from(tup: (GType<S>, U)) -> Self {
|
||||
GArrayType {
|
||||
|
@ -758,7 +756,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
|||
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
|
||||
use self::GType::*;
|
||||
|
||||
if self == other {
|
||||
if other == self {
|
||||
true
|
||||
} else {
|
||||
match (self, other) {
|
||||
|
@ -776,7 +774,13 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
|||
}
|
||||
_ => false,
|
||||
},
|
||||
(Struct(_), Struct(_)) => false,
|
||||
(Struct(l), Struct(r)) => {
|
||||
l.canonical_location == r.canonical_location
|
||||
&& l.members
|
||||
.iter()
|
||||
.zip(r.members.iter())
|
||||
.all(|(m, d_m)| m.ty.can_be_specialized_to(&*d_m.ty))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -889,14 +893,6 @@ impl<'ast, T> TryFrom<FunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
// impl<'ast> TryFrom<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
|
||||
// type Error = SpecializationError;
|
||||
|
||||
// fn try_from(k: DeclarationFunctionKey<'ast>) -> Result<Self, Self::Error> {
|
||||
// try_from_g_function_key(k)
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
||||
fn from(k: ConcreteFunctionKey<'ast>) -> Self {
|
||||
try_from_g_function_key(k).unwrap()
|
||||
|
@ -909,12 +905,6 @@ impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
||||
fn from(k: DeclarationFunctionKey<'ast>) -> Self {
|
||||
try_from_g_function_key(k).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, S> GFunctionKey<'ast, S> {
|
||||
pub fn with_location<T: Into<OwnedTypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
|
||||
module: T,
|
||||
|
@ -993,7 +983,7 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
|
||||
pub fn specialize_declaration_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
|
||||
decl_ty: DeclarationType<'ast>,
|
||||
constants: &GGenericsAssignment<'ast, S>,
|
||||
) -> Result<GType<S>, GenericIdentifier<'ast>> {
|
||||
|
@ -1002,7 +992,7 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
|
|||
DeclarationType::Array(t0) => {
|
||||
// let s1 = t1.size.clone();
|
||||
|
||||
let ty = box specialize_type(*t0.ty, &constants)?;
|
||||
let ty = box specialize_declaration_type(*t0.ty, &constants)?;
|
||||
let size = match t0.size {
|
||||
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
|
||||
DeclarationConstant::Concrete(s) => Ok(s.into()),
|
||||
|
@ -1022,7 +1012,8 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
|
|||
.into_iter()
|
||||
.map(|m| {
|
||||
let id = m.id;
|
||||
specialize_type(*m.ty, constants).map(|ty| GStructMember { ty: box ty, id })
|
||||
specialize_declaration_type(*m.ty, constants)
|
||||
.map(|ty| GStructMember { ty: box ty, id })
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
generics: s0
|
||||
|
@ -1049,7 +1040,9 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
|
|||
})
|
||||
}
|
||||
|
||||
pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature};
|
||||
pub use self::signature::{
|
||||
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
|
||||
};
|
||||
|
||||
pub mod signature {
|
||||
use super::*;
|
||||
|
@ -1194,7 +1187,7 @@ pub mod signature {
|
|||
self.outputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|t| specialize_type(t, &constants))
|
||||
.map(|t| specialize_declaration_type(t, &constants))
|
||||
.collect::<Result<_, _>>()
|
||||
}
|
||||
}
|
||||
|
@ -1244,12 +1237,6 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationSignature<'ast>> for Signature<'ast, T> {
|
||||
fn from(s: DeclarationSignature<'ast>) -> Self {
|
||||
try_from_g_signature(s).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: fmt::Display> fmt::Display for GSignature<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if !self.generics.is_empty() {
|
||||
|
|
|
@ -45,7 +45,7 @@ impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
|
|||
|
||||
impl<'ast, T> From<DeclarationVariable<'ast>> for Variable<'ast, T> {
|
||||
fn from(v: DeclarationVariable<'ast>) -> Self {
|
||||
let _type = v._type.into();
|
||||
let _type = crate::typed_absy::types::try_from_g_type(v._type).unwrap();
|
||||
|
||||
Self { _type, id: v.id }
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue