1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

simplify treatment of user defined types

This commit is contained in:
schaeff 2021-09-23 11:14:18 +03:00
parent 1ffbc024b2
commit 30ebec6f37

View file

@ -55,9 +55,33 @@ impl ErrorInner {
}
}
type GenericDeclarations<'ast> = Option<Vec<Option<DeclarationConstant<'ast>>>>;
type TypeMap<'ast> =
HashMap<OwnedModuleId, HashMap<UserTypeId, (DeclarationType<'ast>, GenericDeclarations<'ast>)>>;
// a single struct to cover all cases of user-defined types
#[derive(Debug, Clone)]
struct UserDeclarationType<'ast> {
generics: Vec<DeclarationConstant<'ast>>,
ty: DeclarationType<'ast>,
}
impl<'ast> UserDeclarationType<'ast> {
// returns the declared generics for this user type
// for alias of basic types this is empty
// for structs this is the same as the used generics
// for aliases of structs this is the names of the generics declared on the left side of the type declaration
fn declaration_generics(&self) -> Vec<&'ast str> {
self.generics
.iter()
.filter_map(|g| match g {
DeclarationConstant::Generic(g) => Some(g),
_ => None,
})
.collect::<BTreeSet<_>>() // we collect into a BTreeSet because draining it after yields the element in the right order defined by Ord
.into_iter()
.map(|g| g.name())
.collect()
}
}
type TypeMap<'ast> = HashMap<OwnedModuleId, HashMap<UserTypeId, UserDeclarationType<'ast>>>;
type ConstantMap<'ast> =
HashMap<OwnedModuleId, HashMap<ConstantIdentifier<'ast>, DeclarationType<'ast>>>;
@ -356,7 +380,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
ty: TypeDefinitionNode<'ast>,
module_id: &ModuleId,
state: &State<'ast, T>,
) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec<ErrorInner>> {
) -> Result<UserDeclarationType<'ast>, Vec<ErrorInner>> {
let pos = ty.pos();
let ty = ty.value;
@ -382,9 +406,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
} else {
match generics_map.insert(g.value, index).is_none() {
true => {
generics.push(Some(DeclarationConstant::Generic(
generics.push(DeclarationConstant::Generic(
GenericIdentifier::with_name(g.value).with_index(index),
)));
));
}
false => {
errors.push(ErrorInner {
@ -420,7 +444,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
return Err(errors);
}
Ok((ty, Some(generics)))
Ok(UserDeclarationType { generics, ty })
}
Err(e) => {
errors.push(e);
@ -486,7 +510,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
s: StructDefinitionNode<'ast>,
module_id: &ModuleId,
state: &State<'ast, T>,
) -> Result<DeclarationType<'ast>, Vec<ErrorInner>> {
) -> Result<DeclarationStructType<'ast>, Vec<ErrorInner>> {
let pos = s.pos();
let s = s.value;
@ -569,7 +593,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
return Err(errors);
}
Ok(DeclarationType::Struct(DeclarationStructType::new(
Ok(DeclarationStructType::new(
module_id.to_path_buf(),
id,
generics,
@ -577,7 +601,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.iter()
.map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone()))
.collect(),
)))
))
}
fn check_symbol_declaration(
@ -620,7 +644,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), (ty, None))
.insert(
declaration.id.to_string(),
UserDeclarationType {
generics: ty
.generics
.clone()
.into_iter()
.map(|g| g.unwrap())
.collect(),
ty: DeclarationType::Struct(ty)
}
)
.is_none());
}
};
@ -781,18 +816,20 @@ impl<'ast, T: Field> Checker<'ast, T> {
.cloned();
match (function_candidates.len(), type_candidate, const_candidate) {
(0, Some((t, alias_generics)), None) => {
(0, Some(t), None) => {
// rename the type to the declared symbol
let t = match t {
DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType {
location: Some(StructLocation {
name: declaration.id.into(),
module: module_id.to_path_buf()
let t = UserDeclarationType {
ty: match t.ty {
DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType {
location: Some(StructLocation {
name: declaration.id.into(),
module: module_id.to_path_buf()
}),
..t
}),
..t
}),
_ => t // type alias
_ => t.ty // all other cases
},
..t
};
// we imported a type, so the symbol it gets bound to should not already exist
@ -814,7 +851,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), (t, alias_generics));
.insert(declaration.id.to_string(), t);
}
(0, None, Some(ty)) => {
match symbol_unifier.insert_constant(declaration.id) {
@ -1306,133 +1343,71 @@ impl<'ast, T: Field> Checker<'ast, T> {
)))
}
UnresolvedType::User(id, generics) => {
let (declaration_type, alias_generics) = types
.get(module_id)
.unwrap()
.get(&id)
.cloned()
.ok_or_else(|| ErrorInner {
pos: Some(pos),
message: format!("Undefined type {}", id),
})?;
let declared_ty =
types
.get(module_id)
.unwrap()
.get(&id)
.cloned()
.ok_or_else(|| ErrorInner {
pos: Some(pos),
message: format!("Undefined type {}", id),
})?;
let generic_identifiers = declared_ty.declaration_generics();
let declaration_type = declared_ty.ty;
// absence of generics is treated as 0 generics, as we do not provide inference for now
let generics = generics.unwrap_or_default();
// check generics
match (declaration_type, alias_generics) {
(DeclarationType::Struct(struct_type), None) => {
match struct_type.generics.len() == generics.len() {
true => {
// downcast the generics to identifiers, as this is the only possibility here
let generic_identifiers = struct_type.generics.iter().map(|c| {
match c.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g.clone(),
_ => unreachable!(),
}
});
// build the generic assignment for this type
let assignment = GGenericsAssignment(generics
.into_iter()
.zip(generic_identifiers)
.map(|(e, g)| match e {
Some(e) => {
self
.check_expression(e, module_id, types)
.and_then(|e| {
UExpression::try_from_typed(e, &UBitwidth::B32)
.map(|e| (g, e))
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
})
match generic_identifiers.len() == generics.len() {
true => {
// build the generic assignment for this type
let assignment = GGenericsAssignment(generics
.into_iter()
.zip(generic_identifiers)
.enumerate()
.map(|(i, (e, g))| match e {
Some(e) => {
self
.check_expression(e, module_id, types)
.and_then(|e| {
UExpression::try_from_typed(e, &UBitwidth::B32)
.map(|e| (GenericIdentifier::with_name(g).with_index(i), e))
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
})
},
None => Err(ErrorInner {
pos: Some(pos),
message:
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
.into(),
})
})
.collect::<Result<_, _>>()?);
},
None => Err(ErrorInner {
pos: Some(pos),
message:
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
.into(),
})
})
.collect::<Result<_, _>>()?);
// specialize the declared type using the generic assignment
Ok(specialize_declaration_type(
DeclarationType::Struct(struct_type),
&assignment,
)
.unwrap())
}
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
struct_type.generics.len(),
if struct_type.generics.len() == 1 {
""
} else {
"s"
},
id,
generics.len()
),
}),
}
// specialize the declared type using the generic assignment
Ok(specialize_declaration_type(declaration_type, &assignment).unwrap())
}
(declaration_type, Some(alias_generics)) => {
match alias_generics.len() == generics.len() {
true => {
let generic_identifiers =
alias_generics.iter().map(|c| match c.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g.clone(),
_ => unreachable!(),
});
// build the generic assignment for this type
let assignment = GGenericsAssignment(generics
.into_iter()
.zip(generic_identifiers)
.map(|(e, g)| match e {
Some(e) => {
self
.check_expression(e, module_id, types)
.and_then(|e| {
UExpression::try_from_typed(e, &UBitwidth::B32)
.map(|e| (g, e))
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
})
})
},
None => Err(ErrorInner {
pos: Some(pos),
message:
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
.into(),
})
})
.collect::<Result<_, _>>()?);
// specialize the declared type using the generic assignment
Ok(specialize_declaration_type(declaration_type, &assignment)
.unwrap())
}
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
alias_generics.len(),
if alias_generics.len() == 1 { "" } else { "s" },
id,
generics.len()
),
}),
}
}
_ => unreachable!(),
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
generic_identifiers.len(),
if generic_identifiers.len() == 1 {
""
} else {
"s"
},
id,
generics.len()
),
}),
}
}
}
@ -1527,7 +1502,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
)))
}
UnresolvedType::User(id, generics) => {
let (declared_ty, alias_generics) = state
let ty = state
.types
.get(module_id)
.unwrap()
@ -1538,99 +1513,47 @@ impl<'ast, T: Field> Checker<'ast, T> {
message: format!("Undefined type {}", id),
})?;
match (declared_ty, alias_generics) {
(ty, Some(alias_generics)) => {
let generics = generics.unwrap_or_default();
let checked_generics: Vec<_> = generics
.into_iter()
.map(|e| match e {
Some(e) => self
.check_generic_expression(
e,
module_id,
state.constants.get(module_id).unwrap_or(&HashMap::new()),
generics_map,
used_generics,
)
.map(Some),
None => Err(ErrorInner {
pos: Some(pos),
message: "Expected u32 constant or identifier, but found `_`"
.into(),
}),
})
.collect::<Result<_, _>>()?;
let generics = generics.unwrap_or_default();
let checked_generics: Vec<_> = generics
.into_iter()
.map(|e| match e {
Some(e) => self
.check_generic_expression(
e,
module_id,
state.constants.get(module_id).unwrap_or(&HashMap::new()),
generics_map,
used_generics,
)
.map(Some),
None => Err(ErrorInner {
pos: Some(pos),
message: "Expected u32 constant or identifier, but found `_`".into(),
}),
})
.collect::<Result<_, _>>()?;
match ty.generics.len() == checked_generics.len() {
true => {
let mut assignment = GGenericsAssignment::default();
assignment.0.extend(
alias_generics.iter().zip(checked_generics.iter()).map(
|(decl_g, g_val)| match decl_g.clone().unwrap() {
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
_ => unreachable!(),
},
),
);
assignment.0.extend(ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone() {
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
_ => unreachable!("generic on declaration struct types must be generic identifiers")
}));
Ok(specialize_declaration_type(ty, &assignment).unwrap())
Ok(specialize_declaration_type(ty.ty, &assignment).unwrap())
}
(DeclarationType::Struct(declared_struct_ty), None) => {
let generics = generics.unwrap_or_default();
match declared_struct_ty.generics.len() == generics.len() {
true => {
let checked_generics: Vec<_> = generics
.into_iter()
.map(|e| match e {
Some(e) => self
.check_generic_expression(
e,
module_id,
state
.constants
.get(module_id)
.unwrap_or(&HashMap::new()),
generics_map,
used_generics,
)
.map(Some),
None => Err(ErrorInner {
pos: Some(pos),
message:
"Expected u32 constant or identifier, but found `_`"
.into(),
}),
})
.collect::<Result<_, _>>()?;
let mut assignment = GGenericsAssignment::default();
assignment.0.extend(declared_struct_ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone().unwrap() {
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
_ => unreachable!("generic on declaration struct types must be generic identifiers")
}));
Ok(DeclarationType::Struct(DeclarationStructType {
generics: checked_generics,
..declared_struct_ty
}))
}
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
declared_struct_ty.generics.len(),
if declared_struct_ty.generics.len() == 1 {
""
} else {
"s"
},
id,
generics.len()
),
}),
}
}
(declared_ty, _) => Ok(declared_ty),
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
ty.generics.len(),
if ty.generics.len() == 1 { "" } else { "s" },
id,
checked_generics.len()
),
}),
}
}
}
@ -3099,7 +3022,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.into())
}
Expression::InlineStruct(id, inline_members) => {
let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() {
let ty = match types.get(module_id).unwrap().get(&id).cloned() {
None => Err(ErrorInner {
pos: Some(pos),
message: format!("Undefined type `{}`", id),
@ -3107,7 +3030,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Some(ty) => Ok(ty),
}?;
let mut declared_struct_type = match ty {
let mut declared_struct_type = match ty.ty {
DeclarationType::Struct(struct_type) => struct_type,
_ => unreachable!(),
};
@ -5529,12 +5452,8 @@ mod tests {
}
.mock();
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
"".into(),
"Foo".into(),
vec![],
vec![],
));
let expected_type =
DeclarationStructType::new("".into(), "Foo".into(), vec![], vec![]);
assert_eq!(
Checker::<Bn128Field>::new().check_struct_type_declaration(
@ -5570,7 +5489,7 @@ mod tests {
}
.mock();
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
let expected_type = DeclarationStructType::new(
"".into(),
"Foo".into(),
vec![],
@ -5578,7 +5497,7 @@ mod tests {
DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement),
DeclarationStructMember::new("bar".into(), DeclarationType::Boolean),
],
));
);
assert_eq!(
Checker::<Bn128Field>::new().check_struct_type_declaration(
@ -5681,7 +5600,7 @@ mod tests {
.get(&*MODULE_ID)
.unwrap()
.get(&"Bar".to_string())
.map(|(ty, _)| ty)
.map(|ty| &ty.ty)
.unwrap(),
&DeclarationType::Struct(DeclarationStructType::new(
(*MODULE_ID).clone(),