1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

implement type aliasing

This commit is contained in:
dark64 2021-08-24 14:49:08 +02:00
parent 6b8f36a429
commit 2a94af6ff0
9 changed files with 362 additions and 38 deletions

View file

@ -0,0 +1,13 @@
type byte = u8
type uint32 = u32
type UInt32Array<N> = uint32[N]
type matrix<R, C> = field[R][C]
def fill<R, C>(field v) -> matrix<R, C>:
return [[v; C]; R]
def main(uint32 a, uint32 b) -> (UInt32Array<2>, matrix<2, 4>):
UInt32Array<2> res = [a, b]
matrix<2, 4> m = fill(1)
return res, m

View file

@ -0,0 +1,14 @@
from "./basic_aliasing.zok" import matrix
from "./struct_aliasing.zok" import Buzz
const u32 R = 2
const u32 C = 4
type matrix_2x4 = matrix<R, C>
def buzz<N>() -> Buzz<N>:
return Buzz { a: [0; N], b: [0; N] }
def main(matrix_2x4 m) -> (Buzz<2>, matrix_2x4):
Buzz<2> b = buzz::<2>()
return b, m

View file

@ -0,0 +1,15 @@
type FieldArray<N> = field[N]
struct Foo<A, B> {
FieldArray<A> a
FieldArray<B> b
}
type Bar = Foo<2, 2>
type Buzz<A> = Foo<A, A>
def main(Bar a) -> Buzz<2>:
Bar bar = Bar { a: [1, 2], b: [1, 2] }
Buzz<2> buzz = Buzz { a: [1, 2], b: [1, 2] }
assert(bar == buzz)
return buzz

View file

@ -1,5 +1,6 @@
use crate::absy;
use crate::absy::SymbolDefinition;
use num_bigint::BigUint;
use std::path::Path;
use zokrates_pest_ast as pest;
@ -10,6 +11,7 @@ impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
pest::SymbolDeclaration::Import(i) => import_directive_to_symbol_vec(i),
pest::SymbolDeclaration::Constant(c) => vec![c.into()],
pest::SymbolDeclaration::Struct(s) => vec![s.into()],
pest::SymbolDeclaration::Type(t) => vec![t.into()],
pest::SymbolDeclaration::Function(f) => vec![f.into()],
}))
}
@ -135,6 +137,31 @@ impl<'ast> From<pest::ConstantDefinition<'ast>> for absy::SymbolDeclarationNode<
}
}
impl<'ast> From<pest::TypeDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(definition: pest::TypeDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
let span = definition.span;
let id = definition.id.span.as_str();
let ty = absy::TypeDefinition {
generics: definition
.generics
.into_iter()
.map(absy::ConstantGenericNode::from)
.collect(),
ty: definition.ty.into(),
}
.span(span.clone());
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(SymbolDefinition::Type(ty)),
}
.span(span)
}
}
impl<'ast> From<pest::FunctionDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;

View file

