1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

generic equivalence, add tests

This commit is contained in:
schaeff 2021-04-19 18:46:13 +02:00
parent 289a1b36c2
commit 3069d598c6
9 changed files with 293 additions and 85 deletions

View file

@ -0,0 +1,8 @@
def foo(field[2] a) -> bool:
return true
def foo(field[1] a) -> bool:
return true
def main() -> bool:
return foo([1])

View file

@ -0,0 +1,11 @@
def foo<N, P>(field[N] a) -> field[P]:
return a
def foo<P, N>(field[P] a) -> field[N]:
return a
def bar<Q>(field[Q] a) -> field[Q]:
return foo(a)
def main() -> field[1]:
return bar([1])

View file

@ -0,0 +1,8 @@
def foo<N>(field[N] a) -> bool:
return true
def foo<P>(field[P] a) -> bool:
return true
def main():
return

View file

@ -0,0 +1,14 @@
// this should compile but requires looking at the generic parameter values, and they are not known at semantic check time.
// It is enough of an edge case to not be worth fixing
def foo<N, P>(field[N] a) -> field[P]:
return [1; P]
def foo<N, P>(field[P] a) -> field[N]:
return [1; N]
def main() -> field[2]:
u32 X = 1
u32 Y = 2
field[2] a = foo::<X, Y>([1])
return a

View file

@ -4,7 +4,7 @@ use crate::flat_absy::{
}; };
use crate::solvers::Solver; use crate::solvers::Solver;
use crate::typed_absy::types::{ use crate::typed_absy::types::{
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use zokrates_field::{Bn128Field, Field}; use zokrates_field::{Bn128Field, Field};
@ -43,11 +43,17 @@ impl FlatEmbed {
.inputs(vec![DeclarationType::uint(32)]) .inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]), .outputs(vec![DeclarationType::FieldElement]),
FlatEmbed::Unpack => DeclarationSignature::new() FlatEmbed::Unpack => DeclarationSignature::new()
.generics(vec![Some(Constant::Generic("N"))]) .generics(vec![Some(Constant::Generic(GenericIdentifier {
name: "N",
index: 0,
}))])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::Boolean, DeclarationType::Boolean,
"N", GenericIdentifier {
name: "N",
index: 0,
},
))]), ))]),
FlatEmbed::U8ToBits => DeclarationSignature::new() FlatEmbed::U8ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(8)]) .inputs(vec![DeclarationType::uint(8)])

View file

