1
0
Fork 0
mirror of synced 2025-09-23 04:08:33 +00:00

Merge branch 'develop' of github.com:Zokrates/ZoKrates into better-boolean-array-equality-check

This commit is contained in:
schaeff 2022-09-27 10:08:30 +02:00
commit 3b7fec990b
26 changed files with 692 additions and 496 deletions

1
Cargo.lock generated
View file

@ -3025,6 +3025,7 @@ dependencies = [
"ark-bls12-377",
"cfg-if 0.1.10",
"csv",
"derivative",
"num-bigint 0.2.6",
"pairing_ce",
"serde",

View file

@ -0,0 +1 @@
Disallow the use of the `private` and `public` keywords on non-entrypoint functions

View file

@ -0,0 +1 @@
Fix duplicate constraint optimiser

View file

@ -20,7 +20,4 @@ serde_json = { version = "1.0", features = ["preserve_order"] }
zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false }
pairing_ce = { version = "^0.21", optional = true }
ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false, optional = true }
derivative = "2.2.0"

View file

@ -1,5 +1,6 @@
use crate::common::FormatString;
use crate::typed::ConcreteType;
use derivative::Derivative;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::fmt;
@ -25,9 +26,14 @@ pub use crate::common::Variable;
pub use self::witness::Witness;
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Serialize, Deserialize, Clone, Derivative)]
#[derivative(Hash, PartialEq, Eq)]
pub enum Statement<T> {
Constraint(QuadComb<T>, LinComb<T>, Option<RuntimeError>),
Constraint(
QuadComb<T>,
LinComb<T>,
#[derivative(Hash = "ignore")] Option<RuntimeError>,
),
Directive(Directive<T>),
Log(FormatString, Vec<(ConcreteType, Vec<LinComb<T>>)>),
}
@ -74,7 +80,16 @@ impl<T: Field> fmt::Display for Directive<T> {
impl<T: Field> fmt::Display for Statement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Statement::Constraint(ref quad, ref lin, _) => write!(f, "{} == {}", quad, lin),
Statement::Constraint(ref quad, ref lin, ref error) => write!(
f,
"{} == {}{}",
quad,
lin,
error
.as_ref()
.map(|e| format!(" // {}", e))
.unwrap_or_else(|| "".to_string())
),
Statement::Directive(ref s) => write!(f, "{}", s),
Statement::Log(ref s, ref expressions) => write!(
f,

View file

@ -341,6 +341,16 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_member_expression(self, ty, e)
}
fn fold_identifier_expression<
E: Expr<'ast, T> + Id<'ast, T> + From<TypedExpression<'ast, T>>,
>(
&mut self,
ty: &E::Ty,
e: IdentifierExpression<'ast, E>,
) -> IdentifierOrExpression<'ast, T, E> {
fold_identifier_expression(self, ty, e)
}
fn fold_element_expression<
E: Expr<'ast, T> + Element<'ast, T> + From<TypedExpression<'ast, T>>,
>(
@ -534,6 +544,19 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
vec![res]
}
pub fn fold_identifier_expression<
'ast,
T: Field,
E: Expr<'ast, T> + Id<'ast, T> + From<TypedExpression<'ast, T>>,
F: Folder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: IdentifierExpression<'ast, E>,
) -> IdentifierOrExpression<'ast, T, E> {
IdentifierOrExpression::Identifier(IdentifierExpression::new(f.fold_name(e.id)))
}
pub fn fold_embed_call<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: EmbedCall<'ast, T>,
@ -571,8 +594,11 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
use ArrayExpressionInner::*;
match e {
Identifier(id) => match f.fold_identifier_expression(ty, id) {
IdentifierOrExpression::Identifier(i) => ArrayExpressionInner::Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)),
Identifier(id) => Identifier(f.fold_name(id)),
Value(exprs) => Value(
exprs
.into_iter()
@ -621,8 +647,11 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
use StructExpressionInner::*;
match e {
Identifier(id) => match f.fold_identifier_expression(ty, id) {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)),
Identifier(id) => Identifier(f.fold_name(id)),
Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()),
FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call),
@ -656,7 +685,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
match e {
Block(block) => Block(f.fold_block_expression(block)),
Identifier(id) => Identifier(f.fold_name(id)),
Identifier(id) => match f.fold_identifier_expression(ty, id) {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()),
FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call),
@ -688,9 +720,12 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
use FieldElementExpression::*;
match e {
Identifier(id) => match f.fold_identifier_expression(&Type::FieldElement, id) {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)),
Number(n) => Number(n),
Identifier(id) => Identifier(f.fold_name(id)),
Add(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
@ -840,9 +875,12 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
use BooleanExpression::*;
match e {
Identifier(id) => match f.fold_identifier_expression(&Type::Boolean, id) {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => BooleanExpression::Block(f.fold_block_expression(block)),
Value(v) => BooleanExpression::Value(v),
Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
FieldEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e),
EqOrBoolean::Boolean(u) => u,
@ -966,9 +1004,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
use UExpressionInner::*;
match e {
Identifier(id) => match f.fold_identifier_expression(&ty, id) {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)),
Value(v) => Value(v),
Identifier(id) => Identifier(f.fold_name(id)),
Add(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);

View file

@ -26,6 +26,7 @@ pub use self::types::{
GArrayType, GStructType, GType, GenericIdentifier, Signature, StructType, TupleType, Type,
UBitwidth,
};
use self::types::{ConcreteArrayType, ConcreteStructType};
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
use crate::untyped::Position;
@ -970,6 +971,27 @@ impl<'ast, T, E> BlockExpression<'ast, T, E> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct IdentifierExpression<'ast, E> {
pub id: Identifier<'ast>,
ty: PhantomData<E>,
}
impl<'ast, E> IdentifierExpression<'ast, E> {
pub fn new(id: Identifier<'ast>) -> Self {
IdentifierExpression {
id,
ty: PhantomData,
}
}
}
impl<'ast, E> fmt::Display for IdentifierExpression<'ast, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.id)
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct MemberExpression<'ast, T, E> {
pub struc: Box<StructExpression<'ast, T>>,
@ -1145,7 +1167,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T,
pub enum FieldElementExpression<'ast, T> {
Block(BlockExpression<'ast, T, Self>),
Number(T),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, Self>),
Add(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -1222,7 +1244,7 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum BooleanExpression<'ast, T> {
Block(BlockExpression<'ast, T, Self>),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, Self>),
Value(bool),
FieldLt(
Box<FieldElementExpression<'ast, T>>,
@ -1363,7 +1385,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum ArrayExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, ArrayExpression<'ast, T>>),
Value(ArrayValue<'ast, T>),
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
Conditional(ConditionalExpression<'ast, T, ArrayExpression<'ast, T>>),
@ -1428,7 +1450,7 @@ impl<'ast, T> StructExpression<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum StructExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, StructExpression<'ast, T>>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
Conditional(ConditionalExpression<'ast, T, StructExpression<'ast, T>>),
@ -1470,7 +1492,7 @@ impl<'ast, T> TupleExpression<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TupleExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, TupleExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, TupleExpression<'ast, T>>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionCallExpression<'ast, T, TupleExpression<'ast, T>>),
Conditional(ConditionalExpression<'ast, T, TupleExpression<'ast, T>>),
@ -1762,14 +1784,14 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
fn from(v: Variable<'ast, T>) -> Self {
match v.get_type() {
Type::FieldElement => FieldElementExpression::Identifier(v.id).into(),
Type::Boolean => BooleanExpression::Identifier(v.id).into(),
Type::Array(ty) => ArrayExpressionInner::Identifier(v.id)
Type::FieldElement => FieldElementExpression::identifier(v.id).into(),
Type::Boolean => BooleanExpression::identifier(v.id).into(),
Type::Array(ty) => ArrayExpression::identifier(v.id)
.annotate(*ty.ty, *ty.size)
.into(),
Type::Struct(ty) => StructExpressionInner::Identifier(v.id).annotate(ty).into(),
Type::Tuple(ty) => TupleExpressionInner::Identifier(v.id).annotate(ty).into(),
Type::Uint(w) => UExpressionInner::Identifier(v.id).annotate(w).into(),
Type::Struct(ty) => StructExpression::identifier(v.id).annotate(ty).into(),
Type::Tuple(ty) => TupleExpression::identifier(v.id).annotate(ty).into(),
Type::Uint(w) => UExpression::identifier(v.id).annotate(w).into(),
Type::Int => unreachable!(),
}
}
@ -1779,7 +1801,8 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
pub trait Expr<'ast, T>: fmt::Display + From<TypedExpression<'ast, T>> {
type Inner;
type Ty: Clone + IntoType<'ast, T>;
type Ty: Clone + IntoType<UExpression<'ast, T>>;
type ConcreteTy: Clone + IntoType<u32>;
fn ty(&self) -> &Self::Ty;
@ -1793,6 +1816,7 @@ pub trait Expr<'ast, T>: fmt::Display + From<TypedExpression<'ast, T>> {
impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
type ConcreteTy = ConcreteType;
fn ty(&self) -> &Self::Ty {
&Type::FieldElement
@ -1814,6 +1838,7 @@ impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
type ConcreteTy = ConcreteType;
fn ty(&self) -> &Self::Ty {
&Type::Boolean
@ -1835,6 +1860,7 @@ impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
type Inner = UExpressionInner<'ast, T>;
type Ty = UBitwidth;
type ConcreteTy = UBitwidth;
fn ty(&self) -> &Self::Ty {
&self.bitwidth
@ -1856,6 +1882,7 @@ impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> {
type Inner = StructExpressionInner<'ast, T>;
type Ty = StructType<'ast, T>;
type ConcreteTy = ConcreteStructType;
fn ty(&self) -> &Self::Ty {
&self.ty
@ -1877,6 +1904,7 @@ impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> {
type Inner = ArrayExpressionInner<'ast, T>;
type Ty = ArrayType<'ast, T>;
type ConcreteTy = ConcreteArrayType;
fn ty(&self) -> &Self::Ty {
&self.ty
@ -1898,6 +1926,7 @@ impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> {
type Inner = TupleExpressionInner<'ast, T>;
type Ty = TupleType<'ast, T>;
type ConcreteTy = ConcreteTupleType;
fn ty(&self) -> &Self::Ty {
&self.ty
@ -1919,6 +1948,7 @@ impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
type ConcreteTy = ConcreteType;
fn ty(&self) -> &Self::Ty {
&Type::Int
@ -1958,6 +1988,11 @@ pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> {
Expression(E::Inner),
}
pub enum IdentifierOrExpression<'ast, T, E: Expr<'ast, T>> {
Identifier(IdentifierExpression<'ast, E>),
Expression(E::Inner),
}
pub enum ElementOrExpression<'ast, T, E: Expr<'ast, T>> {
Element(ElementExpression<'ast, T, E>),
Expression(E::Inner),
@ -2318,37 +2353,37 @@ pub trait Id<'ast, T>: Expr<'ast, T> {
impl<'ast, T: Field> Id<'ast, T> for FieldElementExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
FieldElementExpression::Identifier(id)
FieldElementExpression::Identifier(IdentifierExpression::new(id))
}
}
impl<'ast, T: Field> Id<'ast, T> for BooleanExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
BooleanExpression::Identifier(id)
BooleanExpression::Identifier(IdentifierExpression::new(id))
}
}
impl<'ast, T: Field> Id<'ast, T> for UExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
UExpressionInner::Identifier(id)
UExpressionInner::Identifier(IdentifierExpression::new(id))
}
}
impl<'ast, T: Field> Id<'ast, T> for ArrayExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
ArrayExpressionInner::Identifier(id)
ArrayExpressionInner::Identifier(IdentifierExpression::new(id))
}
}
impl<'ast, T: Field> Id<'ast, T> for StructExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
StructExpressionInner::Identifier(id)
StructExpressionInner::Identifier(IdentifierExpression::new(id))
}
}
impl<'ast, T: Field> Id<'ast, T> for TupleExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
TupleExpressionInner::Identifier(id)
TupleExpressionInner::Identifier(IdentifierExpression::new(id))
}
}

View file

@ -213,6 +213,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_block_expression(self, block)
}
fn fold_identifier_expression<E: Expr<'ast, T> + Id<'ast, T> + ResultFold<'ast, T>>(
&mut self,
ty: &E::Ty,
id: IdentifierExpression<'ast, E>,
) -> Result<IdentifierOrExpression<'ast, T, E>, Self::Error> {
fold_identifier_expression(self, ty, id)
}
fn fold_member_expression<
E: Expr<'ast, T> + Member<'ast, T> + From<TypedExpression<'ast, T>>,
>(
@ -576,7 +584,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e = match e {
Block(block) => Block(f.fold_block_expression(block)?),
Identifier(id) => Identifier(f.fold_name(id)?),
Identifier(id) => match f.fold_identifier_expression(ty, id)? {
IdentifierOrExpression::Identifier(i) => ArrayExpressionInner::Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Value(exprs) => Value(
exprs
.into_iter()
@ -643,8 +654,11 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
use StructExpressionInner::*;
let e = match e {
Identifier(id) => match f.fold_identifier_expression(ty, id)? {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)?),
Identifier(id) => Identifier(f.fold_name(id)?),
Value(exprs) => Value(
exprs
.into_iter()
@ -684,7 +698,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e = match e {
Block(block) => Block(f.fold_block_expression(block)?),
Identifier(id) => Identifier(f.fold_name(id)?),
Identifier(id) => match f.fold_identifier_expression(ty, id)? {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Value(exprs) => Value(
exprs
.into_iter()
@ -722,9 +739,12 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
use FieldElementExpression::*;
let e = match e {
Identifier(id) => match f.fold_identifier_expression(&Type::FieldElement, id)? {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)?),
Number(n) => Number(n),
Identifier(id) => Identifier(f.fold_name(id)?),
Add(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
@ -850,6 +870,21 @@ pub fn fold_member_expression<
)))
}
pub fn fold_identifier_expression<
'ast,
T: Field,
E: Expr<'ast, T> + Id<'ast, T> + From<TypedExpression<'ast, T>>,
F: ResultFolder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: IdentifierExpression<'ast, E>,
) -> Result<IdentifierOrExpression<'ast, T, E>, F::Error> {
Ok(IdentifierOrExpression::Identifier(
IdentifierExpression::new(f.fold_name(e.id)?),
))
}
pub fn fold_eq_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>(
f: &mut F,
e: EqExpression<E>,
@ -925,9 +960,12 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
use BooleanExpression::*;
let e = match e {
Identifier(id) => match f.fold_identifier_expression(&Type::Boolean, id)? {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)?),
Value(v) => Value(v),
Identifier(id) => Identifier(f.fold_name(id)?),
FieldEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => FieldEq(e),
EqOrBoolean::Boolean(u) => u,
@ -1050,9 +1088,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
use UExpressionInner::*;
let e = match e {
Identifier(id) => match f.fold_identifier_expression(&ty, id)? {
IdentifierOrExpression::Identifier(i) => Identifier(i),
IdentifierOrExpression::Expression(u) => u,
},
Block(block) => Block(f.fold_block_expression(block)?),
Value(v) => Value(v),
Identifier(id) => Identifier(f.fold_name(id)?),
Add(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;

View file

@ -7,38 +7,39 @@ use std::collections::BTreeMap;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use zokrates_field::Field;
pub trait IntoType<'ast, T> {
fn into_type(self) -> Type<'ast, T>;
pub trait IntoType<S> {
fn into_type(self) -> GType<S>;
}
impl<'ast, T> IntoType<'ast, T> for Type<'ast, T> {
fn into_type(self) -> Type<'ast, T> {
impl<S> IntoType<S> for GType<S> {
fn into_type(self) -> GType<S> {
self
}
}
impl<'ast, T> IntoType<'ast, T> for StructType<'ast, T> {
fn into_type(self) -> Type<'ast, T> {
Type::Struct(self)
impl<S> IntoType<S> for GStructType<S> {
fn into_type(self) -> GType<S> {
GType::Struct(self)
}
}
impl<'ast, T> IntoType<'ast, T> for ArrayType<'ast, T> {
fn into_type(self) -> Type<'ast, T> {
Type::Array(self)
impl<S> IntoType<S> for GArrayType<S> {
fn into_type(self) -> GType<S> {
GType::Array(self)
}
}
impl<'ast, T> IntoType<'ast, T> for TupleType<'ast, T> {
fn into_type(self) -> Type<'ast, T> {
Type::Tuple(self)
impl<S> IntoType<S> for GTupleType<S> {
fn into_type(self) -> GType<S> {
GType::Tuple(self)
}
}
impl<'ast, T> IntoType<'ast, T> for UBitwidth {
fn into_type(self) -> Type<'ast, T> {
Type::Uint(self)
impl<S> IntoType<S> for UBitwidth {
fn into_type(self) -> GType<S> {
GType::Uint(self)
}
}
@ -232,19 +233,17 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
}
}
impl<'ast, T> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast, T>) -> Self {
match c {
DeclarationConstant::Generic(g) => {
UExpressionInner::Identifier(CoreIdentifier::from(g).into())
.annotate(UBitwidth::B32)
UExpression::identifier(CoreIdentifier::from(g).into()).annotate(UBitwidth::B32)
}
DeclarationConstant::Concrete(v) => {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
}
DeclarationConstant::Constant(v) => {
UExpressionInner::Identifier(CoreIdentifier::from(v).into())
.annotate(UBitwidth::B32)
UExpression::identifier(CoreIdentifier::from(v).into()).annotate(UBitwidth::B32)
}
DeclarationConstant::Expression(e) => e.try_into().unwrap(),
}
@ -1139,9 +1138,9 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq<u32>>(
}
}
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
impl<'ast, T: Field> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
UExpressionInner::Identifier(Identifier::from(CoreIdentifier::Constant(c)))
UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
.annotate(UBitwidth::B32)
}
}
@ -1227,7 +1226,7 @@ pub use self::signature::{
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
};
use super::ShadowedIdentifier;
use super::{Id, ShadowedIdentifier};
pub mod signature {
use super::*;
@ -1296,7 +1295,7 @@ pub mod signature {
}
}
impl<'ast, T: Clone + PartialEq> DeclarationSignature<'ast, T> {
impl<'ast, T: Field> DeclarationSignature<'ast, T> {
pub fn specialize(
&self,
values: Vec<Option<u32>>,

View file

@ -164,7 +164,7 @@ impl<'ast, T> PartialEq<u32> for UExpression<'ast, T> {
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum UExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, UExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>),
Value(u128),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Sub(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -223,13 +223,10 @@ impl<'ast> From<pest::Parameter<'ast>> for untyped::ParameterNode<'ast> {
fn from(param: pest::Parameter<'ast>) -> untyped::ParameterNode<'ast> {
use crate::untyped::NodeValue;
let is_private = param
.visibility
.map(|v| match v {
pest::Visibility::Private(_) => true,
pest::Visibility::Public(_) => false,
})
.unwrap_or(false);
let is_private = param.visibility.map(|v| match v {
pest::Visibility::Private(_) => true,
pest::Visibility::Public(_) => false,
});
let is_mutable = param.mutable.is_some();
@ -949,9 +946,10 @@ mod tests {
.into(),
)
.into(),
untyped::Parameter::public(
untyped::Parameter::new(
untyped::Variable::mutable("b", UnresolvedType::Boolean.mock())
.into(),
.mock(),
None,
)
.into(),
],

View file

@ -4,20 +4,20 @@ use std::fmt;
#[derive(Clone, PartialEq)]
pub struct Parameter<'ast> {
pub id: VariableNode<'ast>,
pub is_private: bool,
pub is_private: Option<bool>,
}
impl<'ast> Parameter<'ast> {
pub fn new(v: VariableNode<'ast>, is_private: bool) -> Self {
pub fn new(v: VariableNode<'ast>, is_private: Option<bool>) -> Self {
Parameter { id: v, is_private }
}
pub fn private(v: VariableNode<'ast>) -> Self {
Self::new(v, true)
Self::new(v, Some(true))
}
pub fn public(v: VariableNode<'ast>) -> Self {
Self::new(v, false)
Self::new(v, Some(false))
}
}
@ -25,7 +25,12 @@ pub type ParameterNode<'ast> = Node<Parameter<'ast>>;
impl<'ast> fmt::Display for Parameter<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let visibility = if self.is_private { "private " } else { "" };
let visibility = if let Some(true) = self.is_private {
"private "
} else {
""
};
write!(
f,
"{}{} {}",

View file

@ -0,0 +1,8 @@
def mul(private field a) -> field { // `private` should not be allowed here
return a * a;
}
def main(private field a, field b) {
assert(mul(a) == b);
return;
}

View file

@ -0,0 +1,8 @@
def mul(public field a) -> field { // `public` should not be allowed here
return a * a;
}
def main(field a, field b) {
assert(mul(a) == b);
return;
}

View file

@ -89,6 +89,7 @@ type ConstantMap<'ast, T> =
/// The global state of the program during semantic checks
#[derive(Debug)]
struct State<'ast, T> {
main_id: OwnedModuleId,
/// The modules yet to be checked, which we consume as we explore the dependency tree
modules: Modules<'ast>,
/// The already checked modules, which we're returning at the end
@ -166,8 +167,9 @@ impl<'ast, T: std::cmp::Ord> SymbolUnifier<'ast, T> {
}
impl<'ast, T: Field> State<'ast, T> {
fn new(modules: Modules<'ast>) -> Self {
fn new(modules: Modules<'ast>, main_id: OwnedModuleId) -> Self {
State {
main_id,
modules,
typed_modules: BTreeMap::new(),
types: BTreeMap::new(),
@ -340,12 +342,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
program: Program<'ast>,
) -> Result<TypedProgram<'ast, T>, Vec<Error>> {
let mut state = State::new(program.modules);
let main_id = program.main.clone();
let mut state = State::new(program.modules, main_id.clone());
let mut errors = vec![];
// recursively type-check modules starting with `main`
match self.check_module(&program.main, &mut state) {
match self.check_module(&main_id, &mut state) {
Ok(()) => {}
Err(e) => errors.extend(e),
};
@ -354,9 +357,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
return Err(errors);
}
let main_id = program.main.clone();
Checker::check_single_main(state.typed_modules.get(&program.main).unwrap()).map_err(
Checker::check_single_main(state.typed_modules.get(&main_id).unwrap()).map_err(
|inner| {
vec![Error {
inner,
@ -744,7 +745,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Symbol::Here(SymbolDefinition::Function(f)) => {
match self.check_function(f, module_id, state) {
match self.check_function(declaration.id, f, module_id, state) {
Ok(funct) => {
match symbol_unifier
.insert_function(declaration.id, funct.signature.clone())
@ -1095,6 +1096,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_function(
&mut self,
id: Identifier<'ast>,
funct_node: FunctionNode<'ast>,
module_id: &ModuleId,
state: &State<'ast, T>,
@ -1130,7 +1132,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
// for declaration signatures, generics cannot be ignored
generics.0.insert(
generic.clone(),
UExpressionInner::Identifier(self.id_in_this_scope(generic.name()).into())
UExpression::identifier(self.id_in_this_scope(generic.name()).into())
.annotate(UBitwidth::B32),
);
@ -1143,6 +1145,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
let arg = arg.value;
// parameters defined on a non-entrypoint function should not have visibility modifiers
if (state.main_id != module_id || id != "main") && arg.is_private.is_some() {
errors.push(ErrorInner {
pos: Some(pos),
message:
"Visibility modifiers on arguments are only allowed on the entrypoint function"
.into(),
});
}
let decl_v = DeclarationVariable::new(
self.id_in_this_scope(arg.id.value.id),
decl_ty.clone(),
@ -1173,7 +1185,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
arguments_checked.push(DeclarationParameter {
id: decl_v,
private: arg.is_private,
private: arg.is_private.unwrap_or(false),
});
}
@ -2350,28 +2362,22 @@ impl<'ast, T: Field> Checker<'ast, T> {
Some(info) => {
let id = info.id;
match info.ty.clone() {
Type::Boolean => Ok(BooleanExpression::Identifier(id.into()).into()),
Type::Uint(bitwidth) => Ok(UExpressionInner::Identifier(id.into())
.annotate(bitwidth)
.into()),
Type::Boolean => Ok(BooleanExpression::identifier(id.into()).into()),
Type::Uint(bitwidth) => {
Ok(UExpression::identifier(id.into()).annotate(bitwidth).into())
}
Type::FieldElement => {
Ok(FieldElementExpression::Identifier(id.into()).into())
}
Type::Array(array_type) => {
Ok(ArrayExpressionInner::Identifier(id.into())
.annotate(*array_type.ty, *array_type.size)
.into())
}
Type::Struct(members) => {
Ok(StructExpressionInner::Identifier(id.into())
.annotate(members)
.into())
}
Type::Tuple(tuple_ty) => {
Ok(TupleExpressionInner::Identifier(id.into())
.annotate(tuple_ty)
.into())
Ok(FieldElementExpression::identifier(id.into()).into())
}
Type::Array(array_type) => Ok(ArrayExpression::identifier(id.into())
.annotate(*array_type.ty, *array_type.size)
.into()),
Type::Struct(members) => Ok(StructExpression::identifier(id.into())
.annotate(members)
.into()),
Type::Tuple(tuple_ty) => Ok(TupleExpression::identifier(id.into())
.annotate(tuple_ty)
.into()),
Type::Int => unreachable!(),
}
}
@ -3751,14 +3757,14 @@ mod tests {
.mock()
}
/// Helper function to create: (private field a) { return; }
/// Helper function to create: (field a) { return; }
fn function1() -> FunctionNode<'static> {
let statements = vec![Statement::Return(None).mock()];
let arguments = vec![untyped::Parameter {
id: untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(),
is_private: true,
}
let arguments = vec![untyped::Parameter::new(
untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(),
None,
)
.mock()];
let signature =
@ -3885,6 +3891,7 @@ mod tests {
vec![("foo".into(), foo), ("bar".into(), bar)]
.into_iter()
.collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -3939,6 +3946,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -3952,10 +3960,10 @@ mod tests {
#[test]
fn duplicate_function_declaration_generic() {
// def foo<P>(private field[P] a) {
// def foo<P>(field[P] a) {
// return;
// }
// def foo(private field[3] a) {
// def foo(field[3] a) {
// return;
// }
//
@ -3963,7 +3971,7 @@ mod tests {
let mut f0 = function0();
f0.value.arguments = vec![untyped::Parameter::private(
f0.value.arguments = vec![untyped::Parameter::new(
untyped::Variable::immutable(
"a",
UnresolvedType::array(
@ -3973,6 +3981,7 @@ mod tests {
.mock(),
)
.mock(),
None,
)
.mock()];
f0.value.signature = UnresolvedSignature::new()
@ -3985,7 +3994,7 @@ mod tests {
let mut f1 = function0();
f1.value.arguments = vec![untyped::Parameter::private(
f1.value.arguments = vec![untyped::Parameter::new(
untyped::Variable::immutable(
"a",
UnresolvedType::array(
@ -3995,6 +4004,7 @@ mod tests {
.mock(),
)
.mock(),
None,
)
.mock()];
f1.value.signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::array(
@ -4018,7 +4028,10 @@ mod tests {
],
};
let mut state = State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok());
@ -4057,8 +4070,10 @@ mod tests {
],
};
let mut state =
State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok());
@ -4111,8 +4126,10 @@ mod tests {
],
};
let mut state =
State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
assert_eq!(
@ -4129,7 +4146,7 @@ mod tests {
// def foo() {
// return;
// }
// def foo(a) {
// def foo(field a) {
// return;
// }
//
@ -4152,6 +4169,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -4201,6 +4219,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -4244,6 +4263,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -4295,6 +4315,7 @@ mod tests {
vec![((*MODULE_ID).clone(), main), ("bar".into(), bar)]
.into_iter()
.collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -4343,6 +4364,7 @@ mod tests {
vec![((*MODULE_ID).clone(), main), ("bar".into(), bar)]
.into_iter()
.collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -4373,7 +4395,7 @@ mod tests {
#[test]
fn undeclared_generic() {
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array(
box UnresolvedType::FieldElement.mock(),
@ -4393,7 +4415,7 @@ mod tests {
fn success() {
// <K, L, M>(field[L][K]) -> field[L][K]
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let signature = UnresolvedSignature::new()
.generics(vec!["K".mock(), "L".mock(), "M".mock()])
@ -4487,7 +4509,7 @@ mod tests {
checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()),
Ok(TypedStatement::definition(
typed::Variable::field_element("a").into(),
FieldElementExpression::Identifier("b".into()).into()
FieldElementExpression::identifier("b".into()).into()
))
);
}
@ -4546,8 +4568,10 @@ mod tests {
];
let module = Module { symbols };
let mut state =
State::<Bn128Field>::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
assert_eq!(
@ -4641,8 +4665,10 @@ mod tests {
];
let module = Module { symbols };
let mut state =
State::<Bn128Field>::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok());
@ -4675,11 +4701,11 @@ mod tests {
.mock();
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = Checker::default();
assert_eq!(
checker.check_function(foo, &*MODULE_ID, &state),
checker.check_function("foo", foo, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
message: "Identifier \"i\" is undefined".into()
@ -4720,7 +4746,7 @@ mod tests {
UBitwidth::B32,
)
.into(),
UExpressionInner::Identifier(
UExpression::identifier(
CoreIdentifier::Source(ShadowedIdentifier::shadow("i", 1)).into(),
)
.annotate(UBitwidth::B32)
@ -4754,11 +4780,11 @@ mod tests {
};
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = Checker::default();
assert_eq!(
checker.check_function(foo, &*MODULE_ID, &state),
checker.check_function("foo", foo, &*MODULE_ID, &state),
Ok(foo_checked)
);
}
@ -4801,11 +4827,11 @@ mod tests {
.mock();
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), functions);
assert_eq!(
checker.check_function(bar, &*MODULE_ID, &state),
checker.check_function("bar", bar, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
message:
@ -4840,11 +4866,11 @@ mod tests {
.mock();
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert_eq!(
checker.check_function(bar, &*MODULE_ID, &state),
checker.check_function("bar", bar, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
@ -4912,8 +4938,10 @@ mod tests {
],
};
let mut state =
State::<Bn128Field>::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert_eq!(
@ -5008,8 +5036,10 @@ mod tests {
],
};
let mut state =
State::<Bn128Field>::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect());
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok());
@ -5049,11 +5079,11 @@ mod tests {
.mock();
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert_eq!(
checker.check_function(bar, &*MODULE_ID, &state),
checker.check_function("bar", bar, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
@ -5082,11 +5112,11 @@ mod tests {
.mock();
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert_eq!(
checker.check_function(bar, &*MODULE_ID, &state),
checker.check_function("bar", bar, &*MODULE_ID, &state),
Err(vec![ErrorInner {
pos: Some((Position::mock(), Position::mock())),
message: "Identifier \"a\" is undefined".into()
@ -5119,12 +5149,12 @@ mod tests {
]);
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let mut checker: Checker<Bn128Field> = new_with_args(Scope::default(), HashSet::new());
assert_eq!(
checker
.check_function(f, &*MODULE_ID, &state)
.check_function("main", f, &*MODULE_ID, &state)
.unwrap_err()[0]
.message,
"Duplicate name in function definition: `a` was previously declared as an argument, a generic parameter or a constant"
@ -5133,7 +5163,7 @@ mod tests {
#[test]
fn duplicate_main_function() {
// def main(a) -> field {
// def main(field a) -> field {
// return 1;
// }
// def main() -> field {
@ -5144,10 +5174,9 @@ mod tests {
let main1_statements: Vec<StatementNode> =
vec![Statement::Return(Some(Expression::IntConstant(1usize.into()).mock())).mock()];
let main1_arguments = vec![zokrates_ast::untyped::Parameter {
id: untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(),
is_private: false,
}
let main1_arguments = vec![zokrates_ast::untyped::Parameter::public(
untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(),
)
.mock()];
let main2_statements: Vec<StatementNode> =
@ -5413,6 +5442,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
let mut checker: Checker<Bn128Field> = Checker::default();
@ -5430,7 +5460,7 @@ mod tests {
fn empty_def() {
// an empty struct should be allowed to be defined
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
@ -5456,7 +5486,7 @@ mod tests {
fn valid_def() {
// a valid struct should be allowed to be defined
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
@ -5500,7 +5530,7 @@ mod tests {
fn duplicate_member_def() {
// definition of a struct with a duplicate member should be rejected
let modules = Modules::new();
let state = State::new(modules);
let state = State::new(modules, (*MODULE_ID).clone());
let declaration: StructDefinitionNode = StructDefinition {
generics: vec![],
@ -5577,6 +5607,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
assert!(Checker::default()
@ -5636,6 +5667,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
assert!(Checker::default()
@ -5669,6 +5701,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
assert!(Checker::default()
@ -5720,6 +5753,7 @@ mod tests {
let mut state = State::<Bn128Field>::new(
vec![((*MODULE_ID).clone(), module)].into_iter().collect(),
(*MODULE_ID).clone(),
);
assert!(Checker::default()
@ -6136,8 +6170,9 @@ mod tests {
let mut foo_field = function0();
foo_field.value.arguments = vec![untyped::Parameter::private(
foo_field.value.arguments = vec![untyped::Parameter::new(
untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(),
None,
)
.mock()];
foo_field.value.statements =
@ -6148,8 +6183,9 @@ mod tests {
let mut foo_u32 = function0();
foo_u32.value.arguments = vec![untyped::Parameter::private(
foo_u32.value.arguments = vec![untyped::Parameter::new(
untyped::Variable::immutable("a", UnresolvedType::Uint(32).mock()).mock(),
None,
)
.mock()];
foo_u32.value.statements =

View file

@ -1,7 +1,7 @@
use zokrates_ast::typed::{
folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression,
ConditionalOrExpression, CoreIdentifier, Expr, Identifier, Type, TypedExpression, TypedProgram,
TypedStatement, Variable,
ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression,
TypedProgram, TypedStatement, Variable,
};
use zokrates_field::Field;
@ -70,7 +70,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
TypedExpression::from(condition),
));
self.index += 1;
BooleanExpression::Identifier(condition_id)
BooleanExpression::identifier(condition_id)
}
};
@ -123,7 +123,7 @@ mod tests {
let s = TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier("c".into()),
BooleanExpression::identifier("c".into()),
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
ConditionalKind::IfElse,
@ -144,8 +144,8 @@ mod tests {
// field foo = if #CONDITION_0 { 1 } else { 2 };
let condition = BooleanExpression::And(
box BooleanExpression::Identifier("c".into()),
box BooleanExpression::Identifier("d".into()),
box BooleanExpression::identifier("c".into()),
box BooleanExpression::identifier("d".into()),
);
let s = TypedStatement::definition(
@ -171,7 +171,7 @@ mod tests {
TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
BooleanExpression::identifier(CoreIdentifier::Condition(0).into()),
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
ConditionalKind::IfElse,
@ -203,13 +203,13 @@ mod tests {
// };
let condition_0 = BooleanExpression::And(
box BooleanExpression::Identifier("c".into()),
box BooleanExpression::Identifier("d".into()),
box BooleanExpression::identifier("c".into()),
box BooleanExpression::identifier("d".into()),
);
let condition_1 = BooleanExpression::And(
box BooleanExpression::Identifier("e".into()),
box BooleanExpression::Identifier("f".into()),
box BooleanExpression::identifier("e".into()),
box BooleanExpression::identifier("f".into()),
);
let s = TypedStatement::definition(
@ -244,9 +244,9 @@ mod tests {
TypedStatement::definition(
Variable::field_element("foo").into(),
FieldElementExpression::conditional(
BooleanExpression::Identifier(CoreIdentifier::Condition(0).into()),
BooleanExpression::identifier(CoreIdentifier::Condition(0).into()),
FieldElementExpression::conditional(
BooleanExpression::Identifier(CoreIdentifier::Condition(1).into()),
BooleanExpression::identifier(CoreIdentifier::Condition(1).into()),
FieldElementExpression::Number(Bn128Field::from(1)),
FieldElementExpression::Number(Bn128Field::from(2)),
ConditionalKind::IfElse,
@ -285,23 +285,23 @@ mod tests {
// };
let condition_0 = BooleanExpression::And(
box BooleanExpression::Identifier("c".into()),
box BooleanExpression::Identifier("d".into()),
box BooleanExpression::identifier("c".into()),
box BooleanExpression::identifier("d".into()),
);
let condition_1 = BooleanExpression::And(
box BooleanExpression::Identifier("e".into()),
box BooleanExpression::Identifier("f".into()),
box BooleanExpression::identifier("e".into()),
box BooleanExpression::identifier("f".into()),
);
let condition_2 = BooleanExpression::And(
box BooleanExpression::Identifier("e".into()),
box BooleanExpression::Identifier("f".into()),
box BooleanExpression::identifier("e".into()),
box BooleanExpression::identifier("f".into()),
);
let condition_id_0 = BooleanExpression::Identifier(CoreIdentifier::Condition(0).into());
let condition_id_1 = BooleanExpression::Identifier(CoreIdentifier::Condition(1).into());
let condition_id_2 = BooleanExpression::Identifier(CoreIdentifier::Condition(2).into());
let condition_id_0 = BooleanExpression::identifier(CoreIdentifier::Condition(0).into());
let condition_id_1 = BooleanExpression::identifier(CoreIdentifier::Condition(1).into());
let condition_id_2 = BooleanExpression::identifier(CoreIdentifier::Condition(2).into());
let s = TypedStatement::definition(
Variable::field_element("foo").into(),

View file

@ -131,7 +131,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
FieldElementExpression::identifier(Identifier::from(const_id)).into(),
)],
signature: DeclarationSignature::new()
.inputs(vec![])
@ -191,7 +191,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
BooleanExpression::Identifier(Identifier::from(const_id.clone())).into(),
BooleanExpression::identifier(Identifier::from(const_id.clone())).into(),
)],
signature: DeclarationSignature::new()
.inputs(vec![])
@ -249,7 +249,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
UExpressionInner::Identifier(Identifier::from(const_id.clone()))
UExpression::identifier(Identifier::from(const_id.clone()))
.annotate(UBitwidth::B32)
.into(),
)],
@ -313,13 +313,13 @@ mod tests {
statements: vec![TypedStatement::Return(
FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
ArrayExpression::identifier(Identifier::from(const_id.clone()))
.annotate(GType::FieldElement, 2u32),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
ArrayExpression::identifier(Identifier::from(const_id.clone()))
.annotate(GType::FieldElement, 2u32),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
@ -398,7 +398,7 @@ mod tests {
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(),
FieldElementExpression::identifier(Identifier::from(const_b_id.clone())).into(),
)],
signature: DeclarationSignature::new()
.inputs(vec![])
@ -425,7 +425,7 @@ mod tests {
const_b_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
box FieldElementExpression::identifier(Identifier::from(
const_a_id.clone(),
)),
box FieldElementExpression::Number(Bn128Field::from(1)),
@ -515,7 +515,7 @@ mod tests {
TypedConstantSymbolDeclaration::new(
bar_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Identifier(
TypedExpression::FieldElement(FieldElementExpression::identifier(
foo_const_id.clone().into(),
)),
DeclarationType::FieldElement,
@ -557,7 +557,7 @@ mod tests {
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(
FieldElementExpression::identifier(Identifier::from(
main_const_id.clone(),
))
.into(),
@ -587,7 +587,7 @@ mod tests {
TypedConstantSymbolDeclaration::new(
main_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Identifier(
TypedExpression::FieldElement(FieldElementExpression::identifier(
foo_const_id.into(),
)),
DeclarationType::FieldElement,
@ -603,7 +603,7 @@ mod tests {
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(
FieldElementExpression::identifier(Identifier::from(
main_const_id.clone(),
))
.into(),
@ -745,7 +745,7 @@ mod tests {
main_baz_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Identifier(main_bar_const_id.clone().into())
ArrayExpression::identifier(main_bar_const_id.clone().into())
.annotate(Type::FieldElement, main_foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
@ -764,7 +764,7 @@ mod tests {
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(
FieldElementExpression::identifier(Identifier::from(
main_foo_const_id.clone(),
))
.into(),
@ -820,7 +820,7 @@ mod tests {
main_baz_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Identifier(main_bar_const_id.into())
ArrayExpression::identifier(main_bar_const_id.into())
.annotate(Type::FieldElement, main_foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
@ -839,7 +839,7 @@ mod tests {
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from(
FieldElementExpression::identifier(Identifier::from(
main_foo_const_id.clone(),
))
.into(),

View file

@ -1,5 +1,5 @@
use std::marker::PhantomData;
use zokrates_ast::typed::types::UBitwidth;
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
use zokrates_ast::typed::{self, Expr, Typed};
use zokrates_ast::zir::{self, Select};
use zokrates_field::Field;
@ -57,6 +57,56 @@ fn flatten_identifier_rec<'ast>(
}
}
fn flatten_identifier_to_expression_rec<'ast, T: Field>(
id: zir::SourceIdentifier<'ast>,
ty: &typed::types::ConcreteType,
) -> Vec<zir::ZirExpression<'ast, T>> {
match ty {
typed::ConcreteType::Int => unreachable!(),
typed::ConcreteType::FieldElement => {
vec![zir::FieldElementExpression::Identifier(zir::Identifier::Source(id)).into()]
}
typed::ConcreteType::Boolean => {
vec![zir::BooleanExpression::Identifier(zir::Identifier::Source(id)).into()]
}
typed::ConcreteType::Uint(bitwidth) => {
vec![
zir::UExpressionInner::Identifier(zir::Identifier::Source(id))
.annotate(bitwidth.to_usize())
.into(),
]
}
typed::ConcreteType::Array(array_type) => (0..*array_type.size)
.flat_map(|i| {
flatten_identifier_to_expression_rec(
zir::SourceIdentifier::Select(box id.clone(), i),
&array_type.ty,
)
})
.collect(),
typed::types::ConcreteType::Struct(members) => members
.iter()
.flat_map(|struct_member| {
flatten_identifier_to_expression_rec(
zir::SourceIdentifier::Member(box id.clone(), struct_member.id.clone()),
&struct_member.ty,
)
})
.collect(),
typed::types::ConcreteType::Tuple(tuple_ty) => tuple_ty
.elements
.iter()
.enumerate()
.flat_map(|(i, ty)| {
flatten_identifier_to_expression_rec(
zir::SourceIdentifier::Element(box id.clone(), i as u32),
ty,
)
})
.collect(),
}
}
trait Flatten<'ast, T: Field> {
fn flatten(
self,
@ -269,6 +319,14 @@ impl<'ast, T: Field> Flattener<T> {
}
}
fn fold_identifier_expression<E: Expr<'ast, T>>(
&mut self,
ty: E::ConcreteTy,
e: typed::IdentifierExpression<'ast, E>,
) -> Vec<zir::ZirExpression<'ast, T>> {
fold_identifier_expression(self, ty, e)
}
fn fold_array_expression(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
@ -367,7 +425,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_array_expression_inner(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteType,
ty: typed::types::ConcreteType,
size: u32,
e: typed::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
@ -377,7 +435,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_struct_expression_inner(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteStructType,
ty: typed::types::ConcreteStructType,
e: typed::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
fold_struct_expression_inner(self, statements_buffer, ty, e)
@ -386,7 +444,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_tuple_expression_inner(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteTupleType,
ty: typed::types::ConcreteTupleType,
e: typed::TupleExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
fold_tuple_expression_inner(self, statements_buffer, ty, e)
@ -461,7 +519,7 @@ fn fold_statement<'ast, T: Field>(
fn fold_array_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteType,
ty: typed::types::ConcreteType,
size: u32,
array: typed::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
@ -474,20 +532,7 @@ fn fold_array_expression_inner<'ast, T: Field>(
f.fold_array_expression(statements_buffer, *block.value)
}
typed::ArrayExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
&typed::types::ConcreteType::array((ty.clone(), size)),
);
variables
.into_iter()
.map(|v| match v._type {
zir::Type::FieldElement => zir::FieldElementExpression::Identifier(v.id).into(),
zir::Type::Boolean => zir::BooleanExpression::Identifier(v.id).into(),
zir::Type::Uint(bitwidth) => zir::UExpressionInner::Identifier(v.id)
.annotate(bitwidth)
.into(),
})
.collect()
f.fold_identifier_expression(ConcreteArrayType::new(ty, size), id)
}
typed::ArrayExpressionInner::Value(exprs) => {
let exprs: Vec<_> = exprs
@ -544,7 +589,7 @@ fn fold_array_expression_inner<'ast, T: Field>(
fn fold_struct_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteStructType,
ty: typed::types::ConcreteStructType,
struc: typed::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match struc {
@ -555,22 +600,7 @@ fn fold_struct_expression_inner<'ast, T: Field>(
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_struct_expression(statements_buffer, *block.value)
}
typed::StructExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
&typed::types::ConcreteType::struc(ty.clone()),
);
variables
.into_iter()
.map(|v| match v._type {
zir::Type::FieldElement => zir::FieldElementExpression::Identifier(v.id).into(),
zir::Type::Boolean => zir::BooleanExpression::Identifier(v.id).into(),
zir::Type::Uint(bitwidth) => zir::UExpressionInner::Identifier(v.id)
.annotate(bitwidth)
.into(),
})
.collect()
}
typed::StructExpressionInner::Identifier(id) => f.fold_identifier_expression(ty, id),
typed::StructExpressionInner::Value(exprs) => exprs
.into_iter()
.flat_map(|e| f.fold_expression(statements_buffer, e))
@ -592,7 +622,7 @@ fn fold_struct_expression_inner<'ast, T: Field>(
fn fold_tuple_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
ty: &typed::types::ConcreteTupleType,
ty: typed::types::ConcreteTupleType,
tuple: typed::TupleExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match tuple {
@ -603,22 +633,7 @@ fn fold_tuple_expression_inner<'ast, T: Field>(
.for_each(|s| f.fold_statement(statements_buffer, s));
f.fold_tuple_expression(statements_buffer, *block.value)
}
typed::TupleExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
&typed::types::ConcreteType::tuple(ty.clone()),
);
variables
.into_iter()
.map(|v| match v._type {
zir::Type::FieldElement => zir::FieldElementExpression::Identifier(v.id).into(),
zir::Type::Boolean => zir::BooleanExpression::Identifier(v.id).into(),
zir::Type::Uint(bitwidth) => zir::UExpressionInner::Identifier(v.id)
.annotate(bitwidth)
.into(),
})
.collect()
}
typed::TupleExpressionInner::Identifier(id) => f.fold_identifier_expression(ty, id),
typed::TupleExpressionInner::Value(exprs) => exprs
.into_iter()
.flat_map(|e| f.fold_expression(statements_buffer, e))
@ -828,6 +843,14 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>(
.collect()
}
fn fold_identifier_expression<'ast, T: Field, E: Expr<'ast, T>>(
f: &mut Flattener<T>,
ty: E::ConcreteTy,
e: typed::IdentifierExpression<'ast, E>,
) -> Vec<zir::ZirExpression<'ast, T>> {
flatten_identifier_to_expression_rec(f.fold_name(e.id), &ty.into_type())
}
fn fold_field_expression<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
@ -835,12 +858,12 @@ fn fold_field_expression<'ast, T: Field>(
) -> zir::FieldElementExpression<'ast, T> {
match e {
typed::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
typed::FieldElementExpression::Identifier(id) => zir::FieldElementExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed::types::ConcreteType::FieldElement)
.pop()
.unwrap()
.id,
),
typed::FieldElementExpression::Identifier(id) => f
.fold_identifier_expression(typed::ConcreteType::FieldElement, id)
.pop()
.unwrap()
.try_into()
.unwrap(),
typed::FieldElementExpression::Add(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
@ -965,12 +988,12 @@ fn fold_boolean_expression<'ast, T: Field>(
f.fold_boolean_expression(statements_buffer, *block.value)
}
typed::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed::types::ConcreteType::Boolean)
.pop()
.unwrap()
.id,
),
typed::BooleanExpression::Identifier(id) => f
.fold_identifier_expression(typed::ConcreteType::Boolean, id)
.pop()
.unwrap()
.try_into()
.unwrap(),
typed::BooleanExpression::FieldEq(e) => f.fold_eq_expression(statements_buffer, e),
typed::BooleanExpression::BoolEq(e) => f.fold_eq_expression(statements_buffer, e),
typed::BooleanExpression::ArrayEq(e) => f.fold_eq_expression(statements_buffer, e),
@ -1084,12 +1107,11 @@ fn fold_uint_expression_inner<'ast, T: Field>(
.into_inner()
}
typed::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed::types::ConcreteType::Uint(bitwidth))
.pop()
typed::UExpressionInner::Identifier(id) => {
zir::UExpression::try_from(f.fold_identifier_expression(bitwidth, id).pop().unwrap())
.unwrap()
.id,
),
.into_inner()
}
typed::UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(statements_buffer, left);
let right = f.fold_uint_expression(statements_buffer, right);
@ -1250,7 +1272,7 @@ fn fold_array_expression<'ast, T: Field>(
let size: u32 = e.size().try_into().unwrap();
f.fold_array_expression_inner(
statements_buffer,
&typed::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(),
typed::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(),
size,
e.into_inner(),
)
@ -1263,7 +1285,7 @@ fn fold_struct_expression<'ast, T: Field>(
) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_struct_expression_inner(
statements_buffer,
&typed::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(),
typed::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
)
}
@ -1275,7 +1297,7 @@ fn fold_tuple_expression<'ast, T: Field>(
) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_tuple_expression_inner(
statements_buffer,
&typed::types::ConcreteTupleType::try_from(e.ty().clone()).unwrap(),
typed::types::ConcreteTupleType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
)
}

View file

@ -544,13 +544,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Error> {
match e {
UExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Uint(e) => Ok(e.as_inner().clone()),
_ => unreachable!("constant stored for a uint should be a uint"),
},
None => Ok(UExpressionInner::Identifier(id)),
},
UExpressionInner::Add(box e1, box e2) => match (
self.fold_uint_expression(e1)?.into_inner(),
self.fold_uint_expression(e2)?.into_inner(),
@ -774,15 +767,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Error> {
match e {
FieldElementExpression::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::FieldElement(e) => Ok(e.clone()),
_ => unreachable!(
"constant stored for a field element should be a field element"
),
},
None => Ok(FieldElementExpression::Identifier(id)),
},
FieldElementExpression::Add(box e1, box e2) => match (
self.fold_field_expression(e1)?,
self.fold_field_expression(e2)?,
@ -936,7 +920,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
match self.constants.get(&id.id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
@ -975,13 +959,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Error> {
match e {
ArrayExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Array(e) => Ok(e.as_inner().clone()),
_ => panic!("constant stored for an array should be an array"),
},
None => Ok(ArrayExpressionInner::Identifier(id)),
},
ArrayExpressionInner::Value(exprs) => {
Ok(ArrayExpressionInner::Value(
exprs
@ -1032,13 +1009,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Error> {
match e {
StructExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Struct(e) => Ok(e.as_inner().clone()),
_ => panic!("constant stored for an array should be an array"),
},
None => Ok(StructExpressionInner::Identifier(id)),
},
StructExpressionInner::Value(v) => {
let v = v.into_iter().zip(ty.iter()).map(|(v, member)|
match self.fold_expression(v) {
@ -1059,19 +1029,23 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
fn fold_identifier_expression<E: Expr<'ast, T> + Id<'ast, T> + ResultFold<'ast, T>>(
&mut self,
_: &E::Ty,
id: IdentifierExpression<'ast, E>,
) -> Result<IdentifierOrExpression<'ast, T, E>, Self::Error> {
match self.constants.get(&id.id).cloned() {
Some(e) => Ok(IdentifierOrExpression::Expression(E::from(e).into_inner())),
None => Ok(IdentifierOrExpression::Identifier(id)),
}
}
fn fold_tuple_expression_inner(
&mut self,
ty: &TupleType<'ast, T>,
e: TupleExpressionInner<'ast, T>,
) -> Result<TupleExpressionInner<'ast, T>, Error> {
match e {
TupleExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Tuple(e) => Ok(e.as_inner().clone()),
_ => panic!("constant stored for an tuple should be an tuple"),
},
None => Ok(TupleExpressionInner::Identifier(id)),
},
TupleExpressionInner::Value(v) => {
let v = v.into_iter().zip(ty.elements.iter().enumerate()).map(|(v, (index, element_ty))|
match self.fold_expression(v) {
@ -1141,13 +1115,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
// These kind of reduction rules are easier to apply later in the process, when we have canonical representations
// of expressions, ie `a + a` would always be written `2 * a`
match e {
BooleanExpression::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Boolean(e) => Ok(e.clone()),
_ => panic!("constant stored for a boolean should be a boolean"),
},
None => Ok(BooleanExpression::Identifier(id)),
},
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1)?;
let e2 = self.fold_field_expression(e2)?;
@ -1432,7 +1399,7 @@ mod tests {
BooleanExpression::Not(box BooleanExpression::Value(true));
let e_default: BooleanExpression<Bn128Field> =
BooleanExpression::Not(box BooleanExpression::Identifier("a".into()));
BooleanExpression::Not(box BooleanExpression::identifier("a".into()));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
@ -1465,14 +1432,14 @@ mod tests {
let e_identifier_true: BooleanExpression<Bn128Field> =
BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::identifier("a".into()),
FieldElementExpression::identifier("a".into()),
));
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::Identifier("b".into()),
FieldElementExpression::identifier("a".into()),
FieldElementExpression::identifier("b".into()),
));
assert_eq!(
@ -1574,18 +1541,14 @@ mod tests {
let e_identifier_true: BooleanExpression<Bn128Field> =
BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32),
ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32),
));
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32),
ArrayExpression::identifier("b".into()).annotate(Type::FieldElement, 1u32),
));
let e_non_canonical_true = BooleanExpression::ArrayEq(EqExpression::new(
@ -1772,30 +1735,30 @@ mod tests {
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Identifier(a_bool.clone())
box BooleanExpression::identifier(a_bool.clone())
)),
Ok(BooleanExpression::Identifier(a_bool.clone()))
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
Ok(BooleanExpression::Identifier(a_bool.clone()))
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Identifier(a_bool.clone())
box BooleanExpression::identifier(a_bool.clone())
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
Ok(BooleanExpression::Value(false))
@ -1842,14 +1805,14 @@ mod tests {
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Identifier(a_bool.clone())
box BooleanExpression::identifier(a_bool.clone())
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
Ok(BooleanExpression::Value(true))
@ -1858,17 +1821,17 @@ mod tests {
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Identifier(a_bool.clone())
box BooleanExpression::identifier(a_bool.clone())
)),
Ok(BooleanExpression::Identifier(a_bool.clone()))
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
Ok(BooleanExpression::Identifier(a_bool.clone()))
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())

View file

@ -3,9 +3,9 @@
use crate::static_analysis::reducer::ConstantDefinitions;
use zokrates_ast::typed::{
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
DeclarationConstant, Expr, FieldElementExpression, Identifier, StructExpression,
StructExpressionInner, StructType, TupleExpression, TupleExpressionInner, TupleType,
TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner,
DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression,
StructExpression, StructExpressionInner, StructType, TupleExpression, TupleExpressionInner,
TupleType, TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner,
};
use zokrates_field::Field;
@ -58,14 +58,18 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
FieldElementExpression::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => v.try_into().unwrap(),
None => FieldElementExpression::Identifier(Identifier::from(
None => FieldElementExpression::identifier(Identifier::from(
CoreIdentifier::Constant(c),
)),
}
@ -79,15 +83,19 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
BooleanExpression::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => v.try_into().unwrap(),
None => {
BooleanExpression::Identifier(Identifier::from(CoreIdentifier::Constant(c)))
BooleanExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
}
}
}
@ -101,16 +109,18 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
UExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => UExpression::try_from(v).unwrap().into_inner(),
None => {
UExpressionInner::Identifier(Identifier::from(CoreIdentifier::Constant(c)))
}
None => UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))),
}
}
e => fold_uint_expression_inner(self, ty, e),
@ -123,16 +133,20 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
ArrayExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(),
None => ArrayExpressionInner::Identifier(Identifier::from(
CoreIdentifier::Constant(c),
)),
None => {
ArrayExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
}
}
}
e => fold_array_expression_inner(self, ty, e),
@ -145,16 +159,20 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: TupleExpressionInner<'ast, T>,
) -> TupleExpressionInner<'ast, T> {
match e {
TupleExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
TupleExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => TupleExpression::try_from(v).unwrap().into_inner(),
None => TupleExpressionInner::Identifier(Identifier::from(
CoreIdentifier::Constant(c),
)),
None => {
TupleExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
}
}
}
e => fold_tuple_expression_inner(self, ty, e),
@ -167,16 +185,20 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
StructExpressionInner::Identifier(IdentifierExpression {
id:
Identifier {
id: CoreIdentifier::Constant(c),
version,
},
..
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => StructExpression::try_from(v).unwrap().into_inner(),
None => StructExpressionInner::Identifier(Identifier::from(
CoreIdentifier::Constant(c),
)),
None => {
StructExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
}
}
}
e => fold_struct_expression_inner(self, ty, e),

View file

@ -601,7 +601,7 @@ mod tests {
let foo: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![TypedStatement::Return(
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
)],
signature: DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -617,13 +617,13 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
@ -634,17 +634,17 @@ mod tests {
.output(DeclarationType::FieldElement),
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![FieldElementExpression::identifier("a".into()).into()],
)
.into(),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Return(FieldElementExpression::Identifier("a".into()).into()),
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
],
signature: DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -689,7 +689,7 @@ mod tests {
statements: vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo").signature(
@ -701,23 +701,23 @@ mod tests {
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
.into(),
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::Identifier(
FieldElementExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.into(),
),
TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
),
],
signature: DeclarationSignature::new()
@ -794,7 +794,7 @@ mod tests {
)
.into()],
statements: vec![TypedStatement::Return(
ArrayExpressionInner::Identifier("a".into())
ArrayExpression::identifier("a".into())
.annotate(Type::FieldElement, 1u32)
.into(),
)],
@ -810,14 +810,14 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier("a".into()).into()].into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
@ -828,7 +828,7 @@ mod tests {
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
vec![ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
)
@ -837,14 +837,14 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::Identifier("a".into())
(FieldElementExpression::identifier("a".into())
+ FieldElementExpression::select(
ArrayExpressionInner::Identifier("b".into())
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32),
0u32,
))
@ -892,7 +892,7 @@ mod tests {
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier("a".into()).into()].into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
@ -909,7 +909,7 @@ mod tests {
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier("b".into())
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
@ -920,7 +920,7 @@ mod tests {
1u32,
)
.into(),
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
@ -928,16 +928,16 @@ mod tests {
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier(
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::Identifier("a".into())
(FieldElementExpression::identifier("a".into())
+ FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from("b").version(1))
ArrayExpression::identifier(Identifier::from("b").version(1))
.annotate(Type::FieldElement, 1u32),
0u32,
))
@ -1018,7 +1018,7 @@ mod tests {
)
.into()],
statements: vec![TypedStatement::Return(
ArrayExpressionInner::Identifier("a".into())
ArrayExpression::identifier("a".into())
.annotate(Type::FieldElement, 1u32)
.into(),
)],
@ -1034,7 +1034,7 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
@ -1043,14 +1043,14 @@ mod tests {
"b",
Type::FieldElement,
UExpressionInner::Sub(
box UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
box UExpression::identifier("n".into()).annotate(UBitwidth::B32),
box 1u32.into(),
)
.annotate(UBitwidth::B32),
)
.into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier("a".into()).into()].into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
@ -1061,7 +1061,7 @@ mod tests {
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
vec![ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
)
@ -1070,14 +1070,14 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::Identifier("a".into())
(FieldElementExpression::identifier("a".into())
+ FieldElementExpression::select(
ArrayExpressionInner::Identifier("b".into())
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32),
0u32,
))
@ -1125,7 +1125,7 @@ mod tests {
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::Identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier("a".into()).into()].into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
@ -1142,7 +1142,7 @@ mod tests {
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier("b".into())
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
@ -1153,7 +1153,7 @@ mod tests {
1u32,
)
.into(),
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
@ -1161,16 +1161,16 @@ mod tests {
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier(
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::Identifier("a".into())
(FieldElementExpression::identifier("a".into())
+ FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from("b").version(1))
ArrayExpression::identifier(Identifier::from("b").version(1))
.annotate(Type::FieldElement, 1u32),
0u32,
))
@ -1258,7 +1258,7 @@ mod tests {
Variable::array(
"ret",
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32),
UExpression::identifier("K".into()).annotate(UBitwidth::B32),
)
.into(),
ArrayExpressionInner::Slice(
@ -1269,7 +1269,7 @@ mod tests {
vec![ArrayExpressionInner::Value(
vec![
TypedExpressionOrSpread::Spread(
ArrayExpressionInner::Identifier("a".into())
ArrayExpression::identifier("a".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
@ -1279,30 +1279,30 @@ mod tests {
)
.annotate(
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32)
UExpression::identifier("K".into()).annotate(UBitwidth::B32)
+ 1u32.into(),
)
.into()],
)
.annotate(
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32)
UExpression::identifier("K".into()).annotate(UBitwidth::B32)
+ 1u32.into(),
),
box 0u32.into(),
box UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32),
box UExpression::identifier("K".into()).annotate(UBitwidth::B32),
)
.annotate(
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32),
UExpression::identifier("K".into()).annotate(UBitwidth::B32),
)
.into(),
),
TypedStatement::Return(
ArrayExpressionInner::Identifier("ret".into())
ArrayExpression::identifier("ret".into())
.annotate(
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32),
UExpression::identifier("K".into()).annotate(UBitwidth::B32),
)
.into(),
),
@ -1320,10 +1320,10 @@ mod tests {
)
.into()],
statements: vec![TypedStatement::Return(
ArrayExpressionInner::Identifier("a".into())
ArrayExpression::identifier("a".into())
.annotate(
Type::FieldElement,
UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32),
UExpression::identifier("K".into()).annotate(UBitwidth::B32),
)
.into(),
)],
@ -1475,7 +1475,7 @@ mod tests {
)
.into()],
statements: vec![TypedStatement::Return(
ArrayExpressionInner::Identifier("a".into())
ArrayExpression::identifier("a".into())
.annotate(Type::FieldElement, 1u32)
.into(),
)],

View file

@ -197,7 +197,7 @@ mod tests {
fn detect_non_constant_bound() {
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
Variable::new("i", Type::Uint(UBitwidth::B32), false),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
vec![],
)];
@ -265,10 +265,10 @@ mod tests {
);
let e: FieldElementExpression<Bn128Field> =
FieldElementExpression::Identifier("a".into());
FieldElementExpression::identifier("a".into());
assert_eq!(
u.fold_field_expression(e),
FieldElementExpression::Identifier(Identifier::from("a").version(1))
FieldElementExpression::identifier(Identifier::from("a").version(1))
);
}
@ -303,7 +303,7 @@ mod tests {
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier("a".into()),
box FieldElementExpression::identifier("a".into()),
box FieldElementExpression::Number(Bn128Field::from(1)),
)
.into(),
@ -315,7 +315,7 @@ mod tests {
Identifier::from("a").version(1)
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("a").version(0)),
box FieldElementExpression::identifier(Identifier::from("a").version(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
)
.into()
@ -360,7 +360,7 @@ mod tests {
.output(DeclarationType::FieldElement),
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![FieldElementExpression::identifier("a".into()).into()],
)
.into(),
);
@ -376,7 +376,7 @@ mod tests {
),
vec![],
vec![
FieldElementExpression::Identifier(Identifier::from("a").version(0))
FieldElementExpression::identifier(Identifier::from("a").version(0))
.into()
]
)
@ -594,43 +594,43 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32)
* UExpression::identifier("n".into()).annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
)],
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32)
* UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32)
* UExpression::identifier("n".into()).annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
)],
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::Return(FieldElementExpression::Identifier("a".into()).into()),
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
@ -665,50 +665,50 @@ mod tests {
),
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(1))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(1))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32)
* UExpressionInner::Identifier(Identifier::from("n").version(1))
* UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32)
* UExpressionInner::Identifier(Identifier::from("n").version(2))
* UExpression::identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(5)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(4)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(4)).into(),
),
TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from("a").version(5)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(5)).into(),
),
],
signature: DeclarationSignature::new()
@ -846,7 +846,7 @@ mod tests {
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
@ -855,7 +855,7 @@ mod tests {
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::Identifier(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
"a".into(),
)
.into()])
@ -879,7 +879,7 @@ mod tests {
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Identifier(Identifier::from("a")).into(),
FieldElementExpression::identifier(Identifier::from("a")).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
@ -888,7 +888,7 @@ mod tests {
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::Identifier(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
Identifier::from("a").version(1),
)
.into()])
@ -951,45 +951,44 @@ mod tests {
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![FieldElementExpression::identifier("a".into()).into()],
)
.into(),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element("a").into(),
(FieldElementExpression::Identifier("a".into())
(FieldElementExpression::identifier("a".into())
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into())
.annotate(UBitwidth::B32),
UExpression::identifier("n".into()).annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![FieldElementExpression::identifier("a".into()).into()],
))
.into(),
),
TypedStatement::Return(FieldElementExpression::Identifier("a".into()).into()),
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
@ -1024,23 +1023,23 @@ mod tests {
),
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(),
UExpressionInner::Identifier("n".into())
UExpression::identifier("n".into())
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::Identifier("a".into()).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(1))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier(
vec![FieldElementExpression::identifier(
Identifier::from("a").version(1),
)
.into()],
@ -1049,20 +1048,20 @@ mod tests {
),
TypedStatement::definition(
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
UExpressionInner::Identifier(Identifier::from("n").version(1))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32)
.into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
(FieldElementExpression::identifier(Identifier::from("a").version(2))
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier(
vec![FieldElementExpression::identifier(
Identifier::from("a").version(2),
)
.into()],
@ -1070,7 +1069,7 @@ mod tests {
.into(),
),
TypedStatement::Return(
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
],
signature: DeclarationSignature::new()

View file

@ -469,27 +469,21 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
let base = match variable.get_type() {
Type::Int => unreachable!(),
Type::FieldElement => {
FieldElementExpression::Identifier(variable.id.clone()).into()
FieldElementExpression::identifier(variable.id.clone()).into()
}
Type::Boolean => BooleanExpression::Identifier(variable.id.clone()).into(),
Type::Uint(bitwidth) => UExpressionInner::Identifier(variable.id.clone())
Type::Boolean => BooleanExpression::identifier(variable.id.clone()).into(),
Type::Uint(bitwidth) => UExpression::identifier(variable.id.clone())
.annotate(bitwidth)
.into(),
Type::Array(array_type) => {
ArrayExpressionInner::Identifier(variable.id.clone())
.annotate(*array_type.ty, *array_type.size)
.into()
}
Type::Struct(members) => {
StructExpressionInner::Identifier(variable.id.clone())
.annotate(members)
.into()
}
Type::Tuple(tuple_ty) => {
TupleExpressionInner::Identifier(variable.id.clone())
.annotate(tuple_ty)
.into()
}
Type::Array(array_type) => ArrayExpression::identifier(variable.id.clone())
.annotate(*array_type.ty, *array_type.size)
.into(),
Type::Struct(members) => StructExpression::identifier(variable.id.clone())
.annotate(members)
.into(),
Type::Tuple(tuple_ty) => TupleExpression::identifier(variable.id.clone())
.annotate(tuple_ty)
.into(),
};
let base = self.fold_expression(base);

View file

@ -0,0 +1,6 @@
{
"entry_point": "./tests/tests/duplicate.zok",
"max_constraint_count": 1,
"curves": ["Bn128"],
"tests": []
}

View file

@ -0,0 +1,4 @@
def main(field a) {
assert(a == 0);
assert(a == 0);
}

View file

@ -29,7 +29,7 @@ import "utils/casts/u32_8_to_bool_256";
///
/// Returns:
/// Return true for S being a valid EdDSA Signature, false otherwise.
def main(private field[2] R, private field S, field[2] A, u32[8] M0, u32[8] M1, BabyJubJubParams context) -> bool {
def main(field[2] R, field S, field[2] A, u32[8] M0, u32[8] M1, BabyJubJubParams context) -> bool {
field[2] G = [context.Gu, context.Gv];
// Check if R is on curve and if it is not in a small subgroup. A is public input and can be checked offline