From 30ebec6f375af5a243679a912b45d5894157f8a2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Sep 2021 11:14:18 +0300 Subject: [PATCH] simplify treatment of user defined types --- zokrates_core/src/semantics.rs | 399 +++++++++++++-------------------- 1 file changed, 159 insertions(+), 240 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 36f7af1d..2c28e0b3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -55,9 +55,33 @@ impl ErrorInner { } } -type GenericDeclarations<'ast> = Option>>>; -type TypeMap<'ast> = - HashMap, GenericDeclarations<'ast>)>>; +// a single struct to cover all cases of user-defined types +#[derive(Debug, Clone)] +struct UserDeclarationType<'ast> { + generics: Vec>, + ty: DeclarationType<'ast>, +} + +impl<'ast> UserDeclarationType<'ast> { + // returns the declared generics for this user type + // for alias of basic types this is empty + // for structs this is the same as the used generics + // for aliases of structs this is the names of the generics declared on the left side of the type declaration + fn declaration_generics(&self) -> Vec<&'ast str> { + self.generics + .iter() + .filter_map(|g| match g { + DeclarationConstant::Generic(g) => Some(g), + _ => None, + }) + .collect::>() // we collect into a BTreeSet because draining it after yields the element in the right order defined by Ord + .into_iter() + .map(|g| g.name()) + .collect() + } +} + +type TypeMap<'ast> = HashMap>>; type ConstantMap<'ast> = HashMap, DeclarationType<'ast>>>; @@ -356,7 +380,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ty: TypeDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec> { + ) -> Result, Vec> { let pos = ty.pos(); let ty = ty.value; @@ -382,9 +406,9 @@ impl<'ast, T: Field> Checker<'ast, T> { } else { match generics_map.insert(g.value, index).is_none() { true => { - generics.push(Some(DeclarationConstant::Generic( + generics.push(DeclarationConstant::Generic( GenericIdentifier::with_name(g.value).with_index(index), - ))); + )); } false => { errors.push(ErrorInner { @@ -420,7 +444,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok((ty, Some(generics))) + Ok(UserDeclarationType { generics, ty }) } Err(e) => { errors.push(e); @@ -486,7 +510,7 @@ impl<'ast, T: Field> Checker<'ast, T> { s: StructDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result, Vec> { + ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -569,7 +593,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok(DeclarationType::Struct(DeclarationStructType::new( + Ok(DeclarationStructType::new( module_id.to_path_buf(), id, generics, @@ -577,7 +601,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .iter() .map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone())) .collect(), - ))) + )) } fn check_symbol_declaration( @@ -620,7 +644,18 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), (ty, None)) + .insert( + declaration.id.to_string(), + UserDeclarationType { + generics: ty + .generics + .clone() + .into_iter() + .map(|g| g.unwrap()) + .collect(), + ty: DeclarationType::Struct(ty) + } + ) .is_none()); } }; @@ -781,18 +816,20 @@ impl<'ast, T: Field> Checker<'ast, T> { .cloned(); match (function_candidates.len(), type_candidate, const_candidate) { - (0, Some((t, alias_generics)), None) => { - + (0, Some(t), None) => { // rename the type to the declared symbol - let t = match t { - DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { - location: Some(StructLocation { - name: declaration.id.into(), - module: module_id.to_path_buf() + let t = UserDeclarationType { + ty: match t.ty { + DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { + location: Some(StructLocation { + name: declaration.id.into(), + module: module_id.to_path_buf() + }), + ..t }), - ..t - }), - _ => t // type alias + _ => t.ty // all other cases + }, + ..t }; // we imported a type, so the symbol it gets bound to should not already exist @@ -814,7 +851,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .types .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id.to_string(), (t, alias_generics)); + .insert(declaration.id.to_string(), t); } (0, None, Some(ty)) => { match symbol_unifier.insert_constant(declaration.id) { @@ -1306,133 +1343,71 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let (declaration_type, alias_generics) = types - .get(module_id) - .unwrap() - .get(&id) - .cloned() - .ok_or_else(|| ErrorInner { - pos: Some(pos), - message: format!("Undefined type {}", id), - })?; + let declared_ty = + types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + })?; + + let generic_identifiers = declared_ty.declaration_generics(); + + let declaration_type = declared_ty.ty; // absence of generics is treated as 0 generics, as we do not provide inference for now let generics = generics.unwrap_or_default(); // check generics - match (declaration_type, alias_generics) { - (DeclarationType::Struct(struct_type), None) => { - match struct_type.generics.len() == generics.len() { - true => { - // downcast the generics to identifiers, as this is the only possibility here - let generic_identifiers = struct_type.generics.iter().map(|c| { - match c.as_ref().unwrap() { - DeclarationConstant::Generic(g) => g.clone(), - _ => unreachable!(), - } - }); - - // build the generic assignment for this type - let assignment = GGenericsAssignment(generics - .into_iter() - .zip(generic_identifiers) - .map(|(e, g)| match e { - Some(e) => { - self - .check_expression(e, module_id, types) - .and_then(|e| { - UExpression::try_from_typed(e, &UBitwidth::B32) - .map(|e| (g, e)) - .map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), - }) + match generic_identifiers.len() == generics.len() { + true => { + // build the generic assignment for this type + let assignment = GGenericsAssignment(generics + .into_iter() + .zip(generic_identifiers) + .enumerate() + .map(|(i, (e, g))| match e { + Some(e) => { + self + .check_expression(e, module_id, types) + .and_then(|e| { + UExpression::try_from_typed(e, &UBitwidth::B32) + .map(|e| (GenericIdentifier::with_name(g).with_index(i), e)) + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), }) - }, - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." - .into(), }) - }) - .collect::>()?); + }, + None => Err(ErrorInner { + pos: Some(pos), + message: + "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." + .into(), + }) + }) + .collect::>()?); - // specialize the declared type using the generic assignment - Ok(specialize_declaration_type( - DeclarationType::Struct(struct_type), - &assignment, - ) - .unwrap()) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - struct_type.generics.len(), - if struct_type.generics.len() == 1 { - "" - } else { - "s" - }, - id, - generics.len() - ), - }), - } + // specialize the declared type using the generic assignment + Ok(specialize_declaration_type(declaration_type, &assignment).unwrap()) } - (declaration_type, Some(alias_generics)) => { - match alias_generics.len() == generics.len() { - true => { - let generic_identifiers = - alias_generics.iter().map(|c| match c.as_ref().unwrap() { - DeclarationConstant::Generic(g) => g.clone(), - _ => unreachable!(), - }); - - // build the generic assignment for this type - let assignment = GGenericsAssignment(generics - .into_iter() - .zip(generic_identifiers) - .map(|(e, g)| match e { - Some(e) => { - self - .check_expression(e, module_id, types) - .and_then(|e| { - UExpression::try_from_typed(e, &UBitwidth::B32) - .map(|e| (g, e)) - .map_err(|e| ErrorInner { - pos: Some(pos), - message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), - }) - }) - }, - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." - .into(), - }) - }) - .collect::>()?); - - // specialize the declared type using the generic assignment - Ok(specialize_declaration_type(declaration_type, &assignment) - .unwrap()) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - alias_generics.len(), - if alias_generics.len() == 1 { "" } else { "s" }, - id, - generics.len() - ), - }), - } - } - _ => unreachable!(), + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + generic_identifiers.len(), + if generic_identifiers.len() == 1 { + "" + } else { + "s" + }, + id, + generics.len() + ), + }), } } } @@ -1527,7 +1502,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } UnresolvedType::User(id, generics) => { - let (declared_ty, alias_generics) = state + let ty = state .types .get(module_id) .unwrap() @@ -1538,99 +1513,47 @@ impl<'ast, T: Field> Checker<'ast, T> { message: format!("Undefined type {}", id), })?; - match (declared_ty, alias_generics) { - (ty, Some(alias_generics)) => { - let generics = generics.unwrap_or_default(); - let checked_generics: Vec<_> = generics - .into_iter() - .map(|e| match e { - Some(e) => self - .check_generic_expression( - e, - module_id, - state.constants.get(module_id).unwrap_or(&HashMap::new()), - generics_map, - used_generics, - ) - .map(Some), - None => Err(ErrorInner { - pos: Some(pos), - message: "Expected u32 constant or identifier, but found `_`" - .into(), - }), - }) - .collect::>()?; + let generics = generics.unwrap_or_default(); + let checked_generics: Vec<_> = generics + .into_iter() + .map(|e| match e { + Some(e) => self + .check_generic_expression( + e, + module_id, + state.constants.get(module_id).unwrap_or(&HashMap::new()), + generics_map, + used_generics, + ) + .map(Some), + None => Err(ErrorInner { + pos: Some(pos), + message: "Expected u32 constant or identifier, but found `_`".into(), + }), + }) + .collect::>()?; + match ty.generics.len() == checked_generics.len() { + true => { let mut assignment = GGenericsAssignment::default(); - assignment.0.extend( - alias_generics.iter().zip(checked_generics.iter()).map( - |(decl_g, g_val)| match decl_g.clone().unwrap() { - DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), - _ => unreachable!(), - }, - ), - ); + assignment.0.extend(ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone() { + DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), + _ => unreachable!("generic on declaration struct types must be generic identifiers") + })); - Ok(specialize_declaration_type(ty, &assignment).unwrap()) + Ok(specialize_declaration_type(ty.ty, &assignment).unwrap()) } - (DeclarationType::Struct(declared_struct_ty), None) => { - let generics = generics.unwrap_or_default(); - match declared_struct_ty.generics.len() == generics.len() { - true => { - let checked_generics: Vec<_> = generics - .into_iter() - .map(|e| match e { - Some(e) => self - .check_generic_expression( - e, - module_id, - state - .constants - .get(module_id) - .unwrap_or(&HashMap::new()), - generics_map, - used_generics, - ) - .map(Some), - None => Err(ErrorInner { - pos: Some(pos), - message: - "Expected u32 constant or identifier, but found `_`" - .into(), - }), - }) - .collect::>()?; - - let mut assignment = GGenericsAssignment::default(); - - assignment.0.extend(declared_struct_ty.generics.iter().zip(checked_generics.iter()).map(|(decl_g, g_val)| match decl_g.clone().unwrap() { - DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()), - _ => unreachable!("generic on declaration struct types must be generic identifiers") - })); - - Ok(DeclarationType::Struct(DeclarationStructType { - generics: checked_generics, - ..declared_struct_ty - })) - } - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected {} generic argument{} on type {}, but got {}", - declared_struct_ty.generics.len(), - if declared_struct_ty.generics.len() == 1 { - "" - } else { - "s" - }, - id, - generics.len() - ), - }), - } - } - (declared_ty, _) => Ok(declared_ty), + false => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} generic argument{} on type {}, but got {}", + ty.generics.len(), + if ty.generics.len() == 1 { "" } else { "s" }, + id, + checked_generics.len() + ), + }), } } } @@ -3099,7 +3022,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .into()) } Expression::InlineStruct(id, inline_members) => { - let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() { + let ty = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { pos: Some(pos), message: format!("Undefined type `{}`", id), @@ -3107,7 +3030,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Some(ty) => Ok(ty), }?; - let mut declared_struct_type = match ty { + let mut declared_struct_type = match ty.ty { DeclarationType::Struct(struct_type) => struct_type, _ => unreachable!(), }; @@ -5529,12 +5452,8 @@ mod tests { } .mock(); - let expected_type = DeclarationType::Struct(DeclarationStructType::new( - "".into(), - "Foo".into(), - vec![], - vec![], - )); + let expected_type = + DeclarationStructType::new("".into(), "Foo".into(), vec![], vec![]); assert_eq!( Checker::::new().check_struct_type_declaration( @@ -5570,7 +5489,7 @@ mod tests { } .mock(); - let expected_type = DeclarationType::Struct(DeclarationStructType::new( + let expected_type = DeclarationStructType::new( "".into(), "Foo".into(), vec![], @@ -5578,7 +5497,7 @@ mod tests { DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement), DeclarationStructMember::new("bar".into(), DeclarationType::Boolean), ], - )); + ); assert_eq!( Checker::::new().check_struct_type_declaration( @@ -5681,7 +5600,7 @@ mod tests { .get(&*MODULE_ID) .unwrap() .get(&"Bar".to_string()) - .map(|(ty, _)| ty) + .map(|ty| &ty.ty) .unwrap(), &DeclarationType::Struct(DeclarationStructType::new( (*MODULE_ID).clone(),