merge develop, fix conflicts
This commit is contained in:
commit
8d9e8899c1
177 changed files with 4721 additions and 566 deletions
|
@ -96,7 +96,7 @@ jobs:
|
|||
zokrates_js_build:
|
||||
docker:
|
||||
- image: zokrates/env:latest
|
||||
resource_class: large
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project/zokrates_js
|
||||
steps:
|
||||
- checkout:
|
||||
|
@ -111,7 +111,7 @@ jobs:
|
|||
zokrates_js_test:
|
||||
docker:
|
||||
- image: zokrates/env:latest
|
||||
resource_class: large
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project/zokrates_js
|
||||
steps:
|
||||
- checkout:
|
||||
|
|
48
Cargo.lock
generated
48
Cargo.lock
generated
|
@ -2109,6 +2109,12 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reduce"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7"
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "0.2.11"
|
||||
|
@ -2991,6 +2997,29 @@ dependencies = [
|
|||
"zokrates_field",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zokrates_analysis"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10",
|
||||
"csv",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"num 0.1.42",
|
||||
"num-bigint 0.2.6",
|
||||
"pretty_assertions 0.6.1",
|
||||
"reduce",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"typed-arena",
|
||||
"zokrates_ast",
|
||||
"zokrates_common",
|
||||
"zokrates_embed",
|
||||
"zokrates_field",
|
||||
"zokrates_fs_resolver",
|
||||
"zokrates_pest_ast",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zokrates_ark"
|
||||
version = "0.1.1"
|
||||
|
@ -3107,9 +3136,23 @@ dependencies = [
|
|||
"zokrates_solidity_test",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zokrates_codegen"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"zokrates_ast",
|
||||
"zokrates_common",
|
||||
"zokrates_embed",
|
||||
"zokrates_field",
|
||||
"zokrates_interpreter",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zokrates_common"
|
||||
version = "0.1.1"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zokrates_core"
|
||||
|
@ -3125,7 +3168,9 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"typed-arena",
|
||||
"zokrates_analysis",
|
||||
"zokrates_ast",
|
||||
"zokrates_codegen",
|
||||
"zokrates_common",
|
||||
"zokrates_embed",
|
||||
"zokrates_field",
|
||||
|
@ -3202,6 +3247,7 @@ dependencies = [
|
|||
"pairing_ce",
|
||||
"serde",
|
||||
"zokrates_abi",
|
||||
"zokrates_analysis",
|
||||
"zokrates_ast",
|
||||
"zokrates_embed",
|
||||
"zokrates_field",
|
||||
|
@ -3212,6 +3258,7 @@ name = "zokrates_js"
|
|||
version = "1.1.4"
|
||||
dependencies = [
|
||||
"console_error_panic_hook",
|
||||
"getrandom",
|
||||
"indexmap",
|
||||
"js-sys",
|
||||
"json",
|
||||
|
@ -3300,6 +3347,7 @@ dependencies = [
|
|||
name = "zokrates_test"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
|
|
|
@ -6,6 +6,8 @@ members = [
|
|||
"zokrates_cli",
|
||||
"zokrates_fs_resolver",
|
||||
"zokrates_stdlib",
|
||||
"zokrates_codegen",
|
||||
"zokrates_analysis",
|
||||
"zokrates_embed",
|
||||
"zokrates_abi",
|
||||
"zokrates_test",
|
||||
|
|
1
changelogs/unreleased/1246-dark64
Normal file
1
changelogs/unreleased/1246-dark64
Normal file
|
@ -0,0 +1 @@
|
|||
Introduce constraint generation through assembly blocks
|
31
zokrates_analysis/Cargo.toml
Normal file
31
zokrates_analysis/Cargo.toml
Normal file
|
@ -0,0 +1,31 @@
|
|||
[package]
|
||||
name = "zokrates_analysis"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["ark", "bellman"]
|
||||
ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark"]
|
||||
bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman"]
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
cfg-if = "0.1"
|
||||
num = { version = "0.1.36", default-features = false }
|
||||
num-bigint = { version = "0.2", default-features = false }
|
||||
lazy_static = "1.4"
|
||||
typed-arena = "1.4.1"
|
||||
reduce = "0.1.1"
|
||||
# serialization and deserialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = { version = "1.0", features = ["preserve_order"] }
|
||||
zokrates_field = { version = "0.5.0", path = "../zokrates_field", default-features = false }
|
||||
zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" }
|
||||
zokrates_common = { version = "0.1", path = "../zokrates_common", default-features = false }
|
||||
zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false }
|
||||
zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false }
|
||||
csv = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "0.6.1"
|
||||
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
|
412
zokrates_analysis/src/assembly_transformer.rs
Normal file
412
zokrates_analysis/src/assembly_transformer.rs
Normal file
|
@ -0,0 +1,412 @@
|
|||
// A static analyser pass to transform user-defined constraints to the form `lin_comb === quad_comb`
|
||||
// This pass can fail if a non-quadratic constraint is found which cannot be transformed to the expected form
|
||||
|
||||
use crate::ZirPropagator;
|
||||
use std::fmt;
|
||||
use zokrates_ast::zir::lqc::LinQuadComb;
|
||||
use zokrates_ast::zir::result_folder::ResultFolder;
|
||||
use zokrates_ast::zir::{FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Error(String);
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AssemblyTransformer;
|
||||
|
||||
impl AssemblyTransformer {
|
||||
pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
|
||||
AssemblyTransformer.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
ZirAssemblyStatement::Assignment(_, _) => Ok(vec![s]),
|
||||
ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
let lhs = self.fold_field_expression(lhs)?;
|
||||
let rhs = self.fold_field_expression(rhs)?;
|
||||
|
||||
let (is_quadratic, lhs, rhs) = match (lhs, rhs) {
|
||||
(
|
||||
lhs @ FieldElementExpression::Identifier(..),
|
||||
rhs @ FieldElementExpression::Identifier(..),
|
||||
) => (true, lhs, rhs),
|
||||
(FieldElementExpression::Mult(x, y), other)
|
||||
| (other, FieldElementExpression::Mult(x, y))
|
||||
if other.is_linear() =>
|
||||
{
|
||||
(
|
||||
x.is_linear() && y.is_linear(),
|
||||
other,
|
||||
FieldElementExpression::Mult(x, y),
|
||||
)
|
||||
}
|
||||
(lhs, rhs) => (false, lhs, rhs),
|
||||
};
|
||||
|
||||
match is_quadratic {
|
||||
true => Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]),
|
||||
false => {
|
||||
let sub = FieldElementExpression::Sub(box lhs, box rhs);
|
||||
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
|
||||
Error("Non-quadratic constraints are not allowed".to_string())
|
||||
})?;
|
||||
|
||||
let linear = lqc
|
||||
.linear
|
||||
.into_iter()
|
||||
.map(|(c, i)| {
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(c),
|
||||
box FieldElementExpression::identifier(i),
|
||||
)
|
||||
})
|
||||
.fold(FieldElementExpression::Number(T::from(0)), |acc, e| {
|
||||
FieldElementExpression::Add(box acc, box e)
|
||||
});
|
||||
|
||||
let lhs = FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(lqc.constant),
|
||||
box linear,
|
||||
);
|
||||
|
||||
let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 {
|
||||
let common_factor = lqc
|
||||
.quadratic
|
||||
.iter()
|
||||
.scan(None, |state: &mut Option<Vec<&Identifier>>, (_, a, b)| {
|
||||
// short circuit if we do not have any common factors anymore
|
||||
if *state == Some(vec![]) {
|
||||
None
|
||||
} else {
|
||||
match state {
|
||||
// only keep factors found in this term
|
||||
Some(factors) => {
|
||||
factors.retain(|&x| x == a || x == b);
|
||||
}
|
||||
// initialisation step, start with the two factors in the first term
|
||||
None => {
|
||||
*state = Some(vec![a, b]);
|
||||
}
|
||||
};
|
||||
state.clone()
|
||||
}
|
||||
})
|
||||
.last()
|
||||
.and_then(|mut v| v.pop().cloned());
|
||||
|
||||
match common_factor {
|
||||
Some(factor) => Ok(FieldElementExpression::Mult(
|
||||
box lqc
|
||||
.quadratic
|
||||
.into_iter()
|
||||
.map(|(c, i0, i1)| {
|
||||
let c = T::zero() - c;
|
||||
let e = match (i0, i1) {
|
||||
(i0, i1) if factor.eq(&i0) => {
|
||||
FieldElementExpression::identifier(i1)
|
||||
}
|
||||
(i0, i1) if factor.eq(&i1) => {
|
||||
FieldElementExpression::identifier(i0)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(c),
|
||||
box e,
|
||||
)
|
||||
})
|
||||
.fold(
|
||||
FieldElementExpression::Number(T::from(0)),
|
||||
|acc, e| FieldElementExpression::Add(box acc, box e),
|
||||
),
|
||||
box FieldElementExpression::identifier(factor),
|
||||
)),
|
||||
None => Err(Error(
|
||||
"Non-quadratic constraints are not allowed".to_string(),
|
||||
)),
|
||||
}?
|
||||
} else {
|
||||
lqc.quadratic
|
||||
.pop()
|
||||
.map(|(c, i0, i1)| {
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(T::zero() - c),
|
||||
box FieldElementExpression::identifier(i0),
|
||||
),
|
||||
box FieldElementExpression::identifier(i1),
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0)))
|
||||
};
|
||||
|
||||
let mut propagator = ZirPropagator::default();
|
||||
let lhs = propagator
|
||||
.fold_field_expression(lhs)
|
||||
.map_err(|e| Error(e.to_string()))?;
|
||||
|
||||
let rhs = propagator
|
||||
.fold_field_expression(rhs)
|
||||
.map_err(|e| Error(e.to_string()))?;
|
||||
|
||||
Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_ast::common::SourceMetadata;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn quadratic() {
|
||||
// x === a * b;
|
||||
let lhs = FieldElementExpression::<Bn128Field>::identifier("x".into());
|
||||
let rhs = FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
);
|
||||
|
||||
let expected = vec![ZirAssemblyStatement::Constraint(
|
||||
FieldElementExpression::identifier("x".into()),
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
SourceMetadata::default(),
|
||||
)];
|
||||
let result = AssemblyTransformer
|
||||
.fold_assembly_statement(ZirAssemblyStatement::Constraint(
|
||||
lhs,
|
||||
rhs,
|
||||
SourceMetadata::default(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_quadratic() {
|
||||
// x === ((a * b) * c);
|
||||
let lhs = FieldElementExpression::<Bn128Field>::identifier("x".into());
|
||||
let rhs = FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
);
|
||||
|
||||
let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::Constraint(
|
||||
lhs,
|
||||
rhs,
|
||||
SourceMetadata::default(),
|
||||
));
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform() {
|
||||
// x === 1 - a * b; --> (-1) + x === (((-1) * a) * b);
|
||||
let lhs = FieldElementExpression::identifier("x".into());
|
||||
let rhs = FieldElementExpression::Sub(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
);
|
||||
|
||||
let expected = vec![ZirAssemblyStatement::Constraint(
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-1)),
|
||||
box FieldElementExpression::identifier("x".into()),
|
||||
),
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-1)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
SourceMetadata::default(),
|
||||
)];
|
||||
|
||||
let result = AssemblyTransformer
|
||||
.fold_assembly_statement(ZirAssemblyStatement::Constraint(
|
||||
lhs,
|
||||
rhs,
|
||||
SourceMetadata::default(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorize() {
|
||||
// x === (a * b) + (b * c); --> x === ((a + c) * b);
|
||||
let lhs = FieldElementExpression::<Bn128Field>::identifier("x".into());
|
||||
let rhs = FieldElementExpression::Add(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
);
|
||||
|
||||
let expected = vec![ZirAssemblyStatement::Constraint(
|
||||
FieldElementExpression::identifier("x".into()),
|
||||
FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
SourceMetadata::default(),
|
||||
)];
|
||||
let result = AssemblyTransformer
|
||||
.fold_assembly_statement(ZirAssemblyStatement::Constraint(
|
||||
lhs,
|
||||
rhs,
|
||||
SourceMetadata::default(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_complex() {
|
||||
// mid = b*c;
|
||||
// x === a+b+c - 2*a*b - 2*a*c - 2*mid + 4*a*mid; // x === a ^ b ^ c
|
||||
// -->
|
||||
// ((((x + ((-1)*a)) + ((-1)*b)) + ((-1)*c)) + (2*mid)) === (((((-2)*b) + ((-2)*c)) + (4*mid)) * a);
|
||||
let lhs = FieldElementExpression::<Bn128Field>::identifier("x".into());
|
||||
let rhs = FieldElementExpression::Add(
|
||||
box FieldElementExpression::Sub(
|
||||
box FieldElementExpression::Sub(
|
||||
box FieldElementExpression::Sub(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("mid".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("mid".into()),
|
||||
),
|
||||
);
|
||||
|
||||
let lhs_expected = FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::identifier("x".into()),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-1)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-1)),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-1)),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("mid".into()),
|
||||
),
|
||||
);
|
||||
|
||||
let rhs_expected = FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-2)),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(-2)),
|
||||
box FieldElementExpression::identifier("c".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::identifier("mid".into()),
|
||||
),
|
||||
),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
);
|
||||
|
||||
let expected = vec![ZirAssemblyStatement::Constraint(
|
||||
lhs_expected,
|
||||
rhs_expected,
|
||||
SourceMetadata::default(),
|
||||
)];
|
||||
let result = AssemblyTransformer
|
||||
.fold_assembly_statement(ZirAssemblyStatement::Constraint(
|
||||
lhs,
|
||||
rhs,
|
||||
SourceMetadata::default(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
|
@ -1,9 +1,7 @@
|
|||
use std::fmt;
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::{
|
||||
result_folder::ResultFolder,
|
||||
result_folder::{fold_statement, fold_uint_expression_inner},
|
||||
Constant, EmbedCall, TypedStatement, UBitwidth, UExpressionInner,
|
||||
result_folder::fold_statement, result_folder::ResultFolder, Constant, EmbedCall, TypedStatement,
|
||||
};
|
||||
use zokrates_ast::typed::{DefinitionRhs, TypedProgram};
|
||||
use zokrates_field::Field;
|
||||
|
@ -71,40 +69,4 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
|
|||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||
match e {
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} >> {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
}
|
107
zokrates_analysis/src/expression_validator.rs
Normal file
107
zokrates_analysis/src/expression_validator.rs
Normal file
|
@ -0,0 +1,107 @@
|
|||
use std::fmt;
|
||||
use zokrates_ast::typed::result_folder::{
|
||||
fold_assembly_statement, fold_field_expression, fold_uint_expression_inner, ResultFolder,
|
||||
};
|
||||
use zokrates_ast::typed::{
|
||||
FieldElementExpression, TypedAssemblyStatement, TypedProgram, UBitwidth, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct Error(String);
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ExpressionValidator;
|
||||
|
||||
impl ExpressionValidator {
|
||||
pub fn validate<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
ExpressionValidator.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
// we allow more dynamic expressions in witness generation
|
||||
TypedAssemblyStatement::Assignment(_, _) => Ok(vec![s]),
|
||||
s => fold_assembly_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
// these should have been propagated away
|
||||
FieldElementExpression::And(_, _)
|
||||
| FieldElementExpression::Or(_, _)
|
||||
| FieldElementExpression::Xor(_, _)
|
||||
| FieldElementExpression::LeftShift(_, _)
|
||||
| FieldElementExpression::RightShift(_, _) => Err(Error(format!(
|
||||
"Found non-constant bitwise operation in field element expression `{}`",
|
||||
e
|
||||
))),
|
||||
FieldElementExpression::Pow(box e, box exp) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let exp = self.fold_uint_expression(exp)?;
|
||||
|
||||
match exp.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(FieldElementExpression::Pow(box e, box exp)),
|
||||
exp => Err(Error(format!(
|
||||
"Found non-constant exponent in power expression `{}**{}`",
|
||||
e,
|
||||
exp.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||
match e {
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} >> {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,8 +14,8 @@ struct Propagator<T> {
|
|||
constants: HashMap<Variable, T>,
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for Propagator<T> {
|
||||
fn fold_statement(&mut self, s: FlatStatement<T>) -> Vec<FlatStatement<T>> {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for Propagator<T> {
|
||||
fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec<FlatStatement<'ast, T>> {
|
||||
match s {
|
||||
FlatStatement::Definition(var, expr) => match self.fold_expression(expr) {
|
||||
FlatExpression::Number(n) => {
|
|
@ -1,11 +1,12 @@
|
|||
use std::collections::HashMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::marker::PhantomData;
|
||||
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
|
||||
use zokrates_ast::typed::{self, Expr, Typed};
|
||||
use zokrates_ast::zir::{self, Id, Select};
|
||||
use zokrates_ast::zir::IntoType as ZirIntoType;
|
||||
use zokrates_ast::zir::{self, Folder, Id, Select};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Flattener<T: Field> {
|
||||
phantom: PhantomData<T>,
|
||||
|
@ -272,6 +273,14 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
s: typed::TypedAssemblyStatement<'ast, T>,
|
||||
) -> zir::ZirAssemblyStatement<'ast, T> {
|
||||
fold_assembly_statement(self, statements_buffer, s)
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
|
@ -449,12 +458,126 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
}
|
||||
}
|
||||
|
||||
// This finder looks for identifiers that were not defined in some block of statements
|
||||
// These identifiers are used as function arguments when moving witness assignment expression
|
||||
// to a zir function.
|
||||
//
|
||||
// Example:
|
||||
// def main(field a, field mut b) -> field {
|
||||
// asm {
|
||||
// b <== a * a;
|
||||
// }
|
||||
// return b;
|
||||
// }
|
||||
// is turned into
|
||||
// def main(field a, field mut b) -> field {
|
||||
// asm {
|
||||
// b <-- (field a) -> field {
|
||||
// return a * a;
|
||||
// }
|
||||
// b == a * a;
|
||||
// }
|
||||
// return b;
|
||||
// }
|
||||
#[derive(Default)]
|
||||
pub struct ArgumentFinder<'ast, T> {
|
||||
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
||||
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||
match s {
|
||||
zir::ZirStatement::Definition(assignee, expr) => {
|
||||
let assignee = self.fold_assignee(assignee);
|
||||
let expr = self.fold_expression(expr);
|
||||
self.identifiers.remove(&assignee.id);
|
||||
vec![zir::ZirStatement::Definition(assignee, expr)]
|
||||
}
|
||||
zir::ZirStatement::MultipleDefinition(assignees, list) => {
|
||||
let assignees: Vec<zir::ZirAssignee<'ast>> = assignees
|
||||
.into_iter()
|
||||
.map(|v| self.fold_assignee(v))
|
||||
.collect();
|
||||
let list = self.fold_expression_list(list);
|
||||
for a in &assignees {
|
||||
self.identifiers.remove(&a.id);
|
||||
}
|
||||
vec![zir::ZirStatement::MultipleDefinition(assignees, list)]
|
||||
}
|
||||
s => zir::folder::fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_identifier_expression<E: zir::Expr<'ast, T> + Id<'ast, T>>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: zir::IdentifierExpression<'ast, E>,
|
||||
) -> zir::IdentifierOrExpression<'ast, T, E> {
|
||||
self.identifiers
|
||||
.insert(e.id.clone(), ty.clone().into_type());
|
||||
zir::IdentifierOrExpression::Identifier(e)
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_assembly_statement<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
s: typed::TypedAssemblyStatement<'ast, T>,
|
||||
) -> zir::ZirAssemblyStatement<'ast, T> {
|
||||
match s {
|
||||
typed::TypedAssemblyStatement::Assignment(a, e) => {
|
||||
let mut statements_buffer: Vec<zir::ZirStatement<'ast, T>> = vec![];
|
||||
let a = f.fold_assignee(a);
|
||||
let e = f.fold_expression(&mut statements_buffer, e);
|
||||
statements_buffer.push(zir::ZirStatement::Return(e));
|
||||
|
||||
let mut finder = ArgumentFinder::default();
|
||||
let mut statements_buffer: Vec<zir::ZirStatement<'ast, T>> = statements_buffer
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|s| finder.fold_statement(s))
|
||||
.collect();
|
||||
statements_buffer.reverse();
|
||||
|
||||
let function = zir::ZirFunction {
|
||||
signature: zir::types::Signature::default()
|
||||
.inputs(finder.identifiers.values().cloned().collect())
|
||||
.outputs(a.iter().map(|a| a.get_type()).collect()),
|
||||
arguments: finder
|
||||
.identifiers
|
||||
.into_iter()
|
||||
.map(|(id, ty)| zir::Parameter {
|
||||
id: zir::Variable::with_id_and_type(id, ty),
|
||||
private: true,
|
||||
})
|
||||
.collect(),
|
||||
statements: statements_buffer,
|
||||
};
|
||||
|
||||
zir::ZirAssemblyStatement::Assignment(a, function)
|
||||
}
|
||||
typed::TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
let lhs = f.fold_field_expression(statements_buffer, lhs);
|
||||
let rhs = f.fold_field_expression(statements_buffer, rhs);
|
||||
zir::ZirAssemblyStatement::Constraint(lhs, rhs, metadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
s: typed::TypedStatement<'ast, T>,
|
||||
) {
|
||||
let res = match s {
|
||||
typed::TypedStatement::Assembly(statements) => {
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_assembly_statement(statements_buffer, s))
|
||||
.collect();
|
||||
vec![zir::ZirStatement::Assembly(statements)]
|
||||
}
|
||||
typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return(
|
||||
f.fold_expression(statements_buffer, expression),
|
||||
)],
|
||||
|
@ -471,7 +594,7 @@ fn fold_statement<'ast, T: Field>(
|
|||
let e = f.fold_boolean_expression(statements_buffer, e);
|
||||
let error = match error {
|
||||
typed::RuntimeError::SourceAssertion(metadata) => {
|
||||
zir::RuntimeError::SourceAssertion(metadata.to_string())
|
||||
zir::RuntimeError::SourceAssertion(metadata)
|
||||
}
|
||||
typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck,
|
||||
typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero,
|
||||
|
@ -896,6 +1019,36 @@ fn fold_field_expression<'ast, T: Field>(
|
|||
)
|
||||
}
|
||||
typed::FieldElementExpression::Pos(box e) => f.fold_field_expression(statements_buffer, e),
|
||||
typed::FieldElementExpression::Xor(box left, box right) => {
|
||||
let left = f.fold_field_expression(statements_buffer, left);
|
||||
let right = f.fold_field_expression(statements_buffer, right);
|
||||
|
||||
zir::FieldElementExpression::Xor(box left, box right)
|
||||
}
|
||||
typed::FieldElementExpression::And(box left, box right) => {
|
||||
let left = f.fold_field_expression(statements_buffer, left);
|
||||
let right = f.fold_field_expression(statements_buffer, right);
|
||||
|
||||
zir::FieldElementExpression::And(box left, box right)
|
||||
}
|
||||
typed::FieldElementExpression::Or(box left, box right) => {
|
||||
let left = f.fold_field_expression(statements_buffer, left);
|
||||
let right = f.fold_field_expression(statements_buffer, right);
|
||||
|
||||
zir::FieldElementExpression::Or(box left, box right)
|
||||
}
|
||||
typed::FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(statements_buffer, e);
|
||||
let by = f.fold_uint_expression(statements_buffer, by);
|
||||
|
||||
zir::FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
typed::FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(statements_buffer, e);
|
||||
let by = f.fold_uint_expression(statements_buffer, by);
|
||||
|
||||
zir::FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
typed::FieldElementExpression::Conditional(c) => f
|
||||
.fold_conditional_expression(statements_buffer, c)
|
||||
.pop()
|
|
@ -1,15 +1,19 @@
|
|||
#![feature(box_patterns, box_syntax)]
|
||||
|
||||
//! Module containing static analysis
|
||||
//!
|
||||
//! @file mod.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
mod assembly_transformer;
|
||||
mod boolean_array_comparator;
|
||||
mod branch_isolator;
|
||||
mod condition_redefiner;
|
||||
mod constant_argument_checker;
|
||||
mod constant_resolver;
|
||||
mod dead_code;
|
||||
mod expression_validator;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod log_ignorer;
|
||||
|
@ -34,14 +38,16 @@ use self::reducer::reduce_program;
|
|||
use self::struct_concretizer::StructConcretizer;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::variable_write_remover::VariableWriteRemover;
|
||||
use crate::compile::CompileConfig;
|
||||
use crate::static_analysis::constant_resolver::ConstantResolver;
|
||||
use crate::static_analysis::dead_code::DeadCodeEliminator;
|
||||
use crate::static_analysis::panic_extractor::PanicExtractor;
|
||||
use crate::static_analysis::zir_propagation::ZirPropagator;
|
||||
use crate::assembly_transformer::AssemblyTransformer;
|
||||
use crate::constant_resolver::ConstantResolver;
|
||||
use crate::dead_code::DeadCodeEliminator;
|
||||
use crate::expression_validator::ExpressionValidator;
|
||||
use crate::panic_extractor::PanicExtractor;
|
||||
pub use crate::zir_propagation::ZirPropagator;
|
||||
use std::fmt;
|
||||
use zokrates_ast::typed::{abi::Abi, TypedProgram};
|
||||
use zokrates_ast::zir::ZirProgram;
|
||||
use zokrates_common::CompileConfig;
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -51,6 +57,9 @@ pub enum Error {
|
|||
ZirPropagation(self::zir_propagation::Error),
|
||||
NonConstantArgument(self::constant_argument_checker::Error),
|
||||
OutOfBounds(self::out_of_bounds::Error),
|
||||
Assembly(self::assembly_transformer::Error),
|
||||
VariableIndex(self::variable_write_remover::Error),
|
||||
InvalidExpression(self::expression_validator::Error),
|
||||
}
|
||||
|
||||
impl From<reducer::Error> for Error {
|
||||
|
@ -83,6 +92,24 @@ impl From<constant_argument_checker::Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<assembly_transformer::Error> for Error {
|
||||
fn from(e: assembly_transformer::Error) -> Self {
|
||||
Error::Assembly(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<variable_write_remover::Error> for Error {
|
||||
fn from(e: variable_write_remover::Error) -> Self {
|
||||
Error::VariableIndex(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<expression_validator::Error> for Error {
|
||||
fn from(e: expression_validator::Error) -> Self {
|
||||
Error::InvalidExpression(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
|
@ -91,6 +118,9 @@ impl fmt::Display for Error {
|
|||
Error::ZirPropagation(e) => write!(f, "{}", e),
|
||||
Error::NonConstantArgument(e) => write!(f, "{}", e),
|
||||
Error::OutOfBounds(e) => write!(f, "{}", e),
|
||||
Error::Assembly(e) => write!(f, "{}", e),
|
||||
Error::VariableIndex(e) => write!(f, "{}", e),
|
||||
Error::InvalidExpression(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -139,6 +169,10 @@ pub fn analyse<'ast, T: Field>(
|
|||
let r = StructConcretizer::concretize(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// validate expressions
|
||||
log::debug!("Static analyser: Validate expressions");
|
||||
let r = ExpressionValidator::validate(r).map_err(Error::from)?;
|
||||
|
||||
// generate abi
|
||||
log::debug!("Static analyser: Generate abi");
|
||||
let abi = r.abi();
|
||||
|
@ -155,7 +189,7 @@ pub fn analyse<'ast, T: Field>(
|
|||
|
||||
// remove assignment to variable index
|
||||
log::debug!("Static analyser: Remove variable index");
|
||||
let r = VariableWriteRemover::apply(r);
|
||||
let r = VariableWriteRemover::apply(r).map_err(Error::from)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// detect non constant shifts and constant lt bounds
|
||||
|
@ -196,5 +230,9 @@ pub fn analyse<'ast, T: Field>(
|
|||
let zir = UintOptimizer::optimize(zir);
|
||||
log::trace!("\n{}", zir);
|
||||
|
||||
log::debug!("Static analyser: Apply constraint transformations in assembly");
|
||||
let zir = AssemblyTransformer::transform(zir).map_err(Error::from)?;
|
||||
log::trace!("\n{}", zir);
|
||||
|
||||
Ok((zir, abi))
|
||||
}
|
|
@ -7,9 +7,12 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
use num::traits::Pow;
|
||||
use num_bigint::BigUint;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fmt;
|
||||
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::result_folder::*;
|
||||
use zokrates_ast::typed::types::Type;
|
||||
|
@ -21,28 +24,22 @@ pub type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>
|
|||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
Type(String),
|
||||
AssertionFailed(String),
|
||||
ValueTooLarge(String),
|
||||
AssertionFailed(RuntimeError),
|
||||
InvalidValue(String),
|
||||
OutOfBounds(u128, u128),
|
||||
NonConstantExponent(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Error::Type(s) => write!(f, "{}", s),
|
||||
Error::AssertionFailed(s) => write!(f, "{}", s),
|
||||
Error::ValueTooLarge(s) => write!(f, "{}", s),
|
||||
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
|
||||
Error::InvalidValue(s) => write!(f, "{}", s),
|
||||
Error::OutOfBounds(index, size) => write!(
|
||||
f,
|
||||
"Out of bounds index ({} >= {}) found during static analysis",
|
||||
index, size
|
||||
),
|
||||
Error::NonConstantExponent(s) => write!(
|
||||
f,
|
||||
"Non-constant exponent `{}` detected during static analysis",
|
||||
s
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -179,13 +176,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_function(
|
||||
&mut self,
|
||||
f: TypedFunction<'ast, T>,
|
||||
) -> Result<TypedFunction<'ast, T>, Error> {
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_conditional_expression<
|
||||
E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>,
|
||||
>(
|
||||
|
@ -215,11 +205,101 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
)
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(assignee, expr) => {
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
let expr = self.fold_expression(expr)?;
|
||||
|
||||
if expr.is_constant() {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
let expr = expr.into_canonical_constant();
|
||||
|
||||
assert!(self.constants.insert(var.id, expr).is_none());
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = expr.into_canonical_constant();
|
||||
Ok(vec![])
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
// invalidate the cache for this identifier, and define the latest
|
||||
// version of the constant in the program, if any
|
||||
Some(c) => Ok(vec![
|
||||
TypedAssemblyStatement::Assignment(v.clone().into(), c),
|
||||
TypedAssemblyStatement::Assignment(assignee, expr),
|
||||
]),
|
||||
None => {
|
||||
Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// the expression being assigned is not constant, invalidate the cache
|
||||
let v = self
|
||||
.try_get_constant_mut(&assignee)
|
||||
.map(|(v, _)| v)
|
||||
.unwrap_or_else(|v| v);
|
||||
|
||||
match self.constants.remove(&v.id) {
|
||||
Some(c) => Ok(vec![
|
||||
TypedAssemblyStatement::Assignment(v.clone().into(), c),
|
||||
TypedAssemblyStatement::Assignment(assignee, expr),
|
||||
]),
|
||||
None => Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]),
|
||||
}
|
||||
}
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(left, right, metadata) => {
|
||||
let left = self.fold_field_expression(left)?;
|
||||
let right = self.fold_field_expression(right)?;
|
||||
|
||||
// a bit hacky, but we use a fake boolean expression to check this
|
||||
let is_equal =
|
||||
BooleanExpression::FieldEq(EqExpression::new(left.clone(), right.clone()));
|
||||
let is_equal = self.fold_boolean_expression(is_equal)?;
|
||||
|
||||
match is_equal {
|
||||
BooleanExpression::Value(true) => Ok(vec![]),
|
||||
BooleanExpression::Value(false) => {
|
||||
Err(Error::AssertionFailed(RuntimeError::SourceAssertion(
|
||||
metadata
|
||||
.message(Some(format!("In asm block: `{} !== {}`", left, right))),
|
||||
)))
|
||||
}
|
||||
_ => Ok(vec![TypedAssemblyStatement::Constraint(
|
||||
left, right, metadata,
|
||||
)]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Error> {
|
||||
match s {
|
||||
TypedStatement::Assembly(statements) => {
|
||||
let statements: Vec<_> = statements
|
||||
.into_iter()
|
||||
.map(|s| self.fold_assembly_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
match statements.len() {
|
||||
0 => Ok(vec![]),
|
||||
_ => Ok(vec![TypedStatement::Assembly(statements)]),
|
||||
}
|
||||
}
|
||||
// propagation to the defined variable if rhs is a constant
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {
|
||||
let assignee = self.fold_assignee(assignee)?;
|
||||
|
@ -373,6 +453,26 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
true => {
|
||||
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
|
||||
FlatEmbed::BitArrayLe => Ok(None), // todo
|
||||
FlatEmbed::FieldToBoolUnsafe => {
|
||||
match FieldElementExpression::try_from_typed(
|
||||
embed_call.arguments[0].clone(),
|
||||
) {
|
||||
Ok(FieldElementExpression::Number(n)) if n == T::from(0) => {
|
||||
Ok(Some(BooleanExpression::Value(false).into()))
|
||||
}
|
||||
Ok(FieldElementExpression::Number(n)) if n == T::from(1) => {
|
||||
Ok(Some(BooleanExpression::Value(true).into()))
|
||||
}
|
||||
Ok(FieldElementExpression::Number(n)) => {
|
||||
Err(Error::InvalidValue(format!(
|
||||
"Cannot call `{}` with value `{}`: should be 0 or 1",
|
||||
embed_call.embed.id(),
|
||||
n
|
||||
)))
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
|
||||
&embed_call.arguments,
|
||||
UBitwidth::B64,
|
||||
|
@ -430,7 +530,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
|
||||
if acc != T::zero() {
|
||||
Err(Error::ValueTooLarge(format!(
|
||||
Err(Error::InvalidValue(format!(
|
||||
"Cannot unpack `{}` to `{}`: value is too large",
|
||||
num,
|
||||
assignee.get_type()
|
||||
|
@ -521,15 +621,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
TypedStatement::Assertion(e, ty) => {
|
||||
let e_str = e.to_string();
|
||||
TypedStatement::Assertion(e, err) => {
|
||||
let expr = self.fold_boolean_expression(e)?;
|
||||
match expr {
|
||||
BooleanExpression::Value(false) => {
|
||||
Err(Error::AssertionFailed(format!("{}: ({})", ty, e_str)))
|
||||
}
|
||||
BooleanExpression::Value(false) => Err(Error::AssertionFailed(err)),
|
||||
BooleanExpression::Value(true) => Ok(vec![]),
|
||||
_ => Ok(vec![TypedStatement::Assertion(expr, ty)]),
|
||||
_ => Ok(vec![TypedStatement::Assertion(expr, err)]),
|
||||
}
|
||||
}
|
||||
s @ TypedStatement::PushCallLog(..) => Ok(vec![s]),
|
||||
|
@ -827,11 +924,140 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
box e1,
|
||||
box UExpressionInner::Value(n2).annotate(UBitwidth::B32),
|
||||
)),
|
||||
(_, e2) => Err(Error::NonConstantExponent(
|
||||
e2.annotate(UBitwidth::B32).to_string(),
|
||||
(e1, e2) => Ok(FieldElementExpression::Pow(
|
||||
box e1,
|
||||
box e2.annotate(UBitwidth::B32),
|
||||
)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Xor(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(FieldElementExpression::Number(n), e)
|
||||
| (e, FieldElementExpression::Number(n))
|
||||
if n == T::from(0) =>
|
||||
{
|
||||
Ok(e)
|
||||
}
|
||||
(e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))),
|
||||
(e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
|
||||
FieldElementExpression::And(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(_, FieldElementExpression::Number(n))
|
||||
| (FieldElementExpression::Number(n), _)
|
||||
if n == T::from(0) =>
|
||||
{
|
||||
Ok(FieldElementExpression::Number(n))
|
||||
}
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Or(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(e, FieldElementExpression::Number(n))
|
||||
| (FieldElementExpression::Number(n), e)
|
||||
if n == T::from(0) =>
|
||||
{
|
||||
Ok(e)
|
||||
}
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => {
|
||||
let two = BigUint::from(2usize);
|
||||
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
|
||||
))
|
||||
}
|
||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shr(by as usize)).unwrap(),
|
||||
)),
|
||||
(e, by) => Ok(FieldElementExpression::RightShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
@ -1333,6 +1559,113 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box 0u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box 2u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box ((Bn128Field::get_required_bits() - 1) as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box ((Bn128Field::get_required_bits() - 3) as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box (Bn128Field::get_required_bits() as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box 0u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box (Bn128Field::get_required_bits() as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box 1u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box 2u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box 4u32.into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box ((Bn128Field::get_required_bits() - 1) as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box (Bn128Field::get_required_bits() as u32).into(),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_else_true() {
|
||||
let e = FieldElementExpression::conditional(
|
|
@ -1,6 +1,6 @@
|
|||
// given a (partial) map of values for program constants, replace where applicable constants by their value
|
||||
|
||||
use crate::static_analysis::reducer::ConstantDefinitions;
|
||||
use crate::reducer::ConstantDefinitions;
|
||||
use zokrates_ast::typed::{
|
||||
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
|
||||
DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression,
|
|
@ -1,6 +1,6 @@
|
|||
// A folder to inline all constant definitions down to a single literal and register them in the state for later use.
|
||||
|
||||
use crate::static_analysis::reducer::{
|
||||
use crate::reducer::{
|
||||
constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error,
|
||||
};
|
||||
use std::collections::{BTreeMap, HashSet};
|
|
@ -26,9 +26,9 @@
|
|||
// - The body of the function is in SSA form
|
||||
// - The return value(s) are assigned to internal variables
|
||||
|
||||
use crate::static_analysis::reducer::Output;
|
||||
use crate::static_analysis::reducer::ShallowTransformer;
|
||||
use crate::static_analysis::reducer::Versions;
|
||||
use crate::reducer::Output;
|
||||
use crate::reducer::ShallowTransformer;
|
||||
use crate::reducer::Versions;
|
||||
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType};
|
|
@ -36,7 +36,7 @@ use zokrates_field::Field;
|
|||
use self::constants_writer::ConstantsWriter;
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
||||
use crate::static_analysis::propagation::{Constants, Propagator};
|
||||
use crate::propagation::{Constants, Propagator};
|
||||
|
||||
use std::fmt;
|
||||
|
|
@ -121,33 +121,42 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
match a {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
}
|
||||
a => fold_assignee(self, a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
let e = self.fold_expression(e);
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedAssemblyStatement::Assignment(a, e)]
|
||||
}
|
||||
s => fold_assembly_statement(self, s),
|
||||
}
|
||||
}
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
|
||||
let e = self.fold_expression(e);
|
||||
|
||||
let a = match a {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
}
|
||||
a => fold_assignee(self, a),
|
||||
};
|
||||
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedStatement::definition(a, e)]
|
||||
}
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
|
||||
let assignee = match assignee {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
}
|
||||
a => fold_assignee(self, a),
|
||||
};
|
||||
let embed_call = self.fold_embed_call(embed_call);
|
||||
let assignee = self.fold_assignee(assignee);
|
||||
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
|
||||
}
|
||||
TypedStatement::For(v, from, to, stats) => {
|
|
@ -5,11 +5,22 @@
|
|||
//! @date 2018
|
||||
|
||||
use std::collections::HashSet;
|
||||
use zokrates_ast::typed::folder::*;
|
||||
use std::fmt;
|
||||
use zokrates_ast::typed::result_folder::ResultFolder;
|
||||
use zokrates_ast::typed::result_folder::*;
|
||||
use zokrates_ast::typed::types::{MemberId, Type};
|
||||
use zokrates_ast::typed::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Error(String);
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VariableWriteRemover;
|
||||
|
||||
impl<'ast> VariableWriteRemover {
|
||||
|
@ -17,7 +28,7 @@ impl<'ast> VariableWriteRemover {
|
|||
VariableWriteRemover
|
||||
}
|
||||
|
||||
pub fn apply<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
|
||||
pub fn apply<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
let mut remover = VariableWriteRemover::new();
|
||||
remover.fold_program(p)
|
||||
}
|
||||
|
@ -452,14 +463,35 @@ fn is_constant<T>(assignee: &TypedAssignee<T>) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) if is_constant(&a) => {
|
||||
Ok(vec![TypedAssemblyStatement::Assignment(a, e)])
|
||||
}
|
||||
TypedAssemblyStatement::Assignment(a, _) => Err(Error(format!(
|
||||
"Cannot assign to an assignee with a variable index `{}`",
|
||||
a
|
||||
))),
|
||||
s => Ok(vec![s]),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {
|
||||
let expr = self.fold_expression(expr);
|
||||
let expr = self.fold_expression(expr)?;
|
||||
|
||||
if is_constant(&assignee) {
|
||||
vec![TypedStatement::definition(assignee, expr)]
|
||||
Ok(vec![TypedStatement::definition(assignee, expr)])
|
||||
} else {
|
||||
// Note: here we redefine the whole object, ideally we would only redefine some of it
|
||||
// Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]`
|
||||
|
@ -486,28 +518,28 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
|
|||
.into(),
|
||||
};
|
||||
|
||||
let base = self.fold_expression(base);
|
||||
let base = self.fold_expression(base)?;
|
||||
|
||||
let indices = indices
|
||||
.into_iter()
|
||||
.map(|a| match a {
|
||||
Access::Select(box i) => {
|
||||
Access::Select(box self.fold_uint_expression(i))
|
||||
Ok(Access::Select(box self.fold_uint_expression(i)?))
|
||||
}
|
||||
a => a,
|
||||
a => Ok(a),
|
||||
})
|
||||
.collect();
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let mut range_checks = HashSet::new();
|
||||
let e = Self::choose_many(base, indices, expr, &mut range_checks);
|
||||
|
||||
range_checks
|
||||
Ok(range_checks
|
||||
.into_iter()
|
||||
.chain(std::iter::once(TypedStatement::definition(
|
||||
TypedAssignee::Identifier(variable),
|
||||
e,
|
||||
)))
|
||||
.collect()
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
s => fold_statement(self, s),
|
|
@ -1,9 +1,13 @@
|
|||
use num::traits::Pow;
|
||||
use num_bigint::BigUint;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
|
||||
use zokrates_ast::zir::types::UBitwidth;
|
||||
use zokrates_ast::zir::{
|
||||
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
||||
IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression,
|
||||
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Constant, Expr,
|
||||
Id, IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression,
|
||||
ZirAssemblyStatement,
|
||||
};
|
||||
use zokrates_ast::zir::{
|
||||
BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression,
|
||||
|
@ -31,7 +35,7 @@ impl fmt::Display for Error {
|
|||
Error::DivisionByZero => {
|
||||
write!(f, "Division by zero detected in zir during static analysis",)
|
||||
}
|
||||
Error::AssertionFailed(err) => write!(f, "{}", err),
|
||||
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -42,6 +46,9 @@ pub struct ZirPropagator<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> ZirPropagator<'ast, T> {
|
||||
pub fn with_constants(constants: Constants<'ast, T>) -> Self {
|
||||
Self { constants }
|
||||
}
|
||||
pub fn propagate(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
|
||||
ZirPropagator::default().fold_program(p)
|
||||
}
|
||||
|
@ -50,6 +57,68 @@ impl<'ast, T: Field> ZirPropagator<'ast, T> {
|
|||
impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
match s {
|
||||
ZirAssemblyStatement::Assignment(assignees, function) => {
|
||||
let assignees: Vec<_> = assignees
|
||||
.into_iter()
|
||||
.map(|a| self.fold_assignee(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let function = self.fold_function(function)?;
|
||||
|
||||
match &function.statements.last().unwrap() {
|
||||
ZirStatement::Return(values) => {
|
||||
if values.iter().all(|v| v.is_constant()) {
|
||||
self.constants.extend(
|
||||
assignees
|
||||
.into_iter()
|
||||
.zip(values.iter())
|
||||
.map(|(a, v)| (a.id, v.clone())),
|
||||
);
|
||||
Ok(vec![])
|
||||
} else {
|
||||
assignees.iter().for_each(|a| {
|
||||
self.constants.remove(&a.id);
|
||||
});
|
||||
Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)])
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
assignees.iter().for_each(|a| {
|
||||
self.constants.remove(&a.id);
|
||||
});
|
||||
Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)])
|
||||
}
|
||||
}
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(left, right, metadata) => {
|
||||
let left = self.fold_field_expression(left)?;
|
||||
let right = self.fold_field_expression(right)?;
|
||||
|
||||
// a bit hacky, but we use a fake boolean expression to check this
|
||||
let is_equal = BooleanExpression::FieldEq(box left.clone(), box right.clone());
|
||||
let is_equal = self.fold_boolean_expression(is_equal)?;
|
||||
|
||||
match is_equal {
|
||||
BooleanExpression::Value(true) => Ok(vec![]),
|
||||
BooleanExpression::Value(false) => {
|
||||
Err(Error::AssertionFailed(RuntimeError::SourceAssertion(
|
||||
metadata
|
||||
.message(Some(format!("In asm block: `{} !== {}`", left, right))),
|
||||
)))
|
||||
}
|
||||
_ => Ok(vec![ZirAssemblyStatement::Constraint(
|
||||
left, right, metadata,
|
||||
)]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: ZirStatement<'ast, T>,
|
||||
|
@ -122,6 +191,19 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
self.fold_expression_list(list)?,
|
||||
)])
|
||||
}
|
||||
ZirStatement::Assembly(statements) => {
|
||||
let statements: Vec<_> = statements
|
||||
.into_iter()
|
||||
.map(|s| self.fold_assembly_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
match statements.len() {
|
||||
0 => Ok(vec![]),
|
||||
_ => Ok(vec![ZirStatement::Assembly(statements)]),
|
||||
}
|
||||
}
|
||||
_ => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
@ -226,6 +308,127 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Xor(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))),
|
||||
(e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::And(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(_, FieldElementExpression::Number(n))
|
||||
| (FieldElementExpression::Number(n), _)
|
||||
if n == T::from(0) =>
|
||||
{
|
||||
Ok(FieldElementExpression::Number(n))
|
||||
}
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Or(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(e, FieldElementExpression::Number(n))
|
||||
| (FieldElementExpression::Number(n), e)
|
||||
if n == T::from(0) =>
|
||||
{
|
||||
Ok(e)
|
||||
}
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(),
|
||||
))
|
||||
}
|
||||
(e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => {
|
||||
let two = BigUint::from(2usize);
|
||||
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||
|
||||
Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
|
||||
))
|
||||
}
|
||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
_,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by as usize >= T::get_required_bits() => {
|
||||
Ok(FieldElementExpression::Number(T::from(0)))
|
||||
}
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shr(by as usize)).unwrap(),
|
||||
)),
|
||||
(e, by) => Ok(FieldElementExpression::RightShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
@ -587,22 +790,28 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
e: ConditionalExpression<'ast, T, E>,
|
||||
) -> Result<ConditionalOrExpression<'ast, T, E>, Self::Error> {
|
||||
let condition = self.fold_boolean_expression(*e.condition)?;
|
||||
let consequence = e.consequence.fold(self)?;
|
||||
let alternative = e.alternative.fold(self)?;
|
||||
|
||||
match (condition, consequence, alternative) {
|
||||
(_, consequence, alternative) if consequence == alternative => Ok(
|
||||
ConditionalOrExpression::Expression(consequence.into_inner()),
|
||||
),
|
||||
(BooleanExpression::Value(true), consequence, _) => Ok(
|
||||
ConditionalOrExpression::Expression(consequence.into_inner()),
|
||||
),
|
||||
(BooleanExpression::Value(false), _, alternative) => Ok(
|
||||
ConditionalOrExpression::Expression(alternative.into_inner()),
|
||||
),
|
||||
(condition, consequence, alternative) => Ok(ConditionalOrExpression::Conditional(
|
||||
ConditionalExpression::new(condition, consequence, alternative),
|
||||
match condition {
|
||||
BooleanExpression::Value(true) => Ok(ConditionalOrExpression::Expression(
|
||||
e.consequence.fold(self)?.into_inner(),
|
||||
)),
|
||||
BooleanExpression::Value(false) => Ok(ConditionalOrExpression::Expression(
|
||||
e.alternative.fold(self)?.into_inner(),
|
||||
)),
|
||||
condition => {
|
||||
let consequence = e.consequence.fold(self)?;
|
||||
let alternative = e.alternative.fold(self)?;
|
||||
|
||||
if consequence == alternative {
|
||||
Ok(ConditionalOrExpression::Expression(
|
||||
consequence.into_inner(),
|
||||
))
|
||||
} else {
|
||||
Ok(ConditionalOrExpression::Conditional(
|
||||
ConditionalExpression::new(condition, consequence, alternative),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -821,6 +1030,115 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::identifier("a".into()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
box UExpressionInner::Value(Bn128Field::get_required_bits() as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(0)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_else() {
|
||||
let mut propagator = ZirPropagator::default();
|
|
@ -16,8 +16,8 @@ use zokrates_proof_systems::Scheme;
|
|||
use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair};
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, GM17> for Ark {
|
||||
fn setup<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
rng: &mut R,
|
||||
) -> SetupKeypair<T, GM17> {
|
||||
let computation = Computation::without_witness(program);
|
||||
|
@ -40,8 +40,8 @@ impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, GM17> for Ark {
|
|||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> Backend<T, GM17> for Ark {
|
||||
fn generate_proof<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Witness<T>,
|
||||
proving_key: Vec<u8>,
|
||||
rng: &mut R,
|
||||
|
|
|
@ -19,8 +19,8 @@ use zokrates_proof_systems::Scheme;
|
|||
const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications.";
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> Backend<T, G16> for Ark {
|
||||
fn generate_proof<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Witness<T>,
|
||||
proving_key: Vec<u8>,
|
||||
rng: &mut R,
|
||||
|
@ -85,8 +85,8 @@ impl<T: Field + ArkFieldExtensions> Backend<T, G16> for Ark {
|
|||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, G16> for Ark {
|
||||
fn setup<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
rng: &mut R,
|
||||
) -> SetupKeypair<T, G16> {
|
||||
println!("{}", G16_WARNING);
|
||||
|
|
|
@ -17,20 +17,20 @@ pub use self::parse::*;
|
|||
pub struct Ark;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Computation<T, I: IntoIterator<Item = Statement<T>>> {
|
||||
program: ProgIterator<T, I>,
|
||||
pub struct Computation<'a, T, I: IntoIterator<Item = Statement<'a, T>>> {
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Option<Witness<T>>,
|
||||
}
|
||||
|
||||
impl<T, I: IntoIterator<Item = Statement<T>>> Computation<T, I> {
|
||||
pub fn with_witness(program: ProgIterator<T, I>, witness: Witness<T>) -> Self {
|
||||
impl<'a, T, I: IntoIterator<Item = Statement<'a, T>>> Computation<'a, T, I> {
|
||||
pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness<T>) -> Self {
|
||||
Computation {
|
||||
program,
|
||||
witness: Some(witness),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn without_witness(program: ProgIterator<T, I>) -> Self {
|
||||
pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self {
|
||||
Computation {
|
||||
program,
|
||||
witness: None,
|
||||
|
@ -72,9 +72,9 @@ fn ark_combination<T: Field + ArkFieldExtensions>(
|
|||
.fold(LinearCombination::zero(), |acc, e| acc + e)
|
||||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>>
|
||||
impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<'a, T>>>
|
||||
ConstraintSynthesizer<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>
|
||||
for Computation<T, I>
|
||||
for Computation<'a, T, I>
|
||||
{
|
||||
fn generate_constraints(
|
||||
self,
|
||||
|
@ -143,7 +143,9 @@ impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>>
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>> Computation<T, I> {
|
||||
impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<'a, T>>>
|
||||
Computation<'a, T, I>
|
||||
{
|
||||
pub fn public_inputs_values(&self) -> Vec<<T::ArkEngine as PairingEngine>::Fr> {
|
||||
self.program
|
||||
.public_inputs_values(self.witness.as_ref().unwrap())
|
||||
|
|
|
@ -130,9 +130,9 @@ impl<T: Field + ArkFieldExtensions> UniversalBackend<T, marlin::Marlin> for Ark
|
|||
res
|
||||
}
|
||||
|
||||
fn setup<I: IntoIterator<Item = Statement<T>>>(
|
||||
fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>>(
|
||||
srs: Vec<u8>,
|
||||
program: ProgIterator<T, I>,
|
||||
program: ProgIterator<'a, T, I>,
|
||||
) -> Result<SetupKeypair<T, marlin::Marlin>, String> {
|
||||
let program = program.collect();
|
||||
|
||||
|
@ -206,8 +206,8 @@ impl<T: Field + ArkFieldExtensions> UniversalBackend<T, marlin::Marlin> for Ark
|
|||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> Backend<T, marlin::Marlin> for Ark {
|
||||
fn generate_proof<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Witness<T>,
|
||||
proving_key: Vec<u8>,
|
||||
rng: &mut R,
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::untyped::{
|
|||
types::{UnresolvedSignature, UnresolvedType},
|
||||
ConstantGenericNode, Expression,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -28,8 +29,9 @@ cfg_if::cfg_if! {
|
|||
|
||||
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
|
||||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum FlatEmbed {
|
||||
FieldToBoolUnsafe,
|
||||
BitArrayLe,
|
||||
Unpack,
|
||||
U8ToBits,
|
||||
|
@ -49,6 +51,9 @@ pub enum FlatEmbed {
|
|||
impl FlatEmbed {
|
||||
pub fn signature(&self) -> UnresolvedSignature {
|
||||
match self {
|
||||
FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new()
|
||||
.inputs(vec![UnresolvedType::FieldElement.into()])
|
||||
.output(UnresolvedType::Boolean.into()),
|
||||
FlatEmbed::BitArrayLe => UnresolvedSignature::new()
|
||||
.generics(vec![ConstantGenericNode::mock("N")])
|
||||
.inputs(vec![
|
||||
|
@ -185,6 +190,9 @@ impl FlatEmbed {
|
|||
|
||||
pub fn typed_signature<T>(&self) -> DeclarationSignature<'static, T> {
|
||||
match self {
|
||||
FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::Boolean),
|
||||
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
||||
.generics(vec![Some(DeclarationConstant::Generic(
|
||||
GenericIdentifier::with_name("N").with_index(0),
|
||||
|
@ -291,6 +299,7 @@ impl FlatEmbed {
|
|||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE",
|
||||
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
|
||||
FlatEmbed::Unpack => "_UNPACK",
|
||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||
|
@ -317,8 +326,8 @@ impl FlatEmbed {
|
|||
/// - constraint system variables
|
||||
/// - arguments
|
||||
#[cfg(feature = "bellman")]
|
||||
pub fn sha256_round<T: Field>(
|
||||
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> {
|
||||
pub fn sha256_round<'ast, T: Field>(
|
||||
) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
|
||||
use zokrates_field::Bn128Field;
|
||||
assert_eq!(T::id(), Bn128Field::id());
|
||||
|
||||
|
@ -420,9 +429,9 @@ pub fn sha256_round<T: Field>(
|
|||
}
|
||||
|
||||
#[cfg(feature = "ark")]
|
||||
pub fn snark_verify_bls12_377<T: Field>(
|
||||
pub fn snark_verify_bls12_377<'ast, T: Field>(
|
||||
n: usize,
|
||||
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> {
|
||||
) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
|
||||
use zokrates_field::Bw6_761Field;
|
||||
assert_eq!(T::id(), Bw6_761Field::id());
|
||||
|
||||
|
@ -546,9 +555,9 @@ fn use_variable(
|
|||
/// # Remarks
|
||||
/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()`
|
||||
/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
pub fn unpack_to_bitwidth<T: Field>(
|
||||
pub fn unpack_to_bitwidth<'ast, T: Field>(
|
||||
bit_width: usize,
|
||||
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> {
|
||||
) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
|
||||
let mut counter = 0;
|
||||
|
||||
let mut layout = HashMap::new();
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use crate::common::SourceMetadata;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub enum RuntimeError {
|
||||
|
@ -25,7 +27,8 @@ pub enum RuntimeError {
|
|||
Euclidean,
|
||||
ShaXor,
|
||||
Division,
|
||||
SourceAssertion(String),
|
||||
SourceAssertion(SourceMetadata),
|
||||
SourceAssemblyConstraint(SourceMetadata),
|
||||
ArgumentBitness,
|
||||
SelectRangeCheck,
|
||||
}
|
||||
|
@ -33,7 +36,9 @@ pub enum RuntimeError {
|
|||
impl From<crate::zir::RuntimeError> for RuntimeError {
|
||||
fn from(error: crate::zir::RuntimeError) -> Self {
|
||||
match error {
|
||||
crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s),
|
||||
crate::zir::RuntimeError::SourceAssertion(metadata) => {
|
||||
RuntimeError::SourceAssertion(metadata)
|
||||
}
|
||||
crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck,
|
||||
crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse,
|
||||
crate::zir::RuntimeError::IncompleteDynamicRange => {
|
||||
|
@ -49,7 +54,8 @@ impl RuntimeError {
|
|||
|
||||
!matches!(
|
||||
self,
|
||||
SourceAssertion(_)
|
||||
SourceAssemblyConstraint(_)
|
||||
| SourceAssertion(_)
|
||||
| Inverse
|
||||
| SelectRangeCheck
|
||||
| ArgumentBitness
|
||||
|
@ -62,6 +68,7 @@ impl fmt::Display for RuntimeError {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
use RuntimeError::*;
|
||||
|
||||
let mut buf = String::new();
|
||||
let msg = match self {
|
||||
BellmanConstraint => "Bellman constraint is unsatisfied",
|
||||
BellmanOneBinding => "Bellman ~one binding is unsatisfied",
|
||||
|
@ -87,7 +94,14 @@ impl fmt::Display for RuntimeError {
|
|||
Euclidean => "Euclidean check failed",
|
||||
ShaXor => "Internal Sha check failed",
|
||||
Division => "Division check failed",
|
||||
SourceAssertion(m) => m.as_str(),
|
||||
SourceAssertion(m) => {
|
||||
write!(&mut buf, "Assertion failed at {}", m).unwrap();
|
||||
buf.as_str()
|
||||
}
|
||||
SourceAssemblyConstraint(m) => {
|
||||
write!(&mut buf, "Unsatisfied constraint at {}", m).unwrap();
|
||||
buf.as_str()
|
||||
}
|
||||
ArgumentBitness => "Argument bitness check failed",
|
||||
SelectRangeCheck => "Out of bounds array access",
|
||||
};
|
||||
|
|
34
zokrates_ast/src/common/metadata.rs
Normal file
34
zokrates_ast/src/common/metadata.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
use crate::untyped::Position;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Hash, Eq, Default, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct SourceMetadata {
|
||||
pub file: String,
|
||||
pub position: Position,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl SourceMetadata {
|
||||
pub fn new(file: String, position: Position) -> Self {
|
||||
Self {
|
||||
file,
|
||||
position,
|
||||
message: None,
|
||||
}
|
||||
}
|
||||
pub fn message(mut self, message: Option<String>) -> Self {
|
||||
self.message = message;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SourceMetadata {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}:{}", self.file, self.position)?;
|
||||
match &self.message {
|
||||
Some(m) => write!(f, ": \"{}\"", m),
|
||||
None => write!(f, ""),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,12 +1,14 @@
|
|||
pub mod embed;
|
||||
mod error;
|
||||
mod format_string;
|
||||
mod metadata;
|
||||
mod parameter;
|
||||
mod solvers;
|
||||
mod variable;
|
||||
|
||||
pub use self::embed::FlatEmbed;
|
||||
pub use self::error::RuntimeError;
|
||||
pub use self::metadata::SourceMetadata;
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::solvers::Solver;
|
||||
pub use self::variable::Variable;
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use crate::zir::ZirFunction;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||
pub enum Solver {
|
||||
pub enum Solver<'ast, T> {
|
||||
ConditionEq,
|
||||
Bits(usize),
|
||||
Div,
|
||||
|
@ -11,19 +12,35 @@ pub enum Solver {
|
|||
ShaAndXorAndXorAnd,
|
||||
ShaCh,
|
||||
EuclideanDiv,
|
||||
#[serde(borrow)]
|
||||
Zir(ZirFunction<'ast, T>),
|
||||
#[cfg(feature = "bellman")]
|
||||
Sha256Round,
|
||||
#[cfg(feature = "ark")]
|
||||
SnarkVerifyBls12377(usize),
|
||||
}
|
||||
|
||||
impl fmt::Display for Solver {
|
||||
impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
match self {
|
||||
Solver::ConditionEq => write!(f, "ConditionEq"),
|
||||
Solver::Bits(n) => write!(f, "Bits({})", n),
|
||||
Solver::Div => write!(f, "Div"),
|
||||
Solver::Xor => write!(f, "Xor"),
|
||||
Solver::Or => write!(f, "Or"),
|
||||
Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"),
|
||||
Solver::ShaCh => write!(f, "ShaCh"),
|
||||
Solver::EuclideanDiv => write!(f, "EuclideanDiv"),
|
||||
Solver::Zir(_) => write!(f, "Zir(..)"),
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => write!(f, "Sha256Round"),
|
||||
#[cfg(feature = "ark")]
|
||||
Solver::SnarkVerifyBls12377(n) => write!(f, "SnarkVerifyBls12377({})", n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Solver {
|
||||
impl<'ast, T> Solver<'ast, T> {
|
||||
pub fn get_signature(&self) -> (usize, usize) {
|
||||
match self {
|
||||
Solver::ConditionEq => (1, 2),
|
||||
|
@ -34,6 +51,7 @@ impl Solver {
|
|||
Solver::ShaAndXorAndXorAnd => (3, 1),
|
||||
Solver::ShaCh => (3, 1),
|
||||
Solver::EuclideanDiv => (2, 2),
|
||||
Solver::Zir(f) => (f.arguments.len(), 1),
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => (768, 26935),
|
||||
#[cfg(feature = "ark")]
|
||||
|
@ -42,7 +60,7 @@ impl Solver {
|
|||
}
|
||||
}
|
||||
|
||||
impl Solver {
|
||||
impl<'ast, T> Solver<'ast, T> {
|
||||
pub fn bits(width: usize) -> Self {
|
||||
Solver::Bits(width)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ use super::*;
|
|||
use crate::common::Variable;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait Folder<T: Field>: Sized {
|
||||
fn fold_program(&mut self, p: FlatProg<T>) -> FlatProg<T> {
|
||||
pub trait Folder<'ast, T: Field>: Sized {
|
||||
fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> {
|
||||
fold_program(self, p)
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ pub trait Folder<T: Field>: Sized {
|
|||
fold_variable(self, v)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: FlatStatement<T>) -> Vec<FlatStatement<T>> {
|
||||
fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec<FlatStatement<'ast, T>> {
|
||||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
|
@ -25,12 +25,15 @@ pub trait Folder<T: Field>: Sized {
|
|||
fold_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_directive(&mut self, d: FlatDirective<T>) -> FlatDirective<T> {
|
||||
fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> {
|
||||
fold_directive(self, d)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: FlatProg<T>) -> FlatProg<T> {
|
||||
pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
p: FlatProg<'ast, T>,
|
||||
) -> FlatProg<'ast, T> {
|
||||
FlatProg {
|
||||
arguments: p
|
||||
.arguments
|
||||
|
@ -46,11 +49,17 @@ pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: FlatProg<T>) -> FlatPr
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<T: Field, F: Folder<T>>(
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: FlatStatement<T>,
|
||||
) -> Vec<FlatStatement<T>> {
|
||||
s: FlatStatement<'ast, T>,
|
||||
) -> Vec<FlatStatement<'ast, T>> {
|
||||
match s {
|
||||
FlatStatement::Block(statements) => vec![FlatStatement::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
)],
|
||||
FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition(
|
||||
f.fold_expression(left),
|
||||
f.fold_expression(right),
|
||||
|
@ -70,7 +79,7 @@ pub fn fold_statement<T: Field, F: Folder<T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_expression<T: Field, F: Folder<T>>(
|
||||
pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: FlatExpression<T>,
|
||||
) -> FlatExpression<T> {
|
||||
|
@ -89,7 +98,10 @@ pub fn fold_expression<T: Field, F: Folder<T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: FlatDirective<T>) -> FlatDirective<T> {
|
||||
pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
ds: FlatDirective<'ast, T>,
|
||||
) -> FlatDirective<'ast, T> {
|
||||
FlatDirective {
|
||||
inputs: ds
|
||||
.inputs
|
||||
|
@ -101,13 +113,13 @@ pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: FlatDirective<T>) -
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_argument<T: Field, F: Folder<T>>(f: &mut F, a: Parameter) -> Parameter {
|
||||
pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter {
|
||||
Parameter {
|
||||
id: f.fold_variable(a.id),
|
||||
private: a.private,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_variable<T: Field, F: Folder<T>>(_f: &mut F, v: Variable) -> Variable {
|
||||
pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable {
|
||||
v
|
||||
}
|
||||
|
|
|
@ -24,14 +24,14 @@ use std::collections::HashMap;
|
|||
use std::fmt;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub type FlatProg<T> = FlatFunction<T>;
|
||||
pub type FlatProg<'ast, T> = FlatFunction<'ast, T>;
|
||||
|
||||
pub type FlatFunction<T> = FlatFunctionIterator<T, Vec<FlatStatement<T>>>;
|
||||
pub type FlatFunction<'ast, T> = FlatFunctionIterator<'ast, T, Vec<FlatStatement<'ast, T>>>;
|
||||
|
||||
pub type FlatProgIterator<T, I> = FlatFunctionIterator<T, I>;
|
||||
pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub struct FlatFunctionIterator<T, I: IntoIterator<Item = FlatStatement<T>>> {
|
||||
pub struct FlatFunctionIterator<'ast, T, I: IntoIterator<Item = FlatStatement<'ast, T>>> {
|
||||
/// Arguments of the function
|
||||
pub arguments: Vec<Parameter>,
|
||||
/// Vector of statements that are executed when running the function
|
||||
|
@ -40,8 +40,8 @@ pub struct FlatFunctionIterator<T, I: IntoIterator<Item = FlatStatement<T>>> {
|
|||
pub return_count: usize,
|
||||
}
|
||||
|
||||
impl<T, I: IntoIterator<Item = FlatStatement<T>>> FlatFunctionIterator<T, I> {
|
||||
pub fn collect(self) -> FlatFunction<T> {
|
||||
impl<'ast, T, I: IntoIterator<Item = FlatStatement<'ast, T>>> FlatFunctionIterator<'ast, T, I> {
|
||||
pub fn collect(self) -> FlatFunction<'ast, T> {
|
||||
FlatFunction {
|
||||
statements: self.statements.into_iter().collect(),
|
||||
arguments: self.arguments,
|
||||
|
@ -50,7 +50,7 @@ impl<T, I: IntoIterator<Item = FlatStatement<T>>> FlatFunctionIterator<T, I> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatFunction<T> {
|
||||
impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
|
@ -81,16 +81,24 @@ impl<T: Field> fmt::Display for FlatFunction<T> {
|
|||
/// * r1cs - R1CS in standard JSON data format
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub enum FlatStatement<T> {
|
||||
pub enum FlatStatement<'ast, T> {
|
||||
Block(Vec<FlatStatement<'ast, T>>),
|
||||
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
|
||||
Definition(Variable, FlatExpression<T>),
|
||||
Directive(FlatDirective<T>),
|
||||
Directive(FlatDirective<'ast, T>),
|
||||
Log(FormatString, Vec<(ConcreteType, Vec<FlatExpression<T>>)>),
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatStatement<T> {
|
||||
impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
FlatStatement::Block(ref statements) => {
|
||||
writeln!(f, "{{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}", s)?;
|
||||
}
|
||||
writeln!(f, "}}")
|
||||
}
|
||||
FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
|
||||
FlatStatement::Condition(ref lhs, ref rhs, ref message) => {
|
||||
write!(f, "{} == {} // {}", lhs, rhs, message)
|
||||
|
@ -116,12 +124,18 @@ impl<T: Field> fmt::Display for FlatStatement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatStatement<T> {
|
||||
impl<'ast, T: Field> FlatStatement<'ast, T> {
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<Variable, Variable>,
|
||||
substitution: &'ast HashMap<Variable, Variable>,
|
||||
) -> FlatStatement<T> {
|
||||
match self {
|
||||
FlatStatement::Block(statements) => FlatStatement::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| s.apply_substitution(substitution))
|
||||
.collect(),
|
||||
),
|
||||
FlatStatement::Definition(id, x) => FlatStatement::Definition(
|
||||
*id.apply_substitution(substitution),
|
||||
x.apply_substitution(substitution),
|
||||
|
@ -167,16 +181,16 @@ impl<T: Field> FlatStatement<T> {
|
|||
}
|
||||
|
||||
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
|
||||
pub struct FlatDirective<T> {
|
||||
pub struct FlatDirective<'ast, T> {
|
||||
pub inputs: Vec<FlatExpression<T>>,
|
||||
pub outputs: Vec<Variable>,
|
||||
pub solver: Solver,
|
||||
pub solver: Solver<'ast, T>,
|
||||
}
|
||||
|
||||
impl<T> FlatDirective<T> {
|
||||
impl<'ast, T> FlatDirective<'ast, T> {
|
||||
pub fn new<E: Into<FlatExpression<T>>>(
|
||||
outputs: Vec<Variable>,
|
||||
solver: Solver,
|
||||
solver: Solver<'ast, T>,
|
||||
inputs: Vec<E>,
|
||||
) -> Self {
|
||||
let (in_len, out_len) = solver.get_signature();
|
||||
|
@ -190,7 +204,7 @@ impl<T> FlatDirective<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatDirective<T> {
|
||||
impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
|
|
|
@ -13,7 +13,9 @@ pub struct UnconstrainedVariableDetector {
|
|||
}
|
||||
|
||||
impl UnconstrainedVariableDetector {
|
||||
pub fn new<T: Field, I: IntoIterator<Item = Statement<T>>>(p: &ProgIterator<T, I>) -> Self {
|
||||
pub fn new<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
|
||||
p: &ProgIterator<'ast, T, I>,
|
||||
) -> Self {
|
||||
UnconstrainedVariableDetector {
|
||||
variables: p
|
||||
.arguments
|
||||
|
@ -32,7 +34,7 @@ impl UnconstrainedVariableDetector {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for UnconstrainedVariableDetector {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector {
|
||||
fn fold_argument(&mut self, p: Parameter) -> Parameter {
|
||||
p
|
||||
}
|
||||
|
@ -40,7 +42,7 @@ impl<T: Field> Folder<T> for UnconstrainedVariableDetector {
|
|||
self.variables.remove(&v);
|
||||
v
|
||||
}
|
||||
fn fold_directive(&mut self, d: Directive<T>) -> Directive<T> {
|
||||
fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
|
||||
self.variables.extend(d.outputs.iter());
|
||||
d
|
||||
}
|
||||
|
|
31
zokrates_ast/src/ir/clean.rs
Normal file
31
zokrates_ast/src/ir/clean.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use super::folder::Folder;
|
||||
use super::{ProgIterator, Statement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Cleaner;
|
||||
|
||||
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
pub fn clean(self) -> ProgIterator<'ast, T, impl IntoIterator<Item = Statement<'ast, T>>> {
|
||||
ProgIterator {
|
||||
arguments: self.arguments,
|
||||
return_count: self.return_count,
|
||||
statements: self
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| Cleaner::default().fold_statement(s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for Cleaner {
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
match s {
|
||||
Statement::Block(statements) => statements
|
||||
.into_iter()
|
||||
.flat_map(|s| self.fold_statement(s))
|
||||
.collect(),
|
||||
s => vec![s],
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,8 +4,8 @@ use super::*;
|
|||
use crate::common::Variable;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait Folder<T: Field>: Sized {
|
||||
fn fold_program(&mut self, p: Prog<T>) -> Prog<T> {
|
||||
pub trait Folder<'ast, T: Field>: Sized {
|
||||
fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> {
|
||||
fold_program(self, p)
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ pub trait Folder<T: Field>: Sized {
|
|||
fold_variable(self, v)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
|
@ -29,12 +29,15 @@ pub trait Folder<T: Field>: Sized {
|
|||
fold_quadratic_combination(self, es)
|
||||
}
|
||||
|
||||
fn fold_directive(&mut self, d: Directive<T>) -> Directive<T> {
|
||||
fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
|
||||
fold_directive(self, d)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
|
||||
pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
p: Prog<'ast, T>,
|
||||
) -> Prog<'ast, T> {
|
||||
Prog {
|
||||
arguments: p
|
||||
.arguments
|
||||
|
@ -50,8 +53,17 @@ pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<T: Field, F: Folder<T>>(f: &mut F, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: Statement<'ast, T>,
|
||||
) -> Vec<Statement<'ast, T>> {
|
||||
match s {
|
||||
Statement::Block(statements) => vec![Statement::Block(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
)],
|
||||
Statement::Constraint(quad, lin, message) => vec![Statement::Constraint(
|
||||
f.fold_quadratic_combination(quad),
|
||||
f.fold_linear_combination(lin),
|
||||
|
@ -74,7 +86,10 @@ pub fn fold_statement<T: Field, F: Folder<T>>(f: &mut F, s: Statement<T>) -> Vec
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_linear_combination<T: Field, F: Folder<T>>(f: &mut F, e: LinComb<T>) -> LinComb<T> {
|
||||
pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: LinComb<T>,
|
||||
) -> LinComb<T> {
|
||||
LinComb(
|
||||
e.0.into_iter()
|
||||
.map(|(variable, coefficient)| (f.fold_variable(variable), coefficient))
|
||||
|
@ -82,7 +97,7 @@ pub fn fold_linear_combination<T: Field, F: Folder<T>>(f: &mut F, e: LinComb<T>)
|
|||
)
|
||||
}
|
||||
|
||||
pub fn fold_quadratic_combination<T: Field, F: Folder<T>>(
|
||||
pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: QuadComb<T>,
|
||||
) -> QuadComb<T> {
|
||||
|
@ -92,7 +107,10 @@ pub fn fold_quadratic_combination<T: Field, F: Folder<T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: Directive<T>) -> Directive<T> {
|
||||
pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
ds: Directive<'ast, T>,
|
||||
) -> Directive<'ast, T> {
|
||||
Directive {
|
||||
inputs: ds
|
||||
.inputs
|
||||
|
@ -104,13 +122,13 @@ pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: Directive<T>) -> Di
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_argument<T: Field, F: Folder<T>>(f: &mut F, a: Parameter) -> Parameter {
|
||||
pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter {
|
||||
Parameter {
|
||||
id: f.fold_variable(a.id),
|
||||
private: a.private,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_variable<T: Field, F: Folder<T>>(_f: &mut F, v: Variable) -> Variable {
|
||||
pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable {
|
||||
v
|
||||
}
|
||||
|
|
|
@ -17,9 +17,9 @@ impl<T: Field> QuadComb<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn from_flat<T: Field, I: IntoIterator<Item = FlatStatement<T>>>(
|
||||
flat_prog_iterator: FlatProgIterator<T, I>,
|
||||
) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> {
|
||||
pub fn from_flat<'ast, T: Field, I: IntoIterator<Item = FlatStatement<'ast, T>>>(
|
||||
flat_prog_iterator: FlatProgIterator<'ast, T, I>,
|
||||
) -> ProgIterator<T, impl IntoIterator<Item = Statement<'ast, T>>> {
|
||||
ProgIterator {
|
||||
statements: flat_prog_iterator.statements.into_iter().map(Into::into),
|
||||
arguments: flat_prog_iterator.arguments,
|
||||
|
@ -52,9 +52,12 @@ impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> From<FlatStatement<T>> for Statement<T> {
|
||||
fn from(flat_statement: FlatStatement<T>) -> Statement<T> {
|
||||
impl<'ast, T: Field> From<FlatStatement<'ast, T>> for Statement<'ast, T> {
|
||||
fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> {
|
||||
match flat_statement {
|
||||
FlatStatement::Block(statements) => {
|
||||
Statement::Block(statements.into_iter().map(Statement::from).collect())
|
||||
}
|
||||
FlatStatement::Condition(linear, quadratic, message) => match quadratic {
|
||||
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
|
||||
QuadComb::from_linear_combinations(lhs.into(), rhs.into()),
|
||||
|
@ -83,8 +86,8 @@ impl<T: Field> From<FlatStatement<T>> for Statement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> From<FlatDirective<T>> for Directive<T> {
|
||||
fn from(ds: FlatDirective<T>) -> Directive<T> {
|
||||
impl<'ast, T: Field> From<FlatDirective<'ast, T>> for Directive<'ast, T> {
|
||||
fn from(ds: FlatDirective<'ast, T>) -> Directive<T> {
|
||||
Directive {
|
||||
inputs: ds
|
||||
.inputs
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::hash::Hash;
|
|||
use zokrates_field::Field;
|
||||
|
||||
mod check;
|
||||
mod clean;
|
||||
mod expression;
|
||||
pub mod folder;
|
||||
pub mod from_flat;
|
||||
|
@ -28,19 +29,22 @@ pub use self::witness::Witness;
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Derivative)]
|
||||
#[derivative(Hash, PartialEq, Eq)]
|
||||
pub enum Statement<T> {
|
||||
pub enum Statement<'ast, T> {
|
||||
#[serde(skip)]
|
||||
Block(Vec<Statement<'ast, T>>),
|
||||
Constraint(
|
||||
QuadComb<T>,
|
||||
LinComb<T>,
|
||||
#[derivative(Hash = "ignore")] Option<RuntimeError>,
|
||||
),
|
||||
Directive(Directive<T>),
|
||||
#[serde(borrow)]
|
||||
Directive(Directive<'ast, T>),
|
||||
Log(FormatString, Vec<(ConcreteType, Vec<LinComb<T>>)>),
|
||||
}
|
||||
|
||||
pub type PublicInputs = BTreeSet<Variable>;
|
||||
|
||||
impl<T: Field> Statement<T> {
|
||||
impl<'ast, T: Field> Statement<'ast, T> {
|
||||
pub fn definition<U: Into<QuadComb<T>>>(v: Variable, e: U) -> Self {
|
||||
Statement::Constraint(e.into(), v.into(), None)
|
||||
}
|
||||
|
@ -51,13 +55,14 @@ impl<T: Field> Statement<T> {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub struct Directive<T> {
|
||||
pub struct Directive<'ast, T> {
|
||||
pub inputs: Vec<QuadComb<T>>,
|
||||
pub outputs: Vec<Variable>,
|
||||
pub solver: Solver,
|
||||
#[serde(borrow)]
|
||||
pub solver: Solver<'ast, T>,
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for Directive<T> {
|
||||
impl<'ast, T: Field> fmt::Display for Directive<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
|
@ -77,9 +82,16 @@ impl<T: Field> fmt::Display for Directive<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for Statement<T> {
|
||||
impl<'ast, T: Field> fmt::Display for Statement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Statement::Block(ref statements) => {
|
||||
writeln!(f, "{{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}", s)?;
|
||||
}
|
||||
write!(f, "}}")
|
||||
}
|
||||
Statement::Constraint(ref quad, ref lin, ref error) => write!(
|
||||
f,
|
||||
"{} == {}{}",
|
||||
|
@ -111,16 +123,16 @@ impl<T: Field> fmt::Display for Statement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type Prog<T> = ProgIterator<T, Vec<Statement<T>>>;
|
||||
pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct ProgIterator<T, I: IntoIterator<Item = Statement<T>>> {
|
||||
pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
||||
pub arguments: Vec<Parameter>,
|
||||
pub return_count: usize,
|
||||
pub statements: I,
|
||||
}
|
||||
|
||||
impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
||||
impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
pub fn new(arguments: Vec<Parameter>, statements: I, return_count: usize) -> Self {
|
||||
Self {
|
||||
arguments,
|
||||
|
@ -129,7 +141,7 @@ impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn collect(self) -> ProgIterator<T, Vec<Statement<T>>> {
|
||||
pub fn collect(self) -> ProgIterator<'ast, T, Vec<Statement<'ast, T>>> {
|
||||
ProgIterator {
|
||||
statements: self.statements.into_iter().collect::<Vec<_>>(),
|
||||
arguments: self.arguments,
|
||||
|
@ -154,7 +166,7 @@ impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
||||
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
pub fn public_inputs_values(&self, witness: &Witness<T>) -> Vec<T> {
|
||||
self.arguments
|
||||
.iter()
|
||||
|
@ -165,7 +177,7 @@ impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Prog<T> {
|
||||
impl<'ast, T> Prog<'ast, T> {
|
||||
pub fn constraint_count(&self) -> usize {
|
||||
self.statements
|
||||
.iter()
|
||||
|
@ -173,7 +185,9 @@ impl<T> Prog<T> {
|
|||
.count()
|
||||
}
|
||||
|
||||
pub fn into_prog_iter(self) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> {
|
||||
pub fn into_prog_iter(
|
||||
self,
|
||||
) -> ProgIterator<'ast, T, impl IntoIterator<Item = Statement<'ast, T>>> {
|
||||
ProgIterator {
|
||||
statements: self.statements.into_iter(),
|
||||
arguments: self.arguments,
|
||||
|
@ -182,7 +196,7 @@ impl<T> Prog<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for Prog<T> {
|
||||
impl<'ast, T: Field> fmt::Display for Prog<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let returns = (0..self.return_count)
|
||||
.map(Variable::public)
|
||||
|
|
|
@ -12,32 +12,35 @@ const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2];
|
|||
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub enum ProgEnum<
|
||||
Bls12_381I: IntoIterator<Item = Statement<Bls12_381Field>>,
|
||||
Bn128I: IntoIterator<Item = Statement<Bn128Field>>,
|
||||
Bls12_377I: IntoIterator<Item = Statement<Bls12_377Field>>,
|
||||
Bw6_761I: IntoIterator<Item = Statement<Bw6_761Field>>,
|
||||
'ast,
|
||||
Bls12_381I: IntoIterator<Item = Statement<'ast, Bls12_381Field>>,
|
||||
Bn128I: IntoIterator<Item = Statement<'ast, Bn128Field>>,
|
||||
Bls12_377I: IntoIterator<Item = Statement<'ast, Bls12_377Field>>,
|
||||
Bw6_761I: IntoIterator<Item = Statement<'ast, Bw6_761Field>>,
|
||||
> {
|
||||
Bls12_381Program(ProgIterator<Bls12_381Field, Bls12_381I>),
|
||||
Bn128Program(ProgIterator<Bn128Field, Bn128I>),
|
||||
Bls12_377Program(ProgIterator<Bls12_377Field, Bls12_377I>),
|
||||
Bw6_761Program(ProgIterator<Bw6_761Field, Bw6_761I>),
|
||||
Bls12_381Program(ProgIterator<'ast, Bls12_381Field, Bls12_381I>),
|
||||
Bn128Program(ProgIterator<'ast, Bn128Field, Bn128I>),
|
||||
Bls12_377Program(ProgIterator<'ast, Bls12_377Field, Bls12_377I>),
|
||||
Bw6_761Program(ProgIterator<'ast, Bw6_761Field, Bw6_761I>),
|
||||
}
|
||||
|
||||
type MemoryProgEnum = ProgEnum<
|
||||
Vec<Statement<Bls12_381Field>>,
|
||||
Vec<Statement<Bn128Field>>,
|
||||
Vec<Statement<Bls12_377Field>>,
|
||||
Vec<Statement<Bw6_761Field>>,
|
||||
type MemoryProgEnum<'ast> = ProgEnum<
|
||||
'ast,
|
||||
Vec<Statement<'ast, Bls12_381Field>>,
|
||||
Vec<Statement<'ast, Bn128Field>>,
|
||||
Vec<Statement<'ast, Bls12_377Field>>,
|
||||
Vec<Statement<'ast, Bw6_761Field>>,
|
||||
>;
|
||||
|
||||
impl<
|
||||
Bls12_381I: IntoIterator<Item = Statement<Bls12_381Field>>,
|
||||
Bn128I: IntoIterator<Item = Statement<Bn128Field>>,
|
||||
Bls12_377I: IntoIterator<Item = Statement<Bls12_377Field>>,
|
||||
Bw6_761I: IntoIterator<Item = Statement<Bw6_761Field>>,
|
||||
> ProgEnum<Bls12_381I, Bn128I, Bls12_377I, Bw6_761I>
|
||||
'ast,
|
||||
Bls12_381I: IntoIterator<Item = Statement<'ast, Bls12_381Field>>,
|
||||
Bn128I: IntoIterator<Item = Statement<'ast, Bn128Field>>,
|
||||
Bls12_377I: IntoIterator<Item = Statement<'ast, Bls12_377Field>>,
|
||||
Bw6_761I: IntoIterator<Item = Statement<'ast, Bw6_761Field>>,
|
||||
> ProgEnum<'ast, Bls12_381I, Bn128I, Bls12_377I, Bw6_761I>
|
||||
{
|
||||
pub fn collect(self) -> MemoryProgEnum {
|
||||
pub fn collect(self) -> MemoryProgEnum<'ast> {
|
||||
match self {
|
||||
ProgEnum::Bls12_381Program(p) => ProgEnum::Bls12_381Program(p.collect()),
|
||||
ProgEnum::Bn128Program(p) => ProgEnum::Bn128Program(p.collect()),
|
||||
|
@ -55,7 +58,7 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
|
||||
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
/// serialize a program iterator, returning the number of constraints serialized
|
||||
/// Note that we only return constraints, not other statements such as directives
|
||||
pub fn serialize<W: Write>(self, mut w: W) -> Result<usize, DynamicError> {
|
||||
|
@ -106,10 +109,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator
|
|||
|
||||
impl<'de, R: Read>
|
||||
ProgEnum<
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bls12_381Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bn128Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bls12_377Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bw6_761Field>>,
|
||||
'de,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_381Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bn128Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_377Field>>,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bw6_761Field>>,
|
||||
>
|
||||
{
|
||||
pub fn deserialize(mut r: R) -> Result<Self, String> {
|
||||
|
|
|
@ -12,9 +12,9 @@ pub trait SMTLib2 {
|
|||
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result;
|
||||
}
|
||||
|
||||
pub struct SMTLib2Display<'a, T>(pub &'a Prog<T>);
|
||||
pub struct SMTLib2Display<'a, 'ast, T>(pub &'a Prog<'ast, T>);
|
||||
|
||||
impl<T: Field> fmt::Display for SMTLib2Display<'_, T> {
|
||||
impl<'ast, T: Field> fmt::Display for SMTLib2Display<'_, 'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.0.to_smtlib2(f)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ impl<T: Field> Visitor<T> for VariableCollector {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> SMTLib2 for Prog<T> {
|
||||
impl<'ast, T: Field> SMTLib2 for Prog<'ast, T> {
|
||||
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut collector = VariableCollector {
|
||||
variables: BTreeSet::<Variable>::new(),
|
||||
|
@ -75,9 +75,10 @@ fn format_prefix_op_smtlib2<T: SMTLib2, Ts: SMTLib2>(
|
|||
write!(f, ")")
|
||||
}
|
||||
|
||||
impl<T: Field> SMTLib2 for Statement<T> {
|
||||
impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> {
|
||||
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Statement::Block(..) => unreachable!(),
|
||||
Statement::Constraint(ref quad, ref lin, _) => {
|
||||
write!(f, "(= (mod ")?;
|
||||
quad.to_smtlib2(f)?;
|
||||
|
@ -91,7 +92,7 @@ impl<T: Field> SMTLib2 for Statement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> SMTLib2 for Directive<T> {
|
||||
impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> {
|
||||
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "")
|
||||
}
|
||||
|
|
|
@ -53,6 +53,11 @@ pub fn visit_module<T: Field, F: Visitor<T>>(f: &mut F, p: &Prog<T>) {
|
|||
|
||||
pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) {
|
||||
match s {
|
||||
Statement::Block(statements) => {
|
||||
for s in statements {
|
||||
f.visit_statement(s);
|
||||
}
|
||||
}
|
||||
Statement::Constraint(quad, lin, error) => {
|
||||
f.visit_quadratic_combination(quad);
|
||||
f.visit_linear_combination(lin);
|
||||
|
|
|
@ -260,6 +260,13 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_assignee(self, a)
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
fold_assembly_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
fold_statement(self, s)
|
||||
}
|
||||
|
@ -515,6 +522,27 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a),
|
||||
f.fold_expression(e),
|
||||
)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
f.fold_field_expression(lhs),
|
||||
f.fold_field_expression(rhs),
|
||||
metadata,
|
||||
)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -539,6 +567,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
TypedStatement::Log(s, e) => {
|
||||
TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect())
|
||||
}
|
||||
TypedStatement::Assembly(statements) => TypedStatement::Assembly(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_assembly_statement(s))
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
vec![res]
|
||||
|
@ -761,6 +795,36 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
Pos(box e)
|
||||
}
|
||||
And(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
And(box left, box right)
|
||||
}
|
||||
Or(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
Or(box left, box right)
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
Xor(box left, box right)
|
||||
}
|
||||
LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
LeftShift(box e, box by)
|
||||
}
|
||||
RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
RightShift(box e, box by)
|
||||
}
|
||||
Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c) {
|
||||
ConditionalOrExpression::Conditional(s) => Conditional(s),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use crate::typed::CanonicalConstantIdentifier;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type SourceIdentifier<'ast> = &'ast str;
|
||||
pub type SourceIdentifier<'ast> = std::borrow::Cow<'ast, str>;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum CoreIdentifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
Source(ShadowedIdentifier<'ast>),
|
||||
Call(usize),
|
||||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
|
@ -29,16 +31,18 @@ impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
|
|||
}
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
#[serde(borrow)]
|
||||
pub id: CoreIdentifier<'ast>,
|
||||
/// the version of the variable, used after SSA transformation
|
||||
pub version: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ShadowedIdentifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
pub id: SourceIdentifier<'ast>,
|
||||
pub shadow: usize,
|
||||
}
|
||||
|
@ -97,7 +101,7 @@ impl<'ast> Identifier<'ast> {
|
|||
// these two From implementations are only used in tests but somehow cfg(test) doesn't work
|
||||
impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
|
||||
fn from(s: &str) -> CoreIdentifier {
|
||||
CoreIdentifier::Source(ShadowedIdentifier::shadow(s, 0))
|
||||
CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(s), 0))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -446,6 +446,24 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
box Self::try_from_int(e1)?,
|
||||
box Self::try_from_int(e2)?,
|
||||
)),
|
||||
IntExpression::And(box e1, box e2) => Ok(Self::And(
|
||||
box Self::try_from_int(e1)?,
|
||||
box Self::try_from_int(e2)?,
|
||||
)),
|
||||
IntExpression::Or(box e1, box e2) => Ok(Self::Or(
|
||||
box Self::try_from_int(e1)?,
|
||||
box Self::try_from_int(e2)?,
|
||||
)),
|
||||
IntExpression::Xor(box e1, box e2) => Ok(Self::Xor(
|
||||
box Self::try_from_int(e1)?,
|
||||
box Self::try_from_int(e2)?,
|
||||
)),
|
||||
IntExpression::LeftShift(box e1, box e2) => {
|
||||
Ok(Self::LeftShift(box Self::try_from_int(e1)?, box e2))
|
||||
}
|
||||
IntExpression::RightShift(box e1, box e2) => {
|
||||
Ok(Self::RightShift(box Self::try_from_int(e1)?, box e2))
|
||||
}
|
||||
IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)),
|
||||
IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)),
|
||||
IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new(
|
||||
|
@ -843,11 +861,6 @@ mod tests {
|
|||
|
||||
let should_error = vec![
|
||||
BigUint::parse_bytes(b"99999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap().into(),
|
||||
IntExpression::xor(n.clone(), n.clone()),
|
||||
IntExpression::or(n.clone(), n.clone()),
|
||||
IntExpression::and(n.clone(), n.clone()),
|
||||
IntExpression::left_shift(n.clone(), i.clone()),
|
||||
IntExpression::right_shift(n.clone(), i.clone()),
|
||||
IntExpression::not(n.clone()),
|
||||
];
|
||||
|
||||
|
|
|
@ -27,9 +27,7 @@ pub use self::types::{
|
|||
UBitwidth,
|
||||
};
|
||||
use self::types::{ConcreteArrayType, ConcreteStructType};
|
||||
|
||||
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
|
||||
use crate::untyped::Position;
|
||||
|
||||
pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable};
|
||||
use std::marker::PhantomData;
|
||||
|
@ -38,7 +36,7 @@ use std::path::{Path, PathBuf};
|
|||
pub use crate::typed::integer::IntExpression;
|
||||
pub use crate::typed::uint::{bitwidth, UExpression, UExpressionInner, UMetadata};
|
||||
|
||||
use crate::common::{FlatEmbed, FormatString};
|
||||
use crate::common::{FlatEmbed, FormatString, SourceMetadata};
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
@ -569,26 +567,9 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Hash, Eq, Default, PartialOrd, Ord)]
|
||||
pub struct AssertionMetadata {
|
||||
pub file: String,
|
||||
pub position: Position,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl fmt::Display for AssertionMetadata {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Assertion failed at {}:{}", self.file, self.position)?;
|
||||
match &self.message {
|
||||
Some(m) => write!(f, ": \"{}\"", m),
|
||||
None => write!(f, ""),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum RuntimeError {
|
||||
SourceAssertion(AssertionMetadata),
|
||||
SourceAssertion(SourceMetadata),
|
||||
SelectRangeCheck,
|
||||
DivisionByZero,
|
||||
}
|
||||
|
@ -677,6 +658,29 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedAssemblyStatement<'ast, T> {
|
||||
Assignment(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
|
||||
Constraint(
|
||||
FieldElementExpression<'ast, T>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
SourceMetadata,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => {
|
||||
write!(f, "{} <-- {};", lhs, rhs)
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(ref lhs, ref rhs, _) => {
|
||||
write!(f, "{} === {};", lhs, rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A statement in a `TypedFunction`
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
|
@ -697,6 +701,7 @@ pub enum TypedStatement<'ast, T> {
|
|||
ConcreteGenericsAssignment<'ast>,
|
||||
),
|
||||
PopCallLog,
|
||||
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedStatement<'ast, T> {
|
||||
|
@ -721,6 +726,14 @@ impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
|
|||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
TypedStatement::Assembly(statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
s => write!(f, "{}{}", "\t".repeat(depth), s),
|
||||
}
|
||||
}
|
||||
|
@ -768,6 +781,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
|||
generics,
|
||||
),
|
||||
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
|
||||
TypedStatement::Assembly(ref statements) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "\t\t{}", s)?;
|
||||
}
|
||||
write!(f, "\t}}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1188,6 +1208,26 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
And(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Or(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Xor(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
LeftShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
RightShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, Self>),
|
||||
Neg(Box<FieldElementExpression<'ast, T>>),
|
||||
Pos(Box<FieldElementExpression<'ast, T>>),
|
||||
|
@ -1196,6 +1236,73 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Select(SelectExpression<'ast, T, Self>),
|
||||
Element(ElementExpression<'ast, T, Self>),
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<TypedAssignee<'ast, T>> for TupleExpression<'ast, T> {
|
||||
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let inner = TupleExpression::identifier(v.id);
|
||||
match v._type {
|
||||
GType::Tuple(tuple_ty) => inner.annotate(tuple_ty),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index),
|
||||
TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id),
|
||||
TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<TypedAssignee<'ast, T>> for StructExpression<'ast, T> {
|
||||
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let inner = StructExpression::identifier(v.id);
|
||||
match v._type {
|
||||
GType::Struct(struct_ty) => inner.annotate(struct_ty),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index),
|
||||
TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id),
|
||||
TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<TypedAssignee<'ast, T>> for ArrayExpression<'ast, T> {
|
||||
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let inner = ArrayExpression::identifier(v.id);
|
||||
match v._type {
|
||||
GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index),
|
||||
TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id),
|
||||
TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<TypedAssignee<'ast, T>> for FieldElementExpression<'ast, T> {
|
||||
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(v) => FieldElementExpression::identifier(v.id),
|
||||
TypedAssignee::Element(box a, index) => {
|
||||
FieldElementExpression::element(a.into(), index)
|
||||
}
|
||||
TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id),
|
||||
TypedAssignee::Select(box a, box index) => {
|
||||
FieldElementExpression::select(a.into(), index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Add for FieldElementExpression<'ast, T> {
|
||||
type Output = Self;
|
||||
|
||||
|
@ -1676,6 +1783,11 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
||||
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
|
||||
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
|
||||
FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs),
|
||||
FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs),
|
||||
FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs),
|
||||
FieldElementExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||
FieldElementExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||
FieldElementExpression::Conditional(ref c) => write!(f, "{}", c),
|
||||
FieldElementExpression::FunctionCall(ref function_call) => {
|
||||
write!(f, "{}", function_call)
|
||||
|
|
|
@ -386,6 +386,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_assignee(self, a)
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
fold_assembly_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -516,6 +523,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
|
||||
Ok(match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a)?,
|
||||
f.fold_expression(e)?,
|
||||
)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
f.fold_field_expression(lhs)?,
|
||||
f.fold_field_expression(rhs)?,
|
||||
metadata,
|
||||
)]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -546,6 +574,15 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
TypedStatement::Assembly(statements) => TypedStatement::Assembly(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_assembly_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
Ok(vec![res])
|
||||
|
@ -780,6 +817,36 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
Pos(box e)
|
||||
}
|
||||
And(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
And(box left, box right)
|
||||
}
|
||||
Or(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
Or(box left, box right)
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
Xor(box left, box right)
|
||||
}
|
||||
LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
LeftShift(box e, box by)
|
||||
}
|
||||
RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
RightShift(box e, box by)
|
||||
}
|
||||
Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c)? {
|
||||
ConditionalOrExpression::Conditional(c) => Conditional(c),
|
||||
ConditionalOrExpression::Expression(u) => u,
|
||||
|
|
|
@ -52,7 +52,10 @@ pub struct GenericIdentifier<'ast> {
|
|||
impl<'ast> From<GenericIdentifier<'ast>> for CoreIdentifier<'ast> {
|
||||
fn from(g: GenericIdentifier<'ast>) -> CoreIdentifier<'ast> {
|
||||
// generic identifiers are always declared in the function scope, which is shadow 0
|
||||
CoreIdentifier::Source(ShadowedIdentifier::shadow(g.name(), 0))
|
||||
CoreIdentifier::Source(ShadowedIdentifier::shadow(
|
||||
std::borrow::Cow::Borrowed(g.name()),
|
||||
0,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,9 +123,10 @@ pub struct SpecializationError;
|
|||
|
||||
pub type ConstantIdentifier<'ast> = &'ast str;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct CanonicalConstantIdentifier<'ast> {
|
||||
pub module: OwnedTypedModuleId,
|
||||
#[serde(borrow)]
|
||||
pub id: ConstantIdentifier<'ast>,
|
||||
}
|
||||
|
||||
|
|
|
@ -277,6 +277,7 @@ impl<'ast> From<pest::Statement<'ast>> for untyped::StatementNode<'ast> {
|
|||
pest::Statement::Assertion(s) => untyped::StatementNode::from(s),
|
||||
pest::Statement::Return(s) => untyped::StatementNode::from(s),
|
||||
pest::Statement::Log(s) => untyped::StatementNode::from(s),
|
||||
pest::Statement::Assembly(s) => untyped::StatementNode::from(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -340,6 +341,32 @@ impl<'ast> From<pest::IterationStatement<'ast>> for untyped::StatementNode<'ast>
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::AssemblyStatement<'ast>> for untyped::StatementNode<'ast> {
|
||||
fn from(statement: pest::AssemblyStatement<'ast>) -> untyped::StatementNode<'ast> {
|
||||
use crate::untyped::NodeValue;
|
||||
|
||||
let statements = statement
|
||||
.inner
|
||||
.into_iter()
|
||||
.map(|s| match s {
|
||||
pest::AssemblyStatementInner::Assignment(a) => {
|
||||
untyped::AssemblyStatement::Assignment(
|
||||
a.assignee.into(),
|
||||
a.expression.into(),
|
||||
matches!(a.operator, pest::AssignmentOperator::AssignConstrain(_)),
|
||||
)
|
||||
.span(a.span)
|
||||
}
|
||||
pest::AssemblyStatementInner::Constraint(c) => {
|
||||
untyped::AssemblyStatement::Constraint(c.lhs.into(), c.rhs.into()).span(c.span)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
untyped::Statement::Assembly(statements).span(statement.span)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::Expression<'ast>> for untyped::ExpressionNode<'ast> {
|
||||
fn from(expression: pest::Expression<'ast>) -> untyped::ExpressionNode<'ast> {
|
||||
match expression {
|
||||
|
|
|
@ -382,6 +382,33 @@ impl<'ast> fmt::Display for Assignee<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AssemblyStatement<'ast> {
|
||||
Assignment(AssigneeNode<'ast>, ExpressionNode<'ast>, bool),
|
||||
Constraint(ExpressionNode<'ast>, ExpressionNode<'ast>),
|
||||
}
|
||||
|
||||
pub type AssemblyStatementNode<'ast> = Node<AssemblyStatement<'ast>>;
|
||||
|
||||
impl<'ast> fmt::Display for AssemblyStatement<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
AssemblyStatement::Assignment(ref lhs, ref rhs, ref constrained) => {
|
||||
write!(
|
||||
f,
|
||||
"{} <{} {}",
|
||||
lhs,
|
||||
if *constrained { "==" } else { "--" },
|
||||
rhs
|
||||
)
|
||||
}
|
||||
AssemblyStatement::Constraint(ref lhs, ref rhs) => {
|
||||
write!(f, "{} === {}", lhs, rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A statement in a `Function`
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
@ -397,6 +424,7 @@ pub enum Statement<'ast> {
|
|||
Vec<StatementNode<'ast>>,
|
||||
),
|
||||
Log(&'ast str, Vec<ExpressionNode<'ast>>),
|
||||
Assembly(Vec<AssemblyStatementNode<'ast>>),
|
||||
}
|
||||
|
||||
pub type StatementNode<'ast> = Node<Statement<'ast>>;
|
||||
|
@ -431,7 +459,7 @@ impl<'ast> fmt::Display for Statement<'ast> {
|
|||
}
|
||||
Statement::Log(ref l, ref expressions) => write!(
|
||||
f,
|
||||
"log({}, {})",
|
||||
"log({}, {});",
|
||||
l,
|
||||
expressions
|
||||
.iter()
|
||||
|
@ -439,6 +467,13 @@ impl<'ast> fmt::Display for Statement<'ast> {
|
|||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
Statement::Assembly(ref statements) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "\t\t{};", s)?;
|
||||
}
|
||||
write!(f, "\t}}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,6 +84,7 @@ use super::*;
|
|||
impl<'ast> NodeValue for Expression<'ast> {}
|
||||
impl<'ast> NodeValue for Assignee<'ast> {}
|
||||
impl<'ast> NodeValue for Statement<'ast> {}
|
||||
impl<'ast> NodeValue for AssemblyStatement<'ast> {}
|
||||
impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
|
||||
impl<'ast> NodeValue for UnresolvedType<'ast> {}
|
||||
impl<'ast> NodeValue for StructDefinition<'ast> {}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord)]
|
||||
#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Position {
|
||||
pub line: usize,
|
||||
pub col: usize,
|
||||
|
|
|
@ -56,6 +56,13 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
self.fold_variable(a)
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Vec<ZirAssemblyStatement<'ast, T>> {
|
||||
fold_assembly_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
|
||||
fold_statement(self, s)
|
||||
}
|
||||
|
@ -135,6 +142,24 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Vec<ZirAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
ZirAssemblyStatement::Assignment(assignees, function) => {
|
||||
let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect();
|
||||
let function = f.fold_function(function);
|
||||
vec![ZirAssemblyStatement::Assignment(assignees, function)]
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
let lhs = f.fold_field_expression(lhs);
|
||||
let rhs = f.fold_field_expression(rhs);
|
||||
vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: ZirStatement<'ast, T>,
|
||||
|
@ -173,6 +198,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
.map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect()))
|
||||
.collect(),
|
||||
),
|
||||
ZirStatement::Assembly(statements) => ZirStatement::Assembly(
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_assembly_statement(s))
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
vec![res]
|
||||
}
|
||||
|
@ -233,6 +264,36 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2);
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::And(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
FieldElementExpression::And(box left, box right)
|
||||
}
|
||||
FieldElementExpression::Or(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
FieldElementExpression::Or(box left, box right)
|
||||
}
|
||||
FieldElementExpression::Xor(box left, box right) => {
|
||||
let left = f.fold_field_expression(left);
|
||||
let right = f.fold_field_expression(right);
|
||||
|
||||
FieldElementExpression::Xor(box left, box right)
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c) {
|
||||
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
use crate::zir::types::MemberId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
use crate::typed::Identifier as CoreIdentifier;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
||||
pub enum Identifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
Source(SourceIdentifier<'ast>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
||||
pub enum SourceIdentifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
Basic(CoreIdentifier<'ast>),
|
||||
Select(Box<SourceIdentifier<'ast>>, u32),
|
||||
Member(Box<SourceIdentifier<'ast>>, MemberId),
|
||||
|
|
287
zokrates_ast/src/zir/lqc.rs
Normal file
287
zokrates_ast/src/zir/lqc.rs
Normal file
|
@ -0,0 +1,287 @@
|
|||
use crate::zir::{FieldElementExpression, Identifier};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub type LinearTerm<'ast, T> = (T, Identifier<'ast>);
|
||||
pub type QuadraticTerm<'ast, T> = (T, Identifier<'ast>, Identifier<'ast>);
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)]
|
||||
pub struct LinQuadComb<'ast, T> {
|
||||
// the constant terms
|
||||
pub constant: T,
|
||||
// the linear terms
|
||||
pub linear: Vec<LinearTerm<'ast, T>>,
|
||||
// the quadratic terms
|
||||
pub quadratic: Vec<QuadraticTerm<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, mut other: Self) -> Self::Output {
|
||||
Self {
|
||||
constant: self.constant + other.constant,
|
||||
linear: {
|
||||
let mut l = self.linear;
|
||||
l.append(&mut other.linear);
|
||||
l
|
||||
},
|
||||
quadratic: {
|
||||
let mut q = self.quadratic;
|
||||
q.append(&mut other.quadratic);
|
||||
q
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, mut other: Self) -> Self::Output {
|
||||
Self {
|
||||
constant: self.constant - other.constant,
|
||||
linear: {
|
||||
let mut l = self.linear;
|
||||
other.linear.iter_mut().for_each(|(c, _)| {
|
||||
*c = T::zero() - &*c;
|
||||
});
|
||||
l.append(&mut other.linear);
|
||||
l
|
||||
},
|
||||
quadratic: {
|
||||
let mut q = self.quadratic;
|
||||
other.quadratic.iter_mut().for_each(|(c, _, _)| {
|
||||
*c = T::zero() - &*c;
|
||||
});
|
||||
q.append(&mut other.quadratic);
|
||||
q
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> LinQuadComb<'ast, T> {
|
||||
fn try_mul(self, rhs: Self) -> Result<Self, ()> {
|
||||
// fail if the result has degree higher than 2
|
||||
if !(self.quadratic.is_empty() || rhs.quadratic.is_empty()) {
|
||||
return Err(());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
constant: self.constant.clone() * rhs.constant.clone(),
|
||||
linear: {
|
||||
// lin0 * const1 + lin1 * const0
|
||||
self.linear
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(c, i)| (c * rhs.constant.clone(), i))
|
||||
.chain(
|
||||
rhs.linear
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(c, i)| (c * self.constant.clone(), i)),
|
||||
)
|
||||
.collect()
|
||||
},
|
||||
quadratic: {
|
||||
// quad0 * const1 + quad1 * const0 + lin0 * lin1
|
||||
self.quadratic
|
||||
.into_iter()
|
||||
.map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1))
|
||||
.chain(
|
||||
rhs.quadratic
|
||||
.into_iter()
|
||||
.map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)),
|
||||
)
|
||||
.chain(self.linear.iter().flat_map(|(cl, l)| {
|
||||
rhs.linear
|
||||
.iter()
|
||||
.map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone()))
|
||||
}))
|
||||
.collect()
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TryFrom<FieldElementExpression<'ast, T>> for LinQuadComb<'ast, T> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(e: FieldElementExpression<'ast, T>) -> Result<Self, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::Number(v) => Ok(Self {
|
||||
constant: v,
|
||||
..Self::default()
|
||||
}),
|
||||
FieldElementExpression::Identifier(id) => Ok(Self {
|
||||
linear: vec![(T::one(), id.id)],
|
||||
..Self::default()
|
||||
}),
|
||||
FieldElementExpression::Add(box left, box right) => {
|
||||
Ok(Self::try_from(left)? + Self::try_from(right)?)
|
||||
}
|
||||
FieldElementExpression::Sub(box left, box right) => {
|
||||
Ok(Self::try_from(left)? - Self::try_from(right)?)
|
||||
}
|
||||
FieldElementExpression::Mult(box left, box right) => {
|
||||
let left = Self::try_from(left)?;
|
||||
let right = Self::try_from(right)?;
|
||||
|
||||
left.try_mul(right)
|
||||
}
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::zir::Id;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn add() {
|
||||
// (2 + 2*a)
|
||||
let a = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*a*b)
|
||||
let b = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*a) + (2 + 2*a*b) => 4 + 2*a + 2*a*b
|
||||
let c = a + b;
|
||||
|
||||
assert_eq!(c.constant, Bn128Field::from(4));
|
||||
assert_eq!(
|
||||
c.linear,
|
||||
vec![
|
||||
(Bn128Field::from(2), "a".into()),
|
||||
(Bn128Field::from(0), "a".into()),
|
||||
(Bn128Field::from(0), "b".into())
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
c.quadratic,
|
||||
vec![(Bn128Field::from(2), "a".into(), "b".into())]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub() {
|
||||
// (2 + 2*a)
|
||||
let a = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*a*b)
|
||||
let b = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*a) - (2 + 2*a*b) => 0 + 2*a + (-2)*a*b
|
||||
let c = a - b;
|
||||
|
||||
assert_eq!(c.constant, Bn128Field::from(0));
|
||||
assert_eq!(
|
||||
c.linear,
|
||||
vec![
|
||||
(Bn128Field::from(2), "a".into()),
|
||||
(Bn128Field::from(0), "a".into()),
|
||||
(Bn128Field::from(0), "b".into())
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
c.quadratic,
|
||||
vec![(Bn128Field::from(-2), "a".into(), "b".into())]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mult() {
|
||||
// (2 + 2*a)
|
||||
let a = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*b)
|
||||
let b = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// (2 + 2*a) * (2 + 2*b) => 4 + 4*b + 4*a + 4*a*b
|
||||
let c = a.try_mul(b).unwrap();
|
||||
|
||||
assert_eq!(c.constant, Bn128Field::from(4));
|
||||
assert_eq!(
|
||||
c.linear,
|
||||
vec![
|
||||
(Bn128Field::from(4), "a".into()),
|
||||
(Bn128Field::from(4), "b".into()),
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
c.quadratic,
|
||||
vec![(Bn128Field::from(4), "a".into(), "b".into())]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mult_degree_error() {
|
||||
// 2*a*b
|
||||
let a = LinQuadComb::try_from(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::identifier("a".into()),
|
||||
box FieldElementExpression::identifier("b".into()),
|
||||
),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// 2*a*b
|
||||
let b = a.clone();
|
||||
|
||||
// (2*a*b) * (2*a*b) would result in a higher degree than expected
|
||||
let c = a.try_mul(b);
|
||||
assert!(c.is_err());
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
pub mod folder;
|
||||
mod from_typed;
|
||||
mod identifier;
|
||||
pub mod lqc;
|
||||
mod parameter;
|
||||
pub mod result_folder;
|
||||
pub mod types;
|
||||
|
@ -10,7 +11,7 @@ mod variable;
|
|||
pub use self::parameter::Parameter;
|
||||
pub use self::types::{Type, UBitwidth};
|
||||
pub use self::variable::Variable;
|
||||
use crate::common::{FlatEmbed, FormatString};
|
||||
use crate::common::{FlatEmbed, FormatString, SourceMetadata};
|
||||
use crate::typed::ConcreteType;
|
||||
pub use crate::zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata};
|
||||
|
||||
|
@ -21,6 +22,7 @@ use zokrates_field::Field;
|
|||
|
||||
pub use self::folder::Folder;
|
||||
pub use self::identifier::{Identifier, SourceIdentifier};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A typed program as a collection of modules, one of them being the main
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
|
@ -34,11 +36,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> {
|
|||
}
|
||||
}
|
||||
/// A typed function
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct ZirFunction<'ast, T> {
|
||||
/// Arguments of the function
|
||||
#[serde(borrow)]
|
||||
pub arguments: Vec<Parameter<'ast>>,
|
||||
/// Vector of statements that are executed when running the function
|
||||
#[serde(borrow)]
|
||||
pub statements: Vec<ZirStatement<'ast, T>>,
|
||||
/// function signature
|
||||
pub signature: Signature,
|
||||
|
@ -67,7 +71,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
|||
writeln!(f)?;
|
||||
}
|
||||
|
||||
writeln!(f, "}}")
|
||||
write!(f, "}}")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,9 +92,9 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirFunction<'ast, T> {
|
|||
|
||||
pub type ZirAssignee<'ast> = Variable<'ast>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub enum RuntimeError {
|
||||
SourceAssertion(String),
|
||||
SourceAssertion(SourceMetadata),
|
||||
SelectRangeCheck,
|
||||
DivisionByZero,
|
||||
IncompleteDynamicRange,
|
||||
|
@ -99,7 +103,7 @@ pub enum RuntimeError {
|
|||
impl fmt::Display for RuntimeError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RuntimeError::SourceAssertion(message) => write!(f, "{}", message),
|
||||
RuntimeError::SourceAssertion(metadata) => write!(f, "{}", metadata),
|
||||
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
|
||||
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
|
||||
RuntimeError::IncompleteDynamicRange => write!(f, "Dynamic comparison is incomplete"),
|
||||
|
@ -109,12 +113,46 @@ impl fmt::Display for RuntimeError {
|
|||
|
||||
impl RuntimeError {
|
||||
pub fn mock() -> Self {
|
||||
RuntimeError::SourceAssertion(String::default())
|
||||
RuntimeError::SourceAssertion(SourceMetadata::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
|
||||
pub enum ZirAssemblyStatement<'ast, T> {
|
||||
Assignment(
|
||||
#[serde(borrow)] Vec<ZirAssignee<'ast>>,
|
||||
ZirFunction<'ast, T>,
|
||||
),
|
||||
Constraint(
|
||||
FieldElementExpression<'ast, T>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
SourceMetadata,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => {
|
||||
write!(
|
||||
f,
|
||||
"{} <-- {};",
|
||||
lhs.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
rhs
|
||||
)
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(ref lhs, ref rhs, _) => {
|
||||
write!(f, "{} === {};", lhs, rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A statement in a `ZirFunction`
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
|
||||
pub enum ZirStatement<'ast, T> {
|
||||
Return(Vec<ZirExpression<'ast, T>>),
|
||||
Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>),
|
||||
|
@ -129,6 +167,8 @@ pub enum ZirStatement<'ast, T> {
|
|||
FormatString,
|
||||
Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>,
|
||||
),
|
||||
#[serde(borrow)]
|
||||
Assembly(Vec<ZirAssemblyStatement<'ast, T>>),
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
|
||||
|
@ -142,15 +182,19 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
match self {
|
||||
ZirStatement::Return(ref exprs) => {
|
||||
write!(
|
||||
f,
|
||||
"return {};",
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
write!(f, "return")?;
|
||||
if !exprs.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
" {}",
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, ";")
|
||||
}
|
||||
ZirStatement::Definition(ref lhs, ref rhs) => {
|
||||
write!(f, "{} = {};", lhs, rhs)
|
||||
|
@ -166,7 +210,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
write!(f, "{}}};", "\t".repeat(depth))
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
ZirStatement::Assertion(ref e, ref error) => {
|
||||
write!(f, "assert({}", e)?;
|
||||
|
@ -200,6 +244,13 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
ZirStatement::Assembly(statements) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -208,8 +259,9 @@ pub trait Typed {
|
|||
fn get_type(&self) -> Type;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct IdentifierExpression<'ast, E> {
|
||||
#[serde(borrow)]
|
||||
pub id: Identifier<'ast>,
|
||||
ty: PhantomData<E>,
|
||||
}
|
||||
|
@ -229,8 +281,9 @@ impl<'ast, E> IdentifierExpression<'ast, E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct ConditionalExpression<'ast, T, E> {
|
||||
#[serde(borrow)]
|
||||
pub condition: Box<BooleanExpression<'ast, T>>,
|
||||
pub consequence: Box<E>,
|
||||
pub alternative: Box<E>,
|
||||
|
@ -256,9 +309,10 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct SelectExpression<'ast, T, E> {
|
||||
pub array: Vec<E>,
|
||||
#[serde(borrow)]
|
||||
pub index: Box<UExpression<'ast, T>>,
|
||||
}
|
||||
|
||||
|
@ -287,11 +341,11 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<'
|
|||
}
|
||||
|
||||
/// A typed expression
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub enum ZirExpression<'ast, T> {
|
||||
Boolean(BooleanExpression<'ast, T>),
|
||||
FieldElement(FieldElementExpression<'ast, T>),
|
||||
Uint(UExpression<'ast, T>),
|
||||
Uint(#[serde(borrow)] UExpression<'ast, T>),
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for ZirExpression<'ast, T> {
|
||||
|
@ -364,15 +418,20 @@ pub trait MultiTyped {
|
|||
fn get_types(&self) -> &Vec<Type>;
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub enum ZirExpressionList<'ast, T> {
|
||||
EmbedCall(FlatEmbed, Vec<u32>, Vec<ZirExpression<'ast, T>>),
|
||||
EmbedCall(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
#[serde(borrow)] Vec<ZirExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// An expression of type `field`
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
|
||||
pub enum FieldElementExpression<'ast, T> {
|
||||
Number(T),
|
||||
#[serde(borrow)]
|
||||
Identifier(IdentifierExpression<'ast, Self>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
Add(
|
||||
|
@ -392,16 +451,57 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Pow(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
#[serde(borrow)] Box<UExpression<'ast, T>>,
|
||||
),
|
||||
And(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Or(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Xor(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
LeftShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
RightShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
impl<'ast, T> FieldElementExpression<'ast, T> {
|
||||
pub fn is_linear(&self) -> bool {
|
||||
match self {
|
||||
FieldElementExpression::Number(_) => true,
|
||||
FieldElementExpression::Identifier(_) => true,
|
||||
FieldElementExpression::Add(box left, box right) => {
|
||||
left.is_linear() && right.is_linear()
|
||||
}
|
||||
FieldElementExpression::Sub(box left, box right) => {
|
||||
left.is_linear() && right.is_linear()
|
||||
}
|
||||
FieldElementExpression::Mult(box left, box right) => matches!(
|
||||
(left, right),
|
||||
(FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_))
|
||||
),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An expression of type `bool`
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
|
||||
pub enum BooleanExpression<'ast, T> {
|
||||
Value(bool),
|
||||
#[serde(borrow)]
|
||||
Identifier(IdentifierExpression<'ast, Self>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
FieldLt(
|
||||
|
@ -501,6 +601,15 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
|
||||
FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs),
|
||||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
||||
FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs),
|
||||
FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs),
|
||||
FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs),
|
||||
FieldElementExpression::LeftShift(ref lhs, ref rhs) => {
|
||||
write!(f, "({} << {})", lhs, rhs)
|
||||
}
|
||||
FieldElementExpression::RightShift(ref lhs, ref rhs) => {
|
||||
write!(f, "({} >> {})", lhs, rhs)
|
||||
}
|
||||
FieldElementExpression::Conditional(ref c) => {
|
||||
write!(f, "{}", c)
|
||||
}
|
||||
|
@ -804,3 +913,36 @@ impl IntoType for UBitwidth {
|
|||
Type::Uint(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Constant: Sized {
|
||||
// return whether this is constant
|
||||
fn is_constant(&self) -> bool;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for ZirExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
match self {
|
||||
ZirExpression::FieldElement(e) => e.is_constant(),
|
||||
ZirExpression::Boolean(e) => e.is_constant(),
|
||||
ZirExpression::Uint(e) => e.is_constant(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self, FieldElementExpression::Number(..))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self, BooleanExpression::Value(..))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for UExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self.as_inner(), UExpressionInner::Value(..))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use crate::zir::Variable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct Parameter<'ast> {
|
||||
#[serde(borrow)]
|
||||
pub id: Variable<'ast>,
|
||||
pub private: bool,
|
||||
}
|
||||
|
|
|
@ -61,6 +61,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
self.fold_variable(a)
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
fold_assembly_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: ZirStatement<'ast, T>,
|
||||
|
@ -152,6 +159,26 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_uint_expression_inner(self, bitwidth, e)
|
||||
}
|
||||
}
|
||||
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: ZirAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, F::Error> {
|
||||
Ok(match s {
|
||||
ZirAssemblyStatement::Assignment(assignees, function) => {
|
||||
let assignees = assignees
|
||||
.into_iter()
|
||||
.map(|a| f.fold_assignee(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
let function = f.fold_function(function)?;
|
||||
vec![ZirAssemblyStatement::Assignment(assignees, function)]
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
let lhs = f.fold_field_expression(lhs)?;
|
||||
let rhs = f.fold_field_expression(rhs)?;
|
||||
vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
|
@ -207,6 +234,16 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
ZirStatement::Log(l, e)
|
||||
}
|
||||
ZirStatement::Assembly(statements) => {
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_assembly_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
ZirStatement::Assembly(statements)
|
||||
}
|
||||
};
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
@ -254,6 +291,36 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2)?;
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Xor(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
FieldElementExpression::Xor(box left, box right)
|
||||
}
|
||||
FieldElementExpression::And(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
FieldElementExpression::And(box left, box right)
|
||||
}
|
||||
FieldElementExpression::Or(box left, box right) => {
|
||||
let left = f.fold_field_expression(left)?;
|
||||
let right = f.fold_field_expression(right)?;
|
||||
|
||||
FieldElementExpression::Or(box left, box right)
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c)? {
|
||||
ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::zir::types::UBitwidth;
|
||||
use crate::zir::IdentifierExpression;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::{ConditionalExpression, SelectExpression};
|
||||
|
@ -91,7 +92,7 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
|
||||
pub enum ShouldReduce {
|
||||
Unknown,
|
||||
True,
|
||||
|
@ -135,7 +136,7 @@ impl ShouldReduce {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct UMetadata<T> {
|
||||
pub max: T,
|
||||
pub should_reduce: ShouldReduce,
|
||||
|
@ -162,16 +163,18 @@ impl<T: Field> UMetadata<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct UExpression<'ast, T> {
|
||||
pub bitwidth: UBitwidth,
|
||||
pub metadata: Option<UMetadata<T>>,
|
||||
#[serde(borrow)]
|
||||
pub inner: UExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum UExpressionInner<'ast, T> {
|
||||
Value(u128),
|
||||
#[serde(borrow)]
|
||||
Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>),
|
||||
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
|
||||
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use crate::zir::types::{Type, UBitwidth};
|
||||
use crate::zir::Identifier;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
|
||||
pub struct Variable<'ast> {
|
||||
#[serde(borrow)]
|
||||
pub id: Identifier<'ast>,
|
||||
pub _type: Type,
|
||||
}
|
||||
|
|
|
@ -22,8 +22,8 @@ use zokrates_proof_systems::Scheme;
|
|||
const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications.";
|
||||
|
||||
impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
|
||||
fn generate_proof<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Witness<T>,
|
||||
proving_key: Vec<u8>,
|
||||
rng: &mut R,
|
||||
|
@ -86,8 +86,8 @@ impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
|
|||
}
|
||||
|
||||
impl<T: Field + BellmanFieldExtensions> NonUniversalBackend<T, G16> for Bellman {
|
||||
fn setup<I: IntoIterator<Item = Statement<T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>, R: RngCore + CryptoRng>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
rng: &mut R,
|
||||
) -> SetupKeypair<T, G16> {
|
||||
println!("{}", G16_WARNING);
|
||||
|
@ -102,8 +102,8 @@ impl<T: Field + BellmanFieldExtensions> NonUniversalBackend<T, G16> for Bellman
|
|||
}
|
||||
|
||||
impl<T: Field + BellmanFieldExtensions> MpcBackend<T, G16> for Bellman {
|
||||
fn initialize<R: Read, W: Write, I: IntoIterator<Item = Statement<T>>>(
|
||||
program: ProgIterator<T, I>,
|
||||
fn initialize<'a, R: Read, W: Write, I: IntoIterator<Item = Statement<'a, T>>>(
|
||||
program: ProgIterator<'a, T, I>,
|
||||
phase1_radix: &mut R,
|
||||
output: &mut W,
|
||||
) -> Result<(), String> {
|
||||
|
@ -130,9 +130,9 @@ impl<T: Field + BellmanFieldExtensions> MpcBackend<T, G16> for Bellman {
|
|||
Ok(hash)
|
||||
}
|
||||
|
||||
fn verify<P: Read, R: Read, I: IntoIterator<Item = Statement<T>>>(
|
||||
fn verify<'a, P: Read, R: Read, I: IntoIterator<Item = Statement<'a, T>>>(
|
||||
params: &mut P,
|
||||
program: ProgIterator<T, I>,
|
||||
program: ProgIterator<'a, T, I>,
|
||||
phase1_radix: &mut R,
|
||||
) -> Result<Vec<[u8; 64]>, String> {
|
||||
let params =
|
||||
|
|
|
@ -23,20 +23,20 @@ pub use self::parse::*;
|
|||
pub struct Bellman;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Computation<T, I: IntoIterator<Item = Statement<T>>> {
|
||||
program: ProgIterator<T, I>,
|
||||
pub struct Computation<'a, T, I: IntoIterator<Item = Statement<'a, T>>> {
|
||||
program: ProgIterator<'a, T, I>,
|
||||
witness: Option<Witness<T>>,
|
||||
}
|
||||
|
||||
impl<T: Field, I: IntoIterator<Item = Statement<T>>> Computation<T, I> {
|
||||
pub fn with_witness(program: ProgIterator<T, I>, witness: Witness<T>) -> Self {
|
||||
impl<'a, T: Field, I: IntoIterator<Item = Statement<'a, T>>> Computation<'a, T, I> {
|
||||
pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness<T>) -> Self {
|
||||
Computation {
|
||||
program,
|
||||
witness: Some(witness),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn without_witness(program: ProgIterator<T, I>) -> Self {
|
||||
pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self {
|
||||
Computation {
|
||||
program,
|
||||
witness: None,
|
||||
|
@ -84,8 +84,8 @@ fn bellman_combination<T: BellmanFieldExtensions, CS: ConstraintSystem<T::Bellma
|
|||
.fold(LinearCombination::zero(), |acc, e| acc + e)
|
||||
}
|
||||
|
||||
impl<T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<T>>>
|
||||
Circuit<T::BellmanEngine> for Computation<T, I>
|
||||
impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<'a, T>>>
|
||||
Circuit<T::BellmanEngine> for Computation<'a, T, I>
|
||||
{
|
||||
fn synthesize<CS: ConstraintSystem<T::BellmanEngine>>(
|
||||
self,
|
||||
|
@ -161,7 +161,9 @@ pub fn get_random_seed<R: RngCore + CryptoRng>(rng: &mut R) -> [u32; 8] {
|
|||
seed
|
||||
}
|
||||
|
||||
impl<T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<T>>> Computation<T, I> {
|
||||
impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<'a, T>>>
|
||||
Computation<'a, T, I>
|
||||
{
|
||||
pub fn prove<R: RngCore + CryptoRng>(
|
||||
self,
|
||||
params: &Parameters<T::BellmanEngine>,
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
- [Comments](language/comments.md)
|
||||
- [Macros](language/macros.md)
|
||||
- [Logging](language/logging.md)
|
||||
- [Assembly](language/assembly.md)
|
||||
|
||||
- [Toolbox](toolbox/index.md)
|
||||
- [CLI](toolbox/cli.md)
|
||||
|
|
75
zokrates_book/src/language/assembly.md
Normal file
75
zokrates_book/src/language/assembly.md
Normal file
|
@ -0,0 +1,75 @@
|
|||
## Assembly
|
||||
|
||||
ZoKrates allows developers to define constraints through assembly blocks. Assembly blocks are considered **unsafe**, as safety and correctness of the resulting arithmetic circuit is in the hands of the developer. Usage of assembly is recommended only in optimization efforts for the experienced developers to minimize constraint count of an arithmetic circuit.
|
||||
|
||||
## Writing assembly
|
||||
|
||||
All constraints must be enclosed within an `asm` block. In an assembly block we can do the following:
|
||||
|
||||
1. Assign to a witness variable using `<--`
|
||||
2. Define a constraint using `===`
|
||||
|
||||
Assigning a value, in general, should be combined with adding a constraint:
|
||||
|
||||
```zok
|
||||
{{#include ../../../zokrates_cli/examples/book/assembly/division.zok}}
|
||||
```
|
||||
|
||||
> The operator `<--` can be sometimes misused, as this operator does not generate any constraints, resulting in unconstrained variables in the constraint system.
|
||||
|
||||
In some cases we can combine the witness assignment and constraint generation with the `<==` operator:
|
||||
|
||||
```zok
|
||||
asm {
|
||||
c <== 1 - a*b;
|
||||
}
|
||||
```
|
||||
|
||||
which is equivalent to:
|
||||
|
||||
```zok
|
||||
asm {
|
||||
c <-- 1 - a*b;
|
||||
c === 1 - a*b;
|
||||
}
|
||||
```
|
||||
|
||||
A constraint can contain arithmetic expressions that are built using multiplication, addition, and other variables or `field` values. Only quadratic expressions are allowed to be included in constraints. Non-quadratic expressions or usage of other arithmetic operators like division or power are not allowed as constraints, but can be used in the witness assignment expression.
|
||||
|
||||
The following code is not allowed:
|
||||
|
||||
```zok
|
||||
asm {
|
||||
d === a*b*c;
|
||||
}
|
||||
```
|
||||
|
||||
as the constraint `d === a*b*c` is not quadratic.
|
||||
|
||||
In some cases, ZoKrates will apply minor transformations on the defined constraints in order to meet the correct format:
|
||||
|
||||
```zok
|
||||
asm {
|
||||
x * (x - 1) === 0;
|
||||
}
|
||||
```
|
||||
|
||||
will be transformed to:
|
||||
|
||||
```zok
|
||||
asm {
|
||||
x === x * x;
|
||||
}
|
||||
```
|
||||
|
||||
## Type casting
|
||||
|
||||
Assembly is a low-level part of the compiler which does not have type safety. In some cases we might want to do zero-cost conversions between `field` and `bool` type.
|
||||
|
||||
### field_to_bool_unsafe
|
||||
|
||||
This call is unsafe because it is the responsibility of the user to constrain the field input:
|
||||
|
||||
```zok
|
||||
{{#include ../../../zokrates_cli/examples/book/assembly/field_to_bool.zok}}
|
||||
```
|
|
@ -72,6 +72,7 @@ pub fn r1cs_program<T: Field>(prog: Prog<T>) -> (Vec<Variable>, usize, Vec<Const
|
|||
for (quad, lin) in prog.statements.iter().filter_map(|s| match s {
|
||||
Statement::Constraint(quad, lin, _) => Some((quad, lin)),
|
||||
Statement::Directive(..) => None,
|
||||
Statement::Block(..) => unreachable!(),
|
||||
Statement::Log(..) => None,
|
||||
}) {
|
||||
for (k, _) in &quad.left.0 {
|
||||
|
@ -95,6 +96,7 @@ pub fn r1cs_program<T: Field>(prog: Prog<T>) -> (Vec<Variable>, usize, Vec<Const
|
|||
// second pass to convert program to raw sparse vectors
|
||||
for (quad, lin) in prog.statements.into_iter().filter_map(|s| match s {
|
||||
Statement::Constraint(quad, lin, _) => Some((quad, lin)),
|
||||
Statement::Block(..) => unreachable!(),
|
||||
Statement::Directive(..) => None,
|
||||
Statement::Log(..) => None,
|
||||
}) {
|
||||
|
|
11
zokrates_cli/examples/book/assembly/division.zok
Normal file
11
zokrates_cli/examples/book/assembly/division.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
def main(field a, field b) -> field {
|
||||
field mut c = 0;
|
||||
field mut invb = 0;
|
||||
asm {
|
||||
invb <-- b == 0 ? 0 : 1 / b;
|
||||
invb * b === 1;
|
||||
c <-- invb * a;
|
||||
a === b * c;
|
||||
}
|
||||
return c;
|
||||
}
|
13
zokrates_cli/examples/book/assembly/field_to_bool.zok
Normal file
13
zokrates_cli/examples/book/assembly/field_to_bool.zok
Normal file
|
@ -0,0 +1,13 @@
|
|||
from "EMBED" import field_to_bool_unsafe;
|
||||
|
||||
def main(field x) -> bool {
|
||||
// we constrain `x` to be 0 or 1
|
||||
asm {
|
||||
x * (x - 1) === 0;
|
||||
}
|
||||
// we can convert `x` to `bool` afterwards, as we constrained it properly
|
||||
// if we failed to constrain `x` to `0` or `1`, the call to `field_to_bool_unsafe` introduces undefined behavior
|
||||
// `field_to_bool_unsafe` call does not produce any extra constraints
|
||||
bool out = field_to_bool_unsafe(x);
|
||||
return out;
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
def main(field mut a, u32 i) {
|
||||
asm {
|
||||
a <-- a << i; // bitwise operations are allowed in witness generation
|
||||
a === a << i; // but not in constraints
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
def main(field[2] mut a, u32 i) -> field[2] {
|
||||
asm {
|
||||
a[i] <== 42; // assigning to a variable index is not allowed in assembly
|
||||
}
|
||||
return a;
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
def main(field a, u32 b) -> field {
|
||||
return a**b;
|
||||
}
|
|
@ -121,7 +121,8 @@ mod tests {
|
|||
use std::io::{BufReader, Read};
|
||||
use std::string::String;
|
||||
use typed_arena::Arena;
|
||||
use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig};
|
||||
use zokrates_common::CompileConfig;
|
||||
use zokrates_core::compile::{compile, CompilationArtifacts};
|
||||
use zokrates_field::Bn128Field;
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
|
@ -219,7 +220,7 @@ mod tests {
|
|||
let interpreter = zokrates_interpreter::Interpreter::default();
|
||||
|
||||
let _ = interpreter
|
||||
.execute(artifacts.prog(), &[Bn128Field::from(0)])
|
||||
.execute(artifacts.prog(), &[Bn128Field::from(0u32)])
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,8 +5,8 @@ use std::fs::File;
|
|||
use std::io::{BufReader, Read};
|
||||
use std::path::{Path, PathBuf};
|
||||
use zokrates_common::constants::BN128;
|
||||
use zokrates_common::helpers::CurveParameter;
|
||||
use zokrates_core::compile::{check, CompileConfig, CompileError};
|
||||
use zokrates_common::{helpers::CurveParameter, CompileConfig};
|
||||
use zokrates_core::compile::{check, CompileError};
|
||||
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@ use std::path::{Path, PathBuf};
|
|||
use typed_arena::Arena;
|
||||
use zokrates_circom::write_r1cs;
|
||||
use zokrates_common::constants::BN128;
|
||||
use zokrates_common::helpers::CurveParameter;
|
||||
use zokrates_core::compile::{compile, CompileConfig, CompileError};
|
||||
use zokrates_common::{helpers::CurveParameter, CompileConfig};
|
||||
use zokrates_core::compile::{compile, CompileError};
|
||||
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
|
|
|
@ -85,8 +85,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
}
|
||||
|
||||
fn cli_compute<T: Field, I: Iterator<Item = ir::Statement<T>>>(
|
||||
ir_prog: ir::ProgIterator<T, I>,
|
||||
fn cli_compute<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
|
||||
ir_prog: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Computing witness...");
|
||||
|
|
|
@ -147,12 +147,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
|
||||
fn cli_generate_proof<
|
||||
'a,
|
||||
T: Field,
|
||||
I: Iterator<Item = ir::Statement<T>>,
|
||||
I: Iterator<Item = ir::Statement<'a, T>>,
|
||||
S: Scheme<T>,
|
||||
B: Backend<T, S>,
|
||||
>(
|
||||
program: ir::ProgIterator<T, I>,
|
||||
program: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Generating proof...");
|
||||
|
|
|
@ -47,8 +47,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
}
|
||||
|
||||
fn cli_smtlib2<T: Field, I: Iterator<Item = ir::Statement<T>>>(
|
||||
ir_prog: ir::ProgIterator<T, I>,
|
||||
fn cli_smtlib2<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
|
||||
ir_prog: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Generating SMTLib2...");
|
||||
|
|
|
@ -43,8 +43,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
}
|
||||
|
||||
fn cli_inspect<T: Field, I: Iterator<Item = ir::Statement<T>>>(
|
||||
ir_prog: ir::ProgIterator<T, I>,
|
||||
fn cli_inspect<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
|
||||
ir_prog: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
let ir_prog: ir::Prog<T> = ir_prog.collect();
|
||||
|
@ -52,6 +52,9 @@ fn cli_inspect<T: Field, I: Iterator<Item = ir::Statement<T>>>(
|
|||
let curve = format!("{:<17} {}", "curve:", T::name());
|
||||
let constraint_count = format!("{:<17} {}", "constraint_count:", ir_prog.constraint_count());
|
||||
|
||||
println!("{}", curve);
|
||||
println!("{}", constraint_count);
|
||||
|
||||
if sub_matches.is_present("ztf") {
|
||||
let output_path =
|
||||
PathBuf::from(sub_matches.value_of("input").unwrap()).with_extension("ztf");
|
||||
|
|
|
@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
|
||||
fn cli_mpc_init<
|
||||
'a,
|
||||
T: Field + BellmanFieldExtensions,
|
||||
I: Iterator<Item = ir::Statement<T>>,
|
||||
I: Iterator<Item = ir::Statement<'a, T>>,
|
||||
S: MpcScheme<T>,
|
||||
B: MpcBackend<T, S>,
|
||||
>(
|
||||
program: ir::ProgIterator<T, I>,
|
||||
program: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Initializing MPC...");
|
||||
|
|
|
@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
|
||||
fn cli_mpc_verify<
|
||||
'a,
|
||||
T: Field + BellmanFieldExtensions,
|
||||
I: Iterator<Item = ir::Statement<T>>,
|
||||
I: Iterator<Item = ir::Statement<'a, T>>,
|
||||
S: MpcScheme<T>,
|
||||
B: MpcBackend<T, S>,
|
||||
>(
|
||||
program: ir::ProgIterator<T, I>,
|
||||
program: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Verifying contributions...");
|
||||
|
|
|
@ -178,12 +178,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
}
|
||||
|
||||
fn cli_setup_non_universal<
|
||||
'a,
|
||||
T: Field,
|
||||
I: Iterator<Item = ir::Statement<T>>,
|
||||
I: Iterator<Item = ir::Statement<'a, T>>,
|
||||
S: NonUniversalScheme<T>,
|
||||
B: NonUniversalBackend<T, S>,
|
||||
>(
|
||||
program: ir::ProgIterator<T, I>,
|
||||
program: ir::ProgIterator<'a, T, I>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
println!("Performing setup...");
|
||||
|
@ -227,12 +228,13 @@ fn cli_setup_non_universal<
|
|||
}
|
||||
|
||||
fn cli_setup_universal<
|
||||
'a,
|
||||
T: Field,
|
||||
I: Iterator<Item = ir::Statement<T>>,
|
||||
I: Iterator<Item = ir::Statement<'a, T>>,
|
||||
S: UniversalScheme<T>,
|
||||
B: UniversalBackend<T, S>,
|
||||
>(
|
||||
program: ir::ProgIterator<T, I>,
|
||||
program: ir::ProgIterator<'a, T, I>,
|
||||
srs: Vec<u8>,
|
||||
sub_matches: &ArgMatches,
|
||||
) -> Result<(), String> {
|
||||
|
|
16
zokrates_codegen/Cargo.toml
Normal file
16
zokrates_codegen/Cargo.toml
Normal file
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "zokrates_codegen"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["ark", "bellman"]
|
||||
ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark"]
|
||||
bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman"]
|
||||
|
||||
[dependencies]
|
||||
zokrates_field = { version = "0.5.0", path = "../zokrates_field", default-features = false }
|
||||
zokrates_common = { version = "0.1.0", path = "../zokrates_common", default-features = false }
|
||||
zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false }
|
||||
zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false }
|
||||
zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false }
|
|
@ -1,3 +1,5 @@
|
|||
#![feature(box_patterns, box_syntax)]
|
||||
|
||||
//! Module containing the `Flattener` to process a program that is R1CS-able.
|
||||
//!
|
||||
//! @file flatten.rs
|
||||
|
@ -9,11 +11,11 @@ mod utils;
|
|||
|
||||
use self::utils::flat_expression_from_bits;
|
||||
use zokrates_ast::zir::{
|
||||
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList,
|
||||
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement,
|
||||
ZirExpressionList,
|
||||
};
|
||||
use zokrates_interpreter::Interpreter;
|
||||
|
||||
use crate::compile::CompileConfig;
|
||||
use std::collections::{
|
||||
hash_map::{Entry, HashMap},
|
||||
VecDeque,
|
||||
|
@ -29,9 +31,10 @@ use zokrates_ast::zir::{
|
|||
UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirFunction,
|
||||
ZirStatement,
|
||||
};
|
||||
use zokrates_common::CompileConfig;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type FlatStatements<T> = VecDeque<FlatStatement<T>>;
|
||||
type FlatStatements<'ast, T> = VecDeque<FlatStatement<'ast, T>>;
|
||||
|
||||
/// Flattens a function
|
||||
///
|
||||
|
@ -63,14 +66,14 @@ pub fn from_function_and_config<T: Field>(
|
|||
|
||||
pub struct FlattenerIteratorInner<'ast, T> {
|
||||
pub statements: VecDeque<ZirStatement<'ast, T>>,
|
||||
pub statements_flattened: FlatStatements<T>,
|
||||
pub statements_flattened: FlatStatements<'ast, T>,
|
||||
pub flattener: Flattener<'ast, T>,
|
||||
}
|
||||
|
||||
pub type FlattenerIterator<'ast, T> = FlatProgIterator<T, FlattenerIteratorInner<'ast, T>>;
|
||||
pub type FlattenerIterator<'ast, T> = FlatProgIterator<'ast, T, FlattenerIteratorInner<'ast, T>>;
|
||||
|
||||
impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> {
|
||||
type Item = FlatStatement<T>;
|
||||
type Item = FlatStatement<'ast, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while self.statements_flattened.is_empty() {
|
||||
|
@ -124,7 +127,7 @@ trait Flatten<'ast, T: Field>: From<ZirExpression<'ast, T>> + Conditional<'ast,
|
|||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Self::Output;
|
||||
}
|
||||
|
||||
|
@ -134,7 +137,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Self::Output {
|
||||
flattener.flatten_field_expression(statements_flattened, self)
|
||||
}
|
||||
|
@ -146,7 +149,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> {
|
|||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Self::Output {
|
||||
flattener.flatten_uint_expression(statements_flattened, self)
|
||||
}
|
||||
|
@ -158,7 +161,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
|
|||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Self::Output {
|
||||
flattener.flatten_boolean_expression(statements_flattened, self)
|
||||
}
|
||||
|
@ -224,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
fn define(
|
||||
&mut self,
|
||||
e: FlatExpression<T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Variable {
|
||||
match e {
|
||||
FlatExpression::Identifier(id) => id,
|
||||
|
@ -273,7 +276,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
#[must_use]
|
||||
fn constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
a: &[FlatExpression<T>],
|
||||
b: &[bool],
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
|
@ -378,7 +381,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
|
||||
fn eq_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
left: FlatExpression<T>,
|
||||
right: FlatExpression<T>,
|
||||
) -> FlatExpression<T> {
|
||||
|
@ -431,7 +434,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `b` - the big-endian bit decomposition of the upper bound of the range
|
||||
fn enforce_constant_le_check_bits(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
a: &[FlatExpression<T>],
|
||||
c: &[bool],
|
||||
error: RuntimeError,
|
||||
|
@ -461,7 +464,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `c` - the constant upper bound of the range
|
||||
fn enforce_constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
error: RuntimeError,
|
||||
|
@ -497,7 +500,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `c` - the constant upper bound of the range
|
||||
fn enforce_constant_lt_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
error: RuntimeError,
|
||||
|
@ -516,9 +519,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn make_conditional(
|
||||
&mut self,
|
||||
statements: FlatStatements<T>,
|
||||
statements: FlatStatements<'ast, T>,
|
||||
condition: FlatExpression<T>,
|
||||
) -> FlatStatements<T> {
|
||||
) -> FlatStatements<'ast, T> {
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| match s {
|
||||
|
@ -579,7 +582,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * U is the type of the expression
|
||||
fn flatten_conditional_expression<U: Flatten<'ast, T>>(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: ConditionalExpression<'ast, T, U>,
|
||||
) -> FlatUExpression<T> {
|
||||
let condition = *e.condition;
|
||||
|
@ -677,7 +680,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
|
||||
fn constant_lt_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
|
@ -701,7 +704,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise
|
||||
fn constant_field_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
|
@ -742,7 +745,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
#[must_use]
|
||||
fn le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
lhs_flattened: FlatExpression<T>,
|
||||
rhs_flattened: FlatExpression<T>,
|
||||
bit_width: usize,
|
||||
|
@ -765,7 +768,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
#[must_use]
|
||||
fn lt_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
lhs_flattened: FlatExpression<T>,
|
||||
rhs_flattened: FlatExpression<T>,
|
||||
bit_width: usize,
|
||||
|
@ -824,7 +827,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * in order to preserve composability.
|
||||
fn flatten_boolean_expression(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
expression: BooleanExpression<'ast, T>,
|
||||
) -> FlatExpression<T> {
|
||||
match expression {
|
||||
|
@ -1030,7 +1033,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `param_expressions` - Arguments of this call
|
||||
fn flatten_embed_call(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
embed: FlatEmbed,
|
||||
generics: Vec<u32>,
|
||||
param_expressions: Vec<ZirExpression<'ast, T>>,
|
||||
|
@ -1046,6 +1049,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.collect();
|
||||
|
||||
match embed {
|
||||
FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()],
|
||||
FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()),
|
||||
FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()),
|
||||
FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()),
|
||||
|
@ -1131,9 +1135,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn flatten_embed_call_aux(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
params: Vec<FlatUExpression<T>>,
|
||||
funct: FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>>,
|
||||
funct: FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>>,
|
||||
) -> Vec<FlatUExpression<T>> {
|
||||
let mut replacement_map = HashMap::new();
|
||||
|
||||
|
@ -1152,6 +1156,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// add all flattened statements, adapt return statements
|
||||
|
||||
let statements = funct.statements.into_iter().map(|stat| match stat {
|
||||
FlatStatement::Block(..) => unreachable!(),
|
||||
FlatStatement::Definition(var, rhs) => {
|
||||
let new_var = self.use_sym();
|
||||
replacement_map.insert(var, new_var);
|
||||
|
@ -1216,7 +1221,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `expr` - `ZirExpression` that will be flattened.
|
||||
fn flatten_expression(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
expr: ZirExpression<'ast, T>,
|
||||
) -> FlatUExpression<T> {
|
||||
match expr {
|
||||
|
@ -1232,7 +1237,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn default_xor(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
left: UExpression<'ast, T>,
|
||||
right: UExpression<'ast, T>,
|
||||
) -> FlatUExpression<T> {
|
||||
|
@ -1293,7 +1298,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn euclidean_division(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
target_bitwidth: UBitwidth,
|
||||
left: UExpression<'ast, T>,
|
||||
right: UExpression<'ast, T>,
|
||||
|
@ -1379,7 +1384,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `expr` - `UExpression` that will be flattened.
|
||||
fn flatten_uint_expression(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
expr: UExpression<'ast, T>,
|
||||
) -> FlatUExpression<T> {
|
||||
// the bitwidth for this type of uint (8, 16 or 32)
|
||||
|
@ -1872,7 +1877,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
e: &FlatUExpression<T>,
|
||||
from: usize,
|
||||
to: usize,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
error: RuntimeError,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
assert!(from <= T::get_required_bits());
|
||||
|
@ -1966,7 +1971,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn flatten_select_expression<U: Flatten<'ast, T>>(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
e: SelectExpression<'ast, T, U>,
|
||||
) -> FlatUExpression<T> {
|
||||
let array = e.array;
|
||||
|
@ -2030,7 +2035,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `expr` - `FieldElementExpression` that will be flattened.
|
||||
fn flatten_field_expression(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
expr: FieldElementExpression<'ast, T>,
|
||||
) -> FlatExpression<T> {
|
||||
match expr {
|
||||
|
@ -2215,6 +2220,39 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FieldElementExpression::Conditional(e) => self
|
||||
.flatten_conditional_expression(statements_flattened, e)
|
||||
.get_field_unchecked(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn flatten_assembly_statement(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
stat: ZirAssemblyStatement<'ast, T>,
|
||||
) {
|
||||
match stat {
|
||||
ZirAssemblyStatement::Assignment(assignees, function) => {
|
||||
let inputs: Vec<FlatExpression<T>> = function
|
||||
.arguments
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
|
||||
.collect();
|
||||
let outputs: Vec<Variable> = assignees
|
||||
.into_iter()
|
||||
.map(|assignee| self.use_variable(&assignee))
|
||||
.collect();
|
||||
let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs);
|
||||
statements_flattened.push_back(FlatStatement::Directive(directive));
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
lhs,
|
||||
rhs,
|
||||
RuntimeError::SourceAssemblyConstraint(metadata),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2226,10 +2264,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * `stat` - `ZirStatement` that will be flattened.
|
||||
fn flatten_statement(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
stat: ZirStatement<'ast, T>,
|
||||
) {
|
||||
match stat {
|
||||
ZirStatement::Assembly(statements) => {
|
||||
let mut block_statements = VecDeque::new();
|
||||
for s in statements {
|
||||
self.flatten_assembly_statement(&mut block_statements, s);
|
||||
}
|
||||
statements_flattened.push_back(FlatStatement::Block(block_statements.into()));
|
||||
}
|
||||
ZirStatement::Return(exprs) => {
|
||||
#[allow(clippy::needless_collect)]
|
||||
// clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator,
|
||||
|
@ -2630,12 +2675,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `statements_flattened` - `FlatStatements<T>` Vector where new flattened statements can be added.
|
||||
/// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
|
||||
/// * `lhs` - `FlatExpression<T>` Left-hand side of the equality expression.
|
||||
/// * `rhs` - `FlatExpression<T>` Right-hand side of the equality expression.
|
||||
fn flatten_equality_assertion(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
lhs: FlatExpression<T>,
|
||||
rhs: FlatExpression<T>,
|
||||
error: RuntimeError,
|
||||
|
@ -2664,11 +2709,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `e` - `FlatExpression<T>` Expression to be assigned to an identifier.
|
||||
/// * `statements_flattened` - `FlatStatements<T>` Vector where new flattened statements can be added.
|
||||
/// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
|
||||
fn identify_expression(
|
||||
&mut self,
|
||||
e: FlatExpression<T>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> FlatExpression<T> {
|
||||
match e.is_linear() {
|
||||
true => e,
|
||||
|
@ -2707,7 +2752,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
fn use_parameter(
|
||||
&mut self,
|
||||
parameter: &ZirParameter<'ast>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
statements_flattened: &mut FlatStatements<'ast, T>,
|
||||
) -> Parameter {
|
||||
let variable = self.use_variable(¶meter.id);
|
||||
|
|
@ -12,4 +12,5 @@ bellman = []
|
|||
ark = []
|
||||
|
||||
|
||||
[dependencies]
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
|
@ -1,6 +1,7 @@
|
|||
pub mod constants;
|
||||
pub mod helpers;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub trait Resolver<E> {
|
||||
|
@ -10,3 +11,23 @@ pub trait Resolver<E> {
|
|||
import_location: PathBuf,
|
||||
) -> Result<(String, PathBuf), E>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)]
|
||||
pub struct CompileConfig {
|
||||
#[serde(default)]
|
||||
pub isolate_branches: bool,
|
||||
#[serde(default)]
|
||||
pub debug: bool,
|
||||
}
|
||||
|
||||
impl CompileConfig {
|
||||
pub fn isolate_branches(mut self, flag: bool) -> Self {
|
||||
self.isolate_branches = flag;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn debug(mut self, debug: bool) -> Self {
|
||||
self.debug = debug;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,8 +8,8 @@ readme = "README.md"
|
|||
|
||||
[features]
|
||||
default = ["ark", "bellman"]
|
||||
ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark"]
|
||||
bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman"]
|
||||
ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark", "zokrates_codegen/ark", "zokrates_analysis/ark"]
|
||||
bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman", "zokrates_codegen/bellman", "zokrates_analysis/bellman"]
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
|
@ -26,6 +26,8 @@ zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" }
|
|||
zokrates_common = { version = "0.1", path = "../zokrates_common", default-features = false }
|
||||
zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false }
|
||||
zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false }
|
||||
zokrates_codegen = { version = "0.1", path = "../zokrates_codegen", default-features = false }
|
||||
zokrates_analysis = { version = "0.1", path = "../zokrates_analysis", default-features = false }
|
||||
zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false }
|
||||
csv = "1"
|
||||
|
||||
|
|
|
@ -3,35 +3,34 @@
|
|||
//! @file compile.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
use crate::flatten::from_function_and_config;
|
||||
use crate::imports::{self, Importer};
|
||||
use crate::macros;
|
||||
use crate::optimizer::optimize;
|
||||
use crate::semantics::{self, Checker};
|
||||
use crate::static_analysis::{self, analyse};
|
||||
use macros::process_macros;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use typed_arena::Arena;
|
||||
use zokrates_analysis::{self, analyse};
|
||||
use zokrates_ast::ir::{self, from_flat::from_flat};
|
||||
use zokrates_ast::typed::abi::Abi;
|
||||
use zokrates_ast::untyped::{Module, OwnedModuleId, Program};
|
||||
use zokrates_ast::zir::ZirProgram;
|
||||
use zokrates_common::Resolver;
|
||||
use zokrates_codegen::from_function_and_config;
|
||||
use zokrates_common::{CompileConfig, Resolver};
|
||||
use zokrates_field::Field;
|
||||
use zokrates_pest_ast as pest;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompilationArtifacts<T, I: IntoIterator<Item = ir::Statement<T>>> {
|
||||
prog: ir::ProgIterator<T, I>,
|
||||
pub struct CompilationArtifacts<'ast, T, I: IntoIterator<Item = ir::Statement<'ast, T>>> {
|
||||
prog: ir::ProgIterator<'ast, T, I>,
|
||||
abi: Abi,
|
||||
}
|
||||
|
||||
impl<T, I: IntoIterator<Item = ir::Statement<T>>> CompilationArtifacts<T, I> {
|
||||
pub fn prog(self) -> ir::ProgIterator<T, I> {
|
||||
impl<'ast, T, I: IntoIterator<Item = ir::Statement<'ast, T>>> CompilationArtifacts<'ast, T, I> {
|
||||
pub fn prog(self) -> ir::ProgIterator<'ast, T, I> {
|
||||
self.prog
|
||||
}
|
||||
|
||||
|
@ -39,11 +38,11 @@ impl<T, I: IntoIterator<Item = ir::Statement<T>>> CompilationArtifacts<T, I> {
|
|||
&self.abi
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> (ir::ProgIterator<T, I>, Abi) {
|
||||
pub fn into_inner(self) -> (ir::ProgIterator<'ast, T, I>, Abi) {
|
||||
(self.prog, self.abi)
|
||||
}
|
||||
|
||||
pub fn collect(self) -> CompilationArtifacts<T, Vec<ir::Statement<T>>> {
|
||||
pub fn collect(self) -> CompilationArtifacts<'ast, T, Vec<ir::Statement<'ast, T>>> {
|
||||
CompilationArtifacts {
|
||||
prog: self.prog.collect(),
|
||||
abi: self.abi,
|
||||
|
@ -67,7 +66,7 @@ pub enum CompileErrorInner {
|
|||
MacroError(macros::Error),
|
||||
SemanticError(semantics::ErrorInner),
|
||||
ReadError(io::Error),
|
||||
AnalysisError(static_analysis::Error),
|
||||
AnalysisError(zokrates_analysis::Error),
|
||||
}
|
||||
|
||||
impl CompileErrorInner {
|
||||
|
@ -142,8 +141,8 @@ impl From<semantics::Error> for CompileError {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<static_analysis::Error> for CompileErrorInner {
|
||||
fn from(error: static_analysis::Error) -> Self {
|
||||
impl From<zokrates_analysis::Error> for CompileErrorInner {
|
||||
fn from(error: zokrates_analysis::Error) -> Self {
|
||||
CompileErrorInner::AnalysisError(error)
|
||||
}
|
||||
}
|
||||
|
@ -173,26 +172,6 @@ impl fmt::Display for CompileErrorInner {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)]
|
||||
pub struct CompileConfig {
|
||||
#[serde(default)]
|
||||
pub isolate_branches: bool,
|
||||
#[serde(default)]
|
||||
pub debug: bool,
|
||||
}
|
||||
|
||||
impl CompileConfig {
|
||||
pub fn isolate_branches(mut self, flag: bool) -> Self {
|
||||
self.isolate_branches = flag;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn debug(mut self, debug: bool) -> Self {
|
||||
self.debug = debug;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
type FilePath = PathBuf;
|
||||
|
||||
pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
|
||||
|
@ -201,8 +180,10 @@ pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
|
|||
resolver: Option<&dyn Resolver<E>>,
|
||||
config: CompileConfig,
|
||||
arena: &'ast Arena<String>,
|
||||
) -> Result<CompilationArtifacts<T, impl IntoIterator<Item = ir::Statement<T>> + 'ast>, CompileErrors>
|
||||
{
|
||||
) -> Result<
|
||||
CompilationArtifacts<'ast, T, impl IntoIterator<Item = ir::Statement<'ast, T>> + 'ast>,
|
||||
CompileErrors,
|
||||
> {
|
||||
let (typed_ast, abi): (zokrates_ast::zir::ZirProgram<'_, T>, _) =
|
||||
check_with_arena(source, location, resolver, &config, arena)?;
|
||||
|
||||
|
@ -218,8 +199,11 @@ pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
|
|||
log::debug!("Optimise IR");
|
||||
let optimized_ir_prog = optimize(ir_prog);
|
||||
|
||||
// clean (remove blocks)
|
||||
let clean_ir_prog = optimized_ir_prog.clean();
|
||||
|
||||
Ok(CompilationArtifacts {
|
||||
prog: optimized_ir_prog,
|
||||
prog: clean_ir_prog,
|
||||
abi,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -147,6 +147,10 @@ impl Importer {
|
|||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::Unpack),
|
||||
},
|
||||
"field_to_bool_unsafe" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe),
|
||||
},
|
||||
"bit_array_le" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
#![feature(box_patterns, box_syntax)]
|
||||
|
||||
pub mod compile;
|
||||
mod flatten;
|
||||
pub mod imports;
|
||||
mod macros;
|
||||
mod optimizer;
|
||||
mod semantics;
|
||||
mod static_analysis;
|
||||
|
|
|
@ -4,7 +4,7 @@ use zokrates_field::Field;
|
|||
#[derive(Default)]
|
||||
pub struct Canonicalizer;
|
||||
|
||||
impl<T: Field> Folder<T> for Canonicalizer {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for Canonicalizer {
|
||||
fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> {
|
||||
l.into_canonical().into()
|
||||
}
|
||||
|
|
|
@ -14,19 +14,21 @@ use zokrates_ast::ir::folder::*;
|
|||
use zokrates_ast::ir::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type SolverCall<'ast, T> = (Solver<'ast, T>, Vec<QuadComb<T>>);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DirectiveOptimizer<T> {
|
||||
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<Variable>>,
|
||||
pub struct DirectiveOptimizer<'ast, T> {
|
||||
calls: HashMap<SolverCall<'ast, T>, Vec<Variable>>,
|
||||
/// Map of renamings for reassigned variables while processing the program.
|
||||
substitution: HashMap<Variable, Variable>,
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> {
|
||||
fn fold_variable(&mut self, v: Variable) -> Variable {
|
||||
*self.substitution.get(&v).unwrap_or(&v)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
match s {
|
||||
Statement::Directive(d) => {
|
||||
let d = self.fold_directive(d);
|
||||
|
|
|
@ -21,8 +21,8 @@ pub struct DuplicateOptimizer {
|
|||
seen: HashSet<Hash>,
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DuplicateOptimizer {
|
||||
fn fold_program(&mut self, p: Prog<T>) -> Prog<T> {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer {
|
||||
fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> {
|
||||
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
||||
let mut canonicalizer = Canonicalizer;
|
||||
|
||||
|
@ -38,7 +38,7 @@ impl<T: Field> Folder<T> for DuplicateOptimizer {
|
|||
fold_program(self, p)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
let hashed = hash(&s);
|
||||
let result = match self.seen.get(&hashed) {
|
||||
Some(_) => vec![],
|
||||
|
|
|
@ -19,9 +19,9 @@ use self::tautology::TautologyOptimizer;
|
|||
use zokrates_ast::ir::{ProgIterator, Statement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub fn optimize<T: Field, I: IntoIterator<Item = Statement<T>>>(
|
||||
p: ProgIterator<T, I>,
|
||||
) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> {
|
||||
pub fn optimize<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
|
||||
p: ProgIterator<'ast, T, I>,
|
||||
) -> ProgIterator<'ast, T, impl IntoIterator<Item = Statement<'ast, T>>> {
|
||||
// remove redefinitions
|
||||
log::debug!("Optimizer: Remove redefinitions and tautologies and directives and duplicates");
|
||||
|
||||
|
|
|
@ -52,8 +52,10 @@ pub struct RedefinitionOptimizer<T> {
|
|||
pub ignore: HashSet<Variable>,
|
||||
}
|
||||
|
||||
impl<T> RedefinitionOptimizer<T> {
|
||||
pub fn init<I: IntoIterator<Item = Statement<T>>>(p: &ProgIterator<T, I>) -> Self {
|
||||
impl<T: Field> RedefinitionOptimizer<T> {
|
||||
pub fn init<'ast, I: IntoIterator<Item = Statement<'ast, T>>>(
|
||||
p: &ProgIterator<'ast, T, I>,
|
||||
) -> Self {
|
||||
RedefinitionOptimizer {
|
||||
substitution: HashMap::new(),
|
||||
ignore: vec![Variable::one()]
|
||||
|
@ -64,10 +66,12 @@ impl<T> RedefinitionOptimizer<T> {
|
|||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
fn fold_statement<'ast>(
|
||||
&mut self,
|
||||
s: Statement<'ast, T>,
|
||||
aggressive: bool,
|
||||
) -> Vec<Statement<'ast, T>> {
|
||||
match s {
|
||||
Statement::Constraint(quad, lin, message) => {
|
||||
let quad = self.fold_quadratic_combination(quad);
|
||||
|
@ -161,9 +165,11 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
.unwrap_or_else(|q| q)
|
||||
})
|
||||
.collect();
|
||||
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
|
||||
for o in d.outputs.iter().cloned() {
|
||||
self.ignore.insert(o);
|
||||
if !aggressive {
|
||||
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
|
||||
for o in d.outputs.iter().cloned() {
|
||||
self.ignore.insert(o);
|
||||
}
|
||||
}
|
||||
vec![Statement::Directive(Directive { inputs, ..d })]
|
||||
}
|
||||
|
@ -172,6 +178,36 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<T> {
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
match s {
|
||||
Statement::Block(statements) => {
|
||||
#[allow(clippy::needless_collect)]
|
||||
// optimize aggressively and clean up in a second pass (we need to collect here)
|
||||
let statements: Vec<_> = statements
|
||||
.into_iter()
|
||||
.flat_map(|s| self.fold_statement(s, true))
|
||||
.collect();
|
||||
|
||||
// clean up
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.filter(|s| match s {
|
||||
// we remove a directive iff it has a single output and this output is in the substitution map, meaning it was propagated
|
||||
Statement::Directive(d) => {
|
||||
d.outputs.len() > 1 || !self.substitution.contains_key(&d.outputs[0])
|
||||
}
|
||||
_ => true,
|
||||
})
|
||||
.collect();
|
||||
|
||||
vec![Statement::Block(statements)]
|
||||
}
|
||||
s => self.fold_statement(s, false),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_linear_combination(&mut self, lc: LinComb<T>) -> LinComb<T> {
|
||||
match lc
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue