simplify treatment of user defined types
This commit is contained in:
parent
1ffbc024b2
commit
30ebec6f37
1 changed files with 159 additions and 240 deletions
|
@ -55,9 +55,33 @@ impl ErrorInner {
|
|||
}
|
||||
}
|
||||
|
||||
type GenericDeclarations<'ast> = Option<Vec<Option<DeclarationConstant<'ast>>>>;
|
||||
type TypeMap<'ast> =
|
||||
HashMap<OwnedModuleId, HashMap<UserTypeId, (DeclarationType<'ast>, GenericDeclarations<'ast>)>>;
|
||||
// a single struct to cover all cases of user-defined types
|
||||
#[derive(Debug, Clone)]
|
||||
struct UserDeclarationType<'ast> {
|
||||
generics: Vec<DeclarationConstant<'ast>>,
|
||||
ty: DeclarationType<'ast>,
|
||||
}
|
||||
|
||||
impl<'ast> UserDeclarationType<'ast> {
|
||||
// returns the declared generics for this user type
|
||||
// for alias of basic types this is empty
|
||||
// for structs this is the same as the used generics
|
||||
// for aliases of structs this is the names of the generics declared on the left side of the type declaration
|
||||
fn declaration_generics(&self) -> Vec<&'ast str> {
|
||||
self.generics
|
||||
.iter()
|
||||
.filter_map(|g| match g {
|
||||
DeclarationConstant::Generic(g) => Some(g),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<BTreeSet<_>>() // we collect into a BTreeSet because draining it after yields the element in the right order defined by Ord
|
||||
.into_iter()
|
||||
.map(|g| g.name())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
type TypeMap<'ast> = HashMap<OwnedModuleId, HashMap<UserTypeId, UserDeclarationType<'ast>>>;
|
||||
type ConstantMap<'ast> =
|
||||
HashMap<OwnedModuleId, HashMap<ConstantIdentifier<'ast>, DeclarationType<'ast>>>;
|
||||
|
||||
|
@ -356,7 +380,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
ty: TypeDefinitionNode<'ast>,
|
||||
module_id: &ModuleId,
|
||||
state: &State<'ast, T>,
|
||||
) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec<ErrorInner>> {
|
||||
) -> Result<UserDeclarationType<'ast>, Vec<ErrorInner>> {
|
||||
let pos = ty.pos();
|
||||
let ty = ty.value;
|
||||
|
||||
|
@ -382,9 +406,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
} else {
|
||||
match generics_map.insert(g.value, index).is_none() {
|
||||
true => {
|
||||
generics.push(Some(DeclarationConstant::Generic(
|
||||
generics.push(DeclarationConstant::Generic(
|
||||
GenericIdentifier::with_name(g.value).with_index(index),
|
||||
)));
|
||||
));
|
||||
}
|
||||
false => {
|
||||
errors.push(ErrorInner {
|
||||
|
@ -420,7 +444,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
return Err(errors);
|
||||
}
|
||||
|
||||
Ok((ty, Some(generics)))
|
||||
Ok(UserDeclarationType { generics, ty })
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(e);
|
||||
|
@ -486,7 +510,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
s: StructDefinitionNode<'ast>,
|
||||
module_id: &ModuleId,
|
||||
state: &State<'ast, T>,
|
||||
) -> Result<DeclarationType<'ast>, Vec<ErrorInner>> {
|
||||
) -> Result<DeclarationStructType<'ast>, Vec<ErrorInner>> {
|
||||
let pos = s.pos();
|
||||
let s = s.value;
|
||||
|
||||
|
@ -569,7 +593,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
return Err(errors);
|
||||
}
|
||||
|
||||
Ok(DeclarationType::Struct(DeclarationStructType::new(
|
||||
Ok(DeclarationStructType::new(
|
||||
module_id.to_path_buf(),
|
||||
id,
|
||||
generics,
|
||||
|
@ -577,7 +601,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.iter()
|
||||
.map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone()))
|
||||
.collect(),
|
||||
)))
|
||||
))
|
||||
}
|
||||
|
||||
fn check_symbol_declaration(
|
||||
|
@ -620,7 +644,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.types
|
||||
.entry(module_id.to_path_buf())
|
||||
.or_default()
|
||||
.insert(declaration.id.to_string(), (ty, None))
|
||||
.insert(
|
||||
declaration.id.to_string(),
|
||||
UserDeclarationType {
|
||||
generics: ty
|
||||
.generics
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|g| g.unwrap())
|
||||
.collect(),
|
||||
ty: DeclarationType::Struct(ty)
|
||||
}
|
||||
)
|
||||
.is_none());
|
||||
}
|
||||
};
|
||||
|
@ -781,18 +816,20 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.cloned();
|
||||
|
||||
match (function_candidates.len(), type_candidate, const_candidate) {
|
||||
(0, Some((t, alias_generics)), None) => {
|
||||
|
||||
(0, Some(t), None) => {
|
||||
// rename the type to the declared symbol
|
||||
let t = match t {
|
||||
DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType {
|
||||
location: Some(StructLocation {
|
||||
name: declaration.id.into(),
|
||||
module: module_id.to_path_buf()
|
||||
let t = UserDeclarationType {
|
||||
ty: match t.ty {
|
||||
DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType {
|
||||
location: Some(StructLocation {
|
||||
name: declaration.id.into(),
|
||||
module: module_id.to_path_buf()
|
||||
}),
|
||||
..t
|
||||
}),
|
||||
..t
|
||||
}),
|
||||
_ => t // type alias
|
||||
_ => t.ty // all other cases
|
||||
},
|
||||
..t
|
||||
};
|
||||
|
||||
// we imported a type, so the symbol it gets bound to should not already exist
|
||||
|
@ -814,7 +851,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.types
|
||||
.entry(module_id.to_path_buf())
|
||||
.or_default()
|
||||
.insert(declaration.id.to_string(), (t, alias_generics));
|
||||
.insert(declaration.id.to_string(), t);
|
||||
}
|
||||
(0, None, Some(ty)) => {
|
||||
match symbol_unifier.insert_constant(declaration.id) {
|
||||
|
@ -1306,133 +1343,71 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
)))
|
||||
}
|
||||
UnresolvedType::User(id, generics) => {
|
||||
let (declaration_type, alias_generics) = types
|
||||
.get(module_id)
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.ok_or_else(|| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undefined type {}", id),
|
||||
})?;
|
||||
let declared_ty =
|
||||
types
|
||||
.get(module_id)
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.ok_or_else(|| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undefined type {}", id),
|
||||
})?;
|
||||
|
||||
let generic_identifiers = declared_ty.declaration_generics();
|
||||
|
||||
let declaration_type = declared_ty.ty;
|
||||
|
||||
// absence of generics is treated as 0 generics, as we do not provide inference for now
|
||||
let generics = generics.unwrap_or_default();
|
||||
|
||||
// check generics
|
||||
match (declaration_type, alias_generics) {
|
||||
(DeclarationType::Struct(struct_type), None) => {
|
||||
match struct_type.generics.len() == generics.len() {
|
||||
true => {
|
||||
// downcast the generics to identifiers, as this is the only possibility here
|
||||
let generic_identifiers = struct_type.generics.iter().map(|c| {
|
||||
match c.as_ref().unwrap() {
|
||||
DeclarationConstant::Generic(g) => g.clone(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
});
|
||||
|
||||
// build the generic assignment for this type
|
||||
let assignment = GGenericsAssignment(generics
|
||||
.into_iter()
|
||||
.zip(generic_identifiers)
|
||||
.map(|(e, g)| match e {
|
||||
Some(e) => {
|
||||
self
|
||||
.check_expression(e, module_id, types)
|
||||
.and_then(|e| {
|
||||
UExpression::try_from_typed(e, &UBitwidth::B32)
|
||||
.map(|e| (g, e))
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
|
||||
})
|
||||
match generic_identifiers.len() == generics.len() {
|
||||
true => {
|
||||
// build the generic assignment for this type
|
||||
let assignment = GGenericsAssignment(generics
|
||||
.into_iter()
|
||||
.zip(generic_identifiers)
|
||||
.enumerate()
|
||||
.map(|(i, (e, g))| match e {
|
||||
Some(e) => {
|
||||
self
|
||||
.check_expression(e, module_id, types)
|
||||
.and_then(|e| {
|
||||
UExpression::try_from_typed(e, &UBitwidth::B32)
|
||||
.map(|e| (GenericIdentifier::with_name(g).with_index(i), e))
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
|
||||
})
|
||||
},
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message:
|
||||
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
|
||||
.into(),
|
||||
})
|
||||
})
|
||||
.collect::<Result<_, _>>()?);
|
||||
},
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message:
|
||||
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
|
||||
.into(),
|
||||
})
|
||||
})
|
||||
.collect::<Result<_, _>>()?);
|
||||
|
||||
// specialize the declared type using the generic assignment
|
||||
Ok(specialize_declaration_type(
|
||||
DeclarationType::Struct(struct_type),
|
||||
&assignment,
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} generic argument{} on type {}, but got {}",
|
||||
struct_type.generics.len(),
|
||||
if struct_type.generics.len() == 1 {
|
||||
""
|
||||
} else {
|
||||
"s"
|
||||
},
|
||||
id,
|
||||
generics.len()
|
||||
),
|
||||
}),
|
||||
}
|
||||
// specialize the declared type using the generic assignment
|
||||
Ok(specialize_declaration_type(declaration_type, &assignment).unwrap())
|
||||
}
|
||||
(declaration_type, Some(alias_generics)) => {
|
||||
match alias_generics.len() == generics.len() {
|
||||
true => {
|
||||
let generic_identifiers =
|
||||
alias_generics.iter().map(|c| match c.as_ref().unwrap() {
|
||||
DeclarationConstant::Generic(g) => g.clone(),
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
// build the generic assignment for this type
|
||||
let assignment = GGenericsAssignment(generics
|
||||
.into_iter()
|
||||
.zip(generic_identifiers)
|
||||
.map(|(e, g)| match e {
|
||||
Some(e) => {
|
||||
self
|
||||
.check_expression(e, module_id, types)
|
||||
.and_then(|e| {
|
||||
UExpression::try_from_typed(e, &UBitwidth::B32)
|
||||
.map(|e| (g, e))
|
||||
.map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
|
||||
})
|
||||
})
|
||||
},
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message:
|
||||
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
|
||||
.into(),
|
||||
})
|
||||
})
|
||||
.collect::<Result<_, _>>()?);
|
||||
|
||||
// specialize the declared type using the generic assignment
|
||||
Ok(specialize_declaration_type(declaration_type, &assignment)
|
||||
.unwrap())
|
||||
}
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} generic argument{} on type {}, but got {}",
|
||||
alias_generics.len(),
|
||||
if alias_generics.len() == 1 { "" } else { "s" },
|
||||
id,
|
||||
generics.len()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} generic argument{} on type {}, but got {}",
|
||||
generic_identifiers.len(),
|
||||
if generic_identifiers.len() == 1 {
|
||||
""
|
||||
} else {
|
||||
"s"
|
||||
},
|
||||
id,
|
||||
generics.len()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1527,7 +1502,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
)))
|
||||
}
|
||||
UnresolvedType::User(id, generics) => {
|
||||
let (declared_ty, alias_generics) = state
|
||||
let ty = state
|
||||
.types
|
||||
.get(module_id)
|
||||
.unwrap()
|
||||
|
@ -1538,99 +1513,47 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
message: format!("Undefined type {}", id),
|
||||
})?;
|
||||
|
||||
match (declared_ty, alias_generics) {
|
||||
(ty, Some(alias_generics)) => {
|
||||
let generics = generics.unwrap_or_default();
|
||||
let checked_generics: Vec<_> = generics
|
||||
.into_iter()
|
||||
.map(|e| match e {
|
||||
Some(e) => self
|
||||
.check_generic_expression(
|
||||
e,
|
||||
module_id,
|
||||
state.constants.get(module_id).unwrap_or(&HashMap::new()),
|
||||
generics_map,
|
||||
used_generics,
|
||||
)
|
||||
.map(Some),
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: "Expected u32 constant or identifier, but found `_`"
|
||||
.into(),
|
||||
}),
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let generics = generics.unwrap_or_default();
|
||||
let checked_generics: Vec<_> = generics
|
||||
.into_iter()
|
||||
.map(|e| match e {
|
||||
Some(e) => self
|
||||
.check_generic_expression(
|
||||
e,
|
||||
module_id,
|
||||
state.constants.get(module_id).unwrap_or(&HashMap::new()),
|
||||
generics_map,
|
||||
used_generics,
|
||||
)
|
||||
.map(Some),
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: "Expected u32 constant or identifier, but found `_`".into(),
|
||||
}),
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match ty.generics.len() == checked_generics.len() {
|
||||
true => {
|
||||
let mut assignment = GGenericsAssignment::default();
|
||||
|
||||
assignment.0.extend(
|
||||
alias_generics.iter().zip(checked_generics.iter()).map(
|
||||
|(decl_g, g_val)| match decl_g.clone().unwrap() {
|
||||
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
),
|
||||
);
|
||||
assignment.0.extend(ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone() {
|
||||
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
|
||||
_ => unreachable!("generic on declaration struct types must be generic identifiers")
|
||||
}));
|
||||
|
||||
Ok(specialize_declaration_type(ty, &assignment).unwrap())
|
||||
Ok(specialize_declaration_type(ty.ty, &assignment).unwrap())
|
||||
}
|
||||
(DeclarationType::Struct(declared_struct_ty), None) => {
|
||||
let generics = generics.unwrap_or_default();
|
||||
match declared_struct_ty.generics.len() == generics.len() {
|
||||
true => {
|
||||
let checked_generics: Vec<_> = generics
|
||||
.into_iter()
|
||||
.map(|e| match e {
|
||||
Some(e) => self
|
||||
.check_generic_expression(
|
||||
e,
|
||||
module_id,
|
||||
state
|
||||
.constants
|
||||
.get(module_id)
|
||||
.unwrap_or(&HashMap::new()),
|
||||
generics_map,
|
||||
used_generics,
|
||||
)
|
||||
.map(Some),
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message:
|
||||
"Expected u32 constant or identifier, but found `_`"
|
||||
.into(),
|
||||
}),
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let mut assignment = GGenericsAssignment::default();
|
||||
|
||||
assignment.0.extend(declared_struct_ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone().unwrap() {
|
||||
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
|
||||
_ => unreachable!("generic on declaration struct types must be generic identifiers")
|
||||
}));
|
||||
|
||||
Ok(DeclarationType::Struct(DeclarationStructType {
|
||||
generics: checked_generics,
|
||||
..declared_struct_ty
|
||||
}))
|
||||
}
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} generic argument{} on type {}, but got {}",
|
||||
declared_struct_ty.generics.len(),
|
||||
if declared_struct_ty.generics.len() == 1 {
|
||||
""
|
||||
} else {
|
||||
"s"
|
||||
},
|
||||
id,
|
||||
generics.len()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
(declared_ty, _) => Ok(declared_ty),
|
||||
false => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected {} generic argument{} on type {}, but got {}",
|
||||
ty.generics.len(),
|
||||
if ty.generics.len() == 1 { "" } else { "s" },
|
||||
id,
|
||||
checked_generics.len()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3099,7 +3022,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.into())
|
||||
}
|
||||
Expression::InlineStruct(id, inline_members) => {
|
||||
let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() {
|
||||
let ty = match types.get(module_id).unwrap().get(&id).cloned() {
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undefined type `{}`", id),
|
||||
|
@ -3107,7 +3030,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
Some(ty) => Ok(ty),
|
||||
}?;
|
||||
|
||||
let mut declared_struct_type = match ty {
|
||||
let mut declared_struct_type = match ty.ty {
|
||||
DeclarationType::Struct(struct_type) => struct_type,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
@ -5529,12 +5452,8 @@ mod tests {
|
|||
}
|
||||
.mock();
|
||||
|
||||
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![],
|
||||
));
|
||||
let expected_type =
|
||||
DeclarationStructType::new("".into(), "Foo".into(), vec![], vec![]);
|
||||
|
||||
assert_eq!(
|
||||
Checker::<Bn128Field>::new().check_struct_type_declaration(
|
||||
|
@ -5570,7 +5489,7 @@ mod tests {
|
|||
}
|
||||
.mock();
|
||||
|
||||
let expected_type = DeclarationType::Struct(DeclarationStructType::new(
|
||||
let expected_type = DeclarationStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
|
@ -5578,7 +5497,7 @@ mod tests {
|
|||
DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement),
|
||||
DeclarationStructMember::new("bar".into(), DeclarationType::Boolean),
|
||||
],
|
||||
));
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Checker::<Bn128Field>::new().check_struct_type_declaration(
|
||||
|
@ -5681,7 +5600,7 @@ mod tests {
|
|||
.get(&*MODULE_ID)
|
||||
.unwrap()
|
||||
.get(&"Bar".to_string())
|
||||
.map(|(ty, _)| ty)
|
||||
.map(|ty| &ty.ty)
|
||||
.unwrap(),
|
||||
&DeclarationType::Struct(DeclarationStructType::new(
|
||||
(*MODULE_ID).clone(),
|
||||
|
|
Loading…
Reference in a new issue