1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

refactor types, implement struct generics

This commit is contained in:
schaeff 2021-07-15 10:44:44 +02:00
parent bfcc733714
commit e5bcbed81f
14 changed files with 489 additions and 360 deletions

View file

@ -432,6 +432,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![],
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
@ -453,6 +454,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![],
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
@ -470,6 +472,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![],
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
@ -487,6 +490,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![],
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement

View file

@ -417,11 +417,13 @@ struct Bar { field a }
ty: ConcreteType::Struct(ConcreteStructType::new(
"foo".into(),
"Foo".into(),
vec![],
vec![ConcreteStructMember {
id: "b".into(),
ty: box ConcreteType::Struct(ConcreteStructType::new(
"bar".into(),
"Bar".into(),
vec![],
vec![ConcreteStructMember {
id: "a".into(),
ty: box ConcreteType::FieldElement

View file

@ -20,7 +20,7 @@ use crate::parser::Position;
use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
use crate::typed_absy::types::{
check_type, specialize_type, ArrayType, DeclarationArrayType, DeclarationConstant,
check_type, specialize_declaration_type, ArrayType, DeclarationArrayType, DeclarationConstant,
DeclarationFunctionKey, DeclarationSignature, DeclarationStructMember, DeclarationStructType,
DeclarationType, GenericIdentifier, StructLocation, StructMember,
};
@ -357,29 +357,34 @@ impl<'ast, T: Field> Checker<'ast, T> {
state: &State<'ast, T>,
) -> Result<TypedConstant<'ast, T>, ErrorInner> {
let pos = c.pos();
let ty = self.check_type(c.value.ty.clone(), module_id, &state.types)?;
let ty = self.check_declaration_type(
c.value.ty.clone(),
module_id,
&state,
&HashMap::default(),
&mut HashSet::default(),
)?;
let checked_expr =
self.check_expression(c.value.expression.clone(), module_id, &state.types)?;
match ty {
Type::FieldElement => {
DeclarationType::FieldElement => {
FieldElementExpression::try_from_typed(checked_expr).map(TypedExpression::from)
}
Type::Boolean => {
DeclarationType::Boolean => {
BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from)
}
Type::Uint(bitwidth) => {
UExpression::try_from_typed(checked_expr, bitwidth).map(TypedExpression::from)
DeclarationType::Uint(bitwidth) => {
UExpression::try_from_typed(checked_expr, &bitwidth).map(TypedExpression::from)
}
Type::Array(ref array_ty) => {
ArrayExpression::try_from_typed(checked_expr, *array_ty.ty.clone())
DeclarationType::Array(ref array_ty) => {
ArrayExpression::try_from_typed(checked_expr, &array_ty).map(TypedExpression::from)
}
DeclarationType::Struct(ref struct_ty) => {
StructExpression::try_from_typed(checked_expr, &struct_ty)
.map(TypedExpression::from)
}
Type::Struct(ref struct_ty) => {
StructExpression::try_from_typed(checked_expr, struct_ty.clone())
.map(TypedExpression::from)
}
Type::Int => Err(checked_expr), // Integers cannot be assigned
DeclarationType::Int => Err(checked_expr), // Integers cannot be assigned
}
.map_err(|e| ErrorInner {
pos: Some(pos),
@ -391,7 +396,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
ty
),
})
.map(|e| TypedConstant::new(ty, e))
.map(|e| TypedConstant::new(e))
}
fn check_struct_type_declaration(
@ -1152,14 +1157,17 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
}),
},
TypedExpression::Int(v) => UExpression::try_from_int(v.clone(), UBitwidth::B32)
.map_err(|_| ErrorInner {
pos: Some(pos),
message: format!(
TypedExpression::Int(v) => {
UExpression::try_from_int(v.clone(), &UBitwidth::B32).map_err(|_| {
ErrorInner {
pos: Some(pos),
message: format!(
"Expected array dimension to be a u32 constant, found {} of type {}",
v, ty
),
}),
}
})
}
_ => Err(ErrorInner {
pos: Some(pos),
message: format!(
@ -1194,33 +1202,38 @@ impl<'ast, T: Field> Checker<'ast, T> {
DeclarationType::Struct(struct_type) => {
match struct_type.generics.len() == generics.len() {
true => {
let checked_generics = generics
// downcast the generics to identifiers, as this is the only possibility here
let generic_identifiers = struct_type.generics.iter().map(|c| {
match c.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g.clone(),
_ => unreachable!(),
}
});
// build the generic assignment for this type
let assignment = GGenericsAssignment(generics
.into_iter()
.map(|e| match e {
Some(e) => self
.zip(generic_identifiers)
.filter_map(|(e, g)| e.map(|e| {
self
.check_expression(e, module_id, types)
.and_then(|e| {
UExpression::try_from_typed(e, UBitwidth::B32)
.map(Some)
UExpression::try_from_typed(e, &UBitwidth::B32)
.map(|e| (g, e))
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
})
}),
None => Ok(None),
})
.collect::<Result<_, _>>()?;
})
}))
.collect::<Result<_, _>>()?);
Ok(Type::Struct(StructType {
canonical_location: struct_type.canonical_location,
location: struct_type.location,
generics: checked_generics,
members: struct_type
.members
.into_iter()
.map(|m| m.into())
.collect(),
}))
// specialize the declared type using the generic assignment
Ok(specialize_declaration_type(
DeclarationType::Struct(struct_type),
&assignment,
)
.unwrap())
}
false => Err(ErrorInner {
pos: Some(pos),
@ -1384,7 +1397,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.into_iter()
.map(|m| {
Ok(DeclarationStructMember {
ty: box specialize_type(*m.ty, &assignment)
ty: box specialize_declaration_type(*m.ty, &assignment)
.map_err(|_| unimplemented!())?,
..m
})
@ -1465,7 +1478,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
},
TypedExpression::Int(v) => {
UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner {
UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner {
pos: Some(pos),
message: format!(
"Expected lower loop bound to be of type u32, found {}",
@ -1495,7 +1508,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
},
TypedExpression::Int(v) => {
UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner {
UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner {
pos: Some(pos),
message: format!(
"Expected upper loop bound to be of type u32, found {}",
@ -1550,9 +1563,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
let res = match expression_list_checked.len() == return_types.len() {
true => match expression_list_checked
.iter()
.zip(return_types.clone())
.map(|(e, t)| TypedExpression::align_to_type(e.clone(), t.into()))
.clone()
.into_iter()
.zip(return_types.iter())
.map(|(e, t)| TypedExpression::align_to_type(e, t))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
vec![ErrorInner {
@ -1641,19 +1655,19 @@ impl<'ast, T: Field> Checker<'ast, T> {
let var_type = var.get_type();
// make sure the assignee has the same type as the rhs
match var_type.clone() {
match var_type {
Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr)
.map(TypedExpression::from),
Type::Boolean => {
BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from)
}
Type::Uint(bitwidth) => UExpression::try_from_typed(checked_expr, bitwidth)
Type::Uint(bitwidth) => UExpression::try_from_typed(checked_expr, &bitwidth)
.map(TypedExpression::from),
Type::Array(array_ty) => {
ArrayExpression::try_from_typed(checked_expr, *array_ty.ty)
Type::Array(ref array_ty) => {
ArrayExpression::try_from_typed(checked_expr, array_ty)
.map(TypedExpression::from)
}
Type::Struct(struct_ty) => {
Type::Struct(ref struct_ty) => {
StructExpression::try_from_typed(checked_expr, struct_ty)
.map(TypedExpression::from)
}
@ -1710,7 +1724,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, types).and_then(|g| {
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
message: format!(
@ -1759,7 +1773,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let mut functions = functions;
let f = functions.pop().unwrap();
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
}])?;
@ -1827,7 +1841,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
};
let checked_typed_index =
UExpression::try_from_typed(checked_index, UBitwidth::B32).map_err(
UExpression::try_from_typed(checked_index, &UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
message: format!(
@ -2127,7 +2141,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Ok(e) => e.into(),
Err(e) => e,
};
let e2_checked = match UExpression::try_from_typed(e2_checked, UBitwidth::B32) {
let e2_checked = match UExpression::try_from_typed(e2_checked, &UBitwidth::B32) {
Ok(e) => e.into(),
Err(e) => e,
};
@ -2261,7 +2275,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, types).and_then(|g| {
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
UExpression::try_from_typed(g, &UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
message: format!(
@ -2307,7 +2321,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let signature = f.signature;
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, &t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
})?;
@ -2513,9 +2527,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
}
(TypedExpression::Struct(e1), TypedExpression::Struct(e2))
if e1.get_type() == e2.get_type() =>
{
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
Ok(BooleanExpression::StructEq(box e1, box e2).into())
}
(TypedExpression::Uint(e1), TypedExpression::Uint(e2))
@ -2659,7 +2671,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map(|e| self.check_expression(e, module_id, types))
.unwrap_or_else(|| Ok(array_size.clone().into()))?;
let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner {
let from = UExpression::try_from_typed(from, &UBitwidth::B32).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected the lower bound of the range to be a u32, found {} of type {}",
@ -2668,7 +2680,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
})?;
let to = UExpression::try_from_typed(to, UBitwidth::B32).map_err(|e| ErrorInner {
let to = UExpression::try_from_typed(to, &UBitwidth::B32).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected the upper bound of the range to be a u32, found {} of type {}",
@ -2699,7 +2711,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let index = self.check_expression(index, module_id, types)?;
let index =
UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| {
UExpression::try_from_typed(index, &UBitwidth::B32).map_err(|e| {
ErrorInner {
pos: Some(pos),
message: format!(
@ -2813,19 +2825,24 @@ impl<'ast, T: Field> Checker<'ast, T> {
.next()
.unwrap_or(Type::Int);
let unwrapped_expressions_or_spreads = match &inferred_type {
let unwrapped_expressions_or_spreads = match inferred_type.clone() {
Type::Int => expressions_or_spreads_checked,
t => expressions_or_spreads_checked
.into_iter()
.map(|e| {
TypedExpressionOrSpread::align_to_type(e, t.clone()).map_err(
|(e, ty)| ErrorInner {
pos: Some(pos),
message: format!("Expected {} to have type {}", e, ty,),
},
)
})
.collect::<Result<Vec<_>, _>>()?,
t => {
let target_array_ty =
ArrayType::new(t, UExpressionInner::Value(0).annotate(UBitwidth::B32));
expressions_or_spreads_checked
.into_iter()
.map(|e| {
TypedExpressionOrSpread::align_to_type(e, &target_array_ty).map_err(
|(e, ty)| ErrorInner {
pos: Some(pos),
message: format!("Expected {} to have type {}", e, ty,),
},
)
})
.collect::<Result<Vec<_>, _>>()?
}
};
// the size of the inline array is the sum of the size of its elements. However expressed as a u32 expression,
@ -2863,15 +2880,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
let count = self.check_expression(count, module_id, types)?;
let count =
UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner {
let count = UExpression::try_from_typed(count, &UBitwidth::B32).map_err(|e| {
ErrorInner {
pos: Some(pos),
message: format!(
"Expected array initializer count to be a u32, found {} of type {}",
e,
e.get_type(),
),
})?;
}
})?;
Ok(ArrayExpressionInner::Repeat(box e, box count.clone())
.annotate(ty, count)
@ -2881,7 +2899,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let ty = match types.get(module_id).unwrap().get(&id).cloned() {
None => Err(ErrorInner {
pos: Some(pos),
message: format!("Undeclared type of name `{}`", id),
message: format!("Undefined type `{}`", id),
}),
Some(ty) => Ok(ty),
}?;
@ -2926,13 +2944,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
let expression_checked =
self.check_expression(value, module_id, types)?;
let expression_checked = TypedExpression::align_to_type(
expression_checked,
Type::from(*member.ty.clone()),
)
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
let expression_checked =
TypedExpression::align_to_type(expression_checked, &*member.ty)
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Member {} of struct {} has type {}, found {} of type {}",
member.id,
id.clone(),
@ -2940,7 +2956,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
e.0,
e.0.get_type(),
),
})?;
})?;
Ok(expression_checked)
}
@ -3056,7 +3072,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let e2 = self.check_expression(e2, module_id, types)?;
let e2 =
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected the left shift right operand to have type `u32`, found {}",
@ -3083,7 +3099,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let e2 = self.check_expression(e2, module_id, types)?;
let e2 =
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expected the right shift right operand to be of type `u32`, found {}",
@ -3402,11 +3418,16 @@ mod tests {
use super::*;
fn struct0() -> StructDefinitionNode<'static> {
StructDefinition { fields: vec![] }.mock()
StructDefinition {
generics: vec![],
fields: vec![],
}
.mock()
}
fn struct1() -> StructDefinitionNode<'static> {
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -3730,7 +3751,7 @@ mod tests {
checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0]
.inner
.message,
"Undeclared symbol `P` in function definition"
"Undeclared symbol `P`"
);
}
}
@ -3839,7 +3860,11 @@ mod tests {
SymbolDeclaration {
id: "foo",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition { fields: vec![] }.mock(),
StructDefinition {
generics: vec![],
fields: vec![],
}
.mock(),
)),
}
.mock(),
@ -3990,7 +4015,7 @@ mod tests {
Checker::<Bn128Field>::new().check_signature(signature, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
message: "Undeclared symbol `K` in function definition".to_string()
message: "Undeclared symbol `K`".to_string()
}])
);
}
@ -5287,12 +5312,17 @@ mod tests {
let modules = Modules::new();
let state = State::new(modules);
let declaration: StructDefinitionNode = StructDefinition { fields: vec![] }.mock();
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
fields: vec![],
}
.mock();
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![],
));
assert_eq!(
@ -5313,6 +5343,7 @@ mod tests {
let state = State::new(modules);
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",
@ -5331,6 +5362,7 @@ mod tests {
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement),
DeclarationStructMember::new("bar".into(), DeclarationType::Boolean),
@ -5355,6 +5387,7 @@ mod tests {
let state = State::new(modules);
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",
@ -5397,6 +5430,7 @@ mod tests {
id: "Foo",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -5411,9 +5445,10 @@ mod tests {
id: "Bar",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::User("Foo".into()).mock(),
ty: UnresolvedType::User("Foo".into(), None).mock(),
}
.mock()],
}
@ -5439,11 +5474,13 @@ mod tests {
&DeclarationType::Struct(DeclarationStructType::new(
(*MODULE_ID).clone(),
"Bar".into(),
vec![],
vec![DeclarationStructMember::new(
"foo".into(),
DeclarationType::Struct(DeclarationStructType::new(
(*MODULE_ID).clone(),
"Foo".into(),
vec![],
vec![DeclarationStructMember::new(
"foo".into(),
DeclarationType::FieldElement
@ -5465,9 +5502,10 @@ mod tests {
id: "Bar",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::User("Foo".into()).mock(),
ty: UnresolvedType::User("Foo".into(), None).mock(),
}
.mock()],
}
@ -5497,9 +5535,10 @@ mod tests {
id: "Foo",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::User("Foo".into()).mock(),
ty: UnresolvedType::User("Foo".into(), None).mock(),
}
.mock()],
}
@ -5531,9 +5570,10 @@ mod tests {
id: "Foo",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "bar",
ty: UnresolvedType::User("Bar".into()).mock(),
ty: UnresolvedType::User("Bar".into(), None).mock(),
}
.mock()],
}
@ -5545,9 +5585,10 @@ mod tests {
id: "Bar",
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::User("Foo".into()).mock(),
ty: UnresolvedType::User("Foo".into(), None).mock(),
}
.mock()],
}
@ -5582,6 +5623,7 @@ mod tests {
// Bar
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -5591,13 +5633,14 @@ mod tests {
assert_eq!(
checker.check_type(
UnresolvedType::User("Foo".into()).mock(),
UnresolvedType::User("Foo".into(), None).mock(),
&*MODULE_ID,
&state.types
),
Ok(Type::Struct(StructType::new(
"".into(),
"Foo".into(),
vec![],
vec![StructMember::new("foo".into(), Type::FieldElement)]
)))
);
@ -5605,7 +5648,7 @@ mod tests {
assert_eq!(
checker
.check_type(
UnresolvedType::User("Bar".into()).mock(),
UnresolvedType::User("Bar".into(), None).mock(),
&*MODULE_ID,
&state.types
)
@ -5628,6 +5671,7 @@ mod tests {
// Foo { foo: 42 }.foo
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -5657,6 +5701,7 @@ mod tests {
.annotate(StructType::new(
"".into(),
"Foo".into(),
vec![],
vec![StructMember::new("foo".into(), Type::FieldElement)]
)),
"foo".into()
@ -5673,6 +5718,7 @@ mod tests {
// Foo { foo: 42 }.bar
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -5711,6 +5757,7 @@ mod tests {
// a A value cannot be defined with B as id, even if A and B have the same members
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![StructDefinitionField {
id: "foo",
ty: UnresolvedType::FieldElement.mock(),
@ -5731,7 +5778,7 @@ mod tests {
)
.unwrap_err()
.message,
"Undefined type Bar"
"Undefined type `Bar`"
);
}
@ -5743,6 +5790,7 @@ mod tests {
// Foo foo = Foo { foo: 42, bar: true }
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",
@ -5777,6 +5825,7 @@ mod tests {
.annotate(StructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
StructMember::new("foo".into(), Type::FieldElement),
StructMember::new("bar".into(), Type::Boolean)
@ -5794,6 +5843,7 @@ mod tests {
// Foo foo = Foo { bar: true, foo: 42 }
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",
@ -5828,6 +5878,7 @@ mod tests {
.annotate(StructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
StructMember::new("foo".into(), Type::FieldElement),
StructMember::new("bar".into(), Type::Boolean)
@ -5845,6 +5896,7 @@ mod tests {
// Foo foo = Foo { foo: 42 }
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",
@ -5886,6 +5938,7 @@ mod tests {
// Foo { foo: 42, baz: 42 } // error
let (mut checker, state) = create_module_with_foo(StructDefinition {
generics: vec![],
fields: vec![
StructDefinitionField {
id: "foo",

View file

@ -112,7 +112,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
@ -268,10 +267,9 @@ mod tests {
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
)),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(1)),
))),
)]
.into_iter()
.collect();
@ -356,10 +354,9 @@ mod tests {
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Boolean,
TypedExpression::Boolean(BooleanExpression::Value(true)),
)),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
BooleanExpression::Value(true),
))),
)]
.into_iter()
.collect();
@ -446,7 +443,6 @@ mod tests {
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into(),
@ -546,19 +542,16 @@ mod tests {
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::array(GArrayType::new(GType::FieldElement, 2usize)),
TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
),
)),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
))),
)]
.into_iter()
.collect();
@ -686,7 +679,6 @@ mod tests {
(
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
@ -695,7 +687,6 @@ mod tests {
(
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
const_a_id,
@ -744,7 +735,6 @@ mod tests {
(
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
@ -753,7 +743,6 @@ mod tests {
(
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
@ -804,12 +793,9 @@ mod tests {
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
))),
)]
.into_iter()
.collect(),
@ -877,12 +863,9 @@ mod tests {
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
))),
)]
.into_iter()
.collect(),

View file

@ -1101,7 +1101,11 @@ fn fold_function<'ast, T: Field>(
.collect(),
statements: main_statements_buffer,
signature: typed_absy::types::ConcreteSignature::try_from(
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
crate::typed_absy::types::try_from_g_signature::<
crate::typed_absy::types::DeclarationConstant<'ast>,
crate::typed_absy::UExpression<'ast, T>,
>(fun.signature)
.unwrap(),
)
.unwrap()
.into(),

View file

@ -72,8 +72,6 @@ impl fmt::Display for Error {
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
println!("{}", self);
// inline user-defined constants
let r = ConstantInliner::inline(self);
// isolate branches
@ -86,13 +84,12 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
println!("{}", r);
// generate abi
let abi = r.abi();
// propagate
let r = Propagator::propagate(r).map_err(Error::from)?;
// remove assignment to variable index
let r = VariableWriteRemover::apply(r);
// detect non constant shifts

View file

@ -327,20 +327,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr)?;
let assignee = self.fold_assignee(assignee)?;
println!(
"{:#?} vs {:#?}",
assignee.get_type().clone(),
expr.get_type().clone()
);
let assignee = self.fold_assignee(assignee)?;
if let (Ok(a), Ok(e)) = (
ConcreteType::try_from(assignee.get_type()),
ConcreteType::try_from(expr.get_type()),
) {
println!("{} vs {}", a, e);
if a != e {
return Err(Error::Type(format!(
"Cannot assign {} of type {} to {} of type {}",

View file

@ -231,6 +231,7 @@ mod tests {
ty: ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
@ -240,6 +241,7 @@ mod tests {
outputs: vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
@ -258,6 +260,7 @@ mod tests {
"type": "struct",
"components": {
"name": "Foo",
"generics": [],
"members": [
{
"name": "a",
@ -276,6 +279,7 @@ mod tests {
"type": "struct",
"components": {
"name": "Foo",
"generics": [],
"members": [
{
"name": "a",
@ -305,11 +309,13 @@ mod tests {
ty: ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![ConcreteStructMember::new(
String::from("bar"),
ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"Bar".into(),
vec![],
vec![
ConcreteStructMember::new(
String::from("a"),
@ -338,12 +344,14 @@ mod tests {
"type": "struct",
"components": {
"name": "Foo",
"generics": [],
"members": [
{
"name": "bar",
"type": "struct",
"components": {
"name": "Bar",
"generics": [],
"members": [
{
"name": "a",
@ -378,6 +386,7 @@ mod tests {
ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![
ConcreteStructMember::new(
String::from("b"),
@ -406,6 +415,7 @@ mod tests {
"type": "struct",
"components": {
"name": "Foo",
"generics": [],
"members": [
{
"name": "b",

View file

@ -138,6 +138,11 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> {
StructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)))
.collect(),
members: t
.members
.into_iter()
@ -175,6 +180,11 @@ pub trait Folder<'ast, T: Field>: Sized {
t: DeclarationStructType<'ast>,
) -> DeclarationStructType<'ast> {
DeclarationStructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_declaration_constant(g)))
.collect(),
members: t
.members
.into_iter()
@ -1044,7 +1054,6 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
c: TypedConstant<'ast, T>,
) -> TypedConstant<'ast, T> {
TypedConstant {
ty: f.fold_type(c.ty),
expression: f.fold_expression(c.expression),
}
}

View file

@ -1,9 +1,12 @@
use crate::typed_absy::types::{ArrayType, Type};
use crate::typed_absy::types::{
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationType, GArrayType, GStructType,
GType, StructMember, StructType, Type,
};
use crate::typed_absy::UBitwidth;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression,
TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
};
use num_bigint::BigUint;
use std::convert::TryFrom;
@ -14,20 +17,46 @@ use zokrates_field::Field;
type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>);
impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> {
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
e: Self,
ty: &GArrayType<S>,
) -> Result<Self, (Self, &GArrayType<S>)> {
match e {
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, ty)
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, &ty.ty)
.map(|e| e.into())
.map_err(|(e, t)| (e.into(), t)),
TypedExpressionOrSpread::Spread(s) => {
ArrayExpression::try_from_int(s.array, ty.clone())
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
.map_err(|e| (e.into(), ty))
}
.map_err(|(e, _)| (e.into(), ty)),
TypedExpressionOrSpread::Spread(s) => ArrayExpression::try_from_int(s.array, ty)
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
.map_err(|e| (e.into(), ty)),
}
}
}
fn get_common_type<'a, T: Field>(t: Type<'a, T>, u: Type<'a, T>) -> Result<Type<'a, T>, ()> {
match (t, u) {
(Type::Int, Type::Int) => Err(()),
(Type::Int, u) => Ok(u),
(t, Type::Int) => Ok(t),
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
get_common_type(*t.ty, *u.ty)?,
t.size,
))),
(Type::Struct(t), Type::Struct(u)) => Ok(Type::Struct(StructType {
members: t
.members
.into_iter()
.zip(u.members.into_iter())
.map(|(m_t, m_u)| {
get_common_type(*m_t.ty.clone(), *m_u.ty)
.map(|ty| StructMember { ty: box ty, ..m_t })
})
.collect::<Result<Vec<_>, _>>()?,
..t
})),
(t, _) => Ok(t),
}
}
impl<'ast, T: Field> TypedExpression<'ast, T> {
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
@ -51,7 +80,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
.into(),
)),
(Int(lhs), Uint(rhs)) => Ok((
UExpression::try_from_int(lhs, rhs.bitwidth())
UExpression::try_from_int(lhs, &rhs.bitwidth())
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(),
Uint(rhs),
@ -60,47 +89,46 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
let bitwidth = lhs.bitwidth();
Ok((
Uint(lhs.clone()),
UExpression::try_from_int(rhs, bitwidth)
UExpression::try_from_int(rhs, &bitwidth)
.map_err(|rhs| (lhs.into(), rhs.into()))?
.into(),
))
}
(Array(lhs), Array(rhs)) => {
fn get_common_type<'a, T: Field>(
t: Type<'a, T>,
u: Type<'a, T>,
) -> Result<Type<'a, T>, ()> {
match (t, u) {
(Type::Int, Type::Int) => Err(()),
(Type::Int, u) => Ok(u),
(t, Type::Int) => Ok(t),
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
get_common_type(*t.ty, *u.ty)?,
t.size,
))),
(t, _) => Ok(t),
}
}
let common_type = get_common_type(lhs.get_type().clone(), rhs.get_type().clone())
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
let common_type =
get_common_type(lhs.inner_type().clone(), rhs.inner_type().clone())
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
let common_type = match common_type {
Type::Array(ty) => ty,
_ => unreachable!(),
};
Ok((
ArrayExpression::try_from_int(lhs.clone(), common_type.clone())
ArrayExpression::try_from_int(lhs.clone(), &common_type)
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
.into(),
ArrayExpression::try_from_int(rhs, common_type)
ArrayExpression::try_from_int(rhs, &common_type)
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
.into(),
))
}
(Struct(lhs), Struct(rhs)) => {
if lhs.get_type() == rhs.get_type() {
Ok((Struct(lhs), Struct(rhs)))
} else {
Err((Struct(lhs), Struct(rhs)))
}
let common_type = get_common_type(lhs.get_type(), rhs.get_type())
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
let common_type = match common_type {
Type::Struct(ty) => ty,
_ => unreachable!(),
};
Ok((
StructExpression::try_from_int(lhs.clone(), &common_type)
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
.into(),
StructExpression::try_from_int(rhs, &common_type)
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
.into(),
))
}
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
@ -110,22 +138,25 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
}
}
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
match ty.clone() {
Type::FieldElement => {
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
e: Self,
ty: &GType<S>,
) -> Result<Self, (Self, &GType<S>)> {
match ty {
GType::FieldElement => {
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
}
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
Type::Uint(bitwidth) => {
GType::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
GType::Uint(bitwidth) => {
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
}
Type::Array(array_ty) => {
ArrayExpression::try_from_typed(e, *array_ty.ty).map(TypedExpression::from)
GType::Array(array_ty) => {
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
}
Type::Struct(struct_ty) => {
GType::Struct(struct_ty) => {
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
}
Type::Int => Err(e),
GType::Int => Err(e),
}
.map_err(|e| (e, ty))
}
@ -299,7 +330,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
)),
IntExpression::Pow(box e1, box e2) => Ok(Self::Pow(
box Self::try_from_int(e1)?,
box UExpression::try_from_int(e2, UBitwidth::B32)?,
box UExpression::try_from_int(e2, &UBitwidth::B32)?,
)),
IntExpression::Div(box e1, box e2) => Ok(Self::Div(
box Self::try_from_int(e1)?,
@ -323,15 +354,21 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
let values = values
.into_iter()
.map(|v| {
TypedExpressionOrSpread::align_to_type(v, Type::FieldElement)
.map_err(|(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
TypedExpressionOrSpread::align_to_type(
v,
&DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Concrete(0),
),
)
.map_err(|(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(FieldElementExpression::select(
@ -351,10 +388,10 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
impl<'ast, T: Field> UExpression<'ast, T> {
pub fn try_from_typed(
e: TypedExpression<'ast, T>,
bitwidth: UBitwidth,
bitwidth: &UBitwidth,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Uint(e) => match e.bitwidth == bitwidth {
TypedExpression::Uint(e) => match e.bitwidth == *bitwidth {
true => Ok(e),
_ => Err(TypedExpression::Uint(e)),
},
@ -367,7 +404,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
pub fn try_from_int(
i: IntExpression<'ast, T>,
bitwidth: UBitwidth,
bitwidth: &UBitwidth,
) -> Result<Self, IntExpression<'ast, T>> {
use self::IntExpression::*;
@ -377,7 +414,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
Ok(UExpressionInner::Value(
u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(),
)
.annotate(bitwidth))
.annotate(*bitwidth))
} else {
Err(Value(i))
}
@ -435,20 +472,26 @@ impl<'ast, T: Field> UExpression<'ast, T> {
let values = values
.into_iter()
.map(|v| {
TypedExpressionOrSpread::align_to_type(v, Type::Uint(bitwidth))
.map_err(|(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
TypedExpressionOrSpread::align_to_type(
v,
&DeclarationArrayType::new(
DeclarationType::Uint(*bitwidth),
DeclarationConstant::Concrete(0),
),
)
.map_err(|(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(UExpression::select(
ArrayExpressionInner::Value(values.into())
.annotate(Type::Uint(bitwidth), size),
.annotate(Type::Uint(*bitwidth), size),
index,
))
}
@ -461,35 +504,35 @@ impl<'ast, T: Field> UExpression<'ast, T> {
}
impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn try_from_typed(
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
e: TypedExpression<'ast, T>,
target_inner_ty: Type<'ast, T>,
target_array_ty: &GArrayType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_inner_ty)
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_array_ty)
.map_err(|_| TypedExpression::Array(e)),
e => Err(e),
}
}
// precondition: `array` is only made of inline arrays and repeat constructs unless it does not contain the Integer type
pub fn try_from_int(
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
array: Self,
target_inner_ty: Type<'ast, T>,
target_array_ty: &GArrayType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.ty();
// elements must fit in the target type
match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => {
let res = match target_inner_ty.clone() {
Type::Int => Ok(inline_array),
t => {
let res = match &*target_array_ty.ty {
GType::Int => Ok(inline_array),
_ => {
// try to convert all elements to the target type
inline_array
.into_iter()
.map(|v| {
TypedExpressionOrSpread::align_to_type(v, t.clone()).map_err(
TypedExpressionOrSpread::align_to_type(v, &target_array_ty).map_err(
|(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => e,
TypedExpressionOrSpread::Spread(a) => {
@ -508,11 +551,11 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size))
}
ArrayExpressionInner::Repeat(box e, box count) => {
match target_inner_ty.clone() {
Type::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
match &*target_array_ty.ty {
GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
.annotate(Type::Int, array_ty.size)),
// try to align the repeated element to the target type
t => TypedExpression::align_to_type(e, t)
t => TypedExpression::align_to_type(e, &t)
.map(|e| {
let ty = e.get_type().clone();
@ -523,7 +566,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
}
}
a => {
if array_ty.ty.weak_eq(&target_inner_ty) {
if *target_array_ty.ty == *array_ty.ty {
Ok(a.annotate(*array_ty.ty, array_ty.size))
} else {
Err(a.annotate(*array_ty.ty, array_ty.size).into())
@ -533,6 +576,50 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Field> StructExpression<'ast, T> {
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
struc: Self,
target_struct_ty: &GStructType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
let struct_ty = struc.ty().clone();
match struc.into_inner() {
StructExpressionInner::Value(inline_struct) => inline_struct
.into_iter()
.zip(target_struct_ty.members.iter())
.map(|(value, target_member)| {
TypedExpression::align_to_type(value, &*target_member.ty)
})
.collect::<Result<Vec<_>, _>>()
.map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone()))
.map_err(|_| unimplemented!()),
s => {
if struct_ty
.members
.iter()
.zip(target_struct_ty.members.iter())
.all(|(m, target_m)| *target_m.ty == *m.ty)
{
Ok(s.annotate(struct_ty.clone()))
} else {
Err(s.annotate(struct_ty.clone()).into())
}
}
}
}
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
e: TypedExpression<'ast, T>,
target_struct_ty: &GStructType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Struct(e) => Self::try_from_int(e.clone(), target_struct_ty)
.map_err(|_| TypedExpression::Struct(e)),
e => Err(e),
}
}
}
impl<'ast, T> From<BigUint> for IntExpression<'ast, T> {
fn from(v: BigUint) -> Self {
IntExpression::Value(v)
@ -652,7 +739,7 @@ mod tests {
for (r, e) in expressions
.into_iter()
.map(|e| UExpression::try_from_int(e, UBitwidth::B32).unwrap())
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32).unwrap())
.zip(expected)
{
assert_eq!(r, e);
@ -665,7 +752,7 @@ mod tests {
for e in should_error
.into_iter()
.map(|e| UExpression::try_from_int(e, UBitwidth::B32))
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32))
{
assert!(e.is_err());
}

View file

@ -107,15 +107,19 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.arguments
.iter()
.map(|p| {
println!("{:#?}", p);
types::ConcreteType::try_from(types::Type::<T>::from(p.id._type.clone()))
.map(|ty| AbiInput {
public: !p.private,
name: p.id.id.to_string(),
ty,
})
.unwrap()
types::ConcreteType::try_from(
crate::typed_absy::types::try_from_g_type::<
crate::typed_absy::types::DeclarationConstant<'ast>,
UExpression<'ast, T>,
>(p.id._type.clone())
.unwrap(),
)
.map(|ty| AbiInput {
public: !p.private,
name: p.id.id.to_string(),
ty,
})
.unwrap()
})
.collect(),
outputs: main
@ -123,7 +127,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.outputs
.iter()
.map(|ty| {
types::ConcreteType::try_from(types::Type::<T>::from(ty.clone())).unwrap()
types::ConcreteType::try_from(
crate::typed_absy::types::try_from_g_type::<
crate::typed_absy::types::DeclarationConstant<'ast>,
UExpression<'ast, T>,
>(ty.clone())
.unwrap(),
)
.unwrap()
})
.collect(),
}
@ -194,7 +205,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
.iter()
.map(|(id, symbol)| match symbol {
TypedConstantSymbol::Here(ref tc) => {
format!("const {} {} = {}", tc.ty, id.id, tc.expression)
format!("const {} {} = {}", "todo", id.id, tc.expression)
}
TypedConstantSymbol::There(ref imported_id) => {
format!(
@ -300,27 +311,25 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {
pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self {
TypedConstant { ty, expression }
pub fn new(expression: TypedExpression<'ast, T>) -> Self {
TypedConstant { expression }
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
write!(f, "const {}({})", self.ty, self.expression)
write!(f, "const {}({})", "todo", self.expression)
}
}
impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> {
impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
fn get_type(&self) -> Type<'ast, T> {
self.ty.clone()
self.expression.get_type()
}
}
@ -1168,24 +1177,6 @@ pub struct StructExpression<'ast, T> {
inner: StructExpressionInner<'ast, T>,
}
impl<'ast, T: Field> StructExpression<'ast, T> {
pub fn try_from_typed(
e: TypedExpression<'ast, T>,
target_struct_ty: StructType<'ast, T>,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Struct(e) => {
if e.ty() == &target_struct_ty {
Ok(e)
} else {
Err(TypedExpression::Struct(e))
}
}
e => Err(e),
}
}
}
impl<'ast, T> StructExpression<'ast, T> {
pub fn ty(&self) -> &StructType<'ast, T> {
&self.ty

View file

@ -225,6 +225,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
t: StructType<'ast, T>,
) -> Result<StructType<'ast, T>, Self::Error> {
Ok(StructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<Vec<_>, _>>()?,
members: t
.members
.into_iter()
@ -260,6 +265,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
t: DeclarationStructType<'ast>,
) -> Result<DeclarationStructType<'ast>, Self::Error> {
Ok(DeclarationStructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_declaration_constant(g)).transpose())
.collect::<Result<Vec<_>, _>>()?,
members: t
.members
.into_iter()
@ -1092,7 +1102,6 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
c: TypedConstant<'ast, T>,
) -> Result<TypedConstant<'ast, T>, F::Error> {
Ok(TypedConstant {
ty: f.fold_type(c.ty)?,
expression: f.fold_expression(c.expression)?,
})
}

View file

@ -122,6 +122,21 @@ pub enum DeclarationConstant<'ast> {
Constant(CanonicalConstantIdentifier<'ast>),
}
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
match (self, other.as_inner()) {
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
_ => true,
}
}
}
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
other.eq(self)
}
}
impl<'ast> From<u32> for DeclarationConstant<'ast> {
fn from(e: u32) -> Self {
DeclarationConstant::Concrete(e)
@ -198,7 +213,7 @@ impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
pub type MemberId = String;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
#[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct GStructMember<S> {
#[serde(rename = "name")]
pub id: MemberId,
@ -210,8 +225,8 @@ pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>
pub type ConcreteStructMember = GStructMember<usize>;
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
fn eq(&self, other: &DeclarationStructMember<'ast>) -> bool {
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructMember<S>> for GStructMember<R> {
fn eq(&self, other: &GStructMember<S>) -> bool {
self.id == other.id && *self.ty == *other.ty
}
}
@ -239,19 +254,19 @@ impl<'ast, T> From<ConcreteStructMember> for StructMember<'ast, T> {
}
}
impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> {
fn from(t: ConcreteStructMember) -> Self {
try_from_g_struct_member(t).unwrap()
}
}
// impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> {
// fn from(t: ConcreteStructMember) -> Self {
// try_from_g_struct_member(t).unwrap()
// }
// }
impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
fn from(t: DeclarationStructMember<'ast>) -> Self {
try_from_g_struct_member(t).unwrap()
}
}
// impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
// fn from(t: DeclarationStructMember<'ast>) -> Self {
// try_from_g_struct_member(t).unwrap()
// }
// }
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
#[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
pub struct GArrayType<S> {
pub size: S,
#[serde(flatten)]
@ -262,13 +277,9 @@ pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
pub type ConcreteArrayType = GArrayType<usize>;
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'ast, T> {
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool {
*self.ty == *other.ty
&& match (self.size.as_inner(), &other.size) {
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
_ => true,
}
impl<'ast, S, R: PartialEq<S>> PartialEq<GArrayType<S>> for GArrayType<R> {
fn eq(&self, other: &GArrayType<S>) -> bool {
*self.ty == *other.ty && self.size == other.size
}
}
@ -298,22 +309,6 @@ impl<S: fmt::Display> fmt::Display for GArrayType<S> {
}
}
impl<'ast, T: PartialEq + fmt::Display> Type<'ast, T> {
// array type equality with non-strict size checks
// sizes always match unless they are different constants
pub fn weak_eq(&self, other: &Self) -> bool {
match (self, other) {
(Type::Array(t), Type::Array(u)) => t.ty.weak_eq(&u.ty),
(Type::Struct(t), Type::Struct(u)) => t
.members
.iter()
.zip(u.members.iter())
.all(|(m, n)| m.ty.weak_eq(&n.ty)),
(t, u) => t == u,
}
}
}
fn try_from_g_array_type<T: TryInto<U>, U>(
t: GArrayType<T>,
) -> Result<GArrayType<U>, SpecializationError> {
@ -370,9 +365,18 @@ pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
pub type ConcreteStructType = GStructType<usize>;
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
impl<S: PartialEq> PartialEq for GStructType<S> {
fn eq(&self, other: &Self) -> bool {
self.canonical_location.eq(&other.canonical_location) && self.generics.eq(&other.generics)
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructType<S>> for GStructType<R> {
fn eq(&self, other: &GStructType<S>) -> bool {
self.canonical_location == other.canonical_location
&& self
.generics
.iter()
.zip(other.generics.iter())
.all(|(a, b)| match (a, b) {
(Some(a), Some(b)) => a == b,
(None, None) => true,
_ => false,
})
}
}
@ -427,11 +431,11 @@ impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
}
}
impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
fn from(t: DeclarationStructType<'ast>) -> Self {
try_from_g_struct_type(t).unwrap()
}
}
// impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
// fn from(t: DeclarationStructType<'ast>) -> Self {
// try_from_g_struct_type(t).unwrap()
// }
// }
impl<S> GStructType<S> {
pub fn new(
@ -514,7 +518,7 @@ impl fmt::Display for UBitwidth {
}
}
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
#[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)]
pub enum GType<S> {
FieldElement,
Boolean,
@ -624,13 +628,13 @@ pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
pub type ConcreteType = GType<usize>;
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
fn eq(&self, other: &DeclarationType<'ast>) -> bool {
impl<'ast, S, R: PartialEq<S>> PartialEq<GType<S>> for GType<R> {
fn eq(&self, other: &GType<S>) -> bool {
use self::GType::*;
match (self, other) {
(Array(l), Array(r)) => l == r,
(Struct(l), Struct(r)) => l.canonical_location == r.canonical_location,
(Struct(l), Struct(r)) => l == r,
(FieldElement, FieldElement) | (Boolean, Boolean) => true,
(Uint(l), Uint(r)) => l == r,
_ => false,
@ -638,7 +642,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
}
}
fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
pub fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
match t {
GType::FieldElement => Ok(GType::FieldElement),
GType::Boolean => Ok(GType::Boolean),
@ -669,12 +673,6 @@ impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
}
}
impl<'ast, T> From<DeclarationType<'ast>> for Type<'ast, T> {
fn from(t: DeclarationType<'ast>) -> Self {
try_from_g_type(t).unwrap()
}
}
impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
fn from(tup: (GType<S>, U)) -> Self {
GArrayType {
@ -758,7 +756,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
use self::GType::*;
if self == other {
if other == self {
true
} else {
match (self, other) {
@ -776,7 +774,13 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
}
_ => false,
},
(Struct(_), Struct(_)) => false,
(Struct(l), Struct(r)) => {
l.canonical_location == r.canonical_location
&& l.members
.iter()
.zip(r.members.iter())
.all(|(m, d_m)| m.ty.can_be_specialized_to(&*d_m.ty))
}
_ => false,
}
}
@ -889,14 +893,6 @@ impl<'ast, T> TryFrom<FunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
}
}
// impl<'ast> TryFrom<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
// type Error = SpecializationError;
// fn try_from(k: DeclarationFunctionKey<'ast>) -> Result<Self, Self::Error> {
// try_from_g_function_key(k)
// }
// }
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
fn from(k: ConcreteFunctionKey<'ast>) -> Self {
try_from_g_function_key(k).unwrap()
@ -909,12 +905,6 @@ impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
}
}
impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
fn from(k: DeclarationFunctionKey<'ast>) -> Self {
try_from_g_function_key(k).unwrap()
}
}
impl<'ast, S> GFunctionKey<'ast, S> {
pub fn with_location<T: Into<OwnedTypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
module: T,
@ -993,7 +983,7 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
}
}
pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
pub fn specialize_declaration_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
decl_ty: DeclarationType<'ast>,
constants: &GGenericsAssignment<'ast, S>,
) -> Result<GType<S>, GenericIdentifier<'ast>> {
@ -1002,7 +992,7 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
DeclarationType::Array(t0) => {
// let s1 = t1.size.clone();
let ty = box specialize_type(*t0.ty, &constants)?;
let ty = box specialize_declaration_type(*t0.ty, &constants)?;
let size = match t0.size {
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
DeclarationConstant::Concrete(s) => Ok(s.into()),
@ -1022,7 +1012,8 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
.into_iter()
.map(|m| {
let id = m.id;
specialize_type(*m.ty, constants).map(|ty| GStructMember { ty: box ty, id })
specialize_declaration_type(*m.ty, constants)
.map(|ty| GStructMember { ty: box ty, id })
})
.collect::<Result<_, _>>()?,
generics: s0
@ -1049,7 +1040,9 @@ pub fn specialize_type<'ast, S: Clone + PartialEq + From<u32> + fmt::Debug>(
})
}
pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature};
pub use self::signature::{
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
};
pub mod signature {
use super::*;
@ -1194,7 +1187,7 @@ pub mod signature {
self.outputs
.clone()
.into_iter()
.map(|t| specialize_type(t, &constants))
.map(|t| specialize_declaration_type(t, &constants))
.collect::<Result<_, _>>()
}
}
@ -1244,12 +1237,6 @@ pub mod signature {
}
}
impl<'ast, T> From<DeclarationSignature<'ast>> for Signature<'ast, T> {
fn from(s: DeclarationSignature<'ast>) -> Self {
try_from_g_signature(s).unwrap()
}
}
impl<S: fmt::Display> fmt::Display for GSignature<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.generics.is_empty() {

View file

@ -45,7 +45,7 @@ impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
impl<'ast, T> From<DeclarationVariable<'ast>> for Variable<'ast, T> {
fn from(v: DeclarationVariable<'ast>) -> Self {
let _type = v._type.into();
let _type = crate::typed_absy::types::try_from_g_type(v._type).unwrap();
Self { _type, id: v.id }
}