@ -133,6 +133,7 @@ pub enum SymbolDefinition<'ast> {
Import(CanonicalImportNode<'ast>),
Struct(StructDefinitionNode<'ast>),
Constant(ConstantDefinitionNode<'ast>),
Type(TypeDefinitionNode<'ast>),
Function(FunctionNode<'ast>),
}
@ -153,12 +154,28 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
i.value.source.display(),
i.value.id
),
SymbolDefinition::Struct(ref t) => write!(f, "struct {}{}", self.id, t),
SymbolDefinition::Struct(ref s) => write!(f, "struct {}{}", self.id, s),
SymbolDefinition::Constant(ref c) => write!(
f,
"const {} {} = {}",
c.value.ty, self.id, c.value.expression
),
SymbolDefinition::Type(ref t) => {
write!(f, "type {}", self.id)?;
if !t.value.generics.is_empty() {
write!(
f,
"<{}>",
t.value
.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, " = {}", t.value.ty)
}
SymbolDefinition::Function(ref func) => {
write!(f, "def {}{}", self.id, func)
}
@ -205,15 +222,18 @@ pub struct StructDefinition<'ast> {
impl<'ast> fmt::Display for StructDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(
f,
"<{}> {{",
self.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", "),
)?;
if !self.generics.is_empty() {
write!(
f,
"<{}> ",
self.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
)?;
}
writeln!(f, "{{")?;
for field in &self.fields {
writeln!(f, " {}", field)?;
}
@ -248,7 +268,34 @@ pub type ConstantDefinitionNode<'ast> = Node<ConstantDefinition<'ast>>;
impl<'ast> fmt::Display for ConstantDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "const {}({})", self.ty, self.expression)
write!(f, "const {} _ = {}", self.ty, self.expression)
}
}
/// A type definition
#[derive(Debug, Clone, PartialEq)]
pub struct TypeDefinition<'ast> {
pub generics: Vec<ConstantGenericNode<'ast>>,
pub ty: UnresolvedTypeNode<'ast>,
}
pub type TypeDefinitionNode<'ast> = Node<TypeDefinition<'ast>>;
impl<'ast> fmt::Display for TypeDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "type _")?;
if !self.generics.is_empty() {
write!(
f,
"<{}>",
self.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, " = {}", self.ty)
}
}

View file

@ -84,6 +84,7 @@ impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {}
impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for ConstantDefinition<'ast> {}
impl<'ast> NodeValue for TypeDefinition<'ast> {}
impl<'ast> NodeValue for Function<'ast> {}
impl<'ast> NodeValue for Module<'ast> {}
impl<'ast> NodeValue for CanonicalImport<'ast> {}

View file