@ -20,7 +20,8 @@ use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
use crate::typed_absy::types::{ use crate::typed_absy::types::{
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature, ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
DeclarationStructMember, DeclarationStructType, DeclarationType, StructLocation, DeclarationStructMember, DeclarationStructType, DeclarationType, GenericIdentifier,
StructLocation,
}; };
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -325,7 +326,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
for field in s.fields { for field in s.fields {
let member_id = field.value.id.to_string(); let member_id = field.value.id.to_string();
match self match self
.check_declaration_type(field.value.ty, module_id, &types, &HashSet::new()) .check_declaration_type(field.value.ty, module_id, &types, &HashMap::new())
.map(|t| (member_id, t)) .map(|t| (member_id, t))
{ {
Ok(f) => match fields_set.insert(f.0.clone()) { Ok(f) => match fields_set.insert(f.0.clone()) {
@ -696,7 +697,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let v = Variable::with_id_and_type( let v = Variable::with_id_and_type(
match generic { match generic {
Constant::Generic(g) => g, Constant::Generic(g) => g.name,
_ => unreachable!(), _ => unreachable!(),
}, },
Type::Uint(UBitwidth::B32), Type::Uint(UBitwidth::B32),
@ -818,12 +819,15 @@ impl<'ast, T: Field> Checker<'ast, T> {
let mut outputs = vec![]; let mut outputs = vec![];
let mut generics = vec![]; let mut generics = vec![];
let mut constants = HashSet::new(); let mut generics_map = HashMap::new();
for g in signature.generics { for (index, g) in signature.generics.iter().enumerate() {
match constants.insert(g.value) { match generics_map.insert(g.value, index).is_none() {
true => { true => {
generics.push(Some(Constant::Generic(g.value))); generics.push(Some(Constant::Generic(GenericIdentifier {
name: g.value,
index,
})));
} }
false => { false => {
errors.push(ErrorInner { errors.push(ErrorInner {
@ -835,7 +839,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
} }
for t in signature.inputs { for t in signature.inputs {
match self.check_declaration_type(t, module_id, types, &constants) { match self.check_declaration_type(t, module_id, types, &generics_map) {
Ok(t) => { Ok(t) => {
inputs.push(t); inputs.push(t);
} }
@ -846,7 +850,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
} }
for t in signature.outputs { for t in signature.outputs {
match self.check_declaration_type(t, module_id, types, &constants) { match self.check_declaration_type(t, module_id, types, &generics_map) {
Ok(t) => { Ok(t) => {
outputs.push(t); outputs.push(t);
} }
@ -936,6 +940,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_generic_expression( fn check_generic_expression(
&mut self, &mut self,
expr: ExpressionNode<'ast>, expr: ExpressionNode<'ast>,
generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<Constant<'ast>, ErrorInner> { ) -> Result<Constant<'ast>, ErrorInner> {
let pos = expr.pos(); let pos = expr.pos();
@ -956,7 +961,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
}) })
} }
} }
Expression::Identifier(name) => Ok(Constant::Generic(name)), Expression::Identifier(name) => {
// check that this generic parameter is defined
match generics_map.get(&name) {
Some(index) => Ok(Constant::Generic(GenericIdentifier {name, index: *index})),
None => Err(ErrorInner {
pos: Some(pos),
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name)
})
}
}
e => Err(ErrorInner { e => Err(ErrorInner {
pos: Some(pos), pos: Some(pos),
message: format!( message: format!(
@ -972,7 +986,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
ty: UnresolvedTypeNode<'ast>, ty: UnresolvedTypeNode<'ast>,
module_id: &ModuleId, module_id: &ModuleId,
types: &TypeMap<'ast>, types: &TypeMap<'ast>,
constants: &HashSet<Identifier<'ast>>, generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<DeclarationType<'ast>, ErrorInner> { ) -> Result<DeclarationType<'ast>, ErrorInner> {
let pos = ty.pos(); let pos = ty.pos();
let ty = ty.value; let ty = ty.value;
@ -982,19 +996,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
UnresolvedType::Boolean => Ok(DeclarationType::Boolean), UnresolvedType::Boolean => Ok(DeclarationType::Boolean),
UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)), UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)),
UnresolvedType::Array(t, size) => { UnresolvedType::Array(t, size) => {
let checked_size = self.check_generic_expression(size.clone())?; let checked_size = self.check_generic_expression(size.clone(), &generics_map)?;
if let Constant::Generic(g) = checked_size {
if !constants.contains(g) {
return Err(ErrorInner {
pos: Some(pos),
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", g)
});
}
};
Ok(DeclarationType::Array(DeclarationArrayType::new( Ok(DeclarationType::Array(DeclarationArrayType::new(
self.check_declaration_type(*t, module_id, types, constants)?, self.check_declaration_type(*t, module_id, types, generics_map)?,
checked_size, checked_size,
))) )))
} }
@ -3022,17 +3027,21 @@ mod tests {
assert!(!unifier.insert_function( assert!(!unifier.insert_function(
"bar", "bar",
DeclarationSignature::new() DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into()
)])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
)); ));
// a `bar` function with a different signature // a `bar` function with a different signature
assert!(unifier.insert_function( assert!(unifier.insert_function(
"bar", "bar",
DeclarationSignature::new() DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into()
)])
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
"K" GenericIdentifier::with_name("K").index(0)
))]) ))])
)); ));
// a `bar` function with a different signature, but which could conflict with the previous one // a `bar` function with a different signature, but which could conflict with the previous one
@ -3609,12 +3618,18 @@ mod tests {
), ),
Ok(DeclarationSignature::new() Ok(DeclarationSignature::new()
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::array((DeclarationType::FieldElement, "K")), DeclarationType::array((
"L" DeclarationType::FieldElement,
GenericIdentifier::with_name("K").index(0)
)),
GenericIdentifier::with_name("L").index(1)
))]) ))])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::array((DeclarationType::FieldElement, "L")), DeclarationType::array((
"K" DeclarationType::FieldElement,
GenericIdentifier::with_name("L").index(1)
)),
GenericIdentifier::with_name("K").index(0)
))])) ))]))
); );
} }

