1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

Merge pull request #945 from Zokrates/generic-structs

Implement generic structs
This commit is contained in:
Thibaut Schaeffer 2021-08-09 11:58:34 +02:00 committed by GitHub
commit f717d243b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 1586 additions and 599 deletions

1
Cargo.lock generated
View file

@ -2547,6 +2547,7 @@ dependencies = [
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json", "serde_json",
"zokrates_abi",
"zokrates_core", "zokrates_core",
"zokrates_field", "zokrates_field",
"zokrates_fs_resolver", "zokrates_fs_resolver",

View file

@ -0,0 +1 @@
Enable constant generics on structs

View file

@ -0,0 +1 @@
Add gm17 verifier to stdlib for bw6_761

View file

@ -299,22 +299,7 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<ConcreteType>) -> Result<Value
serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?; serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?;
match values { match values {
serde_json::Value::Array(values) => { serde_json::Value::Array(values) => parse_strict_json(values, types),
if values.len() != types.len() {
return Err(Error::Type(format!(
"Expected {} inputs, found {}",
types.len(),
values.len()
)));
}
Ok(Values(
types
.into_iter()
.zip(values.into_iter())
.map(|(ty, v)| parse_value(v, ty))
.collect::<Result<_, _>>()?,
))
}
_ => Err(Error::Type(format!( _ => Err(Error::Type(format!(
"Expected an array of values, found `{}`", "Expected an array of values, found `{}`",
values values
@ -322,6 +307,27 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<ConcreteType>) -> Result<Value
} }
} }
pub fn parse_strict_json<T: Field>(
values: Vec<serde_json::Value>,
types: Vec<ConcreteType>,
) -> Result<Values<T>, Error> {
if values.len() != types.len() {
return Err(Error::Type(format!(
"Expected {} inputs, found {}",
types.len(),
values.len()
)));
}
Ok(Values(
types
.into_iter()
.zip(values.into_iter())
.map(|(ty, v)| parse_value(v, ty))
.collect::<Result<_, _>>()?,
))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -428,6 +434,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![],
vec![ConcreteStructMember::new( vec![ConcreteStructMember::new(
"a".into(), "a".into(),
ConcreteType::FieldElement ConcreteType::FieldElement
@ -449,6 +456,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![],
vec![ConcreteStructMember::new( vec![ConcreteStructMember::new(
"a".into(), "a".into(),
ConcreteType::FieldElement ConcreteType::FieldElement
@ -466,6 +474,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![],
vec![ConcreteStructMember::new( vec![ConcreteStructMember::new(
"a".into(), "a".into(),
ConcreteType::FieldElement ConcreteType::FieldElement
@ -483,6 +492,7 @@ mod tests {
vec![ConcreteType::Struct(ConcreteStructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![],
vec![ConcreteStructMember::new( vec![ConcreteStructMember::new(
"a".into(), "a".into(),
ConcreteType::FieldElement ConcreteType::FieldElement

View file

@ -107,7 +107,7 @@ field[2] b = a[1..3] // initialize an array copying a slice from `a`
``` ```
### Structs ### Structs
A struct is a composite datatype representing a named collection of variables. A struct is a composite datatype representing a named collection of variables. Structs can be generic over constants, in order to wrap arrays of generic size. For more details on generic array sizes, see [constant generics](../language/generics.md)
The contained variables can be of any type. The contained variables can be of any type.
The following code shows an example of how to use structs. The following code shows an example of how to use structs.

View file

@ -1,14 +1,14 @@
struct Bar { struct Bar<N> {
field[2] c field[N] c
bool d bool d
} }
struct Foo { struct Foo<P> {
Bar a Bar<P> a
bool b bool b
} }
def main() -> (Foo): def main() -> (Foo<2>):
Foo[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}] Foo<2>[2] f = [Foo { a: Bar { c: [0, 0], d: false }, b: true}, Foo { a: Bar {c: [0, 0], d: false}, b: true}]
f[0].a.c = [42, 43] f[0].a.c = [42, 43]
return f[0] return f[0]

View file

@ -0,0 +1,4 @@
const u32[2] A = [1]
def main() -> u32[2]:
return A

View file

@ -0,0 +1,6 @@
struct A<N, N> {
field[N] a
}
def main():
return

View file

@ -0,0 +1,7 @@
struct A<N> {
field[N] a
}
def main():
A<_> a = A { a: [1] }
return

View file

@ -0,0 +1,6 @@
struct A<1> {
field[1] a
}
def main():
return

View file

@ -0,0 +1,6 @@
struct A<N> {
field[N] a
}
def main(A<1> a, A<2> b) -> bool:
return a == b

View file

@ -0,0 +1,6 @@
struct A {
field[N] a
}
def main():
return

View file

@ -0,0 +1,4 @@
struct A<N> {}
def main():
return

View file

@ -0,0 +1,15 @@
struct B {
field a
}
struct A {
B a
}
def main():
A a = A {
a: B {
a: false
}
}
return

View file

@ -0,0 +1,13 @@
struct B<P> {
field[P] a
}
struct A<N> {
field[N] a
B<N> b
}
def main(A<2> a) -> bool:
u32 SIZE = 1 + 1
A<SIZE> b = A { a: [1, 2], b: B { a: [1, 2] } }
return a == b

View file

@ -79,6 +79,11 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
let id = definition.id.span.as_str(); let id = definition.id.span.as_str();
let ty = absy::StructDefinition { let ty = absy::StructDefinition {
generics: definition
.generics
.into_iter()
.map(absy::ConstantGenericNode::from)
.collect(),
fields: definition fields: definition
.fields .fields
.into_iter() .into_iter()
@ -767,9 +772,25 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
pest::BasicType::U32(t) => UnresolvedType::Uint(32).span(t.span), pest::BasicType::U32(t) => UnresolvedType::Uint(32).span(t.span),
pest::BasicType::U64(t) => UnresolvedType::Uint(64).span(t.span), pest::BasicType::U64(t) => UnresolvedType::Uint(64).span(t.span),
}, },
pest::BasicOrStructType::Struct(t) => { pest::BasicOrStructType::Struct(t) => UnresolvedType::User(
UnresolvedType::User(t.span.as_str().to_string()).span(t.span) t.id.span.as_str().to_string(),
} t.explicit_generics.map(|explicit_generics| {
explicit_generics
.values
.into_iter()
.map(|i| match i {
pest::ConstantGenericValue::Underscore(_) => None,
pest::ConstantGenericValue::Value(v) => {
Some(absy::ExpressionNode::from(v))
}
pest::ConstantGenericValue::Identifier(i) => Some(
absy::Expression::Identifier(i.span.as_str()).span(i.span),
),
})
.collect()
}),
)
.span(t.span),
}; };
let span = t.span; let span = t.span;
@ -785,9 +806,25 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
.unwrap() .unwrap()
.span(span.clone()) .span(span.clone())
} }
pest::Type::Struct(s) => { pest::Type::Struct(s) => UnresolvedType::User(
UnresolvedType::User(s.id.span.as_str().to_string()).span(s.span) s.id.span.as_str().to_string(),
} s.explicit_generics.map(|explicit_generics| {
explicit_generics
.values
.into_iter()
.map(|i| match i {
pest::ConstantGenericValue::Underscore(_) => None,
pest::ConstantGenericValue::Value(v) => {
Some(absy::ExpressionNode::from(v))
}
pest::ConstantGenericValue::Identifier(i) => {
Some(absy::Expression::Identifier(i.span.as_str()).span(i.span))
}
})
.collect()
}),
)
.span(s.span),
} }
} }
} }

View file

@ -153,7 +153,7 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
i.value.source.display(), i.value.source.display(),
i.value.id i.value.id
), ),
SymbolDefinition::Struct(ref t) => write!(f, "struct {} {}", self.id, t), SymbolDefinition::Struct(ref t) => write!(f, "struct {}{}", self.id, t),
SymbolDefinition::Constant(ref c) => write!( SymbolDefinition::Constant(ref c) => write!(
f, f,
"const {} {} = {}", "const {} {} = {}",
@ -199,20 +199,25 @@ pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
/// A struct type definition /// A struct type definition
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct StructDefinition<'ast> { pub struct StructDefinition<'ast> {
pub generics: Vec<ConstantGenericNode<'ast>>,
pub fields: Vec<StructDefinitionFieldNode<'ast>>, pub fields: Vec<StructDefinitionFieldNode<'ast>>,
} }
impl<'ast> fmt::Display for StructDefinition<'ast> { impl<'ast> fmt::Display for StructDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( writeln!(
f, f,
"{}", "<{}> {{",
self.fields self.generics
.iter() .iter()
.map(|fi| fi.to_string()) .map(|g| g.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n") .join(", "),
) )?;
for field in &self.fields {
writeln!(f, " {}", field)?;
}
write!(f, "}}",)
} }
} }
@ -227,7 +232,7 @@ pub struct StructDefinitionField<'ast> {
impl<'ast> fmt::Display for StructDefinitionField<'ast> { impl<'ast> fmt::Display for StructDefinitionField<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}: {},", self.id, self.ty) write!(f, "{} {}", self.ty, self.id)
} }
} }

View file

@ -3,9 +3,7 @@ use crate::absy::UnresolvedTypeNode;
use std::fmt; use std::fmt;
pub type Identifier<'ast> = &'ast str; pub type Identifier<'ast> = &'ast str;
pub type MemberId = String; pub type MemberId = String;
pub type UserTypeId = String; pub type UserTypeId = String;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
@ -14,7 +12,7 @@ pub enum UnresolvedType<'ast> {
Boolean, Boolean,
Uint(usize), Uint(usize),
Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>), Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>),
User(UserTypeId), User(UserTypeId, Option<Vec<Option<ExpressionNode<'ast>>>>),
} }
impl<'ast> fmt::Display for UnresolvedType<'ast> { impl<'ast> fmt::Display for UnresolvedType<'ast> {
@ -24,7 +22,30 @@ impl<'ast> fmt::Display for UnresolvedType<'ast> {
UnresolvedType::Boolean => write!(f, "bool"), UnresolvedType::Boolean => write!(f, "bool"),
UnresolvedType::Uint(bitwidth) => write!(f, "u{}", bitwidth), UnresolvedType::Uint(bitwidth) => write!(f, "u{}", bitwidth),
UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size), UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
UnresolvedType::User(i) => write!(f, "{}", i), UnresolvedType::User(ref id, ref generics) => {
write!(
f,
"{}{}",
id,
generics
.as_ref()
.map(|generics| {
format!(
"<{}>",
generics
.iter()
.map(|e| {
e.as_ref()
.map(|e| e.to_string())
.unwrap_or_else(|| "_".to_string())
})
.collect::<Vec<_>>()
.join(", ")
)
})
.unwrap_or_default()
)
}
} }
} }
} }

View file

