generic equivalence, add tests
This commit is contained in:
parent
289a1b36c2
commit
3069d598c6
9 changed files with 293 additions and 85 deletions
8
zokrates_cli/examples/array_overload.zok
Normal file
8
zokrates_cli/examples/array_overload.zok
Normal file
|
@ -0,0 +1,8 @@
|
|||
def foo(field[2] a) -> bool:
|
||||
return true
|
||||
|
||||
def foo(field[1] a) -> bool:
|
||||
return true
|
||||
|
||||
def main() -> bool:
|
||||
return foo([1])
|
|
@ -0,0 +1,11 @@
|
|||
def foo<N, P>(field[N] a) -> field[P]:
|
||||
return a
|
||||
|
||||
def foo<P, N>(field[P] a) -> field[N]:
|
||||
return a
|
||||
|
||||
def bar<Q>(field[Q] a) -> field[Q]:
|
||||
return foo(a)
|
||||
|
||||
def main() -> field[1]:
|
||||
return bar([1])
|
|
@ -0,0 +1,8 @@
|
|||
def foo<N>(field[N] a) -> bool:
|
||||
return true
|
||||
|
||||
def foo<P>(field[P] a) -> bool:
|
||||
return true
|
||||
|
||||
def main():
|
||||
return
|
|
@ -0,0 +1,14 @@
|
|||
// this should compile but requires looking at the generic parameter values, and they are not known at semantic check time.
|
||||
// It is enough of an edge case to not be worth fixing
|
||||
|
||||
def foo<N, P>(field[N] a) -> field[P]:
|
||||
return [1; P]
|
||||
|
||||
def foo<N, P>(field[P] a) -> field[N]:
|
||||
return [1; N]
|
||||
|
||||
def main() -> field[2]:
|
||||
u32 X = 1
|
||||
u32 Y = 2
|
||||
field[2] a = foo::<X, Y>([1])
|
||||
return a
|
|
@ -4,7 +4,7 @@ use crate::flat_absy::{
|
|||
};
|
||||
use crate::solvers::Solver;
|
||||
use crate::typed_absy::types::{
|
||||
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType,
|
||||
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::{Bn128Field, Field};
|
||||
|
@ -43,11 +43,17 @@ impl FlatEmbed {
|
|||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
FlatEmbed::Unpack => DeclarationSignature::new()
|
||||
.generics(vec![Some(Constant::Generic("N"))])
|
||||
.generics(vec![Some(Constant::Generic(GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
}))])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
"N",
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
))]),
|
||||
FlatEmbed::U8ToBits => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(8)])
|
||||
|
|
|
@ -20,7 +20,8 @@ use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
|
|||
|
||||
use crate::typed_absy::types::{
|
||||
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
||||
DeclarationStructMember, DeclarationStructType, DeclarationType, StructLocation,
|
||||
DeclarationStructMember, DeclarationStructType, DeclarationType, GenericIdentifier,
|
||||
StructLocation,
|
||||
};
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
|
@ -325,7 +326,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
for field in s.fields {
|
||||
let member_id = field.value.id.to_string();
|
||||
match self
|
||||
.check_declaration_type(field.value.ty, module_id, &types, &HashSet::new())
|
||||
.check_declaration_type(field.value.ty, module_id, &types, &HashMap::new())
|
||||
.map(|t| (member_id, t))
|
||||
{
|
||||
Ok(f) => match fields_set.insert(f.0.clone()) {
|
||||
|
@ -696,7 +697,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let v = Variable::with_id_and_type(
|
||||
match generic {
|
||||
Constant::Generic(g) => g,
|
||||
Constant::Generic(g) => g.name,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Type::Uint(UBitwidth::B32),
|
||||
|
@ -818,12 +819,15 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let mut outputs = vec![];
|
||||
let mut generics = vec![];
|
||||
|
||||
let mut constants = HashSet::new();
|
||||
let mut generics_map = HashMap::new();
|
||||
|
||||
for g in signature.generics {
|
||||
match constants.insert(g.value) {
|
||||
for (index, g) in signature.generics.iter().enumerate() {
|
||||
match generics_map.insert(g.value, index).is_none() {
|
||||
true => {
|
||||
generics.push(Some(Constant::Generic(g.value)));
|
||||
generics.push(Some(Constant::Generic(GenericIdentifier {
|
||||
name: g.value,
|
||||
index,
|
||||
})));
|
||||
}
|
||||
false => {
|
||||
errors.push(ErrorInner {
|
||||
|
@ -835,7 +839,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}
|
||||
|
||||
for t in signature.inputs {
|
||||
match self.check_declaration_type(t, module_id, types, &constants) {
|
||||
match self.check_declaration_type(t, module_id, types, &generics_map) {
|
||||
Ok(t) => {
|
||||
inputs.push(t);
|
||||
}
|
||||
|
@ -846,7 +850,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}
|
||||
|
||||
for t in signature.outputs {
|
||||
match self.check_declaration_type(t, module_id, types, &constants) {
|
||||
match self.check_declaration_type(t, module_id, types, &generics_map) {
|
||||
Ok(t) => {
|
||||
outputs.push(t);
|
||||
}
|
||||
|
@ -936,6 +940,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
fn check_generic_expression(
|
||||
&mut self,
|
||||
expr: ExpressionNode<'ast>,
|
||||
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||
) -> Result<Constant<'ast>, ErrorInner> {
|
||||
let pos = expr.pos();
|
||||
|
||||
|
@ -956,7 +961,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
})
|
||||
}
|
||||
}
|
||||
Expression::Identifier(name) => Ok(Constant::Generic(name)),
|
||||
Expression::Identifier(name) => {
|
||||
// check that this generic parameter is defined
|
||||
match generics_map.get(&name) {
|
||||
Some(index) => Ok(Constant::Generic(GenericIdentifier {name, index: *index})),
|
||||
None => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name)
|
||||
})
|
||||
}
|
||||
}
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -972,7 +986,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
ty: UnresolvedTypeNode<'ast>,
|
||||
module_id: &ModuleId,
|
||||
types: &TypeMap<'ast>,
|
||||
constants: &HashSet<Identifier<'ast>>,
|
||||
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||
) -> Result<DeclarationType<'ast>, ErrorInner> {
|
||||
let pos = ty.pos();
|
||||
let ty = ty.value;
|
||||
|
@ -982,19 +996,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
UnresolvedType::Boolean => Ok(DeclarationType::Boolean),
|
||||
UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)),
|
||||
UnresolvedType::Array(t, size) => {
|
||||
let checked_size = self.check_generic_expression(size.clone())?;
|
||||
|
||||
if let Constant::Generic(g) = checked_size {
|
||||
if !constants.contains(g) {
|
||||
return Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", g)
|
||||
});
|
||||
}
|
||||
};
|
||||
let checked_size = self.check_generic_expression(size.clone(), &generics_map)?;
|
||||
|
||||
Ok(DeclarationType::Array(DeclarationArrayType::new(
|
||||
self.check_declaration_type(*t, module_id, types, constants)?,
|
||||
self.check_declaration_type(*t, module_id, types, generics_map)?,
|
||||
checked_size,
|
||||
)))
|
||||
}
|
||||
|
@ -3022,17 +3027,21 @@ mod tests {
|
|||
assert!(!unifier.insert_function(
|
||||
"bar",
|
||||
DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into()
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
));
|
||||
// a `bar` function with a different signature
|
||||
assert!(unifier.insert_function(
|
||||
"bar",
|
||||
DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into()
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
"K"
|
||||
GenericIdentifier::with_name("K").index(0)
|
||||
))])
|
||||
));
|
||||
// a `bar` function with a different signature, but which could conflict with the previous one
|
||||
|
@ -3609,12 +3618,18 @@ mod tests {
|
|||
),
|
||||
Ok(DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::array((DeclarationType::FieldElement, "K")),
|
||||
"L"
|
||||
DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier::with_name("K").index(0)
|
||||
)),
|
||||
GenericIdentifier::with_name("L").index(1)
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::array((DeclarationType::FieldElement, "L")),
|
||||
"K"
|
||||
DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier::with_name("L").index(1)
|
||||
)),
|
||||
GenericIdentifier::with_name("K").index(0)
|
||||
))]))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -657,8 +657,9 @@ mod tests {
|
|||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
|
||||
FieldElementExpression, Identifier, OwnedTypedModuleId, Select, Type, TypedExpression,
|
||||
TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable,
|
||||
FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
|
||||
TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
|
||||
Variable,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
|
@ -865,20 +866,25 @@ mod tests {
|
|||
// return a_2 + b_1[0]
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![
|
||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
||||
],
|
||||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier::with_name("K").index(0),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
|
@ -983,7 +989,11 @@ mod tests {
|
|||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
|
@ -1073,20 +1083,25 @@ mod tests {
|
|||
// return a_2 + b_1[0]
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![
|
||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
||||
],
|
||||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier::with_name("K").index(0),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
|
@ -1200,7 +1215,11 @@ mod tests {
|
|||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
|
@ -1299,19 +1318,21 @@ mod tests {
|
|||
let foo_signature = DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.generics(vec![Some("K".into())]);
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![
|
||||
|
@ -1375,7 +1396,7 @@ mod tests {
|
|||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
|
@ -1448,12 +1469,20 @@ mod tests {
|
|||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "bar")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(vec![("K", 2)].into_iter().collect()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 2)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
|
||||
|
@ -1558,20 +1587,25 @@ mod tests {
|
|||
// Error: Incompatible
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
GenericIdentifier::with_name("K").index(0),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic("K"),
|
||||
GenericIdentifier::with_name("K").index(0),
|
||||
))]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![
|
||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
||||
],
|
||||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier::with_name("K").index(0),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
|
|
|
@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
.map(|(g, v)| {
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::with_id_and_type(
|
||||
*g,
|
||||
g.name,
|
||||
Type::Uint(UBitwidth::B32),
|
||||
)),
|
||||
UExpression::from(*v as u32).into(),
|
||||
|
@ -731,7 +731,9 @@ mod tests {
|
|||
]),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
@ -740,7 +742,11 @@ mod tests {
|
|||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
|
||||
|
@ -805,7 +811,9 @@ mod tests {
|
|||
.into()]),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
@ -912,7 +920,9 @@ mod tests {
|
|||
]),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
@ -921,7 +931,11 @@ mod tests {
|
|||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
|
||||
|
@ -989,7 +1003,9 @@ mod tests {
|
|||
.into()]),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some("K".into())])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
|
|
@ -6,7 +6,46 @@ use std::fmt;
|
|||
use std::hash::{Hash, Hasher};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub type GenericIdentifier<'ast> = &'ast str;
|
||||
#[derive(Debug, Clone, Eq, Ord)]
|
||||
pub struct GenericIdentifier<'ast> {
|
||||
pub name: &'ast str,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl<'ast> GenericIdentifier<'ast> {
|
||||
pub fn with_name(name: &'ast str) -> Self {
|
||||
Self { name, index: 0 }
|
||||
}
|
||||
|
||||
pub fn index(mut self, index: usize) -> Self {
|
||||
self.index = index;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> PartialEq for GenericIdentifier<'ast> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> PartialOrd for GenericIdentifier<'ast> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
self.index.partial_cmp(&other.index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> Hash for GenericIdentifier<'ast> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.index.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for GenericIdentifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SpecializationError;
|
||||
|
@ -53,7 +92,9 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
|
|||
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: Constant<'ast>) -> Self {
|
||||
match c {
|
||||
Constant::Generic(i) => UExpressionInner::Identifier(i.into()).annotate(UBitwidth::B32),
|
||||
Constant::Generic(i) => {
|
||||
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
|
||||
}
|
||||
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
|
||||
}
|
||||
}
|
||||
|
@ -819,13 +860,41 @@ pub mod signature {
|
|||
use super::*;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Clone, Serialize, Deserialize, Eq)]
|
||||
pub struct GSignature<S> {
|
||||
pub generics: Vec<Option<S>>,
|
||||
pub inputs: Vec<GType<S>>,
|
||||
pub outputs: Vec<GType<S>>,
|
||||
}
|
||||
|
||||
impl<S: PartialEq> PartialEq for GSignature<S> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.inputs == other.inputs && self.outputs == other.outputs
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PartialOrd> PartialOrd for GSignature<S> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
self.inputs
|
||||
.partial_cmp(&other.inputs)
|
||||
.map(|c| self.outputs.partial_cmp(&other.outputs).map(|d| c.then(d)))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Ord> Ord for GSignature<S> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(&other).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Hash> Hash for GSignature<S> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.inputs.hash(state);
|
||||
self.outputs.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Default for GSignature<S> {
|
||||
fn default() -> Self {
|
||||
GSignature {
|
||||
|
@ -853,17 +922,17 @@ pub mod signature {
|
|||
|
||||
// both the inner type and the size must match
|
||||
check_type(&t0.ty, &t1.ty, constants)
|
||||
&& match t0.size {
|
||||
&& match &t0.size {
|
||||
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
||||
// matches if this identifier is already in the map
|
||||
Constant::Generic(id) => match constants.0.entry(id) {
|
||||
Constant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||
Entry::Occupied(e) => *e.get() == s1,
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(s1);
|
||||
true
|
||||
}
|
||||
},
|
||||
Constant::Concrete(s0) => s1 == s0 as usize,
|
||||
Constant::Concrete(s0) => s1 == *s0 as usize,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
|
@ -1169,34 +1238,61 @@ pub mod signature {
|
|||
|
||||
#[test]
|
||||
fn signature_equivalence() {
|
||||
let generic = DeclarationSignature::new()
|
||||
.generics(vec![Some("P".into())])
|
||||
// check equivalence of:
|
||||
// <P>(field[P])
|
||||
// <Q>(field[Q])
|
||||
|
||||
let generic1 = DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier {
|
||||
name: "P",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
"P".into(),
|
||||
GenericIdentifier {
|
||||
name: "P",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
))]);
|
||||
let generic2 = DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier {
|
||||
name: "Q",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
GenericIdentifier {
|
||||
name: "Q",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
))]);
|
||||
let specialized = DeclarationSignature::new().inputs(vec![DeclarationType::array(
|
||||
DeclarationArrayType::new(DeclarationType::FieldElement, 3u32.into()),
|
||||
)]);
|
||||
|
||||
assert_eq!(generic, specialized);
|
||||
assert_eq!(generic1, generic2);
|
||||
assert_eq!(
|
||||
{
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
generic.hash(&mut hasher);
|
||||
generic1.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
},
|
||||
{
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
specialized.hash(&mut hasher);
|
||||
generic2.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
generic.partial_cmp(&specialized),
|
||||
generic1.partial_cmp(&generic2),
|
||||
Some(std::cmp::Ordering::Equal)
|
||||
);
|
||||
assert_eq!(generic.cmp(&specialized), std::cmp::Ordering::Equal);
|
||||
assert_eq!(generic1.cmp(&generic2), std::cmp::Ordering::Equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in a new issue