View file

@ -657,8 +657,9 @@ mod tests {
use crate::typed_absy::types::DeclarationSignature; use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{ use crate::typed_absy::{
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
FieldElementExpression, Identifier, OwnedTypedModuleId, Select, Type, TypedExpression, FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable, TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
Variable,
}; };
use zokrates_field::Bn128Field; use zokrates_field::Bn128Field;
@ -865,20 +866,25 @@ mod tests {
// return a_2 + b_1[0] // return a_2 + b_1[0]
let foo_signature = DeclarationSignature::new() let foo_signature = DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]) ))])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]); ))]);
let foo: TypedFunction<Bn128Field> = TypedFunction { let foo: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![ arguments: vec![DeclarationVariable::array(
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), "a",
], DeclarationType::FieldElement,
GenericIdentifier::with_name("K").index(0),
)
.into()],
statements: vec![TypedStatement::Return(vec![ statements: vec![TypedStatement::Return(vec![
ArrayExpressionInner::Identifier("a".into()) ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32) .annotate(Type::FieldElement, 1u32)
@ -983,7 +989,11 @@ mod tests {
TypedStatement::PushCallLog( TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo") DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()), .signature(foo_signature.clone()),
GGenericsAssignment(vec![("K", 1)].into_iter().collect()), GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 1)]
.into_iter()
.collect(),
),
), ),
TypedStatement::Definition( TypedStatement::Definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
@ -1073,20 +1083,25 @@ mod tests {
// return a_2 + b_1[0] // return a_2 + b_1[0]
let foo_signature = DeclarationSignature::new() let foo_signature = DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]) ))])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]); ))]);
let foo: TypedFunction<Bn128Field> = TypedFunction { let foo: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![ arguments: vec![DeclarationVariable::array(
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), "a",
], DeclarationType::FieldElement,
GenericIdentifier::with_name("K").index(0),
)
.into()],
statements: vec![TypedStatement::Return(vec![ statements: vec![TypedStatement::Return(vec![
ArrayExpressionInner::Identifier("a".into()) ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32) .annotate(Type::FieldElement, 1u32)
@ -1200,7 +1215,11 @@ mod tests {
TypedStatement::PushCallLog( TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo") DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()), .signature(foo_signature.clone()),
GGenericsAssignment(vec![("K", 1)].into_iter().collect()), GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 1)]
.into_iter()
.collect(),
),
), ),
TypedStatement::Definition( TypedStatement::Definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
@ -1299,19 +1318,21 @@ mod tests {
let foo_signature = DeclarationSignature::new() let foo_signature = DeclarationSignature::new()
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]) ))])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
))]) ))])
.generics(vec![Some("K".into())]); .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)]);
let foo: TypedFunction<Bn128Field> = TypedFunction { let foo: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::array( arguments: vec![DeclarationVariable::array(
"a", "a",
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
) )
.into()], .into()],
statements: vec![ statements: vec![
@ -1375,7 +1396,7 @@ mod tests {
arguments: vec![DeclarationVariable::array( arguments: vec![DeclarationVariable::array(
"a", "a",
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), Constant::Generic(GenericIdentifier::with_name("K").index(0)),
) )
.into()], .into()],
statements: vec![TypedStatement::Return(vec![ statements: vec![TypedStatement::Return(vec![
@ -1448,12 +1469,20 @@ mod tests {
TypedStatement::PushCallLog( TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo") DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()), .signature(foo_signature.clone()),
GGenericsAssignment(vec![("K", 1)].into_iter().collect()), GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 1)]
.into_iter()
.collect(),
),
), ),
TypedStatement::PushCallLog( TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "bar") DeclarationFunctionKey::with_location("main", "bar")
.signature(foo_signature.clone()), .signature(foo_signature.clone()),
GGenericsAssignment(vec![("K", 2)].into_iter().collect()), GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 2)]
.into_iter()
.collect(),
),
), ),
TypedStatement::Definition( TypedStatement::Definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32) Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
@ -1558,20 +1587,25 @@ mod tests {
// Error: Incompatible // Error: Incompatible
let foo_signature = DeclarationSignature::new() let foo_signature = DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::array(( .inputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), GenericIdentifier::with_name("K").index(0),
))]) ))])
.outputs(vec![DeclarationType::array(( .outputs(vec![DeclarationType::array((
DeclarationType::FieldElement, DeclarationType::FieldElement,
Constant::Generic("K"), GenericIdentifier::with_name("K").index(0),
))]); ))]);
let foo: TypedFunction<Bn128Field> = TypedFunction { let foo: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![ arguments: vec![DeclarationVariable::array(
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), "a",
], DeclarationType::FieldElement,
GenericIdentifier::with_name("K").index(0),
)
.into()],
statements: vec![TypedStatement::Return(vec![ statements: vec![TypedStatement::Return(vec![
ArrayExpressionInner::Identifier("a".into()) ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32) .annotate(Type::FieldElement, 1u32)

View file

@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
.map(|(g, v)| { .map(|(g, v)| {
TypedStatement::Definition( TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type( TypedAssignee::Identifier(Variable::with_id_and_type(
*g, g.name,
Type::Uint(UBitwidth::B32), Type::Uint(UBitwidth::B32),
)), )),
UExpression::from(*v as u32).into(), UExpression::from(*v as u32).into(),
@ -731,7 +731,9 @@ mod tests {
]), ]),
], ],
signature: DeclarationSignature::new() signature: DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]), .outputs(vec![DeclarationType::FieldElement]),
}; };
@ -740,7 +742,11 @@ mod tests {
let ssa = ShallowTransformer::transform( let ssa = ShallowTransformer::transform(
f, f,
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()), &GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 1)]
.into_iter()
.collect(),
),
&mut versions, &mut versions,
); );
@ -805,7 +811,9 @@ mod tests {
.into()]), .into()]),
], ],
signature: DeclarationSignature::new() signature: DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]), .outputs(vec![DeclarationType::FieldElement]),
}; };
@ -912,7 +920,9 @@ mod tests {
]), ]),
], ],
signature: DeclarationSignature::new() signature: DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]), .outputs(vec![DeclarationType::FieldElement]),
}; };
@ -921,7 +931,11 @@ mod tests {
let ssa = ShallowTransformer::transform( let ssa = ShallowTransformer::transform(
f, f,
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()), &GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").index(0), 1)]
.into_iter()
.collect(),
),
&mut versions, &mut versions,
); );
@ -989,7 +1003,9 @@ mod tests {
.into()]), .into()]),
], ],
signature: DeclarationSignature::new() signature: DeclarationSignature::new()
.generics(vec![Some("K".into())]) .generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]), .outputs(vec![DeclarationType::FieldElement]),
}; };

View file

@ -6,7 +6,46 @@ use std::fmt;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
pub type GenericIdentifier<'ast> = &'ast str; #[derive(Debug, Clone, Eq, Ord)]
pub struct GenericIdentifier<'ast> {
pub name: &'ast str,
pub index: usize,
}
impl<'ast> GenericIdentifier<'ast> {
pub fn with_name(name: &'ast str) -> Self {
Self { name, index: 0 }
}
pub fn index(mut self, index: usize) -> Self {
self.index = index;
self
}
}
impl<'ast> PartialEq for GenericIdentifier<'ast> {
fn eq(&self, other: &Self) -> bool {
self.index == other.index
}
}
impl<'ast> PartialOrd for GenericIdentifier<'ast> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.index.partial_cmp(&other.index)
}
}
impl<'ast> Hash for GenericIdentifier<'ast> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.index.hash(state);
}
}
impl<'ast> fmt::Display for GenericIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct SpecializationError; pub struct SpecializationError;
@ -53,7 +92,9 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> { impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
fn from(c: Constant<'ast>) -> Self { fn from(c: Constant<'ast>) -> Self {
match c { match c {
Constant::Generic(i) => UExpressionInner::Identifier(i.into()).annotate(UBitwidth::B32), Constant::Generic(i) => {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
}
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32), Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
} }
} }
@ -819,13 +860,41 @@ pub mod signature {
use super::*; use super::*;
use std::fmt; use std::fmt;
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Clone, Serialize, Deserialize, Eq)]
pub struct GSignature<S> { pub struct GSignature<S> {
pub generics: Vec<Option<S>>, pub generics: Vec<Option<S>>,
pub inputs: Vec<GType<S>>, pub inputs: Vec<GType<S>>,
pub outputs: Vec<GType<S>>, pub outputs: Vec<GType<S>>,
} }
impl<S: PartialEq> PartialEq for GSignature<S> {
fn eq(&self, other: &Self) -> bool {
self.inputs == other.inputs && self.outputs == other.outputs
}
}
impl<S: PartialOrd> PartialOrd for GSignature<S> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.inputs
.partial_cmp(&other.inputs)
.map(|c| self.outputs.partial_cmp(&other.outputs).map(|d| c.then(d)))
.unwrap()
}
}
impl<S: Ord> Ord for GSignature<S> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(&other).unwrap()
}
}
impl<S: Hash> Hash for GSignature<S> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.inputs.hash(state);
self.outputs.hash(state);
}
}
impl<S> Default for GSignature<S> { impl<S> Default for GSignature<S> {
fn default() -> Self { fn default() -> Self {
GSignature { GSignature {
@ -853,17 +922,17 @@ pub mod signature {
// both the inner type and the size must match // both the inner type and the size must match
check_type(&t0.ty, &t1.ty, constants) check_type(&t0.ty, &t1.ty, constants)
&& match t0.size { && match &t0.size {
// if the declared size is an identifier, we insert into the map, or check if the concrete size // if the declared size is an identifier, we insert into the map, or check if the concrete size
// matches if this identifier is already in the map // matches if this identifier is already in the map
Constant::Generic(id) => match constants.0.entry(id) { Constant::Generic(id) => match constants.0.entry(id.clone()) {
Entry::Occupied(e) => *e.get() == s1, Entry::Occupied(e) => *e.get() == s1,
Entry::Vacant(e) => { Entry::Vacant(e) => {
e.insert(s1); e.insert(s1);
true true
} }
}, },
Constant::Concrete(s0) => s1 == s0 as usize, Constant::Concrete(s0) => s1 == *s0 as usize,
} }
} }
(DeclarationType::FieldElement, GType::FieldElement) (DeclarationType::FieldElement, GType::FieldElement)
@ -1169,34 +1238,61 @@ pub mod signature {
#[test] #[test]
fn signature_equivalence() { fn signature_equivalence() {
let generic = DeclarationSignature::new() // check equivalence of:
.generics(vec![Some("P".into())]) // <P>(field[P])
// <Q>(field[Q])
let generic1 = DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier {
name: "P",
index: 0,
}
.into(),
)])
.inputs(vec![DeclarationType::array(DeclarationArrayType::new( .inputs(vec![DeclarationType::array(DeclarationArrayType::new(
DeclarationType::FieldElement, DeclarationType::FieldElement,
"P".into(), GenericIdentifier {
name: "P",
index: 0,
}
.into(),
))]);
let generic2 = DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier {
name: "Q",
index: 0,
}
.into(),
)])
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
DeclarationType::FieldElement,
GenericIdentifier {
name: "Q",
index: 0,
}
.into(),
))]); ))]);
let specialized = DeclarationSignature::new().inputs(vec![DeclarationType::array(
DeclarationArrayType::new(DeclarationType::FieldElement, 3u32.into()),
)]);
assert_eq!(generic, specialized); assert_eq!(generic1, generic2);
assert_eq!( assert_eq!(
{ {
let mut hasher = std::collections::hash_map::DefaultHasher::new(); let mut hasher = std::collections::hash_map::DefaultHasher::new();
generic.hash(&mut hasher); generic1.hash(&mut hasher);
hasher.finish() hasher.finish()
}, },
{ {
let mut hasher = std::collections::hash_map::DefaultHasher::new(); let mut hasher = std::collections::hash_map::DefaultHasher::new();
specialized.hash(&mut hasher); generic2.hash(&mut hasher);
hasher.finish() hasher.finish()
} }
); );
assert_eq!( assert_eq!(
generic.partial_cmp(&specialized), generic1.partial_cmp(&generic2),
Some(std::cmp::Ordering::Equal) Some(std::cmp::Ordering::Equal)
); );
assert_eq!(generic.cmp(&specialized), std::cmp::Ordering::Equal); assert_eq!(generic1.cmp(&generic2), std::cmp::Ordering::Equal);
} }
#[test] #[test]