@ -437,11 +437,13 @@ struct Bar { field a }
ty: ConcreteType::Struct(ConcreteStructType::new( ty: ConcreteType::Struct(ConcreteStructType::new(
"foo".into(), "foo".into(),
"Foo".into(), "Foo".into(),
vec![],
vec![ConcreteStructMember { vec![ConcreteStructMember {
id: "b".into(), id: "b".into(),
ty: box ConcreteType::Struct(ConcreteStructType::new( ty: box ConcreteType::Struct(ConcreteStructType::new(
"bar".into(), "bar".into(),
"Bar".into(), "Bar".into(),
vec![],
vec![ConcreteStructMember { vec![ConcreteStructMember {
id: "a".into(), id: "a".into(),
ty: box ConcreteType::FieldElement ty: box ConcreteType::FieldElement

File diff suppressed because it is too large Load diff

View file

@ -121,7 +121,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
( (
id, id,
TypedConstantSymbol::Here(TypedConstant { TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant, expression: constant,
}), }),
) )
@ -252,8 +251,9 @@ mod tests {
use super::*; use super::*;
use crate::typed_absy::types::DeclarationSignature; use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{ use crate::typed_absy::{
DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement, GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
TypedStatement,
}; };
use zokrates_field::Bn128Field; use zokrates_field::Bn128Field;
@ -276,11 +276,14 @@ mod tests {
}; };
let constants: TypedConstantSymbols<_> = vec![( let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()), CanonicalConstantIdentifier::new(
TypedConstantSymbol::Here(TypedConstant::new( const_id,
GType::FieldElement, "main".into(),
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), DeclarationType::FieldElement,
)), ),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(1)),
))),
)] )]
.into_iter() .into_iter()
.collect(); .collect();
@ -364,11 +367,10 @@ mod tests {
}; };
let constants: TypedConstantSymbols<_> = vec![( let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()), CanonicalConstantIdentifier::new(const_id, "main".into(), DeclarationType::Boolean),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
GType::Boolean, BooleanExpression::Value(true),
TypedExpression::Boolean(BooleanExpression::Value(true)), ))),
)),
)] )]
.into_iter() .into_iter()
.collect(); .collect();
@ -453,9 +455,12 @@ mod tests {
}; };
let constants: TypedConstantSymbols<_> = vec![( let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()), CanonicalConstantIdentifier::new(
const_id,
"main".into(),
DeclarationType::Uint(UBitwidth::B32),
),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128) UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32) .annotate(UBitwidth::B32)
.into(), .into(),
@ -554,20 +559,24 @@ mod tests {
}; };
let constants: TypedConstantSymbols<_> = vec![( let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()), CanonicalConstantIdentifier::new(
TypedConstantSymbol::Here(TypedConstant::new( const_id,
GType::array(GArrayType::new(GType::FieldElement, 2usize)), "main".into(),
TypedExpression::Array( DeclarationType::Array(DeclarationArrayType::new(
ArrayExpressionInner::Value( DeclarationType::FieldElement,
vec![ 2u32,
FieldElementExpression::Number(Bn128Field::from(2)).into(), )),
FieldElementExpression::Number(Bn128Field::from(2)).into(), ),
] TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
.into(), ArrayExpressionInner::Value(
) vec![
.annotate(GType::FieldElement, 2usize), FieldElementExpression::Number(Bn128Field::from(2)).into(),
), FieldElementExpression::Number(Bn128Field::from(2)).into(),
)), ]
.into(),
)
.annotate(GType::FieldElement, 2usize),
))),
)] )]
.into_iter() .into_iter()
.collect(); .collect();
@ -693,18 +702,24 @@ mod tests {
.collect(), .collect(),
constants: vec![ constants: vec![
( (
CanonicalConstantIdentifier::new(const_a_id, "main".into()), CanonicalConstantIdentifier::new(
const_a_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number( TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1), Bn128Field::from(1),
)), )),
)), )),
), ),
( (
CanonicalConstantIdentifier::new(const_b_id, "main".into()), CanonicalConstantIdentifier::new(
const_b_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add( TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from( box FieldElementExpression::Identifier(Identifier::from(
const_a_id, const_a_id,
@ -751,18 +766,24 @@ mod tests {
.collect(), .collect(),
constants: vec![ constants: vec![
( (
CanonicalConstantIdentifier::new(const_a_id, "main".into()), CanonicalConstantIdentifier::new(
const_a_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number( TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1), Bn128Field::from(1),
)), )),
)), )),
), ),
( (
CanonicalConstantIdentifier::new(const_b_id, "main".into()), CanonicalConstantIdentifier::new(
const_b_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new( TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number( TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2), Bn128Field::from(2),
)), )),
@ -812,13 +833,14 @@ mod tests {
.into_iter() .into_iter()
.collect(), .collect(),
constants: vec![( constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()), CanonicalConstantIdentifier::new(
TypedConstantSymbol::Here(TypedConstant::new( foo_const_id,
GType::FieldElement, "foo".into(),
TypedExpression::FieldElement(FieldElementExpression::Number( DeclarationType::FieldElement,
Bn128Field::from(42), ),
)), TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
)), FieldElementExpression::Number(Bn128Field::from(42)),
))),
)] )]
.into_iter() .into_iter()
.collect(), .collect(),
@ -844,10 +866,15 @@ mod tests {
.into_iter() .into_iter()
.collect(), .collect(),
constants: vec![( constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "main".into()), CanonicalConstantIdentifier::new(
foo_const_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::There(CanonicalConstantIdentifier::new( TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
foo_const_id, foo_const_id,
"foo".into(), "foo".into(),
DeclarationType::FieldElement,
)), )),
)] )]
.into_iter() .into_iter()
@ -885,13 +912,14 @@ mod tests {
.into_iter() .into_iter()
.collect(), .collect(),
constants: vec![( constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "main".into()), CanonicalConstantIdentifier::new(
TypedConstantSymbol::Here(TypedConstant::new( foo_const_id,
GType::FieldElement, "main".into(),
TypedExpression::FieldElement(FieldElementExpression::Number( DeclarationType::FieldElement,
Bn128Field::from(42), ),
)), TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
)), FieldElementExpression::Number(Bn128Field::from(42)),
))),
)] )]
.into_iter() .into_iter()
.collect(), .collect(),

View file

@ -129,7 +129,7 @@ impl<'ast, T: Field> Flattener<T> {
p: typed_absy::DeclarationParameter<'ast>, p: typed_absy::DeclarationParameter<'ast>,
) -> Vec<zir::Parameter<'ast>> { ) -> Vec<zir::Parameter<'ast>> {
let private = p.private; let private = p.private;
self.fold_variable(p.id.try_into().unwrap()) self.fold_variable(crate::typed_absy::variable::try_from_g_variable(p.id).unwrap())
.into_iter() .into_iter()
.map(|v| zir::Parameter { id: v, private }) .map(|v| zir::Parameter { id: v, private })
.collect() .collect()
@ -1101,7 +1101,11 @@ fn fold_function<'ast, T: Field>(
.collect(), .collect(),
statements: main_statements_buffer, statements: main_statements_buffer,
signature: typed_absy::types::ConcreteSignature::try_from( signature: typed_absy::types::ConcreteSignature::try_from(
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(), crate::typed_absy::types::try_from_g_signature::<
crate::typed_absy::types::DeclarationConstant<'ast>,
crate::typed_absy::UExpression<'ast, T>,
>(fun.signature)
.unwrap(),
) )
.unwrap() .unwrap()
.into(), .into(),

View file

@ -1278,6 +1278,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
Ok(BooleanExpression::ArrayEq(box e1, box e2)) Ok(BooleanExpression::ArrayEq(box e1, box e2))
} }
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = self.fold_struct_expression(e1)?;
let e2 = self.fold_struct_expression(e2)?;
if let (Ok(t1), Ok(t2)) = (
ConcreteType::try_from(e1.get_type()),
ConcreteType::try_from(e2.get_type()),
) {
if t1 != t2 {
return Err(Error::Type(format!(
"Cannot compare {} of type {} to {} of type {}",
e1, t1, e2, t2
)));
}
};
Ok(BooleanExpression::StructEq(box e1, box e2))
}
BooleanExpression::FieldLt(box e1, box e2) => { BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1)?; let e1 = self.fold_field_expression(e1)?;
let e2 = self.fold_field_expression(e2)?; let e2 = self.fold_field_expression(e2)?;

View file

