Merge pull request #945 from Zokrates/generic-structs
Implement generic structs
This commit is contained in:
commit
f717d243b2
40 changed files with 1586 additions and 599 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2547,6 +2547,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"zokrates_abi",
|
||||
"zokrates_core",
|
||||
"zokrates_field",
|
||||
"zokrates_fs_resolver",
|
||||
|
|
1
changelogs/unreleased/945-schaeff
Normal file
1
changelogs/unreleased/945-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Enable constant generics on structs
|
1
changelogs/unreleased/948-schaeff
Normal file
1
changelogs/unreleased/948-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Add gm17 verifier to stdlib for bw6_761
|
|
@ -299,22 +299,7 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<ConcreteType>) -> Result<Value
|
|||
serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?;
|
||||
|
||||
match values {
|
||||
serde_json::Value::Array(values) => {
|
||||
if values.len() != types.len() {
|
||||
return Err(Error::Type(format!(
|
||||
"Expected {} inputs, found {}",
|
||||
types.len(),
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
Ok(Values(
|
||||
types
|
||||
.into_iter()
|
||||
.zip(values.into_iter())
|
||||
.map(|(ty, v)| parse_value(v, ty))
|
||||
.collect::<Result<_, _>>()?,
|
||||
))
|
||||
}
|
||||
serde_json::Value::Array(values) => parse_strict_json(values, types),
|
||||
_ => Err(Error::Type(format!(
|
||||
"Expected an array of values, found `{}`",
|
||||
values
|
||||
|
@ -322,6 +307,27 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<ConcreteType>) -> Result<Value
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_strict_json<T: Field>(
|
||||
values: Vec<serde_json::Value>,
|
||||
types: Vec<ConcreteType>,
|
||||
) -> Result<Values<T>, Error> {
|
||||
if values.len() != types.len() {
|
||||
return Err(Error::Type(format!(
|
||||
"Expected {} inputs, found {}",
|
||||
types.len(),
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Values(
|
||||
types
|
||||
.into_iter()
|
||||
.zip(values.into_iter())
|
||||
.map(|(ty, v)| parse_value(v, ty))
|
||||
.collect::<Result<_, _>>()?,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -428,6 +434,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -449,6 +456,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -466,6 +474,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
@ -483,6 +492,7 @@ mod tests {
|
|||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
|
|
|
@ -107,7 +107,7 @@ field[2] b = a[1..3] // initialize an array copying a slice from `a`
|
|||
```
|
||||
|
||||
### Structs
|
||||
A struct is a composite datatype representing a named collection of variables.
|
||||
A struct is a composite datatype representing a named collection of variables. Structs can be generic over constants, in order to wrap arrays of generic size. For more details on generic array sizes, see [constant generics](../language/generics.md)
|
||||
The contained variables can be of any type.
|
||||
|
||||
The following code shows an example of how to use structs.
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
struct Bar {
|
||||
field[2] c
|
||||
struct Bar<N> {
|
||||
field[N] c
|
||||
bool d
|
||||
}
|
||||
|
||||
struct Foo {
|
||||
Bar a
|
||||
struct Foo<P> {
|
||||
Bar<P> a
|
||||
bool b
|
||||
}
|
||||
|
||||
def main() -> (Foo):
|
||||
Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}]
|
||||
def main() -> (Foo<2>):
|
||||
Foo<2>[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}]
|
||||
f[0].a.c = [42, 43]
|
||||
return f[0]
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
const u32[2] A = [1]
|
||||
|
||||
def main() -> u32[2]:
|
||||
return A
|
|
@ -0,0 +1,6 @@
|
|||
struct A<N, N> {
|
||||
field[N] a
|
||||
}
|
||||
|
||||
def main():
|
||||
return
|
|
@ -0,0 +1,7 @@
|
|||
struct A<N> {
|
||||
field[N] a
|
||||
}
|
||||
|
||||
def main():
|
||||
A<_> a = A { a: [1] }
|
||||
return
|
|
@ -0,0 +1,6 @@
|
|||
struct A<1> {
|
||||
field[1] a
|
||||
}
|
||||
|
||||
def main():
|
||||
return
|
|
@ -0,0 +1,6 @@
|
|||
struct A<N> {
|
||||
field[N] a
|
||||
}
|
||||
|
||||
def main(A<1> a, A<2> b) -> bool:
|
||||
return a == b
|
|
@ -0,0 +1,6 @@
|
|||
struct A {
|
||||
field[N] a
|
||||
}
|
||||
|
||||
def main():
|
||||
return
|
|
@ -0,0 +1,4 @@
|
|||
struct A<N> {}
|
||||
|
||||
def main():
|
||||
return
|
15
zokrates_cli/examples/compile_errors/wrong_member_type.zok
Normal file
15
zokrates_cli/examples/compile_errors/wrong_member_type.zok
Normal file
|
@ -0,0 +1,15 @@
|
|||
struct B {
|
||||
field a
|
||||
}
|
||||
|
||||
struct A {
|
||||
B a
|
||||
}
|
||||
|
||||
def main():
|
||||
A a = A {
|
||||
a: B {
|
||||
a: false
|
||||
}
|
||||
}
|
||||
return
|
13
zokrates_cli/examples/structs/generic.zok
Normal file
13
zokrates_cli/examples/structs/generic.zok
Normal file
|
@ -0,0 +1,13 @@
|
|||
struct B<P> {
|
||||
field[P] a
|
||||
}
|
||||
|
||||
struct A<N> {
|
||||
field[N] a
|
||||
B<N> b
|
||||
}
|
||||
|
||||
def main(A<2> a) -> bool:
|
||||
u32 SIZE = 1 + 1
|
||||
A<SIZE> b = A { a: [1, 2], b: B { a: [1, 2] } }
|
||||
return a == b
|
|
@ -79,6 +79,11 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
|
|||
let id = definition.id.span.as_str();
|
||||
|
||||
let ty = absy::StructDefinition {
|
||||
generics: definition
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(absy::ConstantGenericNode::from)
|
||||
.collect(),
|
||||
fields: definition
|
||||
.fields
|
||||
.into_iter()
|
||||
|
@ -767,9 +772,25 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
|
|||
pest::BasicType::U32(t) => UnresolvedType::Uint(32).span(t.span),
|
||||
pest::BasicType::U64(t) => UnresolvedType::Uint(64).span(t.span),
|
||||
},
|
||||
pest::BasicOrStructType::Struct(t) => {
|
||||
UnresolvedType::User(t.span.as_str().to_string()).span(t.span)
|
||||
}
|
||||
pest::BasicOrStructType::Struct(t) => UnresolvedType::User(
|
||||
t.id.span.as_str().to_string(),
|
||||
t.explicit_generics.map(|explicit_generics| {
|
||||
explicit_generics
|
||||
.values
|
||||
.into_iter()
|
||||
.map(|i| match i {
|
||||
pest::ConstantGenericValue::Underscore(_) => None,
|
||||
pest::ConstantGenericValue::Value(v) => {
|
||||
Some(absy::ExpressionNode::from(v))
|
||||
}
|
||||
pest::ConstantGenericValue::Identifier(i) => Some(
|
||||
absy::Expression::Identifier(i.span.as_str()).span(i.span),
|
||||
),
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
)
|
||||
.span(t.span),
|
||||
};
|
||||
|
||||
let span = t.span;
|
||||
|
@ -785,9 +806,25 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
|
|||
.unwrap()
|
||||
.span(span.clone())
|
||||
}
|
||||
pest::Type::Struct(s) => {
|
||||
UnresolvedType::User(s.id.span.as_str().to_string()).span(s.span)
|
||||
}
|
||||
pest::Type::Struct(s) => UnresolvedType::User(
|
||||
s.id.span.as_str().to_string(),
|
||||
s.explicit_generics.map(|explicit_generics| {
|
||||
explicit_generics
|
||||
.values
|
||||
.into_iter()
|
||||
.map(|i| match i {
|
||||
pest::ConstantGenericValue::Underscore(_) => None,
|
||||
pest::ConstantGenericValue::Value(v) => {
|
||||
Some(absy::ExpressionNode::from(v))
|
||||
}
|
||||
pest::ConstantGenericValue::Identifier(i) => {
|
||||
Some(absy::Expression::Identifier(i.span.as_str()).span(i.span))
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
)
|
||||
.span(s.span),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -153,7 +153,7 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
|
|||
i.value.source.display(),
|
||||
i.value.id
|
||||
),
|
||||
SymbolDefinition::Struct(ref t) => write!(f, "struct {} {}", self.id, t),
|
||||
SymbolDefinition::Struct(ref t) => write!(f, "struct {}{}", self.id, t),
|
||||
SymbolDefinition::Constant(ref c) => write!(
|
||||
f,
|
||||
"const {} {} = {}",
|
||||
|
@ -199,20 +199,25 @@ pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
|
|||
/// A struct type definition
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct StructDefinition<'ast> {
|
||||
pub generics: Vec<ConstantGenericNode<'ast>>,
|
||||
pub fields: Vec<StructDefinitionFieldNode<'ast>>,
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for StructDefinition<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
self.fields
|
||||
"<{}> {{",
|
||||
self.generics
|
||||
.iter()
|
||||
.map(|fi| fi.to_string())
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
.join(", "),
|
||||
)?;
|
||||
for field in &self.fields {
|
||||
writeln!(f, " {}", field)?;
|
||||
}
|
||||
write!(f, "}}",)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -227,7 +232,7 @@ pub struct StructDefinitionField<'ast> {
|
|||
|
||||
impl<'ast> fmt::Display for StructDefinitionField<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}: {},", self.id, self.ty)
|
||||
write!(f, "{} {}", self.ty, self.id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,7 @@ use crate::absy::UnresolvedTypeNode;
|
|||
use std::fmt;
|
||||
|
||||
pub type Identifier<'ast> = &'ast str;
|
||||
|
||||
pub type MemberId = String;
|
||||
|
||||
pub type UserTypeId = String;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
|
@ -14,7 +12,7 @@ pub enum UnresolvedType<'ast> {
|
|||
Boolean,
|
||||
Uint(usize),
|
||||
Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>),
|
||||
User(UserTypeId),
|
||||
User(UserTypeId, Option<Vec<Option<ExpressionNode<'ast>>>>),
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for UnresolvedType<'ast> {
|
||||
|
@ -24,7 +22,30 @@ impl<'ast> fmt::Display for UnresolvedType<'ast> {
|
|||
UnresolvedType::Boolean => write!(f, "bool"),
|
||||
UnresolvedType::Uint(bitwidth) => write!(f, "u{}", bitwidth),
|
||||
UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
|
||||
UnresolvedType::User(i) => write!(f, "{}", i),
|
||||
UnresolvedType::User(ref id, ref generics) => {
|
||||
write!(
|
||||
f,
|
||||
"{}{}",
|
||||
id,
|
||||
generics
|
||||
.as_ref()
|
||||
.map(|generics| {
|
||||
format!(
|
||||
"<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|e| {
|
||||
e.as_ref()
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or_else(|| "_".to_string())
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
})
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -437,11 +437,13 @@ struct Bar { field a }
|
|||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"foo".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember {
|
||||
id: "b".into(),
|
||||
ty: box ConcreteType::Struct(ConcreteStructType::new(
|
||||
"bar".into(),
|
||||
"Bar".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember {
|
||||
id: "a".into(),
|
||||
ty: box ConcreteType::FieldElement
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -121,7 +121,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
(
|
||||
id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
ty: constant.get_type().clone(),
|
||||
expression: constant,
|
||||
}),
|
||||
)
|
||||
|
@ -252,8 +251,9 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier,
|
||||
TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement,
|
||||
DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
|
||||
GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
|
||||
TypedStatement,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
|
@ -276,11 +276,14 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
|
||||
)),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -364,11 +367,10 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Boolean,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(true)),
|
||||
)),
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into(), DeclarationType::Boolean),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
|
||||
BooleanExpression::Value(true),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -453,9 +455,12 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::Uint(UBitwidth::B32),
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Uint(UBitwidth::B32),
|
||||
UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
|
@ -554,20 +559,24 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::array(GArrayType::new(GType::FieldElement, 2usize)),
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
),
|
||||
)),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
2u32,
|
||||
)),
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
@ -693,18 +702,24 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_a_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
(
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_b_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from(
|
||||
const_a_id,
|
||||
|
@ -751,18 +766,24 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_a_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
(
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_b_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(2),
|
||||
)),
|
||||
|
@ -812,13 +833,14 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(42),
|
||||
)),
|
||||
)),
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"foo".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
@ -844,10 +866,15 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"foo".into(),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -885,13 +912,14 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(42),
|
||||
)),
|
||||
)),
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
|
|
@ -129,7 +129,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
p: typed_absy::DeclarationParameter<'ast>,
|
||||
) -> Vec<zir::Parameter<'ast>> {
|
||||
let private = p.private;
|
||||
self.fold_variable(p.id.try_into().unwrap())
|
||||
self.fold_variable(crate::typed_absy::variable::try_from_g_variable(p.id).unwrap())
|
||||
.into_iter()
|
||||
.map(|v| zir::Parameter { id: v, private })
|
||||
.collect()
|
||||
|
@ -1101,7 +1101,11 @@ fn fold_function<'ast, T: Field>(
|
|||
.collect(),
|
||||
statements: main_statements_buffer,
|
||||
signature: typed_absy::types::ConcreteSignature::try_from(
|
||||
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
|
||||
crate::typed_absy::types::try_from_g_signature::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
crate::typed_absy::UExpression<'ast, T>,
|
||||
>(fun.signature)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
.into(),
|
||||
|
|
|
@ -1278,6 +1278,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
|
||||
Ok(BooleanExpression::ArrayEq(box e1, box e2))
|
||||
}
|
||||
BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = self.fold_struct_expression(e1)?;
|
||||
let e2 = self.fold_struct_expression(e2)?;
|
||||
|
||||
if let (Ok(t1), Ok(t2)) = (
|
||||
ConcreteType::try_from(e1.get_type()),
|
||||
ConcreteType::try_from(e2.get_type()),
|
||||
) {
|
||||
if t1 != t2 {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot compare {} of type {} to {} of type {}",
|
||||
e1, t1, e2, t2
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(BooleanExpression::StructEq(box e1, box e2))
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
|
|
@ -230,16 +230,21 @@ mod tests {
|
|||
public: true,
|
||||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![
|
||||
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
|
||||
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
|
||||
],
|
||||
"Bar".into(),
|
||||
vec![Some(1usize)],
|
||||
vec![ConcreteStructMember::new(
|
||||
String::from("a"),
|
||||
ConcreteType::Array(ConcreteArrayType::new(
|
||||
ConcreteType::FieldElement,
|
||||
1usize,
|
||||
)),
|
||||
)],
|
||||
)),
|
||||
}],
|
||||
outputs: vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
|
||||
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
|
||||
|
@ -257,15 +262,18 @@ mod tests {
|
|||
"public": true,
|
||||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"name": "Bar",
|
||||
"generics": [
|
||||
1
|
||||
],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": "field"
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"type": "bool"
|
||||
"type": "array",
|
||||
"components": {
|
||||
"size": 1,
|
||||
"type": "field"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -276,6 +284,7 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
|
@ -305,11 +314,13 @@ mod tests {
|
|||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![ConcreteStructMember::new(
|
||||
String::from("bar"),
|
||||
ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Bar".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(
|
||||
String::from("a"),
|
||||
|
@ -338,12 +349,14 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "bar",
|
||||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Bar",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "a",
|
||||
|
@ -378,6 +391,7 @@ mod tests {
|
|||
ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"Foo".into(),
|
||||
vec![],
|
||||
vec![
|
||||
ConcreteStructMember::new(
|
||||
String::from("b"),
|
||||
|
@ -386,7 +400,7 @@ mod tests {
|
|||
ConcreteStructMember::new(String::from("c"), ConcreteType::Boolean),
|
||||
],
|
||||
)),
|
||||
2,
|
||||
2usize,
|
||||
)),
|
||||
}],
|
||||
outputs: vec![ConcreteType::Boolean],
|
||||
|
@ -406,6 +420,7 @@ mod tests {
|
|||
"type": "struct",
|
||||
"components": {
|
||||
"name": "Foo",
|
||||
"generics": [],
|
||||
"members": [
|
||||
{
|
||||
"name": "b",
|
||||
|
@ -439,8 +454,8 @@ mod tests {
|
|||
name: String::from("a"),
|
||||
public: false,
|
||||
ty: ConcreteType::Array(ConcreteArrayType::new(
|
||||
ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2)),
|
||||
2,
|
||||
ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2usize)),
|
||||
2usize,
|
||||
)),
|
||||
}],
|
||||
outputs: vec![ConcreteType::FieldElement],
|
||||
|
|
|
@ -138,6 +138,11 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> {
|
||||
StructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)))
|
||||
.collect(),
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -175,6 +180,11 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
t: DeclarationStructType<'ast>,
|
||||
) -> DeclarationStructType<'ast> {
|
||||
DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_declaration_constant(g)))
|
||||
.collect(),
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -222,6 +232,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module),
|
||||
id: i.id,
|
||||
ty: box self.fold_declaration_type(*i.ty),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1044,7 +1055,6 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
c: TypedConstant<'ast, T>,
|
||||
) -> TypedConstant<'ast, T> {
|
||||
TypedConstant {
|
||||
ty: f.fold_type(c.ty),
|
||||
expression: f.fold_expression(c.expression),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
use crate::typed_absy::types::{ArrayType, Type};
|
||||
use crate::typed_absy::types::{
|
||||
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationStructMember,
|
||||
DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
StructType, Type,
|
||||
};
|
||||
use crate::typed_absy::UBitwidth;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
|
||||
IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression,
|
||||
TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||
IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
|
||||
TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||
};
|
||||
use num_bigint::BigUint;
|
||||
use std::convert::TryFrom;
|
||||
|
@ -14,20 +18,107 @@ use zokrates_field::Field;
|
|||
type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>);
|
||||
|
||||
impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> {
|
||||
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
|
||||
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: Self,
|
||||
ty: &GArrayType<S>,
|
||||
) -> Result<Self, (Self, &GArrayType<S>)> {
|
||||
match e {
|
||||
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, ty)
|
||||
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, &ty.ty)
|
||||
.map(|e| e.into())
|
||||
.map_err(|(e, t)| (e.into(), t)),
|
||||
TypedExpressionOrSpread::Spread(s) => {
|
||||
ArrayExpression::try_from_int(s.array, ty.clone())
|
||||
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
|
||||
.map_err(|e| (e.into(), ty))
|
||||
}
|
||||
.map_err(|(e, _)| (e.into(), ty)),
|
||||
TypedExpressionOrSpread::Spread(s) => ArrayExpression::try_from_int(s.array, ty)
|
||||
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
|
||||
.map_err(|e| (e.into(), ty)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait IntegerInference: Sized {
|
||||
type Pattern;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)>;
|
||||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for Type<'ast, T> {
|
||||
type Pattern = DeclarationType<'ast>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
match (self, other) {
|
||||
(Type::Boolean, Type::Boolean) => Ok(DeclarationType::Boolean),
|
||||
(Type::Int, Type::Int) => Err((Type::Int, Type::Int)),
|
||||
(Type::Int, Type::FieldElement) => Ok(DeclarationType::FieldElement),
|
||||
(Type::Int, Type::Uint(bitwidth)) => Ok(DeclarationType::Uint(bitwidth)),
|
||||
(Type::FieldElement, Type::Int) => Ok(DeclarationType::FieldElement),
|
||||
(Type::Uint(bitwidth), Type::Int) => Ok(DeclarationType::Uint(bitwidth)),
|
||||
(Type::FieldElement, Type::FieldElement) => Ok(DeclarationType::FieldElement),
|
||||
(Type::Uint(b0), Type::Uint(b1)) => {
|
||||
if b0 == b1 {
|
||||
Ok(DeclarationType::Uint(b0))
|
||||
} else {
|
||||
Err((Type::Uint(b0), Type::Uint(b1)))
|
||||
}
|
||||
}
|
||||
(Type::Array(t), Type::Array(u)) => Ok(DeclarationType::Array(
|
||||
t.get_common_pattern(u)
|
||||
.map_err(|(t, u)| (Type::Array(t), Type::Array(u)))?,
|
||||
)),
|
||||
(Type::Struct(t), Type::Struct(u)) => Ok(DeclarationType::Struct(
|
||||
t.get_common_pattern(u)
|
||||
.map_err(|(t, u)| (Type::Struct(t), Type::Struct(u)))?,
|
||||
)),
|
||||
(t, u) => Err((t, u)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
|
||||
type Pattern = DeclarationArrayType<'ast>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
let s0 = self.size;
|
||||
let s1 = other.size;
|
||||
|
||||
Ok(DeclarationArrayType::new(
|
||||
self.ty
|
||||
.get_common_pattern(*other.ty)
|
||||
.map_err(|(t, u)| (ArrayType::new(t, s0), ArrayType::new(u, s1)))?,
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("DUMMY")), // sizes are not checked at this stage, therefore we insert a dummy generic variable which will be equal to all possible sizes
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for StructType<'ast, T> {
|
||||
type Pattern = DeclarationStructType<'ast>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
Ok(DeclarationStructType {
|
||||
members: self
|
||||
.members
|
||||
.into_iter()
|
||||
.zip(other.members.into_iter())
|
||||
.map(|(m_t, m_u)| match m_t.ty.get_common_pattern(*m_u.ty) {
|
||||
Ok(ty) => DeclarationStructMember {
|
||||
ty: box ty,
|
||||
id: m_t.id,
|
||||
},
|
||||
Err(..) => unreachable!(
|
||||
"struct instances of the same struct should always have a common type"
|
||||
),
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
canonical_location: self.canonical_location,
|
||||
location: self.location,
|
||||
generics: self
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|_| DeclarationConstant::Generic(GenericIdentifier::with_name("DUMMY")))
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TypedExpression<'ast, T> {
|
||||
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
|
||||
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
|
||||
|
@ -51,7 +142,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
.into(),
|
||||
)),
|
||||
(Int(lhs), Uint(rhs)) => Ok((
|
||||
UExpression::try_from_int(lhs, rhs.bitwidth())
|
||||
UExpression::try_from_int(lhs, &rhs.bitwidth())
|
||||
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
|
||||
.into(),
|
||||
Uint(rhs),
|
||||
|
@ -60,47 +151,50 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
let bitwidth = lhs.bitwidth();
|
||||
Ok((
|
||||
Uint(lhs.clone()),
|
||||
UExpression::try_from_int(rhs, bitwidth)
|
||||
UExpression::try_from_int(rhs, &bitwidth)
|
||||
.map_err(|rhs| (lhs.into(), rhs.into()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Array(lhs), Array(rhs)) => {
|
||||
fn get_common_type<'a, T: Field>(
|
||||
t: Type<'a, T>,
|
||||
u: Type<'a, T>,
|
||||
) -> Result<Type<'a, T>, ()> {
|
||||
match (t, u) {
|
||||
(Type::Int, Type::Int) => Err(()),
|
||||
(Type::Int, u) => Ok(u),
|
||||
(t, Type::Int) => Ok(t),
|
||||
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
|
||||
get_common_type(*t.ty, *u.ty)?,
|
||||
t.size,
|
||||
))),
|
||||
(t, _) => Ok(t),
|
||||
}
|
||||
}
|
||||
let common_type = lhs
|
||||
.get_type()
|
||||
.get_common_pattern(rhs.get_type())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
|
||||
let common_type =
|
||||
get_common_type(lhs.inner_type().clone(), rhs.inner_type().clone())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
let common_type = match common_type {
|
||||
DeclarationType::Array(ty) => ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
ArrayExpression::try_from_int(lhs.clone(), common_type.clone())
|
||||
ArrayExpression::try_from_int(lhs.clone(), &common_type)
|
||||
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
|
||||
.into(),
|
||||
ArrayExpression::try_from_int(rhs, common_type)
|
||||
ArrayExpression::try_from_int(rhs, &common_type)
|
||||
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Struct(lhs), Struct(rhs)) => {
|
||||
if lhs.get_type() == rhs.get_type() {
|
||||
Ok((Struct(lhs), Struct(rhs)))
|
||||
} else {
|
||||
Err((Struct(lhs), Struct(rhs)))
|
||||
}
|
||||
let common_type = lhs
|
||||
.get_type()
|
||||
.get_common_pattern(rhs.get_type())
|
||||
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
|
||||
|
||||
let common_type = match common_type {
|
||||
DeclarationType::Struct(ty) => ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
StructExpression::try_from_int(lhs.clone(), &common_type)
|
||||
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
|
||||
.into(),
|
||||
StructExpression::try_from_int(rhs, &common_type)
|
||||
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
|
@ -110,22 +204,25 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> {
|
||||
match ty.clone() {
|
||||
Type::FieldElement => {
|
||||
pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: Self,
|
||||
ty: >ype<S>,
|
||||
) -> Result<Self, (Self, >ype<S>)> {
|
||||
match ty {
|
||||
GType::FieldElement => {
|
||||
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
|
||||
Type::Uint(bitwidth) => {
|
||||
GType::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
|
||||
GType::Uint(bitwidth) => {
|
||||
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
|
||||
}
|
||||
Type::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(e, *array_ty.ty).map(TypedExpression::from)
|
||||
GType::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(struct_ty) => {
|
||||
GType::Struct(struct_ty) => {
|
||||
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Int => Err(e),
|
||||
GType::Int => Err(e),
|
||||
}
|
||||
.map_err(|e| (e, ty))
|
||||
}
|
||||
|
@ -299,7 +396,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
)),
|
||||
IntExpression::Pow(box e1, box e2) => Ok(Self::Pow(
|
||||
box Self::try_from_int(e1)?,
|
||||
box UExpression::try_from_int(e2, UBitwidth::B32)?,
|
||||
box UExpression::try_from_int(e2, &UBitwidth::B32)?,
|
||||
)),
|
||||
IntExpression::Div(box e1, box e2) => Ok(Self::Div(
|
||||
box Self::try_from_int(e1)?,
|
||||
|
@ -323,15 +420,21 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
let values = values
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, Type::FieldElement)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
TypedExpressionOrSpread::align_to_type(
|
||||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Concrete(0),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(FieldElementExpression::select(
|
||||
|
@ -351,15 +454,15 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
e: TypedExpression<'ast, T>,
|
||||
bitwidth: UBitwidth,
|
||||
bitwidth: &UBitwidth,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Uint(e) => match e.bitwidth == bitwidth {
|
||||
TypedExpression::Uint(e) => match e.bitwidth == *bitwidth {
|
||||
true => Ok(e),
|
||||
_ => Err(TypedExpression::Uint(e)),
|
||||
},
|
||||
TypedExpression::Int(e) => {
|
||||
Self::try_from_int(e.clone(), bitwidth).map_err(|_| TypedExpression::Int(e))
|
||||
Self::try_from_int(e, bitwidth).map_err(TypedExpression::Int)
|
||||
}
|
||||
e => Err(e),
|
||||
}
|
||||
|
@ -367,7 +470,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
|
||||
pub fn try_from_int(
|
||||
i: IntExpression<'ast, T>,
|
||||
bitwidth: UBitwidth,
|
||||
bitwidth: &UBitwidth,
|
||||
) -> Result<Self, IntExpression<'ast, T>> {
|
||||
use self::IntExpression::*;
|
||||
|
||||
|
@ -377,7 +480,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
Ok(UExpressionInner::Value(
|
||||
u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(),
|
||||
)
|
||||
.annotate(bitwidth))
|
||||
.annotate(*bitwidth))
|
||||
} else {
|
||||
Err(Value(i))
|
||||
}
|
||||
|
@ -435,20 +538,26 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
let values = values
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, Type::Uint(bitwidth))
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
TypedExpressionOrSpread::align_to_type(
|
||||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::Uint(*bitwidth),
|
||||
DeclarationConstant::Concrete(0),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
IntExpression::try_from(e).unwrap()
|
||||
}
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
IntExpression::select(a.array, 0u32)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(UExpression::select(
|
||||
ArrayExpressionInner::Value(values.into())
|
||||
.annotate(Type::Uint(bitwidth), size),
|
||||
.annotate(Type::Uint(*bitwidth), size),
|
||||
index,
|
||||
))
|
||||
}
|
||||
|
@ -461,35 +570,34 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_inner_ty: Type<'ast, T>,
|
||||
target_array_ty: &GArrayType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_inner_ty)
|
||||
.map_err(|_| TypedExpression::Array(e)),
|
||||
TypedExpression::Array(e) => Self::try_from_int(e, target_array_ty),
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// precondition: `array` is only made of inline arrays and repeat constructs unless it does not contain the Integer type
|
||||
pub fn try_from_int(
|
||||
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
|
||||
array: Self,
|
||||
target_inner_ty: Type<'ast, T>,
|
||||
target_array_ty: &GArrayType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let array_ty = array.ty();
|
||||
|
||||
// elements must fit in the target type
|
||||
match array.into_inner() {
|
||||
ArrayExpressionInner::Value(inline_array) => {
|
||||
let res = match target_inner_ty.clone() {
|
||||
Type::Int => Ok(inline_array),
|
||||
t => {
|
||||
let res = match &*target_array_ty.ty {
|
||||
GType::Int => Ok(inline_array),
|
||||
_ => {
|
||||
// try to convert all elements to the target type
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
TypedExpressionOrSpread::align_to_type(v, t.clone()).map_err(
|
||||
TypedExpressionOrSpread::align_to_type(v, &target_array_ty).map_err(
|
||||
|(e, _)| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => e,
|
||||
TypedExpressionOrSpread::Spread(a) => {
|
||||
|
@ -508,11 +616,11 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size))
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
match target_inner_ty.clone() {
|
||||
Type::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
|
||||
match &*target_array_ty.ty {
|
||||
GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
|
||||
.annotate(Type::Int, array_ty.size)),
|
||||
// try to align the repeated element to the target type
|
||||
t => TypedExpression::align_to_type(e, t)
|
||||
t => TypedExpression::align_to_type(e, &t)
|
||||
.map(|e| {
|
||||
let ty = e.get_type().clone();
|
||||
|
||||
|
@ -523,7 +631,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
a => {
|
||||
if array_ty.ty.weak_eq(&target_inner_ty) {
|
||||
if *target_array_ty.ty == *array_ty.ty {
|
||||
Ok(a.annotate(*array_ty.ty, array_ty.size))
|
||||
} else {
|
||||
Err(a.annotate(*array_ty.ty, array_ty.size).into())
|
||||
|
@ -533,6 +641,49 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> StructExpression<'ast, T> {
|
||||
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
|
||||
struc: Self,
|
||||
target_struct_ty: &GStructType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let struct_ty = struc.ty().clone();
|
||||
|
||||
match struc.into_inner() {
|
||||
StructExpressionInner::Value(inline_struct) => inline_struct
|
||||
.into_iter()
|
||||
.zip(target_struct_ty.members.iter())
|
||||
.map(|(value, target_member)| {
|
||||
TypedExpression::align_to_type(value, &*target_member.ty)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone()))
|
||||
.map_err(|(v, _)| v),
|
||||
s => {
|
||||
if struct_ty
|
||||
.members
|
||||
.iter()
|
||||
.zip(target_struct_ty.members.iter())
|
||||
.all(|(m, target_m)| *target_m.ty == *m.ty)
|
||||
{
|
||||
Ok(s.annotate(struct_ty))
|
||||
} else {
|
||||
Err(s.annotate(struct_ty).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_struct_ty: &GStructType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Struct(e) => Self::try_from_int(e, target_struct_ty),
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<BigUint> for IntExpression<'ast, T> {
|
||||
fn from(v: BigUint) -> Self {
|
||||
IntExpression::Value(v)
|
||||
|
@ -652,7 +803,7 @@ mod tests {
|
|||
|
||||
for (r, e) in expressions
|
||||
.into_iter()
|
||||
.map(|e| UExpression::try_from_int(e, UBitwidth::B32).unwrap())
|
||||
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32).unwrap())
|
||||
.zip(expected)
|
||||
{
|
||||
assert_eq!(r, e);
|
||||
|
@ -665,7 +816,7 @@ mod tests {
|
|||
|
||||
for e in should_error
|
||||
.into_iter()
|
||||
.map(|e| UExpression::try_from_int(e, UBitwidth::B32))
|
||||
.map(|e| UExpression::try_from_int(e, &UBitwidth::B32))
|
||||
{
|
||||
assert!(e.is_err());
|
||||
}
|
||||
|
|
|
@ -14,15 +14,15 @@ mod integer;
|
|||
mod parameter;
|
||||
pub mod types;
|
||||
mod uint;
|
||||
mod variable;
|
||||
pub mod variable;
|
||||
|
||||
pub use self::identifier::CoreIdentifier;
|
||||
pub use self::parameter::{DeclarationParameter, GParameter};
|
||||
pub use self::types::{
|
||||
CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType,
|
||||
ConstantIdentifier, DeclarationFunctionKey, DeclarationSignature, DeclarationType, GArrayType,
|
||||
GStructType, GType, GenericIdentifier, IntoTypes, Signature, StructType, Type, Types,
|
||||
UBitwidth,
|
||||
ConstantIdentifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
||||
DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
};
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
|
||||
|
@ -107,13 +107,19 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.arguments
|
||||
.iter()
|
||||
.map(|p| {
|
||||
types::ConcreteType::try_from(types::Type::<T>::from(p.id._type.clone()))
|
||||
.map(|ty| AbiInput {
|
||||
public: !p.private,
|
||||
name: p.id.id.to_string(),
|
||||
ty,
|
||||
})
|
||||
.unwrap()
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
UExpression<'ast, T>,
|
||||
>(p.id._type.clone())
|
||||
.unwrap(),
|
||||
)
|
||||
.map(|ty| AbiInput {
|
||||
public: !p.private,
|
||||
name: p.id.id.to_string(),
|
||||
ty,
|
||||
})
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
outputs: main
|
||||
|
@ -121,7 +127,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.outputs
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
types::ConcreteType::try_from(types::Type::<T>::from(ty.clone())).unwrap()
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
UExpression<'ast, T>,
|
||||
>(ty.clone())
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
@ -192,7 +205,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
|||
.iter()
|
||||
.map(|(id, symbol)| match symbol {
|
||||
TypedConstantSymbol::Here(ref tc) => {
|
||||
format!("const {} {} = {}", tc.ty, id.id, tc.expression)
|
||||
format!("const {} {} = {}", id.ty, id.id, tc)
|
||||
}
|
||||
TypedConstantSymbol::There(ref imported_id) => {
|
||||
format!(
|
||||
|
@ -298,27 +311,24 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct TypedConstant<'ast, T> {
|
||||
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
|
||||
pub ty: Type<'ast, T>,
|
||||
pub expression: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedConstant<'ast, T> {
|
||||
pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self {
|
||||
TypedConstant { ty, expression }
|
||||
pub fn new(expression: TypedExpression<'ast, T>) -> Self {
|
||||
TypedConstant { expression }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
|
||||
write!(f, "const {}({})", self.ty, self.expression)
|
||||
write!(f, "{}", self.expression)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> {
|
||||
impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
|
||||
fn get_type(&self) -> Type<'ast, T> {
|
||||
self.ty.clone()
|
||||
self.expression.get_type()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1166,24 +1176,6 @@ pub struct StructExpression<'ast, T> {
|
|||
inner: StructExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> StructExpression<'ast, T> {
|
||||
pub fn try_from_typed(
|
||||
e: TypedExpression<'ast, T>,
|
||||
target_struct_ty: StructType<'ast, T>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
TypedExpression::Struct(e) => {
|
||||
if e.ty() == &target_struct_ty {
|
||||
Ok(e)
|
||||
} else {
|
||||
Err(TypedExpression::Struct(e))
|
||||
}
|
||||
}
|
||||
e => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> StructExpression<'ast, T> {
|
||||
pub fn ty(&self) -> &StructType<'ast, T> {
|
||||
&self.ty
|
||||
|
|
|
@ -121,6 +121,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
Ok(CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module)?,
|
||||
id: i.id,
|
||||
ty: box self.fold_declaration_type(*i.ty)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -225,6 +226,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
t: StructType<'ast, T>,
|
||||
) -> Result<StructType<'ast, T>, Self::Error> {
|
||||
Ok(StructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -260,6 +266,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
t: DeclarationStructType<'ast>,
|
||||
) -> Result<DeclarationStructType<'ast>, Self::Error> {
|
||||
Ok(DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_declaration_constant(g)).transpose())
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -1092,7 +1103,6 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
c: TypedConstant<'ast, T>,
|
||||
) -> Result<TypedConstant<'ast, T>, F::Error> {
|
||||
Ok(TypedConstant {
|
||||
ty: f.fold_type(c.ty)?,
|
||||
expression: f.fold_expression(c.expression)?,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -107,11 +107,20 @@ pub type ConstantIdentifier<'ast> = &'ast str;
|
|||
pub struct CanonicalConstantIdentifier<'ast> {
|
||||
pub module: OwnedTypedModuleId,
|
||||
pub id: ConstantIdentifier<'ast>,
|
||||
pub ty: Box<DeclarationType<'ast>>,
|
||||
}
|
||||
|
||||
impl<'ast> CanonicalConstantIdentifier<'ast> {
|
||||
pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self {
|
||||
CanonicalConstantIdentifier { module, id }
|
||||
pub fn new(
|
||||
id: ConstantIdentifier<'ast>,
|
||||
module: OwnedTypedModuleId,
|
||||
ty: DeclarationType<'ast>,
|
||||
) -> Self {
|
||||
CanonicalConstantIdentifier {
|
||||
module,
|
||||
id,
|
||||
ty: box ty,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,6 +131,21 @@ pub enum DeclarationConstant<'ast> {
|
|||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
|
||||
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
|
||||
match (self, other.as_inner()) {
|
||||
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
|
||||
other.eq(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for DeclarationConstant<'ast> {
|
||||
fn from(e: u32) -> Self {
|
||||
DeclarationConstant::Concrete(e)
|
||||
|
@ -198,7 +222,7 @@ impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
|
|||
|
||||
pub type MemberId = String;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct GStructMember<S> {
|
||||
#[serde(rename = "name")]
|
||||
pub id: MemberId,
|
||||
|
@ -210,8 +234,8 @@ pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>
|
|||
pub type ConcreteStructMember = GStructMember<usize>;
|
||||
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationStructMember<'ast>) -> bool {
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructMember<S>> for GStructMember<R> {
|
||||
fn eq(&self, other: &GStructMember<S>) -> bool {
|
||||
self.id == other.id && *self.ty == *other.ty
|
||||
}
|
||||
}
|
||||
|
@ -239,19 +263,7 @@ impl<'ast, T> From<ConcreteStructMember> for StructMember<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> {
|
||||
fn from(t: ConcreteStructMember) -> Self {
|
||||
try_from_g_struct_member(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
|
||||
fn from(t: DeclarationStructMember<'ast>) -> Self {
|
||||
try_from_g_struct_member(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
|
||||
#[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
|
||||
pub struct GArrayType<S> {
|
||||
pub size: S,
|
||||
#[serde(flatten)]
|
||||
|
@ -262,13 +274,9 @@ pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteArrayType = GArrayType<usize>;
|
||||
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool {
|
||||
*self.ty == *other.ty
|
||||
&& match (self.size.as_inner(), &other.size) {
|
||||
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
|
||||
_ => true,
|
||||
}
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GArrayType<S>> for GArrayType<R> {
|
||||
fn eq(&self, other: &GArrayType<S>) -> bool {
|
||||
*self.ty == *other.ty && self.size == other.size
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,22 +306,6 @@ impl<S: fmt::Display> fmt::Display for GArrayType<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: PartialEq + fmt::Display> Type<'ast, T> {
|
||||
// array type equality with non-strict size checks
|
||||
// sizes always match unless they are different constants
|
||||
pub fn weak_eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Type::Array(t), Type::Array(u)) => t.ty.weak_eq(&u.ty),
|
||||
(Type::Struct(t), Type::Struct(u)) => t
|
||||
.members
|
||||
.iter()
|
||||
.zip(u.members.iter())
|
||||
.all(|(m, n)| m.ty.weak_eq(&n.ty)),
|
||||
(t, u) => t == u,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_from_g_array_type<T: TryInto<U>, U>(
|
||||
t: GArrayType<T>,
|
||||
) -> Result<GArrayType<U>, SpecializationError> {
|
||||
|
@ -350,18 +342,13 @@ impl<'ast> From<ConcreteArrayType> for DeclarationArrayType<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationArrayType<'ast>> for ArrayType<'ast, T> {
|
||||
fn from(t: DeclarationArrayType<'ast>) -> Self {
|
||||
try_from_g_array_type(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct GStructType<S> {
|
||||
#[serde(flatten)]
|
||||
pub canonical_location: StructLocation,
|
||||
#[serde(skip)]
|
||||
pub location: Option<StructLocation>,
|
||||
pub generics: Vec<Option<S>>,
|
||||
pub members: Vec<GStructMember<S>>,
|
||||
}
|
||||
|
||||
|
@ -369,15 +356,25 @@ pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteStructType = GStructType<usize>;
|
||||
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
|
||||
|
||||
impl<S: PartialEq> PartialEq for GStructType<S> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.canonical_location.eq(&other.canonical_location)
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GStructType<S>> for GStructType<R> {
|
||||
fn eq(&self, other: &GStructType<S>) -> bool {
|
||||
self.canonical_location == other.canonical_location
|
||||
&& self
|
||||
.generics
|
||||
.iter()
|
||||
.zip(other.generics.iter())
|
||||
.all(|(a, b)| match (a, b) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Hash> Hash for GStructType<S> {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.canonical_location.hash(state);
|
||||
self.generics.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -389,6 +386,14 @@ fn try_from_g_struct_type<T: TryInto<U>, U>(
|
|||
Ok(GStructType {
|
||||
location: t.location,
|
||||
canonical_location: t.canonical_location,
|
||||
generics: t
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| match g {
|
||||
Some(g) => g.try_into().map(Some).map_err(|_| SpecializationError),
|
||||
None => Ok(None),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
|
@ -417,17 +422,17 @@ impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
|
||||
fn from(t: DeclarationStructType<'ast>) -> Self {
|
||||
try_from_g_struct_type(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GStructType<S> {
|
||||
pub fn new(module: PathBuf, name: String, members: Vec<GStructMember<S>>) -> Self {
|
||||
pub fn new(
|
||||
module: PathBuf,
|
||||
name: String,
|
||||
generics: Vec<Option<S>>,
|
||||
members: Vec<GStructMember<S>>,
|
||||
) -> Self {
|
||||
GStructType {
|
||||
canonical_location: StructLocation { module, name },
|
||||
location: None,
|
||||
generics,
|
||||
members,
|
||||
}
|
||||
}
|
||||
|
@ -498,7 +503,7 @@ impl fmt::Display for UBitwidth {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
#[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
pub enum GType<S> {
|
||||
FieldElement,
|
||||
Boolean,
|
||||
|
@ -608,13 +613,13 @@ pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
|
|||
pub type ConcreteType = GType<usize>;
|
||||
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationType<'ast>) -> bool {
|
||||
impl<'ast, S, R: PartialEq<S>> PartialEq<GType<S>> for GType<R> {
|
||||
fn eq(&self, other: >ype<S>) -> bool {
|
||||
use self::GType::*;
|
||||
|
||||
match (self, other) {
|
||||
(Array(l), Array(r)) => l == r,
|
||||
(Struct(l), Struct(r)) => l.canonical_location == r.canonical_location,
|
||||
(Struct(l), Struct(r)) => l == r,
|
||||
(FieldElement, FieldElement) | (Boolean, Boolean) => true,
|
||||
(Uint(l), Uint(r)) => l == r,
|
||||
_ => false,
|
||||
|
@ -622,7 +627,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
|
||||
pub fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
|
||||
match t {
|
||||
GType::FieldElement => Ok(GType::FieldElement),
|
||||
GType::Boolean => Ok(GType::Boolean),
|
||||
|
@ -653,12 +658,6 @@ impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationType<'ast>> for Type<'ast, T> {
|
||||
fn from(t: DeclarationType<'ast>) -> Self {
|
||||
try_from_g_type(t).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
|
||||
fn from(tup: (GType<S>, U)) -> Self {
|
||||
GArrayType {
|
||||
|
@ -669,10 +668,10 @@ impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
|
|||
}
|
||||
|
||||
impl<S> GArrayType<S> {
|
||||
pub fn new(ty: GType<S>, size: S) -> Self {
|
||||
pub fn new<U: Into<S>>(ty: GType<S>, size: U) -> Self {
|
||||
GArrayType {
|
||||
ty: Box::new(ty),
|
||||
size,
|
||||
size: size.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -694,11 +693,36 @@ impl<S: fmt::Display> fmt::Display for GType<S> {
|
|||
GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth),
|
||||
GType::Int => write!(f, "{{integer}}"),
|
||||
GType::Array(ref array_type) => write!(f, "{}", array_type),
|
||||
GType::Struct(ref struct_type) => write!(f, "{}", struct_type.name(),),
|
||||
GType::Struct(ref struct_type) => write!(f, "{}", struct_type),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: fmt::Display> fmt::Display for GStructType<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}{}",
|
||||
self.name(),
|
||||
if !self.generics.is_empty() {
|
||||
format!(
|
||||
"<{}>",
|
||||
self.generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GType<S> {
|
||||
pub fn array<U: Into<GArrayType<S>>>(array_ty: U) -> Self {
|
||||
GType::Array(array_ty.into())
|
||||
|
@ -717,7 +741,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
|||
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
|
||||
use self::GType::*;
|
||||
|
||||
if self == other {
|
||||
if other == self {
|
||||
true
|
||||
} else {
|
||||
match (self, other) {
|
||||
|
@ -735,7 +759,13 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
|||
}
|
||||
_ => false,
|
||||
},
|
||||
(Struct(_), Struct(_)) => false,
|
||||
(Struct(l), Struct(r)) => {
|
||||
l.canonical_location == r.canonical_location
|
||||
&& l.members
|
||||
.iter()
|
||||
.zip(r.members.iter())
|
||||
.all(|(m, d_m)| m.ty.can_be_specialized_to(&*d_m.ty))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -848,14 +878,6 @@ impl<'ast, T> TryFrom<FunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
// impl<'ast> TryFrom<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
|
||||
// type Error = SpecializationError;
|
||||
|
||||
// fn try_from(k: DeclarationFunctionKey<'ast>) -> Result<Self, Self::Error> {
|
||||
// try_from_g_function_key(k)
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
||||
fn from(k: ConcreteFunctionKey<'ast>) -> Self {
|
||||
try_from_g_function_key(k).unwrap()
|
||||
|
@ -868,12 +890,6 @@ impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
||||
fn from(k: DeclarationFunctionKey<'ast>) -> Self {
|
||||
try_from_g_function_key(k).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, S> GFunctionKey<'ast, S> {
|
||||
pub fn with_location<T: Into<OwnedTypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
|
||||
module: T,
|
||||
|
@ -913,7 +929,118 @@ impl<'ast> ConcreteFunctionKey<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature};
|
||||
use std::collections::btree_map::Entry;
|
||||
|
||||
pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
||||
decl_ty: &DeclarationType<'ast>,
|
||||
ty: >ype<S>,
|
||||
constants: &mut GGenericsAssignment<'ast, S>,
|
||||
) -> bool {
|
||||
match (decl_ty, ty) {
|
||||
(DeclarationType::Array(t0), GType::Array(t1)) => {
|
||||
let s1 = t1.size.clone();
|
||||
|
||||
// both the inner type and the size must match
|
||||
check_type(&t0.ty, &t1.ty, constants)
|
||||
&& match &t0.size {
|
||||
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
||||
// matches if this identifier is already in the map
|
||||
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||
Entry::Occupied(e) => *e.get() == s1,
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(s1);
|
||||
true
|
||||
}
|
||||
},
|
||||
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
|
||||
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
|
||||
// if it does not match, it will be caught during inlining
|
||||
DeclarationConstant::Constant(..) => true,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
| (DeclarationType::Boolean, GType::Boolean) => true,
|
||||
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
|
||||
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
|
||||
s0.canonical_location == s1.canonical_location
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
|
||||
fn from(_: CanonicalConstantIdentifier<'ast>) -> Self {
|
||||
unreachable!("constants should have been removed in constant inlining")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast> {
|
||||
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
|
||||
DeclarationConstant::Constant(c)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn specialize_declaration_type<
|
||||
'ast,
|
||||
S: Clone + PartialEq + From<u32> + fmt::Debug + From<CanonicalConstantIdentifier<'ast>>,
|
||||
>(
|
||||
decl_ty: DeclarationType<'ast>,
|
||||
generics: &GGenericsAssignment<'ast, S>,
|
||||
) -> Result<GType<S>, GenericIdentifier<'ast>> {
|
||||
Ok(match decl_ty {
|
||||
DeclarationType::Int => unreachable!(),
|
||||
DeclarationType::Array(t0) => {
|
||||
// let s1 = t1.size.clone();
|
||||
|
||||
let ty = box specialize_declaration_type(*t0.ty, &generics)?;
|
||||
let size = match t0.size {
|
||||
DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s),
|
||||
DeclarationConstant::Concrete(s) => Ok(s.into()),
|
||||
DeclarationConstant::Constant(c) => Ok(c.into()),
|
||||
}?;
|
||||
|
||||
GType::Array(GArrayType { size, ty })
|
||||
}
|
||||
DeclarationType::FieldElement => GType::FieldElement,
|
||||
DeclarationType::Boolean => GType::Boolean,
|
||||
DeclarationType::Uint(b0) => GType::Uint(b0),
|
||||
DeclarationType::Struct(s0) => GType::Struct(GStructType {
|
||||
members: s0
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
let id = m.id;
|
||||
specialize_declaration_type(*m.ty, generics)
|
||||
.map(|ty| GStructMember { ty: box ty, id })
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
generics: s0
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| match g {
|
||||
Some(constant) => match constant {
|
||||
DeclarationConstant::Generic(s) => {
|
||||
generics.0.get(&s).cloned().ok_or(s).map(Some)
|
||||
}
|
||||
DeclarationConstant::Concrete(s) => Ok(Some(s.into())),
|
||||
DeclarationConstant::Constant(..) => {
|
||||
unreachable!(
|
||||
"identifiers should have been removed in constant inlining"
|
||||
)
|
||||
}
|
||||
},
|
||||
_ => Ok(None),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
canonical_location: s0.canonical_location,
|
||||
location: s0.location,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub use self::signature::{
|
||||
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
|
||||
};
|
||||
|
||||
pub mod signature {
|
||||
use super::*;
|
||||
|
@ -968,83 +1095,6 @@ pub mod signature {
|
|||
pub type ConcreteSignature = GSignature<usize>;
|
||||
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
|
||||
|
||||
use std::collections::btree_map::Entry;
|
||||
|
||||
fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
||||
decl_ty: &DeclarationType<'ast>,
|
||||
ty: >ype<S>,
|
||||
constants: &mut GGenericsAssignment<'ast, S>,
|
||||
) -> bool {
|
||||
match (decl_ty, ty) {
|
||||
(DeclarationType::Array(t0), GType::Array(t1)) => {
|
||||
let s1 = t1.size.clone();
|
||||
|
||||
// both the inner type and the size must match
|
||||
check_type(&t0.ty, &t1.ty, constants)
|
||||
&& match &t0.size {
|
||||
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
||||
// matches if this identifier is already in the map
|
||||
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||
Entry::Occupied(e) => *e.get() == s1,
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(s1);
|
||||
true
|
||||
}
|
||||
},
|
||||
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
|
||||
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
|
||||
// if it does not match, it will be caught during inlining
|
||||
DeclarationConstant::Constant(..) => true,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
| (DeclarationType::Boolean, GType::Boolean) => true,
|
||||
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
|
||||
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
|
||||
s0.canonical_location == s1.canonical_location
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn specialize_type<'ast, S: Clone + PartialEq + PartialEq<usize> + From<u32> + fmt::Debug>(
|
||||
decl_ty: DeclarationType<'ast>,
|
||||
constants: &GGenericsAssignment<'ast, S>,
|
||||
) -> Result<GType<S>, GenericIdentifier<'ast>> {
|
||||
Ok(match decl_ty {
|
||||
DeclarationType::Int => unreachable!(),
|
||||
DeclarationType::Array(t0) => {
|
||||
// let s1 = t1.size.clone();
|
||||
|
||||
let ty = box specialize_type(*t0.ty, &constants)?;
|
||||
let size = match t0.size {
|
||||
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
|
||||
DeclarationConstant::Concrete(s) => Ok(s.into()),
|
||||
DeclarationConstant::Constant(..) => {
|
||||
unreachable!("identifiers should have been removed in constant inlining")
|
||||
}
|
||||
}?;
|
||||
|
||||
GType::Array(GArrayType { size, ty })
|
||||
}
|
||||
DeclarationType::FieldElement => GType::FieldElement,
|
||||
DeclarationType::Boolean => GType::Boolean,
|
||||
DeclarationType::Uint(b0) => GType::Uint(b0),
|
||||
DeclarationType::Struct(s0) => GType::Struct(GStructType {
|
||||
members: s0
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
let id = m.id;
|
||||
specialize_type(*m.ty, constants).map(|ty| GStructMember { ty: box ty, id })
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
canonical_location: s0.canonical_location,
|
||||
location: s0.location,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
impl<'ast> PartialEq<DeclarationSignature<'ast>> for ConcreteSignature {
|
||||
fn eq(&self, other: &DeclarationSignature<'ast>) -> bool {
|
||||
// we keep track of the value of constants in a map, as a given constant can only have one value
|
||||
|
@ -1135,7 +1185,7 @@ pub mod signature {
|
|||
self.outputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|t| specialize_type(t, &constants))
|
||||
.map(|t| specialize_declaration_type(t, &constants))
|
||||
.collect::<Result<_, _>>()
|
||||
}
|
||||
}
|
||||
|
@ -1185,12 +1235,6 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationSignature<'ast>> for Signature<'ast, T> {
|
||||
fn from(s: DeclarationSignature<'ast>) -> Self {
|
||||
try_from_g_signature(s).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: fmt::Display> fmt::Display for GSignature<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if !self.generics.is_empty() {
|
||||
|
@ -1329,8 +1373,7 @@ pub mod signature {
|
|||
GenericIdentifier {
|
||||
name: "P",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
))]);
|
||||
let generic2 = DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
|
@ -1345,8 +1388,7 @@ pub mod signature {
|
|||
GenericIdentifier {
|
||||
name: "Q",
|
||||
index: 0,
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
))]);
|
||||
|
||||
assert_eq!(generic1, generic2);
|
||||
|
|
|
@ -25,16 +25,6 @@ impl<'ast, T> TryFrom<Variable<'ast, T>> for ConcreteVariable<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
// impl<'ast> TryFrom<DeclarationVariable<'ast>> for ConcreteVariable<'ast> {
|
||||
// type Error = SpecializationError;
|
||||
|
||||
// fn try_from(v: DeclarationVariable<'ast>) -> Result<Self, Self::Error> {
|
||||
// let _type = v._type.try_into()?;
|
||||
|
||||
// Ok(Self { _type, id: v.id })
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
|
||||
fn from(v: ConcreteVariable<'ast>) -> Self {
|
||||
let _type = v._type.into();
|
||||
|
@ -43,12 +33,12 @@ impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationVariable<'ast>> for Variable<'ast, T> {
|
||||
fn from(v: DeclarationVariable<'ast>) -> Self {
|
||||
let _type = v._type.into();
|
||||
pub fn try_from_g_variable<T: TryInto<U>, U>(
|
||||
v: GVariable<T>,
|
||||
) -> Result<GVariable<U>, SpecializationError> {
|
||||
let _type = crate::typed_absy::types::try_from_g_type(v._type)?;
|
||||
|
||||
Self { _type, id: v.id }
|
||||
}
|
||||
Ok(GVariable { _type, id: v.id })
|
||||
}
|
||||
|
||||
impl<'ast, S: Clone> GVariable<'ast, S> {
|
||||
|
|
20
zokrates_core_test/tests/tests/structs/member_order.json
Normal file
20
zokrates_core_test/tests/tests/structs/member_order.json
Normal file
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/structs/member_order.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"abi": true,
|
||||
"input": {
|
||||
"values": [{
|
||||
"a":true,
|
||||
"b": "3"
|
||||
}]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
9
zokrates_core_test/tests/tests/structs/member_order.zok
Normal file
9
zokrates_core_test/tests/tests/structs/member_order.zok
Normal file
|
@ -0,0 +1,9 @@
|
|||
struct Foo {
|
||||
field b
|
||||
bool a
|
||||
}
|
||||
|
||||
// this tests the abi, checking that the fields of a `Foo` instance get encoded in the right order
|
||||
// if the the encoder reverses `a` and `b`, the boolean check ends up being done on the field value, which would fail
|
||||
def main(Foo f):
|
||||
return
|
|
@ -34,9 +34,9 @@ ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
|
|||
ty = { ty_array | ty_basic | ty_struct }
|
||||
type_list = _{(ty ~ ("," ~ ty)*)?}
|
||||
// structs
|
||||
ty_struct = { identifier }
|
||||
ty_struct = { identifier ~ explicit_generics? }
|
||||
// type definitions
|
||||
ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* }
|
||||
ty_struct_definition = { "struct" ~ identifier ~ constant_generics_declaration? ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* }
|
||||
struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? }
|
||||
struct_field = { ty ~ identifier }
|
||||
|
||||
|
@ -77,9 +77,9 @@ conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ ex
|
|||
postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers
|
||||
access = { array_access | call_access | member_access }
|
||||
array_access = { "[" ~ range_or_expression ~ "]" }
|
||||
call_access = { explicit_generics? ~ "(" ~ arguments ~ ")" }
|
||||
call_access = { ("::" ~ explicit_generics)? ~ "(" ~ arguments ~ ")" }
|
||||
arguments = { expression_list }
|
||||
explicit_generics = { "::<" ~ constant_generics_values ~ ">" }
|
||||
explicit_generics = { "<" ~ constant_generics_values ~ ">" }
|
||||
constant_generics_values = _{ constant_generics_value ~ ("," ~ constant_generics_value)* }
|
||||
constant_generics_value = { literal | identifier | underscore }
|
||||
underscore = { "_" }
|
||||
|
|
|
@ -147,6 +147,7 @@ mod ast {
|
|||
#[pest_ast(rule(Rule::ty_struct_definition))]
|
||||
pub struct StructDefinition<'ast> {
|
||||
pub id: IdentifierExpression<'ast>,
|
||||
pub generics: Vec<IdentifierExpression<'ast>>,
|
||||
pub fields: Vec<StructField<'ast>>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
|
@ -307,6 +308,7 @@ mod ast {
|
|||
#[pest_ast(rule(Rule::ty_struct))]
|
||||
pub struct StructType<'ast> {
|
||||
pub id: IdentifierExpression<'ast>,
|
||||
pub explicit_generics: Option<ExplicitGenerics<'ast>>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
|
52
zokrates_stdlib/stdlib/snark/gm17.zok
Normal file
52
zokrates_stdlib/stdlib/snark/gm17.zok
Normal file
|
@ -0,0 +1,52 @@
|
|||
#pragma curve bw6_761
|
||||
from "EMBED" import snark_verify_bls12_377 as verify
|
||||
|
||||
struct ProofInner {
|
||||
field[2] a
|
||||
field[2][2] b
|
||||
field[2] c
|
||||
}
|
||||
|
||||
struct Proof<N> {
|
||||
ProofInner proof
|
||||
field[N] inputs
|
||||
}
|
||||
struct VerificationKey<N> {
|
||||
field[2][2] h
|
||||
field[2] g_alpha
|
||||
field[2][2] h_beta
|
||||
field[2] g_gamma
|
||||
field[2][2] h_gamma
|
||||
field[N][2] query // input length + 1
|
||||
}
|
||||
|
||||
def flat<N, F>(field[N][2] input) -> field[F]:
|
||||
assert(F == N * 2)
|
||||
field[F] out = [0; F]
|
||||
for u32 i in 0..N do
|
||||
for u32 j in 0..2 do
|
||||
out[(i * 2) + j] = input[i][j]
|
||||
endfor
|
||||
endfor
|
||||
return out
|
||||
|
||||
def main<N, Q>(Proof<N> proof, VerificationKey<Q> vk) -> bool:
|
||||
assert(Q == N + 1) // query length (Q) should be N + 1
|
||||
field[8] flat_proof = [
|
||||
...proof.proof.a,
|
||||
...flat::<2, 4>(proof.proof.b),
|
||||
...proof.proof.c
|
||||
]
|
||||
|
||||
u32 two_Q = 2 * Q
|
||||
|
||||
field[16 + (2 * Q)] flat_vk = [
|
||||
...flat::<2, 4>(vk.h),
|
||||
...vk.g_alpha,
|
||||
...flat::<2, 4>(vk.h_beta),
|
||||
...vk.g_gamma,
|
||||
...flat::<2, 4>(vk.h_gamma),
|
||||
...flat::<Q, two_Q>(vk.query)
|
||||
]
|
||||
|
||||
return verify(proof.inputs, flat_proof, flat_vk)
|
103
zokrates_stdlib/tests/tests/snark/gm17.json
Normal file
103
zokrates_stdlib/tests/tests/snark/gm17.json
Normal file
|
@ -0,0 +1,103 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/snark/gm17.zok",
|
||||
"curves": ["Bw6_761"],
|
||||
"tests": [
|
||||
{
|
||||
"abi": true,
|
||||
"input": {
|
||||
"values": [
|
||||
{
|
||||
"proof": {
|
||||
"a": [
|
||||
"0x01441e34fd88112583831de068e3bdf67d7a5b020c9650e4dc8e3dd0cf92f62b32668dd4654ddc63fe5293a542756a27",
|
||||
"0x013d7b6097a6ae8534909cb2f2ec2e39f3ccbe8858db0285e45619131db37f84b1c88fbb257a7b8e8944a926bb41aa66"
|
||||
],
|
||||
"b": [
|
||||
[
|
||||
"0x00dcf8242e445213da28281aab32bcf47268bf16624dbca7c828cfbb0e8000bad94926272cba0cd5e9a959cf4e969c7c",
|
||||
"0x00b570276d40ae06ac3feb5db65b37acf1eabd16e1c588d01c553b1a60e5d007d9202a8ad2b6405e521b3eec84772521"
|
||||
],
|
||||
[
|
||||
"0x00acbeabed6267316420b73b9eba39e8c51080b8b507857478a54c0fc259b17eec2921253a15445e2ec3c130706398b0",
|
||||
"0x019b579a061cbc4aed64351d87ba96c071118ef3fd645e630c18986e284de5ffc8a48ea94eeb3bdc8807d62d366e223f"
|
||||
]
|
||||
],
|
||||
"c": [
|
||||
"0x004c93c20cd43f8b7818fcc4ece38243779bedb8b874702df4d6968b75cbe2e6831ab38475e2f0c7bc170171580198df",
|
||||
"0x0177a560e5f6ae87f07aeff2dcdb1e0737b4810aeba8a5ba1bc4c5d0e89f268aae142ab5327afbde8e8bad869702aad3"
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
"0x0000000000000000000000000000000000000000000000000000000000000001",
|
||||
"0x0000000000000000000000000000000000000000000000000000000000000002",
|
||||
"0x0000000000000000000000000000000000000000000000000000000000000003"
|
||||
]
|
||||
},
|
||||
{
|
||||
"h": [
|
||||
[
|
||||
"0x000a4c42894d5fd7ac23ca05eac034d82299dd9db5fa493812e4852bcf50cd88faf8f3e97cd292678b292d11e173949b",
|
||||
"0x001ead78f91728b07146e93ee1f21165f25ad88e0fee997f5527076ca84374d3a6d834b59608226b28ab8b8d5ea9a94f"
|
||||
],
|
||||
[
|
||||
"0x0087b1837c209351af3b67bbfeaea80ed94f690584847b1aa34cc59a2b451f360fc268b2562ea8015f8f4d71c7bf4675",
|
||||
"0x015c50d51c8ed463a4e9cc76fc0583634b04dc26b36e10bfac9169d0baebf58b45b687a81a0ca60400427889bcbc6b76"
|
||||
]
|
||||
],
|
||||
"g_alpha": [
|
||||
"0x004b7af9ab6ef9061adb5ed7ba12e9cd41f508ac758c25c5e629d871a1b980e5242149b522b20c57808fae97cb76b971",
|
||||
"0x0196c16d89a7cccbb8f15775da22c01d5ec45b384829bcaad91b324a482676558d3d6d41f675966b5d22537f4ed77903"
|
||||
],
|
||||
"h_beta": [
|
||||
[
|
||||
"0x014d2d0bcfa272334efbc589dc263c3f2a5d2711f9a0d5fbb3c2ad1b7eebe93459aeee6e1c8bc02041945313aec93d8a",
|
||||
"0x0054800f89ebbbd924328a7782fdbb5260b56059901a06e6ad58c4a7df96018e5ea1c5ffd28ed0dd0139dcced6bde7e8"
|
||||
],
|
||||
[
|
||||
"0x00ca4e270e5fe79ff2a5432daf6e9e5aa22aebf6521a7d3c5ef97d981b05ea93043c6307b47e8a3e00ace9c987fb725e",
|
||||
"0x010cb8f97a5d586777e4f7ca8a0ce4465c0de02951cb8ccca43403b1a669e523c1163ebc9ce7d10edf583894fad70341"
|
||||
]
|
||||
],
|
||||
"g_gamma": [
|
||||
"0x003fa4d4d1fe1a9bb62e704b5ac76a514e4aaf53cfcbd12cb55aa7afecf2c12ce9346737b5594ee872700178748e9ed1",
|
||||
"0x018975a2eb9de8a1982d076b56bb86b5214f89cff897d492e16dcdc1eca2a692eb9f0af5183585ba4aee9d78af2ab570"
|
||||
],
|
||||
"h_gamma": [
|
||||
[
|
||||
"0x000a4c42894d5fd7ac23ca05eac034d82299dd9db5fa493812e4852bcf50cd88faf8f3e97cd292678b292d11e173949b",
|
||||
"0x001ead78f91728b07146e93ee1f21165f25ad88e0fee997f5527076ca84374d3a6d834b59608226b28ab8b8d5ea9a94f"
|
||||
],
|
||||
[
|
||||
"0x0087b1837c209351af3b67bbfeaea80ed94f690584847b1aa34cc59a2b451f360fc268b2562ea8015f8f4d71c7bf4675",
|
||||
"0x015c50d51c8ed463a4e9cc76fc0583634b04dc26b36e10bfac9169d0baebf58b45b687a81a0ca60400427889bcbc6b76"
|
||||
]
|
||||
],
|
||||
"query": [
|
||||
[
|
||||
"0x00dbcc84391e078ae2fa7b5dc8478651b945e155505332a55e5b7be4de52ce83450bbf94f1da270c012104d394b22fda",
|
||||
"0x002dc3039f7236d31fceaa6d8e13d33a5850984193f70c0abfe20a1f4540f59987e49cb0cc2722f1dccb47f1012d38c8"
|
||||
],
|
||||
[
|
||||
"0x00db1bc3a431619ca74564c8a734592151a5fc2d8bfa750d4ffb94126bdaed83dce86bcdc8f966dca3066f67c61c897c",
|
||||
"0x00e97f2f6c94a2676dd3c8646a45684cfd66a644644c1fc8ee5cf2ab4e322a5a82a9f9872ec9e8c7f3f1a9ddf38f2e53"
|
||||
],
|
||||
[
|
||||
"0x008f4c292ba1ae0fa22613e0afaa075796b21a935e591fb8e8b32fa7c0fe0ecda25d5575e1e2b178d5a4bfb8e89f9d36",
|
||||
"0x017cb6aca4e2d1027ab429a2a7d6b8f6e13dfeb427b7eaf9b6e3ca22554fae39f45ee0854098c9753cca04b46f3388d0"
|
||||
],
|
||||
[
|
||||
"0x0168740e2d9cab168df083dd1d340de23d5055f4eed63c87811e94a5bf9c492658c6c58ccb1a48bb153cbe9aa8d98c8d",
|
||||
"0x005b7c28b57504562c1d38a5ba9c67a59c696dc2e51b3c50d96e75e2f399f9106f08f6846d553d32e58b8131ad997fc1"
|
||||
]
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
57
zokrates_stdlib/tests/tests/snark/gm17.zok
Normal file
57
zokrates_stdlib/tests/tests/snark/gm17.zok
Normal file
|
@ -0,0 +1,57 @@
|
|||
// verify a snark
|
||||
// to reproduce the test cases:
|
||||
//
|
||||
// 1. Create a program
|
||||
// ```zokrates
|
||||
// def main(field a, field b) -> field:
|
||||
// return a + b
|
||||
// ```
|
||||
//
|
||||
// 2. Compile it to bls12_377
|
||||
// ```sh
|
||||
// zokrates compile -i program.zok --curve bls12_377
|
||||
// ```
|
||||
//
|
||||
// 3. Run a trusted setup for gm17
|
||||
// ```sh
|
||||
// zokrates setup --proving-scheme gm17 --backend ark
|
||||
// ```
|
||||
//
|
||||
// 4. Execute the program and generate a proof
|
||||
// ```sh
|
||||
// zokrates compute-witness -a 1 2
|
||||
// zokrates generate-proof --proving-scheme gm17 --backend ark
|
||||
// ```
|
||||
//
|
||||
// 5. Generate the test case
|
||||
//
|
||||
// ```sh
|
||||
// cat > gm17.json << EOT
|
||||
// {
|
||||
// "entry_point": "./tests/tests/snark/gm17.zok",
|
||||
// "curves": ["Bw6_761"],
|
||||
// "tests": [
|
||||
// {
|
||||
// "abi": true,
|
||||
// "input": {
|
||||
// "values": [
|
||||
// $(cat proof.json && echo ", " && cat verification.key)
|
||||
// ]
|
||||
// },
|
||||
// "output": {
|
||||
// "Ok": {
|
||||
// "values": ["1"]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// EOT
|
||||
// ```
|
||||
//
|
||||
// `gm17.json` can then be used as a test for this code file
|
||||
|
||||
from "snark/gm17" import main as verify, Proof, VerificationKey
|
||||
|
||||
def main(Proof<3> proof, VerificationKey<4> vk) -> bool:
|
||||
return verify::<3, 4>(proof, vk)
|
|
@ -8,6 +8,7 @@ edition = "2018"
|
|||
zokrates_field = { version = "0.4", path = "../zokrates_field" }
|
||||
zokrates_core = { version = "0.6", path = "../zokrates_core" }
|
||||
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver" }
|
||||
zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
||||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
serde_json = "1.0"
|
||||
|
|
|
@ -5,8 +5,11 @@ use std::fs::File;
|
|||
use std::io::{BufReader, Read};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use zokrates_core::compile::{compile, CompileConfig};
|
||||
use zokrates_core::ir;
|
||||
use zokrates_core::{
|
||||
compile::{compile, CompileConfig},
|
||||
typed_absy::ConcreteType,
|
||||
};
|
||||
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
|
@ -34,6 +37,7 @@ struct Input {
|
|||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct Test {
|
||||
pub abi: Option<bool>,
|
||||
pub input: Input,
|
||||
pub output: TestResult,
|
||||
}
|
||||
|
@ -48,11 +52,24 @@ struct Output {
|
|||
values: Vec<Val>,
|
||||
}
|
||||
|
||||
type Val = String;
|
||||
type Val = serde_json::Value;
|
||||
|
||||
fn parse_val<T: Field>(s: String) -> T {
|
||||
let radix = if s.starts_with("0x") { 16 } else { 10 };
|
||||
T::try_from_str(s.trim_start_matches("0x"), radix).unwrap()
|
||||
fn try_parse_raw_val<T: Field>(s: serde_json::Value) -> Result<T, ()> {
|
||||
match s {
|
||||
serde_json::Value::String(s) => {
|
||||
let radix = if s.starts_with("0x") { 16 } else { 10 };
|
||||
T::try_from_str(s.trim_start_matches("0x"), radix).map_err(|_| ())
|
||||
}
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_parse_abi_val<T: Field>(
|
||||
s: Vec<Val>,
|
||||
types: Vec<ConcreteType>,
|
||||
) -> Result<Vec<T>, zokrates_abi::Error> {
|
||||
use zokrates_abi::Encode;
|
||||
zokrates_abi::parse_strict_json(s, types).map(|v| v.encode())
|
||||
}
|
||||
|
||||
impl<T: Field> From<ir::ExecutionResult<T>> for ComparableResult<T> {
|
||||
|
@ -63,7 +80,12 @@ impl<T: Field> From<ir::ExecutionResult<T>> for ComparableResult<T> {
|
|||
|
||||
impl<T: Field> From<TestResult> for ComparableResult<T> {
|
||||
fn from(r: TestResult) -> ComparableResult<T> {
|
||||
ComparableResult(r.map(|v| v.values.into_iter().map(parse_val).collect()))
|
||||
ComparableResult(r.map(|v| {
|
||||
v.values
|
||||
.into_iter()
|
||||
.map(|v| try_parse_raw_val(v).unwrap())
|
||||
.collect()
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,6 +152,7 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap();
|
||||
|
||||
let bin = artifacts.prog();
|
||||
let abi = artifacts.abi();
|
||||
|
||||
if let Some(target_count) = t.max_constraint_count {
|
||||
let count = bin.constraint_count();
|
||||
|
@ -148,12 +171,21 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
let interpreter = zokrates_core::ir::Interpreter::default();
|
||||
|
||||
for test in t.tests.into_iter() {
|
||||
let input = &test.input.values;
|
||||
let with_abi = test.abi.unwrap_or(false);
|
||||
|
||||
let output = interpreter.execute(
|
||||
bin,
|
||||
&(input.iter().cloned().map(parse_val).collect::<Vec<_>>()),
|
||||
);
|
||||
let input = if with_abi {
|
||||
try_parse_abi_val(test.input.values, abi.signature().inputs).unwrap()
|
||||
} else {
|
||||
test.input
|
||||
.values
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(try_parse_raw_val)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let output = interpreter.execute(bin, &input);
|
||||
|
||||
if let Err(e) = compare(output, test.output) {
|
||||
let mut code = File::open(&entry_point).unwrap();
|
||||
|
|
Loading…
Reference in a new issue