generic equivalence, add tests
This commit is contained in:
parent
289a1b36c2
commit
3069d598c6
9 changed files with 293 additions and 85 deletions
8
zokrates_cli/examples/array_overload.zok
Normal file
8
zokrates_cli/examples/array_overload.zok
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
def foo(field[2] a) -> bool:
|
||||||
|
return true
|
||||||
|
|
||||||
|
def foo(field[1] a) -> bool:
|
||||||
|
return true
|
||||||
|
|
||||||
|
def main() -> bool:
|
||||||
|
return foo([1])
|
|
@ -0,0 +1,11 @@
|
||||||
|
def foo<N, P>(field[N] a) -> field[P]:
|
||||||
|
return a
|
||||||
|
|
||||||
|
def foo<P, N>(field[P] a) -> field[N]:
|
||||||
|
return a
|
||||||
|
|
||||||
|
def bar<Q>(field[Q] a) -> field[Q]:
|
||||||
|
return foo(a)
|
||||||
|
|
||||||
|
def main() -> field[1]:
|
||||||
|
return bar([1])
|
|
@ -0,0 +1,8 @@
|
||||||
|
def foo<N>(field[N] a) -> bool:
|
||||||
|
return true
|
||||||
|
|
||||||
|
def foo<P>(field[P] a) -> bool:
|
||||||
|
return true
|
||||||
|
|
||||||
|
def main():
|
||||||
|
return
|
|
@ -0,0 +1,14 @@
|
||||||
|
// this should compile but requires looking at the generic parameter values, and they are not known at semantic check time.
|
||||||
|
// It is enough of an edge case to not be worth fixing
|
||||||
|
|
||||||
|
def foo<N, P>(field[N] a) -> field[P]:
|
||||||
|
return [1; P]
|
||||||
|
|
||||||
|
def foo<N, P>(field[P] a) -> field[N]:
|
||||||
|
return [1; N]
|
||||||
|
|
||||||
|
def main() -> field[2]:
|
||||||
|
u32 X = 1
|
||||||
|
u32 Y = 2
|
||||||
|
field[2] a = foo::<X, Y>([1])
|
||||||
|
return a
|
|
@ -4,7 +4,7 @@ use crate::flat_absy::{
|
||||||
};
|
};
|
||||||
use crate::solvers::Solver;
|
use crate::solvers::Solver;
|
||||||
use crate::typed_absy::types::{
|
use crate::typed_absy::types::{
|
||||||
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType,
|
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use zokrates_field::{Bn128Field, Field};
|
use zokrates_field::{Bn128Field, Field};
|
||||||
|
@ -43,11 +43,17 @@ impl FlatEmbed {
|
||||||
.inputs(vec![DeclarationType::uint(32)])
|
.inputs(vec![DeclarationType::uint(32)])
|
||||||
.outputs(vec![DeclarationType::FieldElement]),
|
.outputs(vec![DeclarationType::FieldElement]),
|
||||||
FlatEmbed::Unpack => DeclarationSignature::new()
|
FlatEmbed::Unpack => DeclarationSignature::new()
|
||||||
.generics(vec![Some(Constant::Generic("N"))])
|
.generics(vec![Some(Constant::Generic(GenericIdentifier {
|
||||||
|
name: "N",
|
||||||
|
index: 0,
|
||||||
|
}))])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::Boolean,
|
DeclarationType::Boolean,
|
||||||
"N",
|
GenericIdentifier {
|
||||||
|
name: "N",
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
))]),
|
))]),
|
||||||
FlatEmbed::U8ToBits => DeclarationSignature::new()
|
FlatEmbed::U8ToBits => DeclarationSignature::new()
|
||||||
.inputs(vec![DeclarationType::uint(8)])
|
.inputs(vec![DeclarationType::uint(8)])
|
||||||
|
|
|
@ -20,7 +20,8 @@ use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
|
||||||
|
|
||||||
use crate::typed_absy::types::{
|
use crate::typed_absy::types::{
|
||||||
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
||||||
DeclarationStructMember, DeclarationStructType, DeclarationType, StructLocation,
|
DeclarationStructMember, DeclarationStructType, DeclarationType, GenericIdentifier,
|
||||||
|
StructLocation,
|
||||||
};
|
};
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
|
|
||||||
|
@ -325,7 +326,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
for field in s.fields {
|
for field in s.fields {
|
||||||
let member_id = field.value.id.to_string();
|
let member_id = field.value.id.to_string();
|
||||||
match self
|
match self
|
||||||
.check_declaration_type(field.value.ty, module_id, &types, &HashSet::new())
|
.check_declaration_type(field.value.ty, module_id, &types, &HashMap::new())
|
||||||
.map(|t| (member_id, t))
|
.map(|t| (member_id, t))
|
||||||
{
|
{
|
||||||
Ok(f) => match fields_set.insert(f.0.clone()) {
|
Ok(f) => match fields_set.insert(f.0.clone()) {
|
||||||
|
@ -696,7 +697,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
|
|
||||||
let v = Variable::with_id_and_type(
|
let v = Variable::with_id_and_type(
|
||||||
match generic {
|
match generic {
|
||||||
Constant::Generic(g) => g,
|
Constant::Generic(g) => g.name,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
},
|
},
|
||||||
Type::Uint(UBitwidth::B32),
|
Type::Uint(UBitwidth::B32),
|
||||||
|
@ -818,12 +819,15 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
let mut outputs = vec![];
|
let mut outputs = vec![];
|
||||||
let mut generics = vec![];
|
let mut generics = vec![];
|
||||||
|
|
||||||
let mut constants = HashSet::new();
|
let mut generics_map = HashMap::new();
|
||||||
|
|
||||||
for g in signature.generics {
|
for (index, g) in signature.generics.iter().enumerate() {
|
||||||
match constants.insert(g.value) {
|
match generics_map.insert(g.value, index).is_none() {
|
||||||
true => {
|
true => {
|
||||||
generics.push(Some(Constant::Generic(g.value)));
|
generics.push(Some(Constant::Generic(GenericIdentifier {
|
||||||
|
name: g.value,
|
||||||
|
index,
|
||||||
|
})));
|
||||||
}
|
}
|
||||||
false => {
|
false => {
|
||||||
errors.push(ErrorInner {
|
errors.push(ErrorInner {
|
||||||
|
@ -835,7 +839,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
for t in signature.inputs {
|
for t in signature.inputs {
|
||||||
match self.check_declaration_type(t, module_id, types, &constants) {
|
match self.check_declaration_type(t, module_id, types, &generics_map) {
|
||||||
Ok(t) => {
|
Ok(t) => {
|
||||||
inputs.push(t);
|
inputs.push(t);
|
||||||
}
|
}
|
||||||
|
@ -846,7 +850,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
for t in signature.outputs {
|
for t in signature.outputs {
|
||||||
match self.check_declaration_type(t, module_id, types, &constants) {
|
match self.check_declaration_type(t, module_id, types, &generics_map) {
|
||||||
Ok(t) => {
|
Ok(t) => {
|
||||||
outputs.push(t);
|
outputs.push(t);
|
||||||
}
|
}
|
||||||
|
@ -936,6 +940,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
fn check_generic_expression(
|
fn check_generic_expression(
|
||||||
&mut self,
|
&mut self,
|
||||||
expr: ExpressionNode<'ast>,
|
expr: ExpressionNode<'ast>,
|
||||||
|
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||||
) -> Result<Constant<'ast>, ErrorInner> {
|
) -> Result<Constant<'ast>, ErrorInner> {
|
||||||
let pos = expr.pos();
|
let pos = expr.pos();
|
||||||
|
|
||||||
|
@ -956,7 +961,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Expression::Identifier(name) => Ok(Constant::Generic(name)),
|
Expression::Identifier(name) => {
|
||||||
|
// check that this generic parameter is defined
|
||||||
|
match generics_map.get(&name) {
|
||||||
|
Some(index) => Ok(Constant::Generic(GenericIdentifier {name, index: *index})),
|
||||||
|
None => Err(ErrorInner {
|
||||||
|
pos: Some(pos),
|
||||||
|
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
e => Err(ErrorInner {
|
e => Err(ErrorInner {
|
||||||
pos: Some(pos),
|
pos: Some(pos),
|
||||||
message: format!(
|
message: format!(
|
||||||
|
@ -972,7 +986,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
ty: UnresolvedTypeNode<'ast>,
|
ty: UnresolvedTypeNode<'ast>,
|
||||||
module_id: &ModuleId,
|
module_id: &ModuleId,
|
||||||
types: &TypeMap<'ast>,
|
types: &TypeMap<'ast>,
|
||||||
constants: &HashSet<Identifier<'ast>>,
|
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||||
) -> Result<DeclarationType<'ast>, ErrorInner> {
|
) -> Result<DeclarationType<'ast>, ErrorInner> {
|
||||||
let pos = ty.pos();
|
let pos = ty.pos();
|
||||||
let ty = ty.value;
|
let ty = ty.value;
|
||||||
|
@ -982,19 +996,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
UnresolvedType::Boolean => Ok(DeclarationType::Boolean),
|
UnresolvedType::Boolean => Ok(DeclarationType::Boolean),
|
||||||
UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)),
|
UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)),
|
||||||
UnresolvedType::Array(t, size) => {
|
UnresolvedType::Array(t, size) => {
|
||||||
let checked_size = self.check_generic_expression(size.clone())?;
|
let checked_size = self.check_generic_expression(size.clone(), &generics_map)?;
|
||||||
|
|
||||||
if let Constant::Generic(g) = checked_size {
|
|
||||||
if !constants.contains(g) {
|
|
||||||
return Err(ErrorInner {
|
|
||||||
pos: Some(pos),
|
|
||||||
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", g)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(DeclarationType::Array(DeclarationArrayType::new(
|
Ok(DeclarationType::Array(DeclarationArrayType::new(
|
||||||
self.check_declaration_type(*t, module_id, types, constants)?,
|
self.check_declaration_type(*t, module_id, types, generics_map)?,
|
||||||
checked_size,
|
checked_size,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
@ -3022,17 +3027,21 @@ mod tests {
|
||||||
assert!(!unifier.insert_function(
|
assert!(!unifier.insert_function(
|
||||||
"bar",
|
"bar",
|
||||||
DeclarationSignature::new()
|
DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into()
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
));
|
));
|
||||||
// a `bar` function with a different signature
|
// a `bar` function with a different signature
|
||||||
assert!(unifier.insert_function(
|
assert!(unifier.insert_function(
|
||||||
"bar",
|
"bar",
|
||||||
DeclarationSignature::new()
|
DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into()
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
"K"
|
GenericIdentifier::with_name("K").index(0)
|
||||||
))])
|
))])
|
||||||
));
|
));
|
||||||
// a `bar` function with a different signature, but which could conflict with the previous one
|
// a `bar` function with a different signature, but which could conflict with the previous one
|
||||||
|
@ -3609,12 +3618,18 @@ mod tests {
|
||||||
),
|
),
|
||||||
Ok(DeclarationSignature::new()
|
Ok(DeclarationSignature::new()
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::array((DeclarationType::FieldElement, "K")),
|
DeclarationType::array((
|
||||||
"L"
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier::with_name("K").index(0)
|
||||||
|
)),
|
||||||
|
GenericIdentifier::with_name("L").index(1)
|
||||||
))])
|
))])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::array((DeclarationType::FieldElement, "L")),
|
DeclarationType::array((
|
||||||
"K"
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier::with_name("L").index(1)
|
||||||
|
)),
|
||||||
|
GenericIdentifier::with_name("K").index(0)
|
||||||
))]))
|
))]))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -657,8 +657,9 @@ mod tests {
|
||||||
use crate::typed_absy::types::DeclarationSignature;
|
use crate::typed_absy::types::DeclarationSignature;
|
||||||
use crate::typed_absy::{
|
use crate::typed_absy::{
|
||||||
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
|
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
|
||||||
FieldElementExpression, Identifier, OwnedTypedModuleId, Select, Type, TypedExpression,
|
FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
|
||||||
TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable,
|
TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
|
||||||
|
Variable,
|
||||||
};
|
};
|
||||||
use zokrates_field::Bn128Field;
|
use zokrates_field::Bn128Field;
|
||||||
|
|
||||||
|
@ -865,20 +866,25 @@ mod tests {
|
||||||
// return a_2 + b_1[0]
|
// return a_2 + b_1[0]
|
||||||
|
|
||||||
let foo_signature = DeclarationSignature::new()
|
let foo_signature = DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))])
|
))])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))]);
|
))]);
|
||||||
|
|
||||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||||
arguments: vec![
|
arguments: vec![DeclarationVariable::array(
|
||||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
"a",
|
||||||
],
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier::with_name("K").index(0),
|
||||||
|
)
|
||||||
|
.into()],
|
||||||
statements: vec![TypedStatement::Return(vec![
|
statements: vec![TypedStatement::Return(vec![
|
||||||
ArrayExpressionInner::Identifier("a".into())
|
ArrayExpressionInner::Identifier("a".into())
|
||||||
.annotate(Type::FieldElement, 1u32)
|
.annotate(Type::FieldElement, 1u32)
|
||||||
|
@ -983,7 +989,11 @@ mod tests {
|
||||||
TypedStatement::PushCallLog(
|
TypedStatement::PushCallLog(
|
||||||
DeclarationFunctionKey::with_location("main", "foo")
|
DeclarationFunctionKey::with_location("main", "foo")
|
||||||
.signature(foo_signature.clone()),
|
.signature(foo_signature.clone()),
|
||||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
TypedStatement::Definition(
|
TypedStatement::Definition(
|
||||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||||
|
@ -1073,20 +1083,25 @@ mod tests {
|
||||||
// return a_2 + b_1[0]
|
// return a_2 + b_1[0]
|
||||||
|
|
||||||
let foo_signature = DeclarationSignature::new()
|
let foo_signature = DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))])
|
))])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))]);
|
))]);
|
||||||
|
|
||||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||||
arguments: vec![
|
arguments: vec![DeclarationVariable::array(
|
||||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
"a",
|
||||||
],
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier::with_name("K").index(0),
|
||||||
|
)
|
||||||
|
.into()],
|
||||||
statements: vec![TypedStatement::Return(vec![
|
statements: vec![TypedStatement::Return(vec![
|
||||||
ArrayExpressionInner::Identifier("a".into())
|
ArrayExpressionInner::Identifier("a".into())
|
||||||
.annotate(Type::FieldElement, 1u32)
|
.annotate(Type::FieldElement, 1u32)
|
||||||
|
@ -1200,7 +1215,11 @@ mod tests {
|
||||||
TypedStatement::PushCallLog(
|
TypedStatement::PushCallLog(
|
||||||
DeclarationFunctionKey::with_location("main", "foo")
|
DeclarationFunctionKey::with_location("main", "foo")
|
||||||
.signature(foo_signature.clone()),
|
.signature(foo_signature.clone()),
|
||||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
TypedStatement::Definition(
|
TypedStatement::Definition(
|
||||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||||
|
@ -1299,19 +1318,21 @@ mod tests {
|
||||||
let foo_signature = DeclarationSignature::new()
|
let foo_signature = DeclarationSignature::new()
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))])
|
))])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
))])
|
))])
|
||||||
.generics(vec![Some("K".into())]);
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)]);
|
||||||
|
|
||||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||||
arguments: vec![DeclarationVariable::array(
|
arguments: vec![DeclarationVariable::array(
|
||||||
"a",
|
"a",
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
)
|
)
|
||||||
.into()],
|
.into()],
|
||||||
statements: vec![
|
statements: vec![
|
||||||
|
@ -1375,7 +1396,7 @@ mod tests {
|
||||||
arguments: vec![DeclarationVariable::array(
|
arguments: vec![DeclarationVariable::array(
|
||||||
"a",
|
"a",
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||||
)
|
)
|
||||||
.into()],
|
.into()],
|
||||||
statements: vec![TypedStatement::Return(vec![
|
statements: vec![TypedStatement::Return(vec![
|
||||||
|
@ -1448,12 +1469,20 @@ mod tests {
|
||||||
TypedStatement::PushCallLog(
|
TypedStatement::PushCallLog(
|
||||||
DeclarationFunctionKey::with_location("main", "foo")
|
DeclarationFunctionKey::with_location("main", "foo")
|
||||||
.signature(foo_signature.clone()),
|
.signature(foo_signature.clone()),
|
||||||
GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
TypedStatement::PushCallLog(
|
TypedStatement::PushCallLog(
|
||||||
DeclarationFunctionKey::with_location("main", "bar")
|
DeclarationFunctionKey::with_location("main", "bar")
|
||||||
.signature(foo_signature.clone()),
|
.signature(foo_signature.clone()),
|
||||||
GGenericsAssignment(vec![("K", 2)].into_iter().collect()),
|
GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 2)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
TypedStatement::Definition(
|
TypedStatement::Definition(
|
||||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
|
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
|
||||||
|
@ -1558,20 +1587,25 @@ mod tests {
|
||||||
// Error: Incompatible
|
// Error: Incompatible
|
||||||
|
|
||||||
let foo_signature = DeclarationSignature::new()
|
let foo_signature = DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::array((
|
.inputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
GenericIdentifier::with_name("K").index(0),
|
||||||
))])
|
))])
|
||||||
.outputs(vec![DeclarationType::array((
|
.outputs(vec![DeclarationType::array((
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
Constant::Generic("K"),
|
GenericIdentifier::with_name("K").index(0),
|
||||||
))]);
|
))]);
|
||||||
|
|
||||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||||
arguments: vec![
|
arguments: vec![DeclarationVariable::array(
|
||||||
DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(),
|
"a",
|
||||||
],
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier::with_name("K").index(0),
|
||||||
|
)
|
||||||
|
.into()],
|
||||||
statements: vec![TypedStatement::Return(vec![
|
statements: vec![TypedStatement::Return(vec![
|
||||||
ArrayExpressionInner::Identifier("a".into())
|
ArrayExpressionInner::Identifier("a".into())
|
||||||
.annotate(Type::FieldElement, 1u32)
|
.annotate(Type::FieldElement, 1u32)
|
||||||
|
|
|
@ -105,7 +105,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
||||||
.map(|(g, v)| {
|
.map(|(g, v)| {
|
||||||
TypedStatement::Definition(
|
TypedStatement::Definition(
|
||||||
TypedAssignee::Identifier(Variable::with_id_and_type(
|
TypedAssignee::Identifier(Variable::with_id_and_type(
|
||||||
*g,
|
g.name,
|
||||||
Type::Uint(UBitwidth::B32),
|
Type::Uint(UBitwidth::B32),
|
||||||
)),
|
)),
|
||||||
UExpression::from(*v as u32).into(),
|
UExpression::from(*v as u32).into(),
|
||||||
|
@ -731,7 +731,9 @@ mod tests {
|
||||||
]),
|
]),
|
||||||
],
|
],
|
||||||
signature: DeclarationSignature::new()
|
signature: DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.outputs(vec![DeclarationType::FieldElement]),
|
.outputs(vec![DeclarationType::FieldElement]),
|
||||||
};
|
};
|
||||||
|
@ -740,7 +742,11 @@ mod tests {
|
||||||
|
|
||||||
let ssa = ShallowTransformer::transform(
|
let ssa = ShallowTransformer::transform(
|
||||||
f,
|
f,
|
||||||
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
&GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
&mut versions,
|
&mut versions,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -805,7 +811,9 @@ mod tests {
|
||||||
.into()]),
|
.into()]),
|
||||||
],
|
],
|
||||||
signature: DeclarationSignature::new()
|
signature: DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.outputs(vec![DeclarationType::FieldElement]),
|
.outputs(vec![DeclarationType::FieldElement]),
|
||||||
};
|
};
|
||||||
|
@ -912,7 +920,9 @@ mod tests {
|
||||||
]),
|
]),
|
||||||
],
|
],
|
||||||
signature: DeclarationSignature::new()
|
signature: DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.outputs(vec![DeclarationType::FieldElement]),
|
.outputs(vec![DeclarationType::FieldElement]),
|
||||||
};
|
};
|
||||||
|
@ -921,7 +931,11 @@ mod tests {
|
||||||
|
|
||||||
let ssa = ShallowTransformer::transform(
|
let ssa = ShallowTransformer::transform(
|
||||||
f,
|
f,
|
||||||
&GGenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
&GGenericsAssignment(
|
||||||
|
vec![(GenericIdentifier::with_name("K").index(0), 1)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
&mut versions,
|
&mut versions,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -989,7 +1003,9 @@ mod tests {
|
||||||
.into()]),
|
.into()]),
|
||||||
],
|
],
|
||||||
signature: DeclarationSignature::new()
|
signature: DeclarationSignature::new()
|
||||||
.generics(vec![Some("K".into())])
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier::with_name("K").index(0).into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.outputs(vec![DeclarationType::FieldElement]),
|
.outputs(vec![DeclarationType::FieldElement]),
|
||||||
};
|
};
|
||||||
|
|
|
@ -6,7 +6,46 @@ use std::fmt;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
pub type GenericIdentifier<'ast> = &'ast str;
|
#[derive(Debug, Clone, Eq, Ord)]
|
||||||
|
pub struct GenericIdentifier<'ast> {
|
||||||
|
pub name: &'ast str,
|
||||||
|
pub index: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast> GenericIdentifier<'ast> {
|
||||||
|
pub fn with_name(name: &'ast str) -> Self {
|
||||||
|
Self { name, index: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index(mut self, index: usize) -> Self {
|
||||||
|
self.index = index;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast> PartialEq for GenericIdentifier<'ast> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.index == other.index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast> PartialOrd for GenericIdentifier<'ast> {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
self.index.partial_cmp(&other.index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast> Hash for GenericIdentifier<'ast> {
|
||||||
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
|
self.index.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast> fmt::Display for GenericIdentifier<'ast> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SpecializationError;
|
pub struct SpecializationError;
|
||||||
|
@ -53,7 +92,9 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
|
||||||
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
|
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
|
||||||
fn from(c: Constant<'ast>) -> Self {
|
fn from(c: Constant<'ast>) -> Self {
|
||||||
match c {
|
match c {
|
||||||
Constant::Generic(i) => UExpressionInner::Identifier(i.into()).annotate(UBitwidth::B32),
|
Constant::Generic(i) => {
|
||||||
|
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
|
||||||
|
}
|
||||||
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
|
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -819,13 +860,41 @@ pub mod signature {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
#[derive(Clone, Serialize, Deserialize, Eq)]
|
||||||
pub struct GSignature<S> {
|
pub struct GSignature<S> {
|
||||||
pub generics: Vec<Option<S>>,
|
pub generics: Vec<Option<S>>,
|
||||||
pub inputs: Vec<GType<S>>,
|
pub inputs: Vec<GType<S>>,
|
||||||
pub outputs: Vec<GType<S>>,
|
pub outputs: Vec<GType<S>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S: PartialEq> PartialEq for GSignature<S> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.inputs == other.inputs && self.outputs == other.outputs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PartialOrd> PartialOrd for GSignature<S> {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
self.inputs
|
||||||
|
.partial_cmp(&other.inputs)
|
||||||
|
.map(|c| self.outputs.partial_cmp(&other.outputs).map(|d| c.then(d)))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Ord> Ord for GSignature<S> {
|
||||||
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||||
|
self.partial_cmp(&other).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Hash> Hash for GSignature<S> {
|
||||||
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
|
self.inputs.hash(state);
|
||||||
|
self.outputs.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S> Default for GSignature<S> {
|
impl<S> Default for GSignature<S> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
GSignature {
|
GSignature {
|
||||||
|
@ -853,17 +922,17 @@ pub mod signature {
|
||||||
|
|
||||||
// both the inner type and the size must match
|
// both the inner type and the size must match
|
||||||
check_type(&t0.ty, &t1.ty, constants)
|
check_type(&t0.ty, &t1.ty, constants)
|
||||||
&& match t0.size {
|
&& match &t0.size {
|
||||||
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
||||||
// matches if this identifier is already in the map
|
// matches if this identifier is already in the map
|
||||||
Constant::Generic(id) => match constants.0.entry(id) {
|
Constant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||||
Entry::Occupied(e) => *e.get() == s1,
|
Entry::Occupied(e) => *e.get() == s1,
|
||||||
Entry::Vacant(e) => {
|
Entry::Vacant(e) => {
|
||||||
e.insert(s1);
|
e.insert(s1);
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Constant::Concrete(s0) => s1 == s0 as usize,
|
Constant::Concrete(s0) => s1 == *s0 as usize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(DeclarationType::FieldElement, GType::FieldElement)
|
(DeclarationType::FieldElement, GType::FieldElement)
|
||||||
|
@ -1169,34 +1238,61 @@ pub mod signature {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn signature_equivalence() {
|
fn signature_equivalence() {
|
||||||
let generic = DeclarationSignature::new()
|
// check equivalence of:
|
||||||
.generics(vec![Some("P".into())])
|
// <P>(field[P])
|
||||||
|
// <Q>(field[Q])
|
||||||
|
|
||||||
|
let generic1 = DeclarationSignature::new()
|
||||||
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier {
|
||||||
|
name: "P",
|
||||||
|
index: 0,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)])
|
||||||
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
|
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
|
||||||
DeclarationType::FieldElement,
|
DeclarationType::FieldElement,
|
||||||
"P".into(),
|
GenericIdentifier {
|
||||||
|
name: "P",
|
||||||
|
index: 0,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
))]);
|
||||||
|
let generic2 = DeclarationSignature::new()
|
||||||
|
.generics(vec![Some(
|
||||||
|
GenericIdentifier {
|
||||||
|
name: "Q",
|
||||||
|
index: 0,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)])
|
||||||
|
.inputs(vec![DeclarationType::array(DeclarationArrayType::new(
|
||||||
|
DeclarationType::FieldElement,
|
||||||
|
GenericIdentifier {
|
||||||
|
name: "Q",
|
||||||
|
index: 0,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
))]);
|
))]);
|
||||||
let specialized = DeclarationSignature::new().inputs(vec![DeclarationType::array(
|
|
||||||
DeclarationArrayType::new(DeclarationType::FieldElement, 3u32.into()),
|
|
||||||
)]);
|
|
||||||
|
|
||||||
assert_eq!(generic, specialized);
|
assert_eq!(generic1, generic2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
{
|
{
|
||||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||||
generic.hash(&mut hasher);
|
generic1.hash(&mut hasher);
|
||||||
hasher.finish()
|
hasher.finish()
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||||
specialized.hash(&mut hasher);
|
generic2.hash(&mut hasher);
|
||||||
hasher.finish()
|
hasher.finish()
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
generic.partial_cmp(&specialized),
|
generic1.partial_cmp(&generic2),
|
||||||
Some(std::cmp::Ordering::Equal)
|
Some(std::cmp::Ordering::Equal)
|
||||||
);
|
);
|
||||||
assert_eq!(generic.cmp(&specialized), std::cmp::Ordering::Equal);
|
assert_eq!(generic1.cmp(&generic2), std::cmp::Ordering::Equal);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
Loading…
Reference in a new issue