@ -230,16 +230,21 @@ mod tests {
public: true, public: true,
ty: ConcreteType::Struct(ConcreteStructType::new( ty: ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"Foo".into(), "Bar".into(),
vec![ vec![Some(1usize)],
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement), vec![ConcreteStructMember::new(
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean), String::from("a"),
], ConcreteType::Array(ConcreteArrayType::new(
ConcreteType::FieldElement,
1usize,
)),
)],
)), )),
}], }],
outputs: vec![ConcreteType::Struct(ConcreteStructType::new( outputs: vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"Foo".into(), "Foo".into(),
vec![],
vec![ vec![
ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement), ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement),
ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean), ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean),
@ -257,15 +262,18 @@ mod tests {
"public": true, "public": true,
"type": "struct", "type": "struct",
"components": { "components": {
"name": "Foo", "name": "Bar",
"generics": [
1
],
"members": [ "members": [
{ {
"name": "a", "name": "a",
"type": "field" "type": "array",
}, "components": {
{ "size": 1,
"name": "b", "type": "field"
"type": "bool" }
} }
] ]
} }
@ -276,6 +284,7 @@ mod tests {
"type": "struct", "type": "struct",
"components": { "components": {
"name": "Foo", "name": "Foo",
"generics": [],
"members": [ "members": [
{ {
"name": "a", "name": "a",
@ -305,11 +314,13 @@ mod tests {
ty: ConcreteType::Struct(ConcreteStructType::new( ty: ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"Foo".into(), "Foo".into(),
vec![],
vec![ConcreteStructMember::new( vec![ConcreteStructMember::new(
String::from("bar"), String::from("bar"),
ConcreteType::Struct(ConcreteStructType::new( ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"Bar".into(), "Bar".into(),
vec![],
vec![ vec![
ConcreteStructMember::new( ConcreteStructMember::new(
String::from("a"), String::from("a"),
@ -338,12 +349,14 @@ mod tests {
"type": "struct", "type": "struct",
"components": { "components": {
"name": "Foo", "name": "Foo",
"generics": [],
"members": [ "members": [
{ {
"name": "bar", "name": "bar",
"type": "struct", "type": "struct",
"components": { "components": {
"name": "Bar", "name": "Bar",
"generics": [],
"members": [ "members": [
{ {
"name": "a", "name": "a",
@ -378,6 +391,7 @@ mod tests {
ConcreteType::Struct(ConcreteStructType::new( ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"Foo".into(), "Foo".into(),
vec![],
vec![ vec![
ConcreteStructMember::new( ConcreteStructMember::new(
String::from("b"), String::from("b"),
@ -386,7 +400,7 @@ mod tests {
ConcreteStructMember::new(String::from("c"), ConcreteType::Boolean), ConcreteStructMember::new(String::from("c"), ConcreteType::Boolean),
], ],
)), )),
2, 2usize,
)), )),
}], }],
outputs: vec![ConcreteType::Boolean], outputs: vec![ConcreteType::Boolean],
@ -406,6 +420,7 @@ mod tests {
"type": "struct", "type": "struct",
"components": { "components": {
"name": "Foo", "name": "Foo",
"generics": [],
"members": [ "members": [
{ {
"name": "b", "name": "b",
@ -439,8 +454,8 @@ mod tests {
name: String::from("a"), name: String::from("a"),
public: false, public: false,
ty: ConcreteType::Array(ConcreteArrayType::new( ty: ConcreteType::Array(ConcreteArrayType::new(
ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2)), ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2usize)),
2, 2usize,
)), )),
}], }],
outputs: vec![ConcreteType::FieldElement], outputs: vec![ConcreteType::FieldElement],

View file

@ -138,6 +138,11 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> { fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> {
StructType { StructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)))
.collect(),
members: t members: t
.members .members
.into_iter() .into_iter()
@ -175,6 +180,11 @@ pub trait Folder<'ast, T: Field>: Sized {
t: DeclarationStructType<'ast>, t: DeclarationStructType<'ast>,
) -> DeclarationStructType<'ast> { ) -> DeclarationStructType<'ast> {
DeclarationStructType { DeclarationStructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_declaration_constant(g)))
.collect(),
members: t members: t
.members .members
.into_iter() .into_iter()
@ -222,6 +232,7 @@ pub trait Folder<'ast, T: Field>: Sized {
CanonicalConstantIdentifier { CanonicalConstantIdentifier {
module: self.fold_module_id(i.module), module: self.fold_module_id(i.module),
id: i.id, id: i.id,
ty: box self.fold_declaration_type(*i.ty),
} }
} }
@ -1044,7 +1055,6 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
c: TypedConstant<'ast, T>, c: TypedConstant<'ast, T>,
) -> TypedConstant<'ast, T> { ) -> TypedConstant<'ast, T> {
TypedConstant { TypedConstant {
ty: f.fold_type(c.ty),
expression: f.fold_expression(c.expression), expression: f.fold_expression(c.expression),
} }
} }

View file

@ -1,9 +1,13 @@
use crate::typed_absy::types::{ArrayType, Type}; use crate::typed_absy::types::{
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationStructMember,
DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
StructType, Type,
};
use crate::typed_absy::UBitwidth; use crate::typed_absy::UBitwidth;
use crate::typed_absy::{ use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse, ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression, IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
}; };
use num_bigint::BigUint; use num_bigint::BigUint;
use std::convert::TryFrom; use std::convert::TryFrom;
@ -14,20 +18,107 @@ use zokrates_field::Field;
type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>); type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>);
impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> { impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> {
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> { pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
e: Self,
ty: &GArrayType<S>,
) -> Result<Self, (Self, &GArrayType<S>)> {
match e { match e {
TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, ty) TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, &ty.ty)
.map(|e| e.into()) .map(|e| e.into())
.map_err(|(e, t)| (e.into(), t)), .map_err(|(e, _)| (e.into(), ty)),
TypedExpressionOrSpread::Spread(s) => { TypedExpressionOrSpread::Spread(s) => ArrayExpression::try_from_int(s.array, ty)
ArrayExpression::try_from_int(s.array, ty.clone()) .map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e }))
.map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e })) .map_err(|e| (e.into(), ty)),
.map_err(|e| (e.into(), ty))
}
} }
} }
} }
trait IntegerInference: Sized {
type Pattern;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)>;
}
impl<'ast, T> IntegerInference for Type<'ast, T> {
type Pattern = DeclarationType<'ast>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
match (self, other) {
(Type::Boolean, Type::Boolean) => Ok(DeclarationType::Boolean),
(Type::Int, Type::Int) => Err((Type::Int, Type::Int)),
(Type::Int, Type::FieldElement) => Ok(DeclarationType::FieldElement),
(Type::Int, Type::Uint(bitwidth)) => Ok(DeclarationType::Uint(bitwidth)),
(Type::FieldElement, Type::Int) => Ok(DeclarationType::FieldElement),
(Type::Uint(bitwidth), Type::Int) => Ok(DeclarationType::Uint(bitwidth)),
(Type::FieldElement, Type::FieldElement) => Ok(DeclarationType::FieldElement),
(Type::Uint(b0), Type::Uint(b1)) => {
if b0 == b1 {
Ok(DeclarationType::Uint(b0))
} else {
Err((Type::Uint(b0), Type::Uint(b1)))
}
}
(Type::Array(t), Type::Array(u)) => Ok(DeclarationType::Array(
t.get_common_pattern(u)
.map_err(|(t, u)| (Type::Array(t), Type::Array(u)))?,
)),
(Type::Struct(t), Type::Struct(u)) => Ok(DeclarationType::Struct(
t.get_common_pattern(u)
.map_err(|(t, u)| (Type::Struct(t), Type::Struct(u)))?,
)),
(t, u) => Err((t, u)),
}
}
}
impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
type Pattern = DeclarationArrayType<'ast>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
let s0 = self.size;
let s1 = other.size;
Ok(DeclarationArrayType::new(
self.ty
.get_common_pattern(*other.ty)
.map_err(|(t, u)| (ArrayType::new(t, s0), ArrayType::new(u, s1)))?,
DeclarationConstant::Generic(GenericIdentifier::with_name("DUMMY")), // sizes are not checked at this stage, therefore we insert a dummy generic variable which will be equal to all possible sizes
))
}
}
impl<'ast, T> IntegerInference for StructType<'ast, T> {
type Pattern = DeclarationStructType<'ast>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
Ok(DeclarationStructType {
members: self
.members
.into_iter()
.zip(other.members.into_iter())
.map(|(m_t, m_u)| match m_t.ty.get_common_pattern(*m_u.ty) {
Ok(ty) => DeclarationStructMember {
ty: box ty,
id: m_t.id,
},
Err(..) => unreachable!(
"struct instances of the same struct should always have a common type"
),
})
.collect::<Vec<_>>(),
canonical_location: self.canonical_location,
location: self.location,
generics: self
.generics
.into_iter()
.map(|g| {
g.map(|_| DeclarationConstant::Generic(GenericIdentifier::with_name("DUMMY")))
})
.collect(),
})
}
}
impl<'ast, T: Field> TypedExpression<'ast, T> { impl<'ast, T: Field> TypedExpression<'ast, T> {
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible. // return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions // Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
@ -51,7 +142,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
.into(), .into(),
)), )),
(Int(lhs), Uint(rhs)) => Ok(( (Int(lhs), Uint(rhs)) => Ok((
UExpression::try_from_int(lhs, rhs.bitwidth()) UExpression::try_from_int(lhs, &rhs.bitwidth())
.map_err(|lhs| (lhs.into(), rhs.clone().into()))? .map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(), .into(),
Uint(rhs), Uint(rhs),
@ -60,47 +151,50 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
let bitwidth = lhs.bitwidth(); let bitwidth = lhs.bitwidth();
Ok(( Ok((
Uint(lhs.clone()), Uint(lhs.clone()),
UExpression::try_from_int(rhs, bitwidth) UExpression::try_from_int(rhs, &bitwidth)
.map_err(|rhs| (lhs.into(), rhs.into()))? .map_err(|rhs| (lhs.into(), rhs.into()))?
.into(), .into(),
)) ))
} }
(Array(lhs), Array(rhs)) => { (Array(lhs), Array(rhs)) => {
fn get_common_type<'a, T: Field>( let common_type = lhs
t: Type<'a, T>, .get_type()
u: Type<'a, T>, .get_common_pattern(rhs.get_type())
) -> Result<Type<'a, T>, ()> { .map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
match (t, u) {
(Type::Int, Type::Int) => Err(()),
(Type::Int, u) => Ok(u),
(t, Type::Int) => Ok(t),
(Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new(
get_common_type(*t.ty, *u.ty)?,
t.size,
))),
(t, _) => Ok(t),
}
}
let common_type = let common_type = match common_type {
get_common_type(lhs.inner_type().clone(), rhs.inner_type().clone()) DeclarationType::Array(ty) => ty,
.map_err(|_| (lhs.clone().into(), rhs.clone().into()))?; _ => unreachable!(),
};
Ok(( Ok((
ArrayExpression::try_from_int(lhs.clone(), common_type.clone()) ArrayExpression::try_from_int(lhs.clone(), &common_type)
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))? .map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
.into(), .into(),
ArrayExpression::try_from_int(rhs, common_type) ArrayExpression::try_from_int(rhs, &common_type)
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))? .map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
.into(), .into(),
)) ))
} }
(Struct(lhs), Struct(rhs)) => { (Struct(lhs), Struct(rhs)) => {
if lhs.get_type() == rhs.get_type() { let common_type = lhs
Ok((Struct(lhs), Struct(rhs))) .get_type()
} else { .get_common_pattern(rhs.get_type())
Err((Struct(lhs), Struct(rhs))) .map_err(|_| (lhs.clone().into(), rhs.clone().into()))?;
}
let common_type = match common_type {
DeclarationType::Struct(ty) => ty,
_ => unreachable!(),
};
Ok((
StructExpression::try_from_int(lhs.clone(), &common_type)
.map_err(|lhs| (lhs.clone(), rhs.clone().into()))?
.into(),
StructExpression::try_from_int(rhs, &common_type)
.map_err(|rhs| (lhs.clone().into(), rhs.clone()))?
.into(),
))
} }
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())), (Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())), (Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
@ -110,22 +204,25 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
} }
} }
pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result<Self, (Self, Type<'ast, T>)> { pub fn align_to_type<S: PartialEq<UExpression<'ast, T>>>(
match ty.clone() { e: Self,
Type::FieldElement => { ty: &GType<S>,
) -> Result<Self, (Self, &GType<S>)> {
match ty {
GType::FieldElement => {
FieldElementExpression::try_from_typed(e).map(TypedExpression::from) FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
} }
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from), GType::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
Type::Uint(bitwidth) => { GType::Uint(bitwidth) => {
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from) UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
} }
Type::Array(array_ty) => { GType::Array(array_ty) => {
ArrayExpression::try_from_typed(e, *array_ty.ty).map(TypedExpression::from) ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
} }
Type::Struct(struct_ty) => { GType::Struct(struct_ty) => {
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from) StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
} }
Type::Int => Err(e), GType::Int => Err(e),
} }
.map_err(|e| (e, ty)) .map_err(|e| (e, ty))
} }
@ -299,7 +396,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
)), )),
IntExpression::Pow(box e1, box e2) => Ok(Self::Pow( IntExpression::Pow(box e1, box e2) => Ok(Self::Pow(
box Self::try_from_int(e1)?, box Self::try_from_int(e1)?,
box UExpression::try_from_int(e2, UBitwidth::B32)?, box UExpression::try_from_int(e2, &UBitwidth::B32)?,
)), )),
IntExpression::Div(box e1, box e2) => Ok(Self::Div( IntExpression::Div(box e1, box e2) => Ok(Self::Div(
box Self::try_from_int(e1)?, box Self::try_from_int(e1)?,
@ -323,15 +420,21 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
let values = values let values = values
.into_iter() .into_iter()
.map(|v| { .map(|v| {
TypedExpressionOrSpread::align_to_type(v, Type::FieldElement) TypedExpressionOrSpread::align_to_type(
.map_err(|(e, _)| match e { v,
TypedExpressionOrSpread::Expression(e) => { &DeclarationArrayType::new(
IntExpression::try_from(e).unwrap() DeclarationType::FieldElement,
} DeclarationConstant::Concrete(0),
TypedExpressionOrSpread::Spread(a) => { ),
IntExpression::select(a.array, 0u32) )
} .map_err(|(e, _)| match e {
}) TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(FieldElementExpression::select( Ok(FieldElementExpression::select(
@ -351,15 +454,15 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
impl<'ast, T: Field> UExpression<'ast, T> { impl<'ast, T: Field> UExpression<'ast, T> {
pub fn try_from_typed( pub fn try_from_typed(
e: TypedExpression<'ast, T>, e: TypedExpression<'ast, T>,
bitwidth: UBitwidth, bitwidth: &UBitwidth,
) -> Result<Self, TypedExpression<'ast, T>> { ) -> Result<Self, TypedExpression<'ast, T>> {
match e { match e {
TypedExpression::Uint(e) => match e.bitwidth == bitwidth { TypedExpression::Uint(e) => match e.bitwidth == *bitwidth {
true => Ok(e), true => Ok(e),
_ => Err(TypedExpression::Uint(e)), _ => Err(TypedExpression::Uint(e)),
}, },
TypedExpression::Int(e) => { TypedExpression::Int(e) => {
Self::try_from_int(e.clone(), bitwidth).map_err(|_| TypedExpression::Int(e)) Self::try_from_int(e, bitwidth).map_err(TypedExpression::Int)
} }
e => Err(e), e => Err(e),
} }
@ -367,7 +470,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
pub fn try_from_int( pub fn try_from_int(
i: IntExpression<'ast, T>, i: IntExpression<'ast, T>,
bitwidth: UBitwidth, bitwidth: &UBitwidth,
) -> Result<Self, IntExpression<'ast, T>> { ) -> Result<Self, IntExpression<'ast, T>> {
use self::IntExpression::*; use self::IntExpression::*;
@ -377,7 +480,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
Ok(UExpressionInner::Value( Ok(UExpressionInner::Value(
u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(), u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(),
) )
.annotate(bitwidth)) .annotate(*bitwidth))
} else { } else {
Err(Value(i)) Err(Value(i))
} }
@ -435,20 +538,26 @@ impl<'ast, T: Field> UExpression<'ast, T> {
let values = values let values = values
.into_iter() .into_iter()
.map(|v| { .map(|v| {
TypedExpressionOrSpread::align_to_type(v, Type::Uint(bitwidth)) TypedExpressionOrSpread::align_to_type(
.map_err(|(e, _)| match e { v,
TypedExpressionOrSpread::Expression(e) => { &DeclarationArrayType::new(
IntExpression::try_from(e).unwrap() DeclarationType::Uint(*bitwidth),
} DeclarationConstant::Concrete(0),
TypedExpressionOrSpread::Spread(a) => { ),
IntExpression::select(a.array, 0u32) )
} .map_err(|(e, _)| match e {
}) TypedExpressionOrSpread::Expression(e) => {
IntExpression::try_from(e).unwrap()
}
TypedExpressionOrSpread::Spread(a) => {
IntExpression::select(a.array, 0u32)
}
})
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(UExpression::select( Ok(UExpression::select(
ArrayExpressionInner::Value(values.into()) ArrayExpressionInner::Value(values.into())
.annotate(Type::Uint(bitwidth), size), .annotate(Type::Uint(*bitwidth), size),
index, index,
)) ))
} }
@ -461,35 +570,34 @@ impl<'ast, T: Field> UExpression<'ast, T> {
} }
impl<'ast, T: Field> ArrayExpression<'ast, T> { impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn try_from_typed( pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
e: TypedExpression<'ast, T>, e: TypedExpression<'ast, T>,
target_inner_ty: Type<'ast, T>, target_array_ty: &GArrayType<S>,
) -> Result<Self, TypedExpression<'ast, T>> { ) -> Result<Self, TypedExpression<'ast, T>> {
match e { match e {
TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_inner_ty) TypedExpression::Array(e) => Self::try_from_int(e, target_array_ty),
.map_err(|_| TypedExpression::Array(e)),
e => Err(e), e => Err(e),
} }
} }
// precondition: `array` is only made of inline arrays and repeat constructs unless it does not contain the Integer type // precondition: `array` is only made of inline arrays and repeat constructs unless it does not contain the Integer type
pub fn try_from_int( pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
array: Self, array: Self,
target_inner_ty: Type<'ast, T>, target_array_ty: &GArrayType<S>,
) -> Result<Self, TypedExpression<'ast, T>> { ) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.ty(); let array_ty = array.ty();
// elements must fit in the target type // elements must fit in the target type
match array.into_inner() { match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => { ArrayExpressionInner::Value(inline_array) => {
let res = match target_inner_ty.clone() { let res = match &*target_array_ty.ty {
Type::Int => Ok(inline_array), GType::Int => Ok(inline_array),
t => { _ => {
// try to convert all elements to the target type // try to convert all elements to the target type
inline_array inline_array
.into_iter() .into_iter()
.map(|v| { .map(|v| {
TypedExpressionOrSpread::align_to_type(v, t.clone()).map_err( TypedExpressionOrSpread::align_to_type(v, &target_array_ty).map_err(
|(e, _)| match e { |(e, _)| match e {
TypedExpressionOrSpread::Expression(e) => e, TypedExpressionOrSpread::Expression(e) => e,
TypedExpressionOrSpread::Spread(a) => { TypedExpressionOrSpread::Spread(a) => {
@ -508,11 +616,11 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size)) Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size))
} }
ArrayExpressionInner::Repeat(box e, box count) => { ArrayExpressionInner::Repeat(box e, box count) => {
match target_inner_ty.clone() { match &*target_array_ty.ty {
Type::Int => Ok(ArrayExpressionInner::Repeat(box e, box count) GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count)
.annotate(Type::Int, array_ty.size)), .annotate(Type::Int, array_ty.size)),
// try to align the repeated element to the target type // try to align the repeated element to the target type
t => TypedExpression::align_to_type(e, t) t => TypedExpression::align_to_type(e, &t)
.map(|e| { .map(|e| {
let ty = e.get_type().clone(); let ty = e.get_type().clone();
@ -523,7 +631,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
} }
} }
a => { a => {
if array_ty.ty.weak_eq(&target_inner_ty) { if *target_array_ty.ty == *array_ty.ty {
Ok(a.annotate(*array_ty.ty, array_ty.size)) Ok(a.annotate(*array_ty.ty, array_ty.size))
} else { } else {
Err(a.annotate(*array_ty.ty, array_ty.size).into()) Err(a.annotate(*array_ty.ty, array_ty.size).into())
@ -533,6 +641,49 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
} }
} }
impl<'ast, T: Field> StructExpression<'ast, T> {
pub fn try_from_int<S: PartialEq<UExpression<'ast, T>>>(
struc: Self,
target_struct_ty: &GStructType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
let struct_ty = struc.ty().clone();
match struc.into_inner() {
StructExpressionInner::Value(inline_struct) => inline_struct
.into_iter()
.zip(target_struct_ty.members.iter())
.map(|(value, target_member)| {
TypedExpression::align_to_type(value, &*target_member.ty)
})
.collect::<Result<Vec<_>, _>>()
.map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone()))
.map_err(|(v, _)| v),
s => {
if struct_ty
.members
.iter()
.zip(target_struct_ty.members.iter())
.all(|(m, target_m)| *target_m.ty == *m.ty)
{
Ok(s.annotate(struct_ty))
} else {
Err(s.annotate(struct_ty).into())
}
}
}
}
pub fn try_from_typed<S: PartialEq<UExpression<'ast, T>>>(
e: TypedExpression<'ast, T>,
target_struct_ty: &GStructType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Struct(e) => Self::try_from_int(e, target_struct_ty),
e => Err(e),
}
}
}
impl<'ast, T> From<BigUint> for IntExpression<'ast, T> { impl<'ast, T> From<BigUint> for IntExpression<'ast, T> {
fn from(v: BigUint) -> Self { fn from(v: BigUint) -> Self {
IntExpression::Value(v) IntExpression::Value(v)
@ -652,7 +803,7 @@ mod tests {
for (r, e) in expressions for (r, e) in expressions
.into_iter() .into_iter()
.map(|e| UExpression::try_from_int(e, UBitwidth::B32).unwrap()) .map(|e| UExpression::try_from_int(e, &UBitwidth::B32).unwrap())
.zip(expected) .zip(expected)
{ {
assert_eq!(r, e); assert_eq!(r, e);
@ -665,7 +816,7 @@ mod tests {
for e in should_error for e in should_error
.into_iter() .into_iter()
.map(|e| UExpression::try_from_int(e, UBitwidth::B32)) .map(|e| UExpression::try_from_int(e, &UBitwidth::B32))
{ {
assert!(e.is_err()); assert!(e.is_err());
} }

View file

@ -14,15 +14,15 @@ mod integer;
mod parameter; mod parameter;
pub mod types; pub mod types;
mod uint; mod uint;
mod variable; pub mod variable;
pub use self::identifier::CoreIdentifier; pub use self::identifier::CoreIdentifier;
pub use self::parameter::{DeclarationParameter, GParameter}; pub use self::parameter::{DeclarationParameter, GParameter};
pub use self::types::{ pub use self::types::{
CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType, CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType,
ConstantIdentifier, DeclarationFunctionKey, DeclarationSignature, DeclarationType, GArrayType, ConstantIdentifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
GStructType, GType, GenericIdentifier, IntoTypes, Signature, StructType, Type, Types, DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
UBitwidth, IntoTypes, Signature, StructType, Type, Types, UBitwidth,
}; };
use crate::typed_absy::types::ConcreteGenericsAssignment; use crate::typed_absy::types::ConcreteGenericsAssignment;
@ -107,13 +107,19 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.arguments .arguments
.iter() .iter()
.map(|p| { .map(|p| {
types::ConcreteType::try_from(types::Type::<T>::from(p.id._type.clone())) types::ConcreteType::try_from(
.map(|ty| AbiInput { crate::typed_absy::types::try_from_g_type::<
public: !p.private, crate::typed_absy::types::DeclarationConstant<'ast>,
name: p.id.id.to_string(), UExpression<'ast, T>,
ty, >(p.id._type.clone())
}) .unwrap(),
.unwrap() )
.map(|ty| AbiInput {
public: !p.private,
name: p.id.id.to_string(),
ty,
})
.unwrap()
}) })
.collect(), .collect(),
outputs: main outputs: main
@ -121,7 +127,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.outputs .outputs
.iter() .iter()
.map(|ty| { .map(|ty| {
types::ConcreteType::try_from(types::Type::<T>::from(ty.clone())).unwrap() types::ConcreteType::try_from(
crate::typed_absy::types::try_from_g_type::<
crate::typed_absy::types::DeclarationConstant<'ast>,
UExpression<'ast, T>,
>(ty.clone())
.unwrap(),
)
.unwrap()
}) })
.collect(), .collect(),
} }
@ -192,7 +205,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
.iter() .iter()
.map(|(id, symbol)| match symbol { .map(|(id, symbol)| match symbol {
TypedConstantSymbol::Here(ref tc) => { TypedConstantSymbol::Here(ref tc) => {
format!("const {} {} = {}", tc.ty, id.id, tc.expression) format!("const {} {} = {}", id.ty, id.id, tc)
} }
TypedConstantSymbol::There(ref imported_id) => { TypedConstantSymbol::There(ref imported_id) => {
format!( format!(
@ -298,27 +311,24 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> { pub struct TypedConstant<'ast, T> {
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>, pub expression: TypedExpression<'ast, T>,
} }
impl<'ast, T> TypedConstant<'ast, T> { impl<'ast, T> TypedConstant<'ast, T> {
pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self { pub fn new(expression: TypedExpression<'ast, T>) -> Self {
TypedConstant { ty, expression } TypedConstant { expression }
} }
} }
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds write!(f, "{}", self.expression)
write!(f, "const {}({})", self.ty, self.expression)
} }
} }
impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> { impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
fn get_type(&self) -> Type<'ast, T> { fn get_type(&self) -> Type<'ast, T> {
self.ty.clone() self.expression.get_type()
} }
} }
@ -1166,24 +1176,6 @@ pub struct StructExpression<'ast, T> {
inner: StructExpressionInner<'ast, T>, inner: StructExpressionInner<'ast, T>,
} }
impl<'ast, T: Field> StructExpression<'ast, T> {
pub fn try_from_typed(
e: TypedExpression<'ast, T>,
target_struct_ty: StructType<'ast, T>,
) -> Result<Self, TypedExpression<'ast, T>> {
match e {
TypedExpression::Struct(e) => {
if e.ty() == &target_struct_ty {
Ok(e)
} else {
Err(TypedExpression::Struct(e))
}
}
e => Err(e),
}
}
}
impl<'ast, T> StructExpression<'ast, T> { impl<'ast, T> StructExpression<'ast, T> {
pub fn ty(&self) -> &StructType<'ast, T> { pub fn ty(&self) -> &StructType<'ast, T> {
&self.ty &self.ty

View file

@ -121,6 +121,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
Ok(CanonicalConstantIdentifier { Ok(CanonicalConstantIdentifier {
module: self.fold_module_id(i.module)?, module: self.fold_module_id(i.module)?,
id: i.id, id: i.id,
ty: box self.fold_declaration_type(*i.ty)?,
}) })
} }
@ -225,6 +226,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
t: StructType<'ast, T>, t: StructType<'ast, T>,
) -> Result<StructType<'ast, T>, Self::Error> { ) -> Result<StructType<'ast, T>, Self::Error> {
Ok(StructType { Ok(StructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<Vec<_>, _>>()?,
members: t members: t
.members .members
.into_iter() .into_iter()
@ -260,6 +266,11 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
t: DeclarationStructType<'ast>, t: DeclarationStructType<'ast>,
) -> Result<DeclarationStructType<'ast>, Self::Error> { ) -> Result<DeclarationStructType<'ast>, Self::Error> {
Ok(DeclarationStructType { Ok(DeclarationStructType {
generics: t
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_declaration_constant(g)).transpose())
.collect::<Result<Vec<_>, _>>()?,
members: t members: t
.members .members
.into_iter() .into_iter()
@ -1092,7 +1103,6 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
c: TypedConstant<'ast, T>, c: TypedConstant<'ast, T>,
) -> Result<TypedConstant<'ast, T>, F::Error> { ) -> Result<TypedConstant<'ast, T>, F::Error> {
Ok(TypedConstant { Ok(TypedConstant {
ty: f.fold_type(c.ty)?,
expression: f.fold_expression(c.expression)?, expression: f.fold_expression(c.expression)?,
}) })
} }

View file

@ -107,11 +107,20 @@ pub type ConstantIdentifier<'ast> = &'ast str;
pub struct CanonicalConstantIdentifier<'ast> { pub struct CanonicalConstantIdentifier<'ast> {
pub module: OwnedTypedModuleId, pub module: OwnedTypedModuleId,
pub id: ConstantIdentifier<'ast>, pub id: ConstantIdentifier<'ast>,
pub ty: Box<DeclarationType<'ast>>,
} }
impl<'ast> CanonicalConstantIdentifier<'ast> { impl<'ast> CanonicalConstantIdentifier<'ast> {
pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self { pub fn new(
CanonicalConstantIdentifier { module, id } id: ConstantIdentifier<'ast>,
module: OwnedTypedModuleId,
ty: DeclarationType<'ast>,
) -> Self {
CanonicalConstantIdentifier {
module,
id,
ty: box ty,
}
} }
} }
@ -122,6 +131,21 @@ pub enum DeclarationConstant<'ast> {
Constant(CanonicalConstantIdentifier<'ast>), Constant(CanonicalConstantIdentifier<'ast>),
} }
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
match (self, other.as_inner()) {
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
_ => true,
}
}
}
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
other.eq(self)
}
}
impl<'ast> From<u32> for DeclarationConstant<'ast> { impl<'ast> From<u32> for DeclarationConstant<'ast> {
fn from(e: u32) -> Self { fn from(e: u32) -> Self {
DeclarationConstant::Concrete(e) DeclarationConstant::Concrete(e)
@ -198,7 +222,7 @@ impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
pub type MemberId = String; pub type MemberId = String;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct GStructMember<S> { pub struct GStructMember<S> {
#[serde(rename = "name")] #[serde(rename = "name")]
pub id: MemberId, pub id: MemberId,
@ -210,8 +234,8 @@ pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>
pub type ConcreteStructMember = GStructMember<usize>; pub type ConcreteStructMember = GStructMember<usize>;
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>; pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationStructMember<'ast>> for StructMember<'ast, T> { impl<'ast, S, R: PartialEq<S>> PartialEq<GStructMember<S>> for GStructMember<R> {
fn eq(&self, other: &DeclarationStructMember<'ast>) -> bool { fn eq(&self, other: &GStructMember<S>) -> bool {
self.id == other.id && *self.ty == *other.ty self.id == other.id && *self.ty == *other.ty
} }
} }
@ -239,19 +263,7 @@ impl<'ast, T> From<ConcreteStructMember> for StructMember<'ast, T> {
} }
} }
impl<'ast> From<ConcreteStructMember> for DeclarationStructMember<'ast> { #[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
fn from(t: ConcreteStructMember) -> Self {
try_from_g_struct_member(t).unwrap()
}
}
impl<'ast, T> From<DeclarationStructMember<'ast>> for StructMember<'ast, T> {
fn from(t: DeclarationStructMember<'ast>) -> Self {
try_from_g_struct_member(t).unwrap()
}
}
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)]
pub struct GArrayType<S> { pub struct GArrayType<S> {
pub size: S, pub size: S,
#[serde(flatten)] #[serde(flatten)]
@ -262,13 +274,9 @@ pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
pub type ConcreteArrayType = GArrayType<usize>; pub type ConcreteArrayType = GArrayType<usize>;
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>; pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'ast, T> { impl<'ast, S, R: PartialEq<S>> PartialEq<GArrayType<S>> for GArrayType<R> {
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool { fn eq(&self, other: &GArrayType<S>) -> bool {
*self.ty == *other.ty *self.ty == *other.ty && self.size == other.size
&& match (self.size.as_inner(), &other.size) {
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
_ => true,
}
} }
} }
@ -298,22 +306,6 @@ impl<S: fmt::Display> fmt::Display for GArrayType<S> {
} }
} }
impl<'ast, T: PartialEq + fmt::Display> Type<'ast, T> {
// array type equality with non-strict size checks
// sizes always match unless they are different constants
pub fn weak_eq(&self, other: &Self) -> bool {
match (self, other) {
(Type::Array(t), Type::Array(u)) => t.ty.weak_eq(&u.ty),
(Type::Struct(t), Type::Struct(u)) => t
.members
.iter()
.zip(u.members.iter())
.all(|(m, n)| m.ty.weak_eq(&n.ty)),
(t, u) => t == u,
}
}
}
fn try_from_g_array_type<T: TryInto<U>, U>( fn try_from_g_array_type<T: TryInto<U>, U>(
t: GArrayType<T>, t: GArrayType<T>,
) -> Result<GArrayType<U>, SpecializationError> { ) -> Result<GArrayType<U>, SpecializationError> {
@ -350,18 +342,13 @@ impl<'ast> From<ConcreteArrayType> for DeclarationArrayType<'ast> {
} }
} }
impl<'ast, T> From<DeclarationArrayType<'ast>> for ArrayType<'ast, T> {
fn from(t: DeclarationArrayType<'ast>) -> Self {
try_from_g_array_type(t).unwrap()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord)]
pub struct GStructType<S> { pub struct GStructType<S> {
#[serde(flatten)] #[serde(flatten)]
pub canonical_location: StructLocation, pub canonical_location: StructLocation,
#[serde(skip)] #[serde(skip)]
pub location: Option<StructLocation>, pub location: Option<StructLocation>,
pub generics: Vec<Option<S>>,
pub members: Vec<GStructMember<S>>, pub members: Vec<GStructMember<S>>,
} }
@ -369,15 +356,25 @@ pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
pub type ConcreteStructType = GStructType<usize>; pub type ConcreteStructType = GStructType<usize>;
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>; pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
impl<S: PartialEq> PartialEq for GStructType<S> { impl<'ast, S, R: PartialEq<S>> PartialEq<GStructType<S>> for GStructType<R> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &GStructType<S>) -> bool {
self.canonical_location.eq(&other.canonical_location) self.canonical_location == other.canonical_location
&& self
.generics
.iter()
.zip(other.generics.iter())
.all(|(a, b)| match (a, b) {
(Some(a), Some(b)) => a == b,
(None, None) => true,
_ => false,
})
} }
} }
impl<S: Hash> Hash for GStructType<S> { impl<S: Hash> Hash for GStructType<S> {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
self.canonical_location.hash(state); self.canonical_location.hash(state);
self.generics.hash(state);
} }
} }
@ -389,6 +386,14 @@ fn try_from_g_struct_type<T: TryInto<U>, U>(
Ok(GStructType { Ok(GStructType {
location: t.location, location: t.location,
canonical_location: t.canonical_location, canonical_location: t.canonical_location,
generics: t
.generics
.into_iter()
.map(|g| match g {
Some(g) => g.try_into().map(Some).map_err(|_| SpecializationError),
None => Ok(None),
})
.collect::<Result<_, _>>()?,
members: t members: t
.members .members
.into_iter() .into_iter()
@ -417,17 +422,17 @@ impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
} }
} }
impl<'ast, T> From<DeclarationStructType<'ast>> for StructType<'ast, T> {
fn from(t: DeclarationStructType<'ast>) -> Self {
try_from_g_struct_type(t).unwrap()
}
}
impl<S> GStructType<S> { impl<S> GStructType<S> {
pub fn new(module: PathBuf, name: String, members: Vec<GStructMember<S>>) -> Self { pub fn new(
module: PathBuf,
name: String,
generics: Vec<Option<S>>,
members: Vec<GStructMember<S>>,
) -> Self {
GStructType { GStructType {
canonical_location: StructLocation { module, name }, canonical_location: StructLocation { module, name },
location: None, location: None,
generics,
members, members,
} }
} }
@ -498,7 +503,7 @@ impl fmt::Display for UBitwidth {
} }
} }
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] #[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)]
pub enum GType<S> { pub enum GType<S> {
FieldElement, FieldElement,
Boolean, Boolean,
@ -608,13 +613,13 @@ pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
pub type ConcreteType = GType<usize>; pub type ConcreteType = GType<usize>;
pub type Type<'ast, T> = GType<UExpression<'ast, T>>; pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> { impl<'ast, S, R: PartialEq<S>> PartialEq<GType<S>> for GType<R> {
fn eq(&self, other: &DeclarationType<'ast>) -> bool { fn eq(&self, other: &GType<S>) -> bool {
use self::GType::*; use self::GType::*;
match (self, other) { match (self, other) {
(Array(l), Array(r)) => l == r, (Array(l), Array(r)) => l == r,
(Struct(l), Struct(r)) => l.canonical_location == r.canonical_location, (Struct(l), Struct(r)) => l == r,
(FieldElement, FieldElement) | (Boolean, Boolean) => true, (FieldElement, FieldElement) | (Boolean, Boolean) => true,
(Uint(l), Uint(r)) => l == r, (Uint(l), Uint(r)) => l == r,
_ => false, _ => false,
@ -622,7 +627,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationType<'ast>> for Type<'ast, T> {
} }
} }
fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> { pub fn try_from_g_type<T: TryInto<U>, U>(t: GType<T>) -> Result<GType<U>, SpecializationError> {
match t { match t {
GType::FieldElement => Ok(GType::FieldElement), GType::FieldElement => Ok(GType::FieldElement),
GType::Boolean => Ok(GType::Boolean), GType::Boolean => Ok(GType::Boolean),
@ -653,12 +658,6 @@ impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
} }
} }
impl<'ast, T> From<DeclarationType<'ast>> for Type<'ast, T> {
fn from(t: DeclarationType<'ast>) -> Self {
try_from_g_type(t).unwrap()
}
}
impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> { impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
fn from(tup: (GType<S>, U)) -> Self { fn from(tup: (GType<S>, U)) -> Self {
GArrayType { GArrayType {
@ -669,10 +668,10 @@ impl<S, U: Into<S>> From<(GType<S>, U)> for GArrayType<S> {
} }
impl<S> GArrayType<S> { impl<S> GArrayType<S> {
pub fn new(ty: GType<S>, size: S) -> Self { pub fn new<U: Into<S>>(ty: GType<S>, size: U) -> Self {
GArrayType { GArrayType {
ty: Box::new(ty), ty: Box::new(ty),
size, size: size.into(),
} }
} }
} }
@ -694,11 +693,36 @@ impl<S: fmt::Display> fmt::Display for GType<S> {
GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth),
GType::Int => write!(f, "{{integer}}"), GType::Int => write!(f, "{{integer}}"),
GType::Array(ref array_type) => write!(f, "{}", array_type), GType::Array(ref array_type) => write!(f, "{}", array_type),
GType::Struct(ref struct_type) => write!(f, "{}", struct_type.name(),), GType::Struct(ref struct_type) => write!(f, "{}", struct_type),
} }
} }
} }
impl<S: fmt::Display> fmt::Display for GStructType<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}{}",
self.name(),
if !self.generics.is_empty() {
format!(
"<{}>",
self.generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)
} else {
"".to_string()
}
)
}
}
impl<S> GType<S> { impl<S> GType<S> {
pub fn array<U: Into<GArrayType<S>>>(array_ty: U) -> Self { pub fn array<U: Into<GArrayType<S>>>(array_ty: U) -> Self {
GType::Array(array_ty.into()) GType::Array(array_ty.into())
@ -717,7 +741,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool { pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
use self::GType::*; use self::GType::*;
if self == other { if other == self {
true true
} else { } else {
match (self, other) { match (self, other) {
@ -735,7 +759,13 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
} }
_ => false, _ => false,
}, },
(Struct(_), Struct(_)) => false, (Struct(l), Struct(r)) => {
l.canonical_location == r.canonical_location
&& l.members
.iter()
.zip(r.members.iter())
.all(|(m, d_m)| m.ty.can_be_specialized_to(&*d_m.ty))
}
_ => false, _ => false,
} }
} }
@ -848,14 +878,6 @@ impl<'ast, T> TryFrom<FunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
} }
} }
// impl<'ast> TryFrom<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
// type Error = SpecializationError;
// fn try_from(k: DeclarationFunctionKey<'ast>) -> Result<Self, Self::Error> {
// try_from_g_function_key(k)
// }
// }
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> { impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
fn from(k: ConcreteFunctionKey<'ast>) -> Self { fn from(k: ConcreteFunctionKey<'ast>) -> Self {
try_from_g_function_key(k).unwrap() try_from_g_function_key(k).unwrap()
@ -868,12 +890,6 @@ impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
} }
} }
impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
fn from(k: DeclarationFunctionKey<'ast>) -> Self {
try_from_g_function_key(k).unwrap()
}
}
impl<'ast, S> GFunctionKey<'ast, S> { impl<'ast, S> GFunctionKey<'ast, S> {
pub fn with_location<T: Into<OwnedTypedModuleId>, U: Into<FunctionIdentifier<'ast>>>( pub fn with_location<T: Into<OwnedTypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
module: T, module: T,
@ -913,7 +929,118 @@ impl<'ast> ConcreteFunctionKey<'ast> {
} }
} }
pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature}; use std::collections::btree_map::Entry;
pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
decl_ty: &DeclarationType<'ast>,
ty: &GType<S>,
constants: &mut GGenericsAssignment<'ast, S>,
) -> bool {
match (decl_ty, ty) {
(DeclarationType::Array(t0), GType::Array(t1)) => {
let s1 = t1.size.clone();
// both the inner type and the size must match
check_type(&t0.ty, &t1.ty, constants)
&& match &t0.size {
// if the declared size is an identifier, we insert into the map, or check if the concrete size
// matches if this identifier is already in the map
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
Entry::Occupied(e) => *e.get() == s1,
Entry::Vacant(e) => {
e.insert(s1);
true
}
},
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
// if it does not match, it will be caught during inlining
DeclarationConstant::Constant(..) => true,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
| (DeclarationType::Boolean, GType::Boolean) => true,
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
s0.canonical_location == s1.canonical_location
}
_ => false,
}
}
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
fn from(_: CanonicalConstantIdentifier<'ast>) -> Self {
unreachable!("constants should have been removed in constant inlining")
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast> {
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
DeclarationConstant::Constant(c)
}
}
pub fn specialize_declaration_type<
'ast,
S: Clone + PartialEq + From<u32> + fmt::Debug + From<CanonicalConstantIdentifier<'ast>>,
>(
decl_ty: DeclarationType<'ast>,
generics: &GGenericsAssignment<'ast, S>,
) -> Result<GType<S>, GenericIdentifier<'ast>> {
Ok(match decl_ty {
DeclarationType::Int => unreachable!(),
DeclarationType::Array(t0) => {
// let s1 = t1.size.clone();
let ty = box specialize_declaration_type(*t0.ty, &generics)?;
let size = match t0.size {
DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s),
DeclarationConstant::Concrete(s) => Ok(s.into()),
DeclarationConstant::Constant(c) => Ok(c.into()),
}?;
GType::Array(GArrayType { size, ty })
}
DeclarationType::FieldElement => GType::FieldElement,
DeclarationType::Boolean => GType::Boolean,
DeclarationType::Uint(b0) => GType::Uint(b0),
DeclarationType::Struct(s0) => GType::Struct(GStructType {
members: s0
.members
.into_iter()
.map(|m| {
let id = m.id;
specialize_declaration_type(*m.ty, generics)
.map(|ty| GStructMember { ty: box ty, id })
})
.collect::<Result<_, _>>()?,
generics: s0
.generics
.into_iter()
.map(|g| match g {
Some(constant) => match constant {
DeclarationConstant::Generic(s) => {
generics.0.get(&s).cloned().ok_or(s).map(Some)
}
DeclarationConstant::Concrete(s) => Ok(Some(s.into())),
DeclarationConstant::Constant(..) => {
unreachable!(
"identifiers should have been removed in constant inlining"
)
}
},
_ => Ok(None),
})
.collect::<Result<_, _>>()?,
canonical_location: s0.canonical_location,
location: s0.location,
}),
})
}
pub use self::signature::{
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
};
pub mod signature { pub mod signature {
use super::*; use super::*;
@ -968,83 +1095,6 @@ pub mod signature {
pub type ConcreteSignature = GSignature<usize>; pub type ConcreteSignature = GSignature<usize>;
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>; pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
use std::collections::btree_map::Entry;
fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
decl_ty: &DeclarationType<'ast>,
ty: &GType<S>,
constants: &mut GGenericsAssignment<'ast, S>,
) -> bool {
match (decl_ty, ty) {
(DeclarationType::Array(t0), GType::Array(t1)) => {
let s1 = t1.size.clone();
// both the inner type and the size must match
check_type(&t0.ty, &t1.ty, constants)
&& match &t0.size {
// if the declared size is an identifier, we insert into the map, or check if the concrete size
// matches if this identifier is already in the map
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
Entry::Occupied(e) => *e.get() == s1,
Entry::Vacant(e) => {
e.insert(s1);
true
}
},
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
// if it does not match, it will be caught during inlining
DeclarationConstant::Constant(..) => true,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
| (DeclarationType::Boolean, GType::Boolean) => true,
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
s0.canonical_location == s1.canonical_location
}
_ => false,
}
}
fn specialize_type<'ast, S: Clone + PartialEq + PartialEq<usize> + From<u32> + fmt::Debug>(
decl_ty: DeclarationType<'ast>,
constants: &GGenericsAssignment<'ast, S>,
) -> Result<GType<S>, GenericIdentifier<'ast>> {
Ok(match decl_ty {
DeclarationType::Int => unreachable!(),
DeclarationType::Array(t0) => {
// let s1 = t1.size.clone();
let ty = box specialize_type(*t0.ty, &constants)?;
let size = match t0.size {
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
DeclarationConstant::Concrete(s) => Ok(s.into()),
DeclarationConstant::Constant(..) => {
unreachable!("identifiers should have been removed in constant inlining")
}
}?;
GType::Array(GArrayType { size, ty })
}
DeclarationType::FieldElement => GType::FieldElement,
DeclarationType::Boolean => GType::Boolean,
DeclarationType::Uint(b0) => GType::Uint(b0),
DeclarationType::Struct(s0) => GType::Struct(GStructType {
members: s0
.members
.into_iter()
.map(|m| {
let id = m.id;
specialize_type(*m.ty, constants).map(|ty| GStructMember { ty: box ty, id })
})
.collect::<Result<_, _>>()?,
canonical_location: s0.canonical_location,
location: s0.location,
}),
})
}
impl<'ast> PartialEq<DeclarationSignature<'ast>> for ConcreteSignature { impl<'ast> PartialEq<DeclarationSignature<'ast>> for ConcreteSignature {
fn eq(&self, other: &DeclarationSignature<'ast>) -> bool { fn eq(&self, other: &DeclarationSignature<'ast>) -> bool {
// we keep track of the value of constants in a map, as a given constant can only have one value // we keep track of the value of constants in a map, as a given constant can only have one value
@ -1135,7 +1185,7 @@ pub mod signature {
self.outputs self.outputs
.clone() .clone()
.into_iter() .into_iter()
.map(|t| specialize_type(t, &constants)) .map(|t| specialize_declaration_type(t, &constants))
.collect::<Result<_, _>>() .collect::<Result<_, _>>()
} }
} }
@ -1185,12 +1235,6 @@ pub mod signature {
} }
} }
impl<'ast, T> From<DeclarationSignature<'ast>> for Signature<'ast, T> {
fn from(s: DeclarationSignature<'ast>) -> Self {
try_from_g_signature(s).unwrap()
}
}
impl<S: fmt::Display> fmt::Display for GSignature<S> { impl<S: fmt::Display> fmt::Display for GSignature<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.generics.is_empty() { if !self.generics.is_empty() {
@ -1329,8 +1373,7 @@ pub mod signature {
GenericIdentifier { GenericIdentifier {
name: "P", name: "P",
index: 0, index: 0,
} },
.into(),
))]); ))]);
let generic2 = DeclarationSignature::new() let generic2 = DeclarationSignature::new()
.generics(vec![Some( .generics(vec![Some(
@ -1345,8 +1388,7 @@ pub mod signature {
GenericIdentifier { GenericIdentifier {
name: "Q", name: "Q",
index: 0, index: 0,
} },
.into(),
))]); ))]);
assert_eq!(generic1, generic2); assert_eq!(generic1, generic2);

View file

@ -25,16 +25,6 @@ impl<'ast, T> TryFrom<Variable<'ast, T>> for ConcreteVariable<'ast> {
} }
} }
// impl<'ast> TryFrom<DeclarationVariable<'ast>> for ConcreteVariable<'ast> {
// type Error = SpecializationError;
// fn try_from(v: DeclarationVariable<'ast>) -> Result<Self, Self::Error> {
// let _type = v._type.try_into()?;
// Ok(Self { _type, id: v.id })
// }
// }
impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> { impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
fn from(v: ConcreteVariable<'ast>) -> Self { fn from(v: ConcreteVariable<'ast>) -> Self {
let _type = v._type.into(); let _type = v._type.into();
@ -43,12 +33,12 @@ impl<'ast, T> From<ConcreteVariable<'ast>> for Variable<'ast, T> {
} }
} }
impl<'ast, T> From<DeclarationVariable<'ast>> for Variable<'ast, T> { pub fn try_from_g_variable<T: TryInto<U>, U>(
fn from(v: DeclarationVariable<'ast>) -> Self { v: GVariable<T>,
let _type = v._type.into(); ) -> Result<GVariable<U>, SpecializationError> {
let _type = crate::typed_absy::types::try_from_g_type(v._type)?;
Self { _type, id: v.id } Ok(GVariable { _type, id: v.id })
}
} }
impl<'ast, S: Clone> GVariable<'ast, S> { impl<'ast, S: Clone> GVariable<'ast, S> {

View file

@ -0,0 +1,20 @@
{
"entry_point": "./tests/tests/structs/member_order.zok",
"curves": ["Bn128"],
"tests": [
{
"abi": true,
"input": {
"values": [{
"a":true,
"b": "3"
}]
},
"output": {
"Ok": {
"values": []
}
}
}
]
}

View file

@ -0,0 +1,9 @@
struct Foo {
field b
bool a
}
// this tests the abi, checking that the fields of a `Foo` instance get encoded in the right order
// if the the encoder reverses `a` and `b`, the boolean check ends up being done on the field value, which would fail
def main(Foo f):
return

View file

@ -34,9 +34,9 @@ ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
ty = { ty_array | ty_basic | ty_struct } ty = { ty_array | ty_basic | ty_struct }
type_list = _{(ty ~ ("," ~ ty)*)?} type_list = _{(ty ~ ("," ~ ty)*)?}
// structs // structs
ty_struct = { identifier } ty_struct = { identifier ~ explicit_generics? }
// type definitions // type definitions
ty_struct_definition = { "struct" ~ identifier ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* } ty_struct_definition = { "struct" ~ identifier ~ constant_generics_declaration? ~ "{" ~ NEWLINE* ~ struct_field_list ~ NEWLINE* ~ "}" ~ NEWLINE* }
struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? } struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? }
struct_field = { ty ~ identifier } struct_field = { ty ~ identifier }
@ -77,9 +77,9 @@ conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ ex
postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers
access = { array_access | call_access | member_access } access = { array_access | call_access | member_access }
array_access = { "[" ~ range_or_expression ~ "]" } array_access = { "[" ~ range_or_expression ~ "]" }
call_access = { explicit_generics? ~ "(" ~ arguments ~ ")" } call_access = { ("::" ~ explicit_generics)? ~ "(" ~ arguments ~ ")" }
arguments = { expression_list } arguments = { expression_list }
explicit_generics = { "::<" ~ constant_generics_values ~ ">" } explicit_generics = { "<" ~ constant_generics_values ~ ">" }
constant_generics_values = _{ constant_generics_value ~ ("," ~ constant_generics_value)* } constant_generics_values = _{ constant_generics_value ~ ("," ~ constant_generics_value)* }
constant_generics_value = { literal | identifier | underscore } constant_generics_value = { literal | identifier | underscore }
underscore = { "_" } underscore = { "_" }

View file

@ -147,6 +147,7 @@ mod ast {
#[pest_ast(rule(Rule::ty_struct_definition))] #[pest_ast(rule(Rule::ty_struct_definition))]
pub struct StructDefinition<'ast> { pub struct StructDefinition<'ast> {
pub id: IdentifierExpression<'ast>, pub id: IdentifierExpression<'ast>,
pub generics: Vec<IdentifierExpression<'ast>>,
pub fields: Vec<StructField<'ast>>, pub fields: Vec<StructField<'ast>>,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
@ -307,6 +308,7 @@ mod ast {
#[pest_ast(rule(Rule::ty_struct))] #[pest_ast(rule(Rule::ty_struct))]
pub struct StructType<'ast> { pub struct StructType<'ast> {
pub id: IdentifierExpression<'ast>, pub id: IdentifierExpression<'ast>,
pub explicit_generics: Option<ExplicitGenerics<'ast>>,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
} }

View file

@ -0,0 +1,52 @@
#pragma curve bw6_761
from "EMBED" import snark_verify_bls12_377 as verify
struct ProofInner {
field[2] a
field[2][2] b
field[2] c
}
struct Proof<N> {
ProofInner proof
field[N] inputs
}
struct VerificationKey<N> {
field[2][2] h
field[2] g_alpha
field[2][2] h_beta
field[2] g_gamma
field[2][2] h_gamma
field[N][2] query // input length + 1
}
def flat<N, F>(field[N][2] input) -> field[F]:
assert(F == N * 2)
field[F] out = [0; F]
for u32 i in 0..N do
for u32 j in 0..2 do
out[(i * 2) + j] = input[i][j]
endfor
endfor
return out
def main<N, Q>(Proof<N> proof, VerificationKey<Q> vk) -> bool:
assert(Q == N + 1) // query length (Q) should be N + 1
field[8] flat_proof = [
...proof.proof.a,
...flat::<2, 4>(proof.proof.b),
...proof.proof.c
]
u32 two_Q = 2 * Q
field[16 + (2 * Q)] flat_vk = [
...flat::<2, 4>(vk.h),
...vk.g_alpha,
...flat::<2, 4>(vk.h_beta),
...vk.g_gamma,
...flat::<2, 4>(vk.h_gamma),
...flat::<Q, two_Q>(vk.query)
]
return verify(proof.inputs, flat_proof, flat_vk)

View file

@ -0,0 +1,103 @@
{
"entry_point": "./tests/tests/snark/gm17.zok",
"curves": ["Bw6_761"],
"tests": [
{
"abi": true,
"input": {
"values": [
{
"proof": {
"a": [
"0x01441e34fd88112583831de068e3bdf67d7a5b020c9650e4dc8e3dd0cf92f62b32668dd4654ddc63fe5293a542756a27",
"0x013d7b6097a6ae8534909cb2f2ec2e39f3ccbe8858db0285e45619131db37f84b1c88fbb257a7b8e8944a926bb41aa66"
],
"b": [
[
"0x00dcf8242e445213da28281aab32bcf47268bf16624dbca7c828cfbb0e8000bad94926272cba0cd5e9a959cf4e969c7c",
"0x00b570276d40ae06ac3feb5db65b37acf1eabd16e1c588d01c553b1a60e5d007d9202a8ad2b6405e521b3eec84772521"
],
[
"0x00acbeabed6267316420b73b9eba39e8c51080b8b507857478a54c0fc259b17eec2921253a15445e2ec3c130706398b0",
"0x019b579a061cbc4aed64351d87ba96c071118ef3fd645e630c18986e284de5ffc8a48ea94eeb3bdc8807d62d366e223f"
]
],
"c": [
"0x004c93c20cd43f8b7818fcc4ece38243779bedb8b874702df4d6968b75cbe2e6831ab38475e2f0c7bc170171580198df",
"0x0177a560e5f6ae87f07aeff2dcdb1e0737b4810aeba8a5ba1bc4c5d0e89f268aae142ab5327afbde8e8bad869702aad3"
]
},
"inputs": [
"0x0000000000000000000000000000000000000000000000000000000000000001",
"0x0000000000000000000000000000000000000000000000000000000000000002",
"0x0000000000000000000000000000000000000000000000000000000000000003"
]
},
{
"h": [
[
"0x000a4c42894d5fd7ac23ca05eac034d82299dd9db5fa493812e4852bcf50cd88faf8f3e97cd292678b292d11e173949b",
"0x001ead78f91728b07146e93ee1f21165f25ad88e0fee997f5527076ca84374d3a6d834b59608226b28ab8b8d5ea9a94f"
],
[
"0x0087b1837c209351af3b67bbfeaea80ed94f690584847b1aa34cc59a2b451f360fc268b2562ea8015f8f4d71c7bf4675",
"0x015c50d51c8ed463a4e9cc76fc0583634b04dc26b36e10bfac9169d0baebf58b45b687a81a0ca60400427889bcbc6b76"
]
],
"g_alpha": [
"0x004b7af9ab6ef9061adb5ed7ba12e9cd41f508ac758c25c5e629d871a1b980e5242149b522b20c57808fae97cb76b971",
"0x0196c16d89a7cccbb8f15775da22c01d5ec45b384829bcaad91b324a482676558d3d6d41f675966b5d22537f4ed77903"
],
"h_beta": [
[
"0x014d2d0bcfa272334efbc589dc263c3f2a5d2711f9a0d5fbb3c2ad1b7eebe93459aeee6e1c8bc02041945313aec93d8a",
"0x0054800f89ebbbd924328a7782fdbb5260b56059901a06e6ad58c4a7df96018e5ea1c5ffd28ed0dd0139dcced6bde7e8"
],
[
"0x00ca4e270e5fe79ff2a5432daf6e9e5aa22aebf6521a7d3c5ef97d981b05ea93043c6307b47e8a3e00ace9c987fb725e",
"0x010cb8f97a5d586777e4f7ca8a0ce4465c0de02951cb8ccca43403b1a669e523c1163ebc9ce7d10edf583894fad70341"
]
],
"g_gamma": [
"0x003fa4d4d1fe1a9bb62e704b5ac76a514e4aaf53cfcbd12cb55aa7afecf2c12ce9346737b5594ee872700178748e9ed1",
"0x018975a2eb9de8a1982d076b56bb86b5214f89cff897d492e16dcdc1eca2a692eb9f0af5183585ba4aee9d78af2ab570"
],
"h_gamma": [
[
"0x000a4c42894d5fd7ac23ca05eac034d82299dd9db5fa493812e4852bcf50cd88faf8f3e97cd292678b292d11e173949b",
"0x001ead78f91728b07146e93ee1f21165f25ad88e0fee997f5527076ca84374d3a6d834b59608226b28ab8b8d5ea9a94f"
],
[
"0x0087b1837c209351af3b67bbfeaea80ed94f690584847b1aa34cc59a2b451f360fc268b2562ea8015f8f4d71c7bf4675",
"0x015c50d51c8ed463a4e9cc76fc0583634b04dc26b36e10bfac9169d0baebf58b45b687a81a0ca60400427889bcbc6b76"
]
],
"query": [
[
"0x00dbcc84391e078ae2fa7b5dc8478651b945e155505332a55e5b7be4de52ce83450bbf94f1da270c012104d394b22fda",
"0x002dc3039f7236d31fceaa6d8e13d33a5850984193f70c0abfe20a1f4540f59987e49cb0cc2722f1dccb47f1012d38c8"
],
[
"0x00db1bc3a431619ca74564c8a734592151a5fc2d8bfa750d4ffb94126bdaed83dce86bcdc8f966dca3066f67c61c897c",
"0x00e97f2f6c94a2676dd3c8646a45684cfd66a644644c1fc8ee5cf2ab4e322a5a82a9f9872ec9e8c7f3f1a9ddf38f2e53"
],
[
"0x008f4c292ba1ae0fa22613e0afaa075796b21a935e591fb8e8b32fa7c0fe0ecda25d5575e1e2b178d5a4bfb8e89f9d36",
"0x017cb6aca4e2d1027ab429a2a7d6b8f6e13dfeb427b7eaf9b6e3ca22554fae39f45ee0854098c9753cca04b46f3388d0"
],
[
"0x0168740e2d9cab168df083dd1d340de23d5055f4eed63c87811e94a5bf9c492658c6c58ccb1a48bb153cbe9aa8d98c8d",
"0x005b7c28b57504562c1d38a5ba9c67a59c696dc2e51b3c50d96e75e2f399f9106f08f6846d553d32e58b8131ad997fc1"
]
]
}
]
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,57 @@
// verify a snark
// to reproduce the test cases:
//
// 1. Create a program
// ```zokrates
// def main(field a, field b) -> field:
// return a + b
// ```
//
// 2. Compile it to bls12_377
// ```sh
// zokrates compile -i program.zok --curve bls12_377
// ```
//
// 3. Run a trusted setup for gm17
// ```sh
// zokrates setup --proving-scheme gm17 --backend ark
// ```
//
// 4. Execute the program and generate a proof
// ```sh
// zokrates compute-witness -a 1 2
// zokrates generate-proof --proving-scheme gm17 --backend ark
// ```
//
// 5. Generate the test case
//
// ```sh
// cat > gm17.json << EOT
// {
// "entry_point": "./tests/tests/snark/gm17.zok",
// "curves": ["Bw6_761"],
// "tests": [
// {
// "abi": true,
// "input": {
// "values": [
// $(cat proof.json && echo ", " && cat verification.key)
// ]
// },
// "output": {
// "Ok": {
// "values": ["1"]
// }
// }
// }
// ]
// }
// EOT
// ```
//
// `gm17.json` can then be used as a test for this code file
from "snark/gm17" import main as verify, Proof, VerificationKey
def main(Proof<3> proof, VerificationKey<4> vk) -> bool:
return verify::<3, 4>(proof, vk)

View file

@ -8,6 +8,7 @@ edition = "2018"
zokrates_field = { version = "0.4", path = "../zokrates_field" } zokrates_field = { version = "0.4", path = "../zokrates_field" }
zokrates_core = { version = "0.6", path = "../zokrates_core" } zokrates_core = { version = "0.6", path = "../zokrates_core" }
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver" } zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver" }
zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
serde = "1.0" serde = "1.0"
serde_derive = "1.0" serde_derive = "1.0"
serde_json = "1.0" serde_json = "1.0"

View file

@ -5,8 +5,11 @@ use std::fs::File;
use std::io::{BufReader, Read}; use std::io::{BufReader, Read};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use zokrates_core::compile::{compile, CompileConfig};
use zokrates_core::ir; use zokrates_core::ir;
use zokrates_core::{
compile::{compile, CompileConfig},
typed_absy::ConcreteType,
};
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
use zokrates_fs_resolver::FileSystemResolver; use zokrates_fs_resolver::FileSystemResolver;
@ -34,6 +37,7 @@ struct Input {
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Test { struct Test {
pub abi: Option<bool>,
pub input: Input, pub input: Input,
pub output: TestResult, pub output: TestResult,
} }
@ -48,11 +52,24 @@ struct Output {
values: Vec<Val>, values: Vec<Val>,
} }
type Val = String; type Val = serde_json::Value;
fn parse_val<T: Field>(s: String) -> T { fn try_parse_raw_val<T: Field>(s: serde_json::Value) -> Result<T, ()> {
let radix = if s.starts_with("0x") { 16 } else { 10 }; match s {
T::try_from_str(s.trim_start_matches("0x"), radix).unwrap() serde_json::Value::String(s) => {
let radix = if s.starts_with("0x") { 16 } else { 10 };
T::try_from_str(s.trim_start_matches("0x"), radix).map_err(|_| ())
}
_ => Err(()),
}
}
fn try_parse_abi_val<T: Field>(
s: Vec<Val>,
types: Vec<ConcreteType>,
) -> Result<Vec<T>, zokrates_abi::Error> {
use zokrates_abi::Encode;
zokrates_abi::parse_strict_json(s, types).map(|v| v.encode())
} }
impl<T: Field> From<ir::ExecutionResult<T>> for ComparableResult<T> { impl<T: Field> From<ir::ExecutionResult<T>> for ComparableResult<T> {
@ -63,7 +80,12 @@ impl<T: Field> From<ir::ExecutionResult<T>> for ComparableResult<T> {
impl<T: Field> From<TestResult> for ComparableResult<T> { impl<T: Field> From<TestResult> for ComparableResult<T> {
fn from(r: TestResult) -> ComparableResult<T> { fn from(r: TestResult) -> ComparableResult<T> {
ComparableResult(r.map(|v| v.values.into_iter().map(parse_val).collect())) ComparableResult(r.map(|v| {
v.values
.into_iter()
.map(|v| try_parse_raw_val(v).unwrap())
.collect()
}))
} }
} }
@ -130,6 +152,7 @@ fn compile_and_run<T: Field>(t: Tests) {
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap(); let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap();
let bin = artifacts.prog(); let bin = artifacts.prog();
let abi = artifacts.abi();
if let Some(target_count) = t.max_constraint_count { if let Some(target_count) = t.max_constraint_count {
let count = bin.constraint_count(); let count = bin.constraint_count();
@ -148,12 +171,21 @@ fn compile_and_run<T: Field>(t: Tests) {
let interpreter = zokrates_core::ir::Interpreter::default(); let interpreter = zokrates_core::ir::Interpreter::default();
for test in t.tests.into_iter() { for test in t.tests.into_iter() {
let input = &test.input.values; let with_abi = test.abi.unwrap_or(false);
let output = interpreter.execute( let input = if with_abi {
bin, try_parse_abi_val(test.input.values, abi.signature().inputs).unwrap()
&(input.iter().cloned().map(parse_val).collect::<Vec<_>>()), } else {
); test.input
.values
.iter()
.cloned()
.map(try_parse_raw_val)
.collect::<Result<Vec<_>, _>>()
.unwrap()
};
let output = interpreter.execute(bin, &input);
if let Err(e) = compare(output, test.output) { if let Err(e) = compare(output, test.output) {
let mut code = File::open(&entry_point).unwrap(); let mut code = File::open(&entry_point).unwrap();