@ -55,7 +55,9 @@ impl ErrorInner {
}
}
type TypeMap<'ast> = HashMap<OwnedModuleId, HashMap<UserTypeId, DeclarationType<'ast>>>;
type GenericDeclarations<'ast> = Option<Vec<Option<DeclarationConstant<'ast>>>>;
type TypeMap<'ast> =
HashMap<OwnedModuleId, HashMap<UserTypeId, (DeclarationType<'ast>, GenericDeclarations<'ast>)>>;
type ConstantMap<'ast> =
HashMap<OwnedModuleId, HashMap<ConstantIdentifier<'ast>, DeclarationType<'ast>>>;
@ -349,6 +351,85 @@ impl<'ast, T: Field> Checker<'ast, T> {
})
}
fn check_type_definition(
&mut self,
ty: TypeDefinitionNode<'ast>,
module_id: &ModuleId,
state: &State<'ast, T>,
) -> Result<(DeclarationType<'ast>, GenericDeclarations<'ast>), Vec<ErrorInner>> {
let pos = ty.pos();
let ty = ty.value;
let mut errors = vec![];
let mut generics = vec![];
let mut generics_map = HashMap::new();
for (index, g) in ty.generics.iter().enumerate() {
if state
.constants
.get(module_id)
.and_then(|m| m.get(g.value))
.is_some()
{
errors.push(ErrorInner {
pos: Some(g.pos()),
message: format!(
"Generic parameter {p} conflicts with constant symbol {p}",
p = g.value
),
});
} else {
match generics_map.insert(g.value, index).is_none() {
true => {
generics.push(Some(DeclarationConstant::Generic(GenericIdentifier {
name: g.value,
index,
})));
}
false => {
errors.push(ErrorInner {
pos: Some(g.pos()),
message: format!("Generic parameter {} is already declared", g.value),
});
}
}
}
}
let mut used_generics = HashSet::new();
match self.check_declaration_type(
ty.ty,
module_id,
state,
&generics_map,
&mut used_generics,
) {
Ok(ty) => {
// check that all declared generics were used
for declared_generic in generics_map.keys() {
if !used_generics.contains(declared_generic) {
errors.push(ErrorInner {
pos: Some(pos),
message: format!("Generic parameter {} must be used", declared_generic),
});
}
}
if !errors.is_empty() {
return Err(errors);
}
Ok((ty, Some(generics)))
}
Err(e) => {
errors.push(e);
Err(errors)
}
}
}
fn check_constant_definition(
&mut self,
id: ConstantIdentifier<'ast>,
@ -541,7 +622,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), ty)
.insert(declaration.id.to_string(), (ty, None))
.is_none());
}
};
@ -593,6 +674,35 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
}
Symbol::Here(SymbolDefinition::Type(t)) => {
match self.check_type_definition(t, module_id, state) {
Ok(ty) => {
match symbol_unifier.insert_type(declaration.id) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {
assert!(state
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), ty)
.is_none());
}
};
}
Err(e) => {
errors.extend(e.into_iter().map(|inner| inner.in_file(module_id)));
}
}
}
Symbol::Here(SymbolDefinition::Function(f)) => {
match self.check_function(f, module_id, state) {
Ok(funct) => {
@ -673,7 +783,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.cloned();
match (function_candidates.len(), type_candidate, const_candidate) {
(0, Some(t), None) => {
(0, Some((t, alias_generics)), None) => {
// rename the type to the declared symbol
let t = match t {
@ -684,7 +794,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
..t
}),
_ => unreachable!()
_ => t // type alias
};
// we imported a type, so the symbol it gets bound to should not already exist
@ -706,7 +816,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), t);
.insert(declaration.id.to_string(), (t, alias_generics));
}
(0, None, Some(ty)) => {
match symbol_unifier.insert_constant(declaration.id) {
@ -1187,23 +1297,22 @@ impl<'ast, T: Field> Checker<'ast, T> {
)))
}
UnresolvedType::User(id, generics) => {
let declaration_type =
types
.get(module_id)
.unwrap()
.get(&id)
.cloned()
.ok_or_else(|| ErrorInner {
pos: Some(pos),
message: format!("Undefined type {}", id),
})?;
let (declaration_type, alias_generics) = types
.get(module_id)
.unwrap()
.get(&id)
.cloned()
.ok_or_else(|| ErrorInner {
pos: Some(pos),
message: format!("Undefined type {}", id),
})?;
// absence of generics is treated as 0 generics, as we do not provide inference for now
let generics = generics.unwrap_or_default();
// check generics
match declaration_type {
DeclarationType::Struct(struct_type) => {
match (declaration_type, alias_generics) {
(DeclarationType::Struct(struct_type), None) => {
match struct_type.generics.len() == generics.len() {
true => {
// downcast the generics to identifiers, as this is the only possibility here
@ -1263,7 +1372,58 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
}
}
_ => unreachable!("user defined types should always be structs"),
(declaration_type, Some(alias_generics)) => {
match alias_generics.len() == generics.len() {
true => {
let generic_identifiers =
alias_generics.iter().map(|c| match c.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g.clone(),
_ => unreachable!(),
});
// build the generic assignment for this type
let assignment = GGenericsAssignment(generics
.into_iter()
.zip(generic_identifiers)
.map(|(e, g)| match e {
Some(e) => {
self
.check_expression(e, module_id, types)
.and_then(|e| {
UExpression::try_from_typed(e, &UBitwidth::B32)
.map(|e| (g, e))
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected u32 expression, but got expression of type {}", e.get_type()),
})
})
},
None => Err(ErrorInner {
pos: Some(pos),
message:
"Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet."
.into(),
})
})
.collect::<Result<_, _>>()?);
// specialize the declared type using the generic assignment
Ok(specialize_declaration_type(declaration_type, &assignment)
.unwrap())
}
false => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected {} generic argument{} on type {}, but got {}",
alias_generics.len(),
if alias_generics.len() == 1 { "" } else { "s" },
id,
generics.len()
),
}),
}
}
_ => unreachable!(),
}
}
}
@ -1358,7 +1518,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
)))
}
UnresolvedType::User(id, generics) => {
let declared_ty = state
let (declared_ty, alias_generics) = state
.types
.get(module_id)
.unwrap()
@ -1369,8 +1529,43 @@ impl<'ast, T: Field> Checker<'ast, T> {
message: format!("Undefined type {}", id),
})?;
match declared_ty {
DeclarationType::Struct(declared_struct_ty) => {
match (declared_ty, alias_generics) {
(ty, Some(alias_generics)) => {
let generics = generics.unwrap_or_default();
let checked_generics: Vec<_> = generics
.into_iter()
.map(|e| match e {
Some(e) => self
.check_generic_expression(
e,
module_id,
state.constants.get(module_id).unwrap_or(&HashMap::new()),
generics_map,
used_generics,
)
.map(Some),
None => Err(ErrorInner {
pos: Some(pos),
message: "Expected u32 constant or identifier, but found `_`"
.into(),
}),
})
.collect::<Result<_, _>>()?;
let mut assignment = GGenericsAssignment::default();
assignment.0.extend(
alias_generics.iter().zip(checked_generics.iter()).map(
|(decl_g, g_val)| match decl_g.clone().unwrap() {
DeclarationConstant::Generic(g) => (g, g_val.clone().unwrap()),
_ => unreachable!(),
},
),
);
Ok(specialize_declaration_type(ty, &assignment).unwrap())
}
(DeclarationType::Struct(declared_struct_ty), None) => {
let generics = generics.unwrap_or_default();
match declared_struct_ty.generics.len() == generics.len() {
true => {
@ -1441,7 +1636,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
}
}
_ => Ok(declared_ty),
(declared_ty, _) => Ok(declared_ty),
}
}
}
@ -2910,7 +3105,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.into())
}
Expression::InlineStruct(id, inline_members) => {
let ty = match types.get(module_id).unwrap().get(&id).cloned() {
let (ty, _) = match types.get(module_id).unwrap().get(&id).cloned() {
None => Err(ErrorInner {
pos: Some(pos),
message: format!("Undefined type `{}`", id),

View file

@ -4,7 +4,7 @@ file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ symbol_declaration* ~ EOI }
pragma = { "#pragma" ~ "curve" ~ curve }
curve = @{ (ASCII_ALPHANUMERIC | "_") * }
symbol_declaration = { (import_directive | ty_struct_definition | const_definition | function_definition) ~ NEWLINE* }
symbol_declaration = { (import_directive | ty_struct_definition | const_definition | type_definition | function_definition) ~ NEWLINE* }
import_directive = { main_import_directive | from_import_directive }
from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ import_symbol_list ~ NEWLINE* }
@ -14,6 +14,7 @@ import_symbol = { identifier ~ ("as" ~ identifier)? }
import_symbol_list = _{ import_symbol ~ ("," ~ import_symbol)* }
function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* }
const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*}
type_definition = {"type" ~ identifier ~ constant_generics_declaration? ~ "=" ~ ty ~ NEWLINE*}
return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? }
constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" }
constant_generics_list = _{ identifier ~ ("," ~ identifier)* }
@ -163,6 +164,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) }
// the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly
// accepted
keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
keyword = @{"assert"|"as"|"bool"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
"in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"|"u64"
}

View file

@ -17,8 +17,8 @@ pub use ast::{
InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression, Parameter,
PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression,
Statement, StructDefinition, StructField, SymbolDeclaration, TernaryExpression, ToExpression,
Type, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression, UnaryOperator, Underscore,
Visibility,
Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee, UnaryExpression,
UnaryOperator, Underscore, Visibility,
};
mod ast {
@ -140,6 +140,7 @@ mod ast {
Import(ImportDirective<'ast>),
Constant(ConstantDefinition<'ast>),
Struct(StructDefinition<'ast>),
Type(TypeDefinition<'ast>),
Function(FunctionDefinition<'ast>),
}
@ -184,6 +185,16 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::type_definition))]
pub struct TypeDefinition<'ast> {
pub id: IdentifierExpression<'ast>,
pub generics: Vec<IdentifierExpression<'ast>>,
pub ty: Type<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::import_directive))]
pub enum ImportDirective<'ast> {