1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge branch 'u8-playground' of github.com:Zokrates/ZoKrates into u8-playground

This commit is contained in:
schaeff 2020-07-16 15:43:41 +02:00
commit ad2838e4eb
161 changed files with 2383 additions and 1872 deletions

73
Cargo.lock generated
View file

@ -27,15 +27,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "arrayvec"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9"
dependencies = [
"nodrop",
]
[[package]]
name = "assert_cli"
version = "0.5.4"
@ -120,17 +111,6 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "blake2-rfc_bellman_edition"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdc60350286c7c3db13b98e91dbe5c8b6830a6821bc20af5b0c310ce94d74915"
dependencies = [
"arrayvec",
"byteorder",
"constant_time_eq",
]
[[package]]
name = "block-buffer"
version = "0.7.3"
@ -261,12 +241,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "crossbeam"
version = "0.7.3"
@ -338,12 +312,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "csv"
version = "1.1.3"
@ -802,12 +770,6 @@ dependencies = [
"autocfg",
]
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "num"
version = "0.1.42"
@ -1245,23 +1207,6 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "sapling-crypto_ce"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4ff5309ec3e4bd800ad4ab3f71e9b76e9ea81c9f0eda6efa16008afbe440b3"
dependencies = [
"bellman_ce",
"blake2-rfc_bellman_edition",
"byteorder",
"digest",
"rand 0.4.6",
"serde",
"serde_derive",
"sha2",
"tiny-keccak",
]
[[package]]
name = "scoped-tls"
version = "1.0.0"
@ -1470,15 +1415,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]]
name = "typed-arena"
version = "1.7.0"
@ -1790,7 +1726,6 @@ dependencies = [
"typed-arena",
"wasm-bindgen-test",
"zokrates_common",
"zokrates_embed",
"zokrates_field",
"zokrates_pest_ast",
]
@ -1803,14 +1738,6 @@ dependencies = [
"zokrates_test",
]
[[package]]
name = "zokrates_embed"
version = "0.1.1"
dependencies = [
"bellman_ce",
"sapling-crypto_ce",
]
[[package]]
name = "zokrates_field"
version = "0.3.6"

View file

@ -6,7 +6,6 @@ members = [
"zokrates_cli",
"zokrates_fs_resolver",
"zokrates_stdlib",
"zokrates_embed",
"zokrates_abi",
"zokrates_test",
"zokrates_core_test",

View file

@ -5,5 +5,5 @@ def foo(field[3] a) -> (field):
def main() -> (field, field):
field[3] a = [0, 0, 0]
field res = foo(a)
a[1] == 0
assert(a[1] == 0)
return res, a[1]

View file

@ -5,5 +5,5 @@ def main(field a) -> (field, field):
field[2] result = [0, 0]
field r = foo(a)
result[1] = r
result[1] == r
assert(result[1] == r)
return result[1], r

View file

@ -1,4 +1,4 @@
def main() -> ():
field pMinusOne = 21888242871839275222246405745257275088548364400416034343698204186575808495616
0 - 1 == pMinusOne
assert(0 - 1 == pMinusOne)
return

View file

@ -2,6 +2,6 @@ import "hashes/sha256/512bitPacked" as sha256packed
def main(private field a, private field b, private field c, private field d) -> ():
field[2] h = sha256packed([a, b, c, d])
h[0] == 263561599766550617289250058199814760685
h[1] == 65303172752238645975888084098459749904
assert(h[0] == 263561599766550617289250058199814760685)
assert(h[1] == 65303172752238645975888084098459749904)
return

View file

@ -5,5 +5,5 @@ def incr(field a) -> (field):
def main() -> ():
field x = 1
field res = incr(x)
x == 1 // x has not changed
assert(x == 1) // x has not changed
return

View file

@ -3,5 +3,5 @@
def main(field a, field b) -> (field):
field y = if a + 2 == 3 && a * 2 == 2 then 1 else 0 fi
field z = if y == 1 && 1-y == 0 then y else 1 fi
b == 1
assert(b == 1)
return a

View file

@ -2,6 +2,6 @@
def main(field a, field b) -> (field):
field y = if a + 2 == 4 || b * 2 == 2 then 1 else 0 fi
field z = if y == 1 || y == 0 then y else 1 fi
z == 1
assert(z == 1)
return z

View file

@ -2,6 +2,6 @@
def main(field a) -> (field): // a needs to be 1
field b = a + 5 // inline comment
field c = a + b + a + 4
a == 1 // another inline comment
assert(a == 1) // another inline comment
field d = a + c + a + b
return b + c + d

View file

@ -1,6 +1,19 @@
struct Foo {
field a
}
struct Bar {
Foo[1] foo
}
def isEqual(field a, field b) -> (bool):
return a == b
def main(field a) -> (field):
field b = (a + 5) * 6
2 * b == a * 12 + 60
assert(2 * b == a * 12 + 60)
field c = 7 * (b + a)
c == 7 * b + 7 * a
return b + c
assert(isEqual(c, 7 * b + 7 * a))
field k = if [1, 2] == [3, 4] then 1 else 3 fi
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
return b + c

View file

@ -0,0 +1,5 @@
def assert() -> ():
return
def main() -> ():
return

View file

@ -1,5 +1,5 @@
// a and b are factorization of c
def main(field c, private field a, private field b) -> ():
field d = a * b
c == d
assert(c == d)
return

View file

@ -7,5 +7,5 @@ import "./bar"
def main() -> (field):
MyBar my_bar = MyBar {}
Bar bar = Bar {}
my_bar == bar
assert(my_bar == bar)
return foo() + bar()

View file

@ -2,5 +2,5 @@ def foo() -> (field):
return 1
def main() -> ():
foo() + (1 + 44*3) == 1
assert(foo() + (1 + 44*3) == 1)
return

View file

@ -6,14 +6,15 @@ import "hashes/utils/256bitsDirectionHelper" as multiplex
// Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes
// directionSelector=> 1/true if current digest is on the rhs of the hash
def main(bool[256] rootDigest, private bool[256] leafDigest, private bool[3] directionSelector, bool[256] PathDigest0, private bool[256] PathDigest1, private bool[256] PathDigest2) -> ():
def main(u32[8] rootDigest, private u32[8] leafDigest, private bool[3] directionSelector, u32[8] PathDigest0, private u32[8] PathDigest1, private u32[8] PathDigest2) -> ():
BabyJubJubParams context = context()
//Setup
bool[256] currentDigest = leafDigest
u32[8] currentDigest = leafDigest
//Loop up the tree
bool[512] preimage = multiplex(directionSelector[0], currentDigest, PathDigest0)
u32[16] preimage = multiplex(directionSelector[0], currentDigest, PathDigest0)
currentDigest = hash(preimage)
preimage = multiplex(directionSelector[1], currentDigest, PathDigest1)
@ -22,7 +23,7 @@ def main(bool[256] rootDigest, private bool[256] leafDigest, private bool[3] dir
preimage = multiplex(directionSelector[2], currentDigest, PathDigest2)
currentDigest = hash(preimage)
rootDigest == currentDigest
assert(rootDigest == currentDigest)
return

View file

@ -3,17 +3,17 @@ import "utils/multiplexer/256bit" as multiplex
// Merkle-Tree inclusion proof for tree depth 3
def main(field treeDepth, bool[256] rootDigest, private bool[256] leafDigest, private bool[2] directionSelector, bool[256] PathDigest0, private bool[256] PathDigest1) -> ():
def main(field treeDepth, u32[8] rootDigest, private u32[8] leafDigest, private bool[2] directionSelector, u32[8] PathDigest0, private u32[8] PathDigest1) -> ():
//Setup
bool[256] currentDigest = leafDigest
u32[8] currentDigest = leafDigest
field counter = 1
bool currentDirection = false
//Loop up the tree
currentDirection = directionSelector[0]
bool[256] lhs = multiplex(currentDirection, currentDigest, PathDigest0)
bool[256] rhs = multiplex(!currentDirection, currentDigest, PathDigest0)
u32[8] lhs = multiplex(currentDirection, currentDigest, PathDigest0)
u32[8] rhs = multiplex(!currentDirection, currentDigest, PathDigest0)
currentDigest = sha256(lhs, rhs)
counter = counter + 1
@ -23,7 +23,7 @@ def main(field treeDepth, bool[256] rootDigest, private bool[256] leafDigest, pr
currentDigest = sha256(lhs, rhs)
counter = counter + 1
counter == treeDepth
rootDigest == currentDigest
assert(counter == treeDepth)
assert(rootDigest == currentDigest)
return

View file

@ -1,4 +1,4 @@
// this code does not need to be flattened
def main(field x, field a, field b) -> (field):
a == b * 7
assert(a == b * 7)
return x + a + b

View file

@ -1,6 +1,6 @@
// this code does not need to be flattened
def main(field x, field y, field z) -> (field):
field a = x + 3*y - z *2 - x * 12
3*y - z *2 - x * 12 == a - x
(x + y) - ((z + 3*x) - y) == (x - y) + ((2*x - 4*y) + (4*y - 2*z))
assert(3*y - z *2 - x * 12 == a - x)
assert((x + y) - ((z + 3*x) - y) == (x - y) + ((2*x - 4*y) + (4*y - 2*z)))
return x

View file

@ -1,16 +1,16 @@
def main() -> ():
field x = 2**4
x == 16
assert(x == 16)
x = x**2
x == 256
assert(x == 256)
field y = 3**3
y == 27
assert(y == 27)
field z = y**2
z == 729
assert(z == 729)
field a = 5**2
a == 25
assert(a == 25)
a = a**2
a == 625
assert(a == 625)
a = 5**5
a == 3125
assert(a == 3125)
return

View file

@ -1,5 +1,5 @@
def foo(field a, field b) -> (field, field):
a == b + 2
assert(a == b + 2)
return a, b
def main() -> (field):

View file

@ -20,32 +20,32 @@ def countDuplicates(field e11,field e12,field e21,field e22) -> (field):
return duplicates
// returns 0 for x in (1..4)
def validateInput(field x) -> (field):
return (x-1)*(x-2)*(x-3)*(x-4)
def validateInput(field x) -> (bool):
return (x-1)*(x-2)*(x-3)*(x-4) == 0
// variables naming: box'row''column'
def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) -> (bool):
// validate inputs
0 == validateInput(a11)
0 == validateInput(a12)
0 == validateInput(a21)
0 == validateInput(a22)
assert(validateInput(a11))
assert(validateInput(a12))
assert(validateInput(a21))
assert(validateInput(a22))
0 == validateInput(b11)
0 == validateInput(b12)
0 == validateInput(b21)
0 == validateInput(b22)
assert(validateInput(b11))
assert(validateInput(b12))
assert(validateInput(b21))
assert(validateInput(b22))
0 == validateInput(c11)
0 == validateInput(c12)
0 == validateInput(c21)
0 == validateInput(c22)
assert(validateInput(c11))
assert(validateInput(c12))
assert(validateInput(c21))
assert(validateInput(c22))
0 == validateInput(d11)
0 == validateInput(d12)
0 == validateInput(d21)
0 == validateInput(d22)
assert(validateInput(d11))
assert(validateInput(d12))
assert(validateInput(d21))
assert(validateInput(d22))
field duplicates = 0 // globally counts duplicate entries in boxes, rows and columns

View file

@ -3,11 +3,11 @@ def main(field x) -> (field):
field a = 5
field b = 7
field c = if a == b then 4 else 3 fi
c == 3
assert(c == 3)
field d = if a == 5 then 1 else 2 fi
d == 1
assert(d == 1)
field e = if a < b then 5 else 6 fi
e == 5
assert(e == 5)
field f = if b < a then 7 else 8 fi
f == 8
assert(f == 8)
return x

View file

@ -3,12 +3,13 @@
// * we don't enforce only one number being non-prime, so there could be multiple Waldos
def isWaldo(field a, field p, field q) -> (bool):
// make sure that neither `p` nor `q` is `1`
(!(p == 1) && !(q == 1)) == true
// make sure that p and q are both not one
assert(p != 1 && q != 1)
// we know how to factor a
return a == p * q
// define all
def main(field[3] a, private field index, private field p, private field q) -> (bool):
// prover provides the index of Waldo
return isWaldo(a[index], p, q)

View file

@ -1,3 +1,3 @@
def main(field a, field b) -> ():
a==b
assert(a == b)
return

View file

@ -1,260 +0,0 @@
[
[
false,
false,
false,
true,
true,
true,
true,
true,
false,
false,
true,
true,
true,
false,
true,
true,
true,
false,
false,
false,
true,
false,
true,
true,
true,
false,
false,
true,
true,
false,
false,
false,
true,
true,
false,
false,
false,
false,
true,
false,
false,
false,
false,
true,
true,
true,
true,
false,
true,
false,
true,
true,
true,
false,
false,
false,
true,
false,
false,
true,
false,
true,
false,
false,
false,
false,
true,
true,
true,
true,
false,
false,
true,
false,
false,
false,
true,
true,
true,
false,
true,
true,
true,
false,
false,
false,
true,
true,
false,
false,
true,
true,
false,
false,
true,
false,
false,
false,
true,
false,
true,
true,
false,
false,
false,
false,
false,
true,
false,
true,
false,
false,
false,
false,
false,
true,
false,
true,
false,
false,
true,
false,
true,
true,
false,
true,
true,
false,
false,
false,
false,
true,
false,
false,
false,
false,
false,
true,
false,
true,
false,
true,
false,
true,
false,
true,
true,
false,
false,
false,
true,
false,
false,
true,
true,
false,
false,
false,
false,
true,
false,
true,
false,
false,
true,
true,
true,
false,
false,
true,
true,
true,
false,
false,
true,
true,
true,
false,
false,
false,
true,
true,
true,
true,
false,
false,
true,
true,
false,
true,
false,
true,
true,
true,
true,
false,
true,
true,
true,
true,
false,
false,
false,
true,
false,
false,
true,
true,
true,
false,
true,
false,
false,
false,
false,
false,
false,
true,
true,
true,
true,
false,
true,
true,
true,
true,
true,
false,
true,
false,
true,
false,
true,
true,
false,
false,
true,
true,
false,
false,
false,
false,
true,
true,
true,
true,
false,
true,
false,
false,
true,
false,
true,
true,
false,
true
]
]

View file

@ -1 +0,0 @@
~out_0 1

View file

@ -1,14 +0,0 @@
import "EMBED/sha256round" as sha256
def main(private bool[256] expected) -> (field):
bool[256] a = [false; 256]
bool[256] b = [false; 256]
b[253] = true
b[255] = true
bool[256] IV = [false, true, true, false, true, false, true, false, false, false, false, false, true, false, false, true, true, true, true, false, false, true, true, false, false, true, true, false, false, true, true, true, true, false, true, true, true, false, true, true, false, true, true, false, false, true, true, true, true, false, true, false, true, true, true, false, true, false, false, false, false, true, false, true, false, false, true, true, true, true, false, false, false, true, true, false, true, true, true, false, true, true, true, true, false, false, true, true, false, true, true, true, false, false, true, false, true, false, true, false, false, true, false, true, false, true, false, false, true, true, true, true, true, true, true, true, false, true, false, true, false, false, true, true, true, false, true, false, false, true, false, true, false, false, false, true, false, false, false, false, true, true, true, false, false, true, false, true, false, false, true, false, false, true, true, true, true, true, true, true, true, false, false, true, true, false, true, true, false, false, false, false, false, true, false, true, false, true, true, false, true, false, false, false, true, false, false, false, true, true, false, false, false, false, false, true, true, true, true, true, true, false, false, false, false, false, true, true, true, true, false, true, true, false, false, true, true, false, true, false, true, false, true, true, false, true, false, true, true, false, true, true, true, true, true, false, false, false, false, false, true, true, false, false, true, true, false, true, false, false, false, true, true, false, false, true]
expected == sha256([...a, ...b], IV)
return 1

View file

@ -9,7 +9,7 @@ build = "build.rs"
[features]
default = ["bellman_ce/nolog"]
libsnark = ["cc", "cmake", "git2"]
wasm = ["bellman_ce/wasm", "zokrates_embed/wasm"]
wasm = ["bellman_ce/wasm"]
multicore = ["bellman_ce/multicore"]
[dependencies]
@ -29,7 +29,6 @@ pairing_ce = "^0.21"
ff_ce = "^0.9"
zokrates_field = { version = "0.3.0", path = "../zokrates_field" }
zokrates_pest_ast = { version = "0.1.0", path = "../zokrates_pest_ast" }
zokrates_embed = { path = "../zokrates_embed" }
zokrates_common = { path = "../zokrates_common" }
rand = "0.4"
csv = "1"

View file

@ -262,23 +262,8 @@ impl<'ast, T: Field> From<pest::AssertionStatement<'ast>> for absy::StatementNod
fn from(statement: pest::AssertionStatement<'ast>) -> absy::StatementNode<T> {
use absy::NodeValue;
match statement.expression {
pest::Expression::Binary(e) => match e.op {
pest::BinaryOperator::Eq => absy::Statement::Condition(
absy::ExpressionNode::from(*e.left),
absy::ExpressionNode::from(*e.right),
),
_ => unimplemented!(
"Assertion statements should be an equality check, found {}",
statement.span.as_str()
),
},
_ => unimplemented!(
"Assertion statements should be an equality check, found {}",
statement.span.as_str()
),
}
.span(statement.span)
absy::Statement::Assertion(absy::ExpressionNode::from(statement.expression))
.span(statement.span)
}
}
@ -389,7 +374,14 @@ impl<'ast, T: Field> From<pest::BinaryExpression<'ast>> for absy::ExpressionNode
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
),
o => unimplemented!("Operator {:?} not implemented", o),
// rewrite (a != b)` as `!(a == b)`
pest::BinaryOperator::NotEq => absy::Expression::Not(
box absy::Expression::Eq(
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
)
.span(expression.span.clone()),
),
}
.span(expression.span)
}

View file

@ -299,7 +299,7 @@ pub enum Statement<'ast, T> {
Return(ExpressionListNode<'ast, T>),
Declaration(VariableNode<'ast>),
Definition(AssigneeNode<'ast, T>, ExpressionNode<'ast, T>),
Condition(ExpressionNode<'ast, T>, ExpressionNode<'ast, T>),
Assertion(ExpressionNode<'ast, T>),
For(
VariableNode<'ast>,
ExpressionNode<'ast, T>,
@ -317,7 +317,7 @@ impl<'ast, T: fmt::Display> fmt::Display for Statement<'ast, T> {
Statement::Return(ref expr) => write!(f, "return {}", expr),
Statement::Declaration(ref var) => write!(f, "{}", var),
Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
Statement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "assert({})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
for l in list {
@ -346,7 +346,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for Statement<'ast, T> {
Statement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
Statement::Condition(ref lhs, ref rhs) => write!(f, "Condition({:?}, {:?})", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
for l in list {

View file

@ -1,19 +1,16 @@
use crate::solvers::Solver;
use bellman::pairing::ff::ScalarEngine;
use flat_absy::{
FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement,
FlatVariable,
};
use std::collections::HashMap;
use typed_absy::types::{FunctionKey, Signature, Type};
use zokrates_embed::{generate_sha256_round_constraints, BellmanConstraint};
use zokrates_field::Field;
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum FlatEmbed {
Sha256Round,
Unpack(usize),
U8ToBits,
U16ToBits,
@ -26,12 +23,6 @@ pub enum FlatEmbed {
impl FlatEmbed {
pub fn signature(&self) -> Signature {
match self {
FlatEmbed::Sha256Round => Signature::new()
.inputs(vec![
Type::array(Type::Boolean, 512),
Type::array(Type::Boolean, 256),
])
.outputs(vec![Type::array(Type::Boolean, 256)]),
FlatEmbed::Unpack(bitwidth) => Signature::new()
.inputs(vec![Type::FieldElement])
.outputs(vec![Type::array(Type::Boolean, *bitwidth)]),
@ -62,7 +53,6 @@ impl FlatEmbed {
pub fn id(&self) -> &'static str {
match self {
FlatEmbed::Sha256Round => "_SHA256_ROUND",
FlatEmbed::Unpack(_) => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS",
FlatEmbed::U16ToBits => "_U16_TO_BITS",
@ -76,144 +66,12 @@ impl FlatEmbed {
/// Actually get the `FlatFunction` that this `FlatEmbed` represents
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> {
match self {
FlatEmbed::Sha256Round => sha256_round(),
FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth),
_ => unreachable!(),
}
}
}
// util to convert a vector of `(variable_id, coefficient)` to a flat_expression
// we build a binary tree of additions by splitting the vector recursively
fn flat_expression_from_vec<T: Field>(
v: &[(usize, <<T as Field>::BellmanEngine as ScalarEngine>::Fr)],
) -> FlatExpression<T> {
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (key, val) = v[0].clone();
FlatExpression::Mult(
box FlatExpression::Number(T::from_bellman(val)),
box FlatExpression::Identifier(FlatVariable::new(key)),
)
}
n => {
let (u, v) = v.split_at(n / 2);
FlatExpression::Add(
box flat_expression_from_vec(u),
box flat_expression_from_vec(v),
)
}
}
}
impl<T: Field> From<BellmanConstraint<T::BellmanEngine>> for FlatStatement<T> {
fn from(c: zokrates_embed::BellmanConstraint<T::BellmanEngine>) -> FlatStatement<T> {
let rhs_a = flat_expression_from_vec(&c.a);
let rhs_b = flat_expression_from_vec(&c.b);
let lhs = flat_expression_from_vec(&c.c);
FlatStatement::Condition(lhs, FlatExpression::Mult(box rhs_a, box rhs_b))
}
}
/// Returns a flat function which computes a sha256 round
///
/// # Remarks
///
/// The variables inside the function are set in this order:
/// - constraint system variables
/// - arguments
pub fn sha256_round<T: Field>() -> FlatFunction<T> {
// Define iterators for all indices at hand
let (r1cs, input_indices, current_hash_indices, output_indices) =
generate_sha256_round_constraints::<T::BellmanEngine>();
// indices of the input
let input_indices = input_indices.into_iter();
// indices of the current hash
let current_hash_indices = current_hash_indices.into_iter();
// indices of the output
let output_indices = output_indices.into_iter();
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
// indices of the sha256round constraint system variables
let cs_indices = (0..variable_count).into_iter();
// indices of the arguments to the function
// apply an offset of `variable_count` to get the indice of our dummy `input` argument
let input_argument_indices = input_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// apply an offset of `variable_count` to get the indice of our dummy `current_hash` argument
let current_hash_argument_indices = current_hash_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// define parameters to the function based on the variables
let arguments = input_argument_indices
.clone()
.chain(current_hash_argument_indices.clone())
.map(|i| FlatParameter {
id: FlatVariable::new(i),
private: true,
})
.collect();
// define a binding of the first variable in the constraint system to one
let one_binding_statement = FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(T::from(1)),
);
let input_binding_statements =
// bind input and current_hash to inputs
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
FlatStatement::Condition(
FlatVariable::new(cs_index).into(),
FlatVariable::new(argument_index).into(),
)
});
// insert flattened statements to represent constraints
let constraint_statements = r1cs.constraints.into_iter().map(|c| c.into());
// define which subset of the witness is returned
let outputs: Vec<FlatExpression<T>> = output_indices
.map(|o| FlatExpression::Identifier(FlatVariable::new(o)))
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(FlatDirective {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
.collect(),
solver: Solver::Sha256Round,
});
// insert a statement to return the subset of the witness
let return_statement = FlatStatement::Return(FlatExpressionList {
expressions: outputs,
});
let statements = std::iter::once(directive_statement)
.chain(std::iter::once(one_binding_statement))
.chain(input_binding_statements)
.chain(constraint_statements)
.chain(std::iter::once(return_statement))
.collect();
FlatFunction {
arguments,
statements,
}
}
fn use_variable(
layout: &mut HashMap<String, FlatVariable>,
name: String,
@ -361,86 +219,4 @@ mod tests {
);
}
}
#[cfg(test)]
mod sha256 {
use super::*;
use ir::Interpreter;
#[test]
fn generate_sha256_constraints() {
let compiled = sha256_round();
// function should have 768 inputs
assert_eq!(compiled.arguments.len(), 768,);
// function should return 256 values
assert_eq!(
compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Return(v) => Some(v),
_ => None,
})
.next()
.unwrap()
.expressions
.len(),
256,
);
// directive should take 768 inputs and return n_var outputs
let directive = compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Directive(d) => Some(d.clone()),
_ => None,
})
.next()
.unwrap();
assert_eq!(directive.inputs.len(), 768);
assert_eq!(directive.outputs.len(), 26935);
// function input should be offset by variable_count
assert_eq!(
compiled.arguments[0].id,
FlatVariable::new(directive.outputs.len() + 1)
);
// bellman variable #0: index 0 should equal 1
assert_eq!(
compiled.statements[1],
FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(Bn128Field::from(1))
)
);
// bellman input #0: index 1 should equal zokrates input #0: index v_count
assert_eq!(
compiled.statements[2],
FlatStatement::Condition(
FlatVariable::new(1).into(),
FlatVariable::new(26936).into()
)
);
let f = crate::ir::Function::from(compiled);
let prog = crate::ir::Prog {
main: f,
private: vec![true; 768],
};
let input = (0..512)
.map(|_| 0)
.chain((0..256).map(|_| 1))
.map(|i| Bn128Field::from(i))
.collect();
let interpreter = Interpreter::default();
interpreter.execute(&prog, &input).unwrap();
}
}
}

View file

@ -575,9 +575,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
BooleanExpression::FieldEq(box lhs, box rhs) => {
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
@ -616,6 +613,53 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res
}
BooleanExpression::UintEq(box lhs, box rhs) => {
// We reduce each side into range and apply the same approach as for field elements
// Wanted: (Y = (X != 0) ? 1 : 0)
// X = a - b
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
// Y == X * M
// 0 == (1-Y) * X
let name_y = self.use_sym();
let name_m = self.use_sym();
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool());
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool());
let lhs = self
.flatten_uint_expression(symbols, statements_flattened, lhs)
.get_field_unchecked();
let rhs = self
.flatten_uint_expression(symbols, statements_flattened, rhs)
.get_field_unchecked();
let x = FlatExpression::Sub(box lhs, box rhs);
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![x.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x),
));
res
}
BooleanExpression::Le(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(
symbols,
@ -1851,24 +1895,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
None => {}
}
}
ZirStatement::Condition(lhs, rhs) => {
// flatten expr1 and expr2 to n flattened expressions with n the number of primitive types for expr1
// add n conditions to check equality of the n expressions
ZirStatement::Assertion(e) => {
// naive approach: flatten the boolean to a single field element and constrain it to 1
let lhs = self
.flatten_expression(symbols, statements_flattened, lhs)
.get_field_unchecked();
let rhs = self
.flatten_expression(symbols, statements_flattened, rhs)
.get_field_unchecked();
let e = self.flatten_boolean_expression(symbols, statements_flattened, e);
if lhs.is_linear() {
statements_flattened.push(FlatStatement::Condition(lhs, rhs));
} else if rhs.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(rhs, lhs));
if e.is_linear() {
statements_flattened.push(FlatStatement::Condition(
e,
FlatExpression::Number(T::from(1)),
));
} else {
unreachable!()
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::from(1)),
e,
));
}
}
ZirStatement::MultipleDefinition(vars, rhs) => {

View file

@ -149,17 +149,6 @@ impl Importer {
// handle the case of special bellman and packing imports
if import.source.starts_with("EMBED") {
match import.source.to_str().unwrap() {
"EMBED/sha256round" => {
let alias = alias.unwrap_or("sha256round");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/unpack" => {
let alias = alias.unwrap_or("unpack");

View file

@ -4,7 +4,6 @@ use ir::Directive;
use solvers::Solver;
use std::collections::BTreeMap;
use std::fmt;
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::Field;
pub type ExecutionResult<T> = Result<Witness<T>, Error>;
@ -186,17 +185,6 @@ impl Interpreter {
vec![a * (b - c.clone()) + c]
}
Solver::Div => vec![inputs[0].clone() / inputs[1].clone()],
Solver::Sha256Round => {
let i = &inputs[0..512];
let h = &inputs[512..];
let i: Vec<_> = i.iter().map(|x| x.clone().into_bellman()).collect();
let h: Vec<_> = h.iter().map(|x| x.clone().into_bellman()).collect();
assert!(h.len() == 256);
generate_sha256_round_witness::<T::BellmanEngine>(&i, &h)
.into_iter()
.map(|x| T::from_bellman(x))
.collect()
}
};
assert_eq!(res.len(), expected_output_count);

View file

@ -17,7 +17,6 @@ extern crate lazy_static;
extern crate pairing_ce as pairing;
extern crate regex;
extern crate zokrates_common;
extern crate zokrates_embed;
extern crate zokrates_field;
extern crate zokrates_pest_ast;

View file

@ -875,27 +875,21 @@ impl<'ast> Checker<'ast> {
}
.map_err(|e| vec![e])
}
Statement::Condition(lhs, rhs) => {
let checked_lhs = self
.check_expression(lhs, module_id, &types)
.map_err(|e| vec![e])?;
let checked_rhs = self
.check_expression(rhs, module_id, &types)
Statement::Assertion(e) => {
let e = self
.check_expression(e, module_id, &types)
.map_err(|e| vec![e])?;
if checked_lhs.get_type() == checked_rhs.get_type() {
Ok(TypedStatement::Condition(checked_lhs, checked_rhs))
} else {
Err(ErrorInner {
match e {
TypedExpression::Boolean(e) => Ok(TypedStatement::Assertion(e)),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {:?} to {} of type {:?}",
checked_lhs,
checked_lhs.get_type(),
checked_rhs,
checked_rhs.get_type(),
"Expected {} to be of type bool, found {}",
e,
e.get_type(),
),
})
}),
}
.map_err(|e| vec![e])
}
@ -1543,6 +1537,54 @@ impl<'ast> Checker<'ast> {
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
Ok(BooleanExpression::BoolEq(box e1, box e2).into())
}
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {} to {} of type {}",
e1,
e1.get_type(),
e2,
e2.get_type()
),
})
}
}
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(BooleanExpression::StructEq(box e1, box e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {} to {} of type {}",
e1,
e1.get_type(),
e2,
e2.get_type()
),
})
}
}
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(BooleanExpression::UintEq(box e1, box e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {} to {} of type {}",
e1,
e1.get_type(),
e2,
e2.get_type()
),
})
}
}
(e1, e2) => Err(ErrorInner {
pos: Some(pos),
message: format!(
@ -3136,9 +3178,12 @@ mod tests {
// def bar():
// 2 == foo()
// should fail
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Condition(
Expression::FieldConstant(Bn128Field::from(2)).mock(),
Expression::FunctionCall("foo", vec![]).mock(),
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Assertion(
Expression::Eq(
box Expression::FieldConstant(Bn128Field::from(2)).mock(),
box Expression::FunctionCall("foo", vec![]).mock(),
)
.mock(),
)
.mock()];
@ -3535,9 +3580,12 @@ mod tests {
// def bar():
// 1 == foo()
// should fail
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Condition(
Expression::FieldConstant(Bn128Field::from(1)).mock(),
Expression::FunctionCall("foo", vec![]).mock(),
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Assertion(
Expression::Eq(
box Expression::FieldConstant(Bn128Field::from(1)).mock(),
box Expression::FunctionCall("foo", vec![]).mock(),
)
.mock(),
)
.mock()];

View file

@ -6,7 +6,6 @@ pub enum Solver {
ConditionEq,
Bits(usize),
Div,
Sha256Round,
Xor,
Or,
ShaAndXorAndXorAnd,
@ -25,7 +24,6 @@ impl Solver {
Solver::ConditionEq => (1, 2),
Solver::Bits(bit_width) => (1, *bit_width),
Solver::Div => (2, 1),
Solver::Sha256Round => (768, 26935),
Solver::Xor => (2, 1),
Solver::Or => (2, 1),
Solver::ShaAndXorAndXorAnd => (3, 1),

View file

@ -1,232 +0,0 @@
use crate::flat_absy::{FlatExpression, FlatExpressionList, FlatFunction, FlatStatement};
use crate::flat_absy::{FlatParameter, FlatVariable};
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::types::{Signature, Type};
use bellman::pairing::ff::ScalarEngine;
use reduce::Reduce;
use zokrates_embed::{generate_sha256_round_constraints, BellmanConstraint};
use zokrates_field::Field;
// util to convert a vector of `(variable_id, coefficient)` to a flat_expression
fn flat_expression_from_vec<T: Field>(
v: Vec<(usize, <<T as Field>::BellmanEngine as ScalarEngine>::Fr)>,
) -> FlatExpression<T> {
match v
.into_iter()
.map(|(key, val)| {
FlatExpression::Mult(
box FlatExpression::Number(T::from_bellman(val)),
box FlatExpression::Identifier(FlatVariable::new(key)),
)
})
.reduce(|acc, e| FlatExpression::Add(box acc, box e))
{
Some(e @ FlatExpression::Mult(..)) => {
FlatExpression::Add(box FlatExpression::Number(T::zero()), box e)
} // the R1CS serializer only recognizes Add
Some(e) => e,
None => FlatExpression::Number(T::zero()),
}
}
impl<T: Field> From<BellmanConstraint<T::BellmanEngine>> for FlatStatement<T> {
fn from(c: zokrates_embed::BellmanConstraint<T::BellmanEngine>) -> FlatStatement<T> {
let rhs_a = flat_expression_from_vec(c.a);
let rhs_b = flat_expression_from_vec(c.b);
let lhs = flat_expression_from_vec(c.c);
FlatStatement::Condition(lhs, FlatExpression::Mult(box rhs_a, box rhs_b))
}
}
/// Returns a flat function which computes a sha256 round
///
/// # Remarks
///
/// The variables inside the function are set in this order:
/// - constraint system variables
/// - arguments
pub fn sha_round<T: Field>() -> FlatFunction<T> {
// Define iterators for all indices at hand
let (r1cs, input_indices, current_hash_indices, output_indices) =
generate_sha256_round_constraints::<T::BellmanEngine>();
// indices of the input
let input_indices = input_indices.into_iter();
// indices of the current hash
let current_hash_indices = current_hash_indices.into_iter();
// indices of the output
let output_indices = output_indices.into_iter();
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
// indices of the sha256round constraint system variables
let cs_indices = (0..variable_count).into_iter();
// indices of the arguments to the function
// apply an offset of `variable_count` to get the indice of our dummy `input` argument
let input_argument_indices = input_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// apply an offset of `variable_count` to get the indice of our dummy `current_hash` argument
let current_hash_argument_indices = current_hash_indices
.clone()
.into_iter()
.map(|i| i + variable_count);
// define the signature of the resulting function
let signature = Signature {
inputs: vec![
Type::array(Type::FieldElement, input_indices.len()),
Type::array(Type::FieldElement, current_hash_indices.len()),
],
outputs: vec![Type::array(Type::FieldElement, output_indices.len())],
};
// define parameters to the function based on the variables
let arguments = input_argument_indices
.clone()
.chain(current_hash_argument_indices.clone())
.map(|i| FlatParameter {
id: FlatVariable::new(i),
private: true,
})
.collect();
// define a binding of the first variable in the constraint system to one
let one_binding_statement = FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(T::from(1)),
);
let input_binding_statements =
// bind input and current_hash to inputs
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
FlatStatement::Condition(
FlatVariable::new(cs_index).into(),
FlatVariable::new(argument_index).into(),
)
});
// insert flattened statements to represent constraints
let constraint_statements = r1cs.constraints.into_iter().map(|c| c.into());
// define which subset of the witness is returned
let outputs: Vec<FlatExpression<T>> = output_indices
.map(|o| FlatExpression::Identifier(FlatVariable::new(o)))
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(DirectiveStatement {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
.collect(),
helper: Helper::Rust(RustHelper::Sha256Round),
});
// insert a statement to return the subset of the witness
let return_statement = FlatStatement::Return(FlatExpressionList {
expressions: outputs,
});
let statements = std::iter::once(directive_statement)
.chain(std::iter::once(one_binding_statement))
.chain(input_binding_statements)
.chain(constraint_statements)
.chain(std::iter::once(return_statement))
.collect();
FlatFunction {
id: "main".to_owned(),
arguments,
statements,
signature,
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
#[test]
fn generate_sha256_constraints() {
let compiled = sha_round();
// function should have a signature of 768 inputs and 256 outputs
assert_eq!(
compiled.signature,
Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 512),
Type::array(Type::FieldElement, 256)
])
.outputs(vec![Type::array(Type::FieldElement, 256)])
);
// function should have 768 inputs
assert_eq!(compiled.arguments.len(), 768,);
// function should return 256 values
assert_eq!(
compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Return(v) => Some(v),
_ => None,
})
.next()
.unwrap()
.expressions
.len(),
256,
);
// directive should take 768 inputs and return n_var outputs
let directive = compiled
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Directive(d) => Some(d.clone()),
_ => None,
})
.next()
.unwrap();
assert_eq!(directive.inputs.len(), 768);
assert_eq!(directive.outputs.len(), 26935);
// function input should be offset by variable_count
assert_eq!(
compiled.arguments[0].id,
FlatVariable::new(directive.outputs.len() + 1)
);
// bellman variable #0: index 0 should equal 1
assert_eq!(
compiled.statements[1],
FlatStatement::Condition(
FlatVariable::new(0).into(),
FlatExpression::Number(Bn128Field::from(1))
)
);
// bellman input #0: index 1 should equal zokrates input #0: index v_count
assert_eq!(
compiled.statements[2],
FlatStatement::Condition(FlatVariable::new(1).into(), FlatVariable::new(26936).into())
);
let f = crate::ir::Function::from(compiled);
let prog = crate::ir::Prog {
main: f,
private: vec![true; 768],
};
let input = (0..512).map(|_| 0).chain((0..256).map(|_| 1)).collect();
prog.execute(&input).unwrap();
}
}

View file

@ -248,14 +248,9 @@ pub fn fold_statement<'ast, T: Field>(
.map(|v| zir::ZirStatement::Declaration(v))
.collect()
}
typed_absy::TypedStatement::Condition(left, right) => {
let left = f.fold_expression(left);
let right = f.fold_expression(right);
assert_eq!(left.len(), right.len());
left.into_iter()
.zip(right.into_iter())
.map(|(left, right)| zir::ZirStatement::Condition(left, right))
.collect()
typed_absy::TypedStatement::Assertion(e) => {
let e = f.fold_boolean_expression(e);
vec![zir::ZirStatement::Assertion(e)]
}
typed_absy::TypedStatement::For(..) => unreachable!(),
typed_absy::TypedStatement::MultipleDefinition(variables, elist) => {
@ -555,6 +550,68 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let e2 = f.fold_boolean_expression(e2);
zir::BooleanExpression::BoolEq(box e1, box e2)
}
typed_absy::BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1);
let e2 = f.fold_array_expression(e2);
assert_eq!(e1.len(), e2.len());
e1.into_iter().zip(e2.into_iter()).fold(
zir::BooleanExpression::Value(true),
|acc, (e1, e2)| {
zir::BooleanExpression::And(
box acc,
box match (e1, e2) {
(
zir::ZirExpression::FieldElement(e1),
zir::ZirExpression::FieldElement(e2),
) => zir::BooleanExpression::FieldEq(box e1, box e2),
(zir::ZirExpression::Boolean(e1), zir::ZirExpression::Boolean(e2)) => {
zir::BooleanExpression::BoolEq(box e1, box e2)
}
(zir::ZirExpression::Uint(e1), zir::ZirExpression::Uint(e2)) => {
zir::BooleanExpression::UintEq(box e1, box e2)
}
_ => unreachable!(),
},
)
},
)
}
typed_absy::BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1);
let e2 = f.fold_struct_expression(e2);
assert_eq!(e1.len(), e2.len());
e1.into_iter().zip(e2.into_iter()).fold(
zir::BooleanExpression::Value(true),
|acc, (e1, e2)| {
zir::BooleanExpression::And(
box acc,
box match (e1, e2) {
(
zir::ZirExpression::FieldElement(e1),
zir::ZirExpression::FieldElement(e2),
) => zir::BooleanExpression::FieldEq(box e1, box e2),
(zir::ZirExpression::Boolean(e1), zir::ZirExpression::Boolean(e2)) => {
zir::BooleanExpression::BoolEq(box e1, box e2)
}
(zir::ZirExpression::Uint(e1), zir::ZirExpression::Uint(e2)) => {
zir::BooleanExpression::UintEq(box e1, box e2)
}
_ => unreachable!(),
},
)
},
)
}
typed_absy::BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintEq(box e1, box e2)
}
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -104,10 +104,6 @@ impl<'ast, T: Field> Inliner<'ast, T> {
let unpack = crate::embed::FlatEmbed::Unpack(T::get_required_bits());
let unpack_key = unpack.key::<T>();
// define a function in the main module for the `sha256_round` embed
let sha256_round = crate::embed::FlatEmbed::Sha256Round;
let sha256_round_key = sha256_round.key::<T>();
// define a function in the main module for the `u32_to_bits` embed
let u32_to_bits = crate::embed::FlatEmbed::U32ToBits;
let u32_to_bits_key = u32_to_bits.key::<T>();
@ -140,7 +136,6 @@ impl<'ast, T: Field> Inliner<'ast, T> {
TypedModule {
functions: vec![
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
(u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)),
(u16_from_bits_key, TypedFunctionSymbol::Flat(u16_from_bits)),
(u8_from_bits_key, TypedFunctionSymbol::Flat(u8_from_bits)),

View file

@ -55,7 +55,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = VariableAccessRemover::apply(r);
// convert to zir, removing complex types
let zir = Flattener::flatten(r.clone());
let zir = Flattener::flatten(r);
// optimize uint expressions
let zir = UintOptimizer::optimize(zir);

View file

@ -112,13 +112,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
TypedStatement::Definition(TypedAssignee::Member(..), _) => {
unreachable!("struct update should have been replaced with full struct redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// propagate the boolean
TypedStatement::Assertion(e) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(
self.fold_expression(e1),
self.fold_expression(e2),
))
Some(TypedStatement::Assertion(self.fold_boolean_expression(e)))
}
// only loops with variable bounds are expected here
// we stop propagation here as constants maybe be modified inside the loop body
@ -957,6 +954,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
// Note: we only propagate when we see constants, as comparing of arbitrary expressions would lead to
// a lot of false negatives due to expressions not being in a canonical form
// For example, `2 * a` is equivalent to `a + a`, but our notion of equality would not detect that here
// These kind of reduction rules are easier to apply later in the process, when we have canonical representations
// of expressions, ie `a + a` would always be written `2 * a`
match e {
BooleanExpression::Identifier(id) => match self
.constants

View file

@ -53,6 +53,24 @@ fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast,
}
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::UintEq(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left = force_reduce(left);
let right = force_reduce(right);
BooleanExpression::UintEq(box left, box right)
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
if e.metadata.is_some() {
return e;
@ -398,17 +416,17 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
)],
},
},
// we need to put back in range to assert
ZirStatement::Condition(lhs, rhs) => {
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
vec![ZirStatement::Condition(
force_reduce(lhs).into(),
force_reduce(rhs).into(),
)]
}
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],
}
ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right)) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
// we can only compare two unsigned integers if they are in range
let left = force_reduce(left);
let right = force_reduce(right);
vec![ZirStatement::Assertion(BooleanExpression::UintEq(
box left, box right,
))]
}
s => fold_statement(self, s),
}

View file

@ -84,13 +84,12 @@ impl<'ast> Unroller<'ast> {
match head {
Access::Select(head) => {
statements.insert(TypedStatement::Condition(
statements.insert(TypedStatement::Assertion(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
@ -1089,13 +1088,12 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
TypedStatement::Assertion(
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
@ -1227,13 +1225,12 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
TypedStatement::Assertion(
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(

View file

@ -39,7 +39,7 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
_ => unreachable!(),
};
self.statements.push(TypedStatement::Condition(
self.statements.push(TypedStatement::Assertion(
(0..size)
.map(|index| {
BooleanExpression::FieldEq(
@ -53,7 +53,6 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
})
.unwrap()
.into(),
BooleanExpression::Value(true).into(),
));
(0..size)
@ -170,7 +169,7 @@ mod tests {
assert_eq!(
VariableAccessRemover::new().fold_statement(access),
vec![
TypedStatement::Condition(
TypedStatement::Assertion(
BooleanExpression::Or(
box BooleanExpression::FieldEq(
box FieldElementExpression::Identifier("i".into()),
@ -182,7 +181,6 @@ mod tests {
)
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b")),

View file

@ -165,9 +165,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
TypedStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
}
TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)),
TypedStatement::Condition(left, right) => {
TypedStatement::Condition(f.fold_expression(left), f.fold_expression(right))
}
TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)),
TypedStatement::For(v, from, to, statements) => TypedStatement::For(
f.fold_variable(v),
from,
@ -325,6 +323,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1);
let e2 = f.fold_array_expression(e2);
BooleanExpression::ArrayEq(box e1, box e2)
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1);
let e2 = f.fold_struct_expression(e2);
BooleanExpression::StructEq(box e1, box e2)
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -300,7 +300,7 @@ pub enum TypedStatement<'ast, T> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
Declaration(Variable<'ast>),
Condition(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
Assertion(BooleanExpression<'ast, T>),
For(
Variable<'ast>,
FieldElementExpression<'ast, T>,
@ -327,9 +327,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
TypedStatement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
TypedStatement::Condition(ref lhs, ref rhs) => {
write!(f, "Condition({:?}, {:?})", lhs, rhs)
}
TypedStatement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
TypedStatement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
for l in list {
@ -376,7 +374,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
}
TypedStatement::Declaration(ref var) => write!(f, "{}", var),
TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
TypedStatement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
TypedStatement::Assertion(ref e) => write!(f, "assert({})", e),
TypedStatement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
for l in list {
@ -639,6 +637,12 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
ArrayEq(Box<ArrayExpression<'ast, T>>, Box<ArrayExpression<'ast, T>>),
StructEq(
Box<StructExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
),
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Ge(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -906,6 +910,9 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::StructEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
@ -985,6 +992,15 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
BooleanExpression::BoolEq(ref lhs, ref rhs) => {
write!(f, "BoolEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::ArrayEq(ref lhs, ref rhs) => {
write!(f, "ArrayEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::StructEq(ref lhs, ref rhs) => {
write!(f, "StructEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintEq(ref lhs, ref rhs) => {
write!(f, "UintEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs),

View file

@ -130,9 +130,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
ZirStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
}
ZirStatement::Declaration(v) => ZirStatement::Declaration(f.fold_variable(v)),
ZirStatement::Condition(left, right) => {
ZirStatement::Condition(f.fold_expression(left), f.fold_expression(right))
}
ZirStatement::Assertion(e) => ZirStatement::Assertion(f.fold_boolean_expression(e)),
ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition(
variables.into_iter().map(|v| f.fold_variable(v)).collect(),
f.fold_expression_list(elist),
@ -201,6 +199,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -191,7 +191,7 @@ pub enum ZirStatement<'ast, T> {
Return(Vec<ZirExpression<'ast, T>>),
Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>),
Declaration(Variable<'ast>),
Condition(ZirExpression<'ast, T>, ZirExpression<'ast, T>),
Assertion(BooleanExpression<'ast, T>),
MultipleDefinition(Vec<Variable<'ast>>, ZirExpressionList<'ast, T>),
}
@ -212,9 +212,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirStatement<'ast, T> {
ZirStatement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
ZirStatement::Condition(ref lhs, ref rhs) => {
write!(f, "Condition({:?}, {:?})", lhs, rhs)
}
ZirStatement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
ZirStatement::MultipleDefinition(ref lhs, ref rhs) => {
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
}
@ -235,9 +233,9 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
}
write!(f, "")
}
ZirStatement::Declaration(ref var) => write!(f, "{}", var),
ZirStatement::Declaration(ref var) => write!(f, "assert({})", var),
ZirStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
ZirStatement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
ZirStatement::Assertion(ref e) => write!(f, "{}", e),
ZirStatement::MultipleDefinition(ref ids, ref rhs) => {
for (i, id) in ids.iter().enumerate() {
write!(f, "{}", id)?;
@ -399,6 +397,7 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Ge(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -511,6 +510,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),

View file

@ -15,7 +15,7 @@ fn out_of_range() {
let source = r#"
def main(private field a) -> (field):
field x = if a < 5555 then 3333 else 4444 fi
x == 3333
assert(x == 3333)
return 1
"#
.to_string();

View file

@ -1,4 +1,4 @@
import "utils/pack/nonStrictUnpack256.zok" as unpack256
import "utils/pack/bool/nonStrictUnpack256.zok" as unpack256
def main(field[2] inputs) -> (bool[512]):

View file

@ -1,3 +1,3 @@
def main(field a) -> ():
a == 1
assert(a == 1)
return

View file

@ -6,9 +6,9 @@ def local(field a) -> (field): // this costs 3 constraints per call
def main(field a) -> ():
// calling a local function many times with the same arg should cost only once
local(a) + local(a) + local(a) + local(a) + local(a) == 5 * (a ** 8)
assert(local(a) + local(a) + local(a) + local(a) + local(a) == 5 * (a ** 8))
// calling an imported function many times with the same arg should cost only once
dep(a) + dep(a) + dep(a) + dep(a) + dep(a) == 5 * (a ** 4)
assert(dep(a) + dep(a) + dep(a) + dep(a) + dep(a) == 5 * (a ** 4))
return

View file

@ -4,15 +4,11 @@
"tests": [
{
"input": {
"values": [
"12"
]
"values": []
},
"output": {
"Ok": {
"values": [
"12"
]
"values": []
}
}
}

View file

@ -1,8 +1,8 @@
def main(field g) -> (field):
9 == 1 + 2 * 2 ** 2 // Checks precedence of arithmetic operators (expecting transitive behaviour)
9 == 2 ** 2 * 2 + 1
7 == 2 ** 2 * 2 - 1
3 == 2 ** 2 / 2 + 1
def main() -> ():
assert(9 == 1 + 2 * 2 ** 2) // Checks precedence of arithmetic operators (expecting transitive behaviour)
assert(9 == 2 ** 2 * 2 + 1)
assert(7 == 2 ** 2 * 2 - 1)
assert(3 == 2 ** 2 / 2 + 1)
field a = if 3 == 2 ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
field b = if 3 == 3 && 4 < 5 then 1 else 0 fi // checks precedence of boolean operators
@ -11,6 +11,8 @@ def main(field g) -> (field):
field e = if 2 >= 1 && 4 > 5 || 1 == 1 then 1 else 0 fi
field f = if 1 < 2 && false || 4 < 5 && 2 >= 1 then 1 else 0 fi
assert(0x00 ^ 0x00 == 0x00)
//check if all statements have evalutated to true
a * b * c * d * e * f == 1
return g
assert(a * b * c * d * e * f == 1)
return

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/uint/eq.zok",
"tests": []
}

View file

@ -0,0 +1,3 @@
def main(private u32 a, u32 b) -> (field):
field result = if a * a == b then 1 else 0 fi
return result

View file

@ -60,47 +60,47 @@ def main(u32 e, u32 f, u32[4] terms) -> ():
// rotate
u32 rotated = right_rotate_4(e)
rotated == 0x81234567
assert(rotated == 0x81234567)
// and
(e & f) == 0x00204460
assert((e & f) == 0x00204460)
// xor
(e ^ f) == 0x1317131f
assert((e ^ f) == 0x1317131f)
// shift
e >> 12 == 0x00012345
e << 12 == 0x45678000
assert(e >> 12 == 0x00012345)
assert(e << 12 == 0x45678000)
// not
!e == 0xedcba987
assert(!e == 0xedcba987)
// add
terms[0] + terms[1] + terms[2] + terms[3] == 0xddddddda
assert(terms[0] + terms[1] + terms[2] + terms[3] == 0xddddddda)
// to_bits
bool[32] bits1 = to_bits(e)
bool[32] expected1 = [false, false, false, true, false, false, true, false, false, false, true, true, false, true, false, false, false, true, false, true, false, true, true, false, false, true, true, true, true, false, false, false]
bits1 == expected1
e == from_bits(expected1)
assert(bits1 == expected1)
assert(e == from_bits(expected1))
bool[32] bits2 = to_bits(f)
bool[32] expected2 = [false, false, false, false, false, false, false, true, false, false, true, false, false, false, true, true, false, true, false, false, false, true, false, true, false, true, true, false, false, true, true, true]
bits2 == expected2
f == from_bits(expected2)
assert(bits2 == expected2)
assert(f == from_bits(expected2))
// S0
u32 e2 = right_rotate_2(e)
u32 e13 = right_rotate_13(e)
u32 e22 = right_rotate_22(e)
u32 S0 = e2 ^ e13 ^ e22
S0 == 0x66146474
assert(S0 == 0x66146474)
// S1
u32 e6 = right_rotate_6(e)
u32 e11 = right_rotate_11(e)
u32 e25 = right_rotate_25(e)
u32 S1 = e6 ^ e11 ^ e25
S1 == 0x3561abda
assert(S1 == 0x3561abda)
return

View file

@ -54,6 +54,7 @@ def right_rotate_22(u32 e) -> (u32):
def right_rotate_25(u32 e) -> (u32):
bool[32] b = to_bits(e)
u32 res = from_bits([...b[7..], ...b[..7]])
return res
def main() -> ():
@ -62,47 +63,47 @@ def main() -> ():
// rotate
u32 rotated = right_rotate_4(e)
rotated == 0x81234567
assert(rotated == 0x81234567)
// and
(e & f) == 0x00204460
assert((e & f) == 0x00204460)
// xor
(e ^ f) == 0x1317131f
assert((e ^ f) == 0x1317131f)
// shift
e >> 12 == 0x00012345
e << 12 == 0x45678000
assert(e >> 12 == 0x00012345)
assert(e << 12 == 0x45678000)
// not
!e == 0xedcba987
assert(!e == 0xedcba987)
// add
0xfefefefe + 0xefefefef + 0xffffffff + 0xeeeeeeee == 0xddddddda
assert(0xfefefefe + 0xefefefef + 0xffffffff + 0xeeeeeeee == 0xddddddda)
// to_bits
bool[32] bits1 = to_bits(e)
bool[32] expected1 = [false, false, false, true, false, false, true, false, false, false, true, true, false, true, false, false, false, true, false, true, false, true, true, false, false, true, true, true, true, false, false, false]
bits1 == expected1
e == from_bits(expected1)
assert(bits1 == expected1)
assert(e == from_bits(expected1))
bool[32] bits2 = to_bits(f)
bool[32] expected2 = [false, false, false, false, false, false, false, true, false, false, true, false, false, false, true, true, false, true, false, false, false, true, false, true, false, true, true, false, false, true, true, true]
bits2 == expected2
f == from_bits(expected2)
assert(bits2 == expected2)
assert(f == from_bits(expected2))
// S0
u32 e2 = right_rotate_2(e)
u32 e13 = right_rotate_13(e)
u32 e22 = right_rotate_22(e)
u32 S0 = e2 ^ e13 ^ e22
S0 == 0x66146474
assert(S0 == 0x66146474)
// S1
u32 e6 = right_rotate_6(e)
u32 e11 = right_rotate_11(e)
u32 e25 = right_rotate_25(e)
u32 S1 = e6 ^ e11 ^ e25
S1 == 0x3561abda
assert(S1 == 0x3561abda)
return

View file

@ -1,6 +1,6 @@
{
"entry_point": "./tests/tests/uint/sha256.zok",
"max_constraint_count": 43000,
"max_constraint_count": 30000,
"tests": [
{
"input": {

View file

@ -1,14 +0,0 @@
[package]
name = "zokrates_embed"
version = "0.1.1"
authors = ["schaeff <thibaut@schaeff.fr>"]
edition = "2018"
[features]
default = ["bellman_ce/nolog"]
wasm = ["bellman_ce/wasm", "sapling-crypto_ce/wasm"]
multicore = ["bellman_ce/multicore", "sapling-crypto_ce/multicore"]
[dependencies]
bellman_ce = { version = "^0.3", default-features = false}
sapling-crypto_ce = { version = "0.1.3", default-features = false }

View file

@ -1,319 +0,0 @@
extern crate sapling_crypto_ce as sapling_crypto;
use sapling_crypto::bellman;
use bellman::{
pairing::{ff::Field, Engine},
ConstraintSystem, Index, LinearCombination, SynthesisError, Variable,
};
use sapling_crypto::circuit::{
boolean::{AllocatedBit, Boolean},
sha256::sha256_compression_function,
uint32::UInt32,
};
#[derive(Debug)]
pub struct BellmanR1CS<E: Engine> {
pub aux_count: usize,
pub constraints: Vec<BellmanConstraint<E>>,
}
impl<E: Engine> BellmanR1CS<E> {
pub fn new() -> Self {
BellmanR1CS {
aux_count: 0,
constraints: vec![],
}
}
}
#[derive(Debug)]
pub struct BellmanWitness<E: Engine> {
pub values: Vec<E::Fr>,
}
#[derive(Debug, PartialEq)]
pub struct BellmanConstraint<E: Engine> {
pub a: Vec<(usize, E::Fr)>,
pub b: Vec<(usize, E::Fr)>,
pub c: Vec<(usize, E::Fr)>,
}
fn sha256_round<E: Engine, CS: ConstraintSystem<E>>(
mut cs: CS,
input: &Vec<Option<E::Fr>>,
current_hash: &Vec<Option<E::Fr>>,
) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>), SynthesisError> {
// Allocate bits for `input`
let input_bits = input
.iter()
.enumerate()
.map(|(index, i)| {
AllocatedBit::alloc::<E, _>(
&mut cs.namespace(|| format!("input_{}", index)),
Some(*i == Some(<E::Fr as Field>::one())),
)
.unwrap()
})
.collect::<Vec<_>>();
// Define Booleans whose values are the defined bits
let input = input_bits
.iter()
.map(|i| Boolean::Is(i.clone()))
.collect::<Vec<_>>();
// Allocate bits for `current_hash`
let current_hash_bits = current_hash
.iter()
.enumerate()
.map(|(index, i)| {
AllocatedBit::alloc::<E, _>(
&mut cs.namespace(|| format!("current_hash_{}", index)),
Some(*i == Some(<E::Fr as Field>::one())),
)
.unwrap()
})
.collect::<Vec<_>>();
// Define Booleans whose values are the defined bits
let current_hash = current_hash_bits
.chunks(32)
.map(|chunk| {
UInt32::from_bits_be(
&chunk
.into_iter()
.map(|i| Boolean::Is(i.clone()))
.collect::<Vec<_>>(),
)
})
.collect::<Vec<_>>();
// Apply the compression function, returning the 8 bytes of outputs
let res = sha256_compression_function::<E, _>(&mut cs, &input, &current_hash).unwrap();
// Extract the 256 bits of output out of the 8 bytes
let output_bits = res
.into_iter()
.flat_map(|u| u.into_bits_be())
.map(|b| b.get_variable().unwrap().clone())
.collect::<Vec<_>>();
// Return indices of `input`, `current_hash` and `output` in the CS
Ok((
input_bits
.into_iter()
.map(|b| var_to_index(b.get_variable()))
.collect(),
current_hash_bits
.into_iter()
.map(|b| var_to_index(b.get_variable()))
.collect(),
output_bits
.into_iter()
.map(|b| var_to_index(b.get_variable()))
.collect(),
))
}
impl<E: Engine> ConstraintSystem<E> for BellmanWitness<E> {
type Root = Self;
fn alloc<F, A, AR>(&mut self, _: A, f: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<E::Fr, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
let index = self.values.len();
let var = Variable::new_unchecked(Index::Aux(index));
self.values.push(f().unwrap());
Ok(var)
}
fn alloc_input<F, A, AR>(&mut self, _: A, _: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<E::Fr, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
unreachable!("Bellman helpers are not allowed to allocate public variables")
}
fn enforce<A, AR, LA, LB, LC>(&mut self, _: A, _: LA, _: LB, _: LC)
where
A: FnOnce() -> AR,
AR: Into<String>,
LA: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LB: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LC: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
{
// do nothing
}
fn push_namespace<NR, N>(&mut self, _: N)
where
NR: Into<String>,
N: FnOnce() -> NR,
{
// do nothing
}
fn pop_namespace(&mut self) {
// do nothing
}
fn get_root(&mut self) -> &mut Self::Root {
self
}
}
impl<E: Engine> ConstraintSystem<E> for BellmanR1CS<E> {
type Root = Self;
fn alloc<F, A, AR>(&mut self, _: A, _: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<E::Fr, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
// we don't care about the value as we're only generating the CS
let index = self.aux_count;
let var = Variable::new_unchecked(Index::Aux(index));
self.aux_count += 1;
Ok(var)
}
fn alloc_input<F, A, AR>(&mut self, _: A, _: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<E::Fr, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
unreachable!("Bellman helpers are not allowed to allocate public variables")
}
fn enforce<A, AR, LA, LB, LC>(&mut self, _: A, a: LA, b: LB, c: LC)
where
A: FnOnce() -> AR,
AR: Into<String>,
LA: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LB: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LC: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
{
let a = a(LinearCombination::zero());
let b = b(LinearCombination::zero());
let c = c(LinearCombination::zero());
let a = a
.as_ref()
.into_iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.collect();
let b = b
.as_ref()
.into_iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.collect();
let c = c
.as_ref()
.into_iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.collect();
self.constraints.push(BellmanConstraint { a, b, c });
}
fn push_namespace<NR, N>(&mut self, _: N)
where
NR: Into<String>,
N: FnOnce() -> NR,
{
// do nothing
}
fn pop_namespace(&mut self) {
// do nothing
}
fn get_root(&mut self) -> &mut Self::Root {
self
}
}
pub fn generate_sha256_round_constraints<E: Engine>(
) -> (BellmanR1CS<E>, Vec<usize>, Vec<usize>, Vec<usize>) {
let mut cs = BellmanR1CS::new();
let (input_bits, current_hash_bits, output_bits) =
sha256_round(&mut cs, &vec![None; 512], &vec![None; 256]).unwrap();
// res is now the allocated bits for `input`, `current_hash` and `sha256_output`
(cs, input_bits, current_hash_bits, output_bits)
}
pub fn generate_sha256_round_witness<E: Engine>(
input: &[E::Fr],
current_hash: &[E::Fr],
) -> Vec<E::Fr> {
assert_eq!(input.len(), 512);
assert_eq!(current_hash.len(), 256);
let mut cs: BellmanWitness<E> = BellmanWitness {
values: vec![<E::Fr as Field>::one()],
};
sha256_round(
&mut cs,
&input.iter().map(|x| Some(x.clone())).collect(),
&current_hash.iter().map(|x| Some(x.clone())).collect(),
)
.unwrap();
cs.values
}
fn var_to_index(v: Variable) -> usize {
match v.get_unchecked() {
Index::Aux(i) => i + 1,
Index::Input(0) => 0,
_ => unreachable!("No public variables should have been allocated"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use bellman::pairing::bn256::{Bn256, Fr};
#[test]
fn generate_constraints() {
let (_c, input, current_hash, output) = generate_sha256_round_constraints::<Bn256>();
assert_eq!(input.len(), 512);
assert_eq!(current_hash.len(), 256);
assert_eq!(output.len(), 256);
}
#[test]
fn generate_witness() {
let witness =
generate_sha256_round_witness::<Bn256>(&vec![Fr::one(); 512], &vec![Fr::zero(); 256]);
assert_eq!(witness.len(), 26935);
}
#[test]
fn test_cs() {
use sapling_crypto::circuit::test::TestConstraintSystem;
let mut cs: TestConstraintSystem<Bn256> = TestConstraintSystem::new();
let _ = sha256_round(
&mut cs,
&vec![Some(Fr::zero()); 512],
&vec![Some(Fr::one()); 256],
)
.unwrap();
assert!(cs.is_satisfied());
}
}

1105
zokrates_js/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -26,8 +26,6 @@ describe('tests', function() {
it('should resolve stdlib module', function() {
assert.doesNotThrow(() => {
const code = `
import "hashes/sha256/512bit" as sha256
import "ecc/edwardsAdd" as edwardsAdd
def main() -> ():
return
`;

View file

@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac
var ZoKratesHighlightRules = function () {
var keywords = (
"endfor|as|return|byte|field|bool|if|then|fi|do|else|export|false|def|for|import|from|uint|in|public|private|struct|true"
"assert|endfor|as|return|byte|field|bool|if|then|fi|do|else|export|false|def|for|import|from|uint|in|public|private|struct|true"
);
var keywordMapper = this.createKeywordMapper({

View file

@ -46,7 +46,7 @@ statement = { (return_statement // does not require subsequent newline
iteration_statement = { "for" ~ ty ~ identifier ~ "in" ~ expression ~ ".." ~ expression ~ "do" ~ NEWLINE* ~ statement* ~ "endfor"}
return_statement = { "return" ~ expression_list}
definition_statement = { optionally_typed_assignee_list ~ "=" ~ expression } // declare and assign, so only identifiers are allowed, unlike `assignment_statement`
expression_statement = {expression}
expression_statement = {"assert" ~ "(" ~ expression ~ ")"}
optionally_typed_assignee_list = _{ optionally_typed_assignee ~ ("," ~ optionally_typed_assignee)* }
optionally_typed_assignee = { (ty ~ assignee) | (assignee) } // we don't use { ty? ~ identifier } as with a single token, it gets parsed as `ty` but we want `identifier`
@ -124,6 +124,8 @@ op_unary = { op_not }
WHITESPACE = _{ " " | "\t" | "\\" ~ NEWLINE}
COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) }
keyword = @{"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
// the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly
// accepted
keyword = @{"assert"|"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
"in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"
}

View file

@ -32,21 +32,22 @@ mod ast {
static ref PREC_CLIMBER: PrecClimber<Rule> = build_precedence_climber();
}
// based on https://docs.python.org/3/reference/expressions.html#operator-precedence
fn build_precedence_climber() -> PrecClimber<Rule> {
PrecClimber::new(vec![
Operator::new(Rule::op_or, Assoc::Left),
Operator::new(Rule::op_and, Assoc::Left),
Operator::new(Rule::op_lt, Assoc::Left)
| Operator::new(Rule::op_lte, Assoc::Left)
| Operator::new(Rule::op_gt, Assoc::Left)
| Operator::new(Rule::op_gte, Assoc::Left)
| Operator::new(Rule::op_not_equal, Assoc::Left)
| Operator::new(Rule::op_equal, Assoc::Left),
Operator::new(Rule::op_bit_or, Assoc::Left),
Operator::new(Rule::op_bit_xor, Assoc::Left),
Operator::new(Rule::op_bit_and, Assoc::Left),
Operator::new(Rule::op_equal, Assoc::Left)
| Operator::new(Rule::op_not_equal, Assoc::Left),
Operator::new(Rule::op_lte, Assoc::Left)
| Operator::new(Rule::op_gte, Assoc::Left)
| Operator::new(Rule::op_lt, Assoc::Left)
| Operator::new(Rule::op_gt, Assoc::Left),
Operator::new(Rule::op_right_shift, Assoc::Left)
| Operator::new(Rule::op_left_shift, Assoc::Left),
Operator::new(Rule::op_left_shift, Assoc::Left)
| Operator::new(Rule::op_right_shift, Assoc::Left),
Operator::new(Rule::op_add, Assoc::Left) | Operator::new(Rule::op_sub, Assoc::Left),
Operator::new(Rule::op_mul, Assoc::Left) | Operator::new(Rule::op_div, Assoc::Left),
Operator::new(Rule::op_pow, Assoc::Left),
@ -1174,9 +1175,9 @@ mod tests {
field a = 1
a[32 + x][55] = y
for field i in 0..3 do
a == 1 + 2 + 3+ 4+ 5+ 6+ 6+ 7+ 8 + 4+ 5+ 3+ 4+ 2+ 3
assert(a == 1 + 2 + 3+ 4+ 5+ 6+ 6+ 7+ 8 + 4+ 5+ 3+ 4+ 2+ 3)
endfor
a.member == 1
assert(a.member == 1)
return a
"#;
let res = generate_ast(&source);

View file

@ -1,4 +1,4 @@
import "utils/pack/nonStrictUnpack256" as unpack256
import "utils/pack/bool/nonStrictUnpack256" as unpack256
// Compress JubJub Curve Point to 256bit array using big endianness bit order
// Python reference code from pycrypto:

View file

@ -13,6 +13,6 @@ def main(field[2] pt, BabyJubJubParams context) -> (bool):
field vv = pt[1] * pt[1]
field uuvv = uu * vv
a * uu + vv == 1 + d * uuvv
assert(a * uu + vv == 1 + d * uuvv)
return true

View file

@ -1,19 +1,19 @@
import "ecc/edwardsAdd" as add
import "ecc/edwardsScalarMult" as multiply
import "utils/pack/nonStrictUnpack256" as unpack256
import "utils/pack/bool/nonStrictUnpack256" as unpack256
from "ecc/babyjubjubParams" import BabyJubJubParams
// Verifies that the point is not one of the low-order points.
// If any of the points is multiplied by the cofactor, the resulting point
// will be infinity.
// Returns 1 if the point is not one of the low-order points, 0 otherwise.
// Returns true if the point is not one of the low-order points, false otherwise.
// Curve parameters are defined with the last argument
// https://github.com/zcash-hackworks/sapling-crypto/blob/master/src/jubjub/edwards.rs#L166
def main(field[2] pt, BabyJubJubParams context) -> (field):
def main(field[2] pt, BabyJubJubParams context) -> (bool):
field cofactor = context.JUBJUBC
cofactor == 8
assert(cofactor == 8)
// Co-factor currently hard-coded to 8 for efficiency reasons
// See discussion here: https://github.com/Zokrates/ZoKrates/pull/301#discussion_r267203391
@ -24,6 +24,4 @@ def main(field[2] pt, BabyJubJubParams context) -> (field):
ptExp = add(ptExp, ptExp, context) // 4*pt
ptExp = add(ptExp, ptExp, context) // 8*pt
field out = if ptExp[0] == 0 && ptExp[1] == 1 then 0 else 1 fi
return out
return !(ptExp[0] == 0 && ptExp[1] == 1)

View file

@ -1,5 +1,5 @@
import "ecc/edwardsAdd" as add
import "ecc/edwardsOnCurve" as assertOnCurve
import "ecc/edwardsOnCurve" as onCurve
from "ecc/babyjubjubParams" import BabyJubJubParams
// Function that implements scalar multiplication for a fixed base point
@ -22,6 +22,6 @@ def main(bool[256] exponent, field[2] pt, BabyJubJubParams context) -> (field[2]
doubledP = add(doubledP, doubledP, context)
endfor
true == assertOnCurve(accumulatedP, context)
assert(onCurve(accumulatedP, context))
return accumulatedP

View file

@ -1,6 +1,6 @@
import "ecc/edwardsAdd" as add
import "ecc/edwardsScalarMult" as multiply
import "utils/pack/nonStrictUnpack256" as unpack256
import "utils/pack/bool/nonStrictUnpack256" as unpack256
from "ecc/babyjubjubParams" import BabyJubJubParams
/// Verifies match of a given public/private keypair.
@ -24,6 +24,6 @@ def main(field[2] pk, field sk, BabyJubJubParams context) -> (bool):
bool[256] skBits = unpack256(sk)
field[2] ptExp = multiply(skBits, G, context)
bool out = ptExp[0] == pk[0] && ptExp[1] == pk[1]
bool out = ptExp[0] == pk[0] && ptExp[1] == pk[1]
return out

View file

@ -4,6 +4,8 @@ import "ecc/babyjubjubParams" as context
import "ecc/edwardsAdd" as add
import "ecc/edwardsCompress" as edwardsCompress
from "ecc/babyjubjubParams" import BabyJubJubParams
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
// Code to export generators used in this example:
// import bitstring
@ -16,7 +18,27 @@ from "ecc/babyjubjubParams" import BabyJubJubParams
// print(hasher.dsl_code)
// 512bit to 256bit Pedersen hash using compression of the field elements
def main(bool[512] e) -> (bool[256]):
def main(u32[16] input) -> (u32[8]):
bool[512] e = [ \
...to_bits(input[0]),
...to_bits(input[1]),
...to_bits(input[2]),
...to_bits(input[3]),
...to_bits(input[4]),
...to_bits(input[5]),
...to_bits(input[6]),
...to_bits(input[7]),
...to_bits(input[8]),
...to_bits(input[9]),
...to_bits(input[10]),
...to_bits(input[11]),
...to_bits(input[12]),
...to_bits(input[13]),
...to_bits(input[14]),
...to_bits(input[15])
]
BabyJubJubParams context = context()
field[2] a = context.INFINITY //Infinity
//Round 0
@ -705,4 +727,14 @@ def main(bool[512] e) -> (bool[256]):
a = add(a, [cx, cy], context)
bool[256] aC = edwardsCompress(a)
return aC
return [\
from_bits(aC[0..32]),
from_bits(aC[32..64]),
from_bits(aC[64..96]),
from_bits(aC[96..128]),
from_bits(aC[128..160]),
from_bits(aC[160..192]),
from_bits(aC[192..224]),
from_bits(aC[224..256])
]

View file

@ -1,13 +1,12 @@
import "./IVconstants" as IVconstants
import "./shaRoundNoBoolCheck" as sha256
import "./shaRound" as sha256
// A function that takes 4 bool[256] arrays as inputs
// and applies 2 rounds of sha256 compression.
// It returns an array of 256 bool.
def main(bool[256] a, bool[256] b, bool[256] c, bool[256] d) -> (bool[256]):
def main(u32[8] a, u32[8] b, u32[8] c, u32[8] d) -> (u32[8]):
bool[256] IV = IVconstants()
bool[256] digest1 = sha256(a, b, IV)
bool[256] digest2 = sha256(c, d, digest1)
u32[8] IV = IVconstants()
u32[8] digest1 = sha256([...a, ...b], IV)
u32[8] digest2 = sha256([...c, ...d], digest1)
return digest2

View file

@ -1,15 +1,31 @@
import "./1536bit" as sha256
// Take two bool[256] arrays as input
// and returns their sha256 full round output as an array of 256 bool.
def main(bool[256] a, bool[256] b, bool[256] c, bool[256] d) -> (bool[256]):
def main(u32[8] a, u32[8] b, u32[8] c, u32[8] d) -> (u32[8]):
// Hash is computed on the full 1024bit block size
// padding does not fit in the first two blocks
// add dummy block (single "1" followed by "0" + total length)
bool[256] dummyblock1 = [true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false]
// total length of message is 1024 bits: 0b10000000000
bool[256] dummyblock2 = [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false, false, false, false, false]
u32[8] dummyblock1 = [ \
0x80000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000
]
bool[256] digest = sha256(a, b, c, d, dummyblock1, dummyblock2)
u32[8] dummyblock2 = [ \
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000400
]
return digest
return sha256(a, b, c, d, dummyblock1, dummyblock2)

View file

@ -1,14 +1,14 @@
import "./IVconstants" as IVconstants
import "./shaRoundNoBoolCheck" as sha256
import "./shaRound" as sha256
// A function that takes 6 bool[256] arrays as inputs
// A function that takes 6 u32[8] arrays as inputs
// and applies 3 rounds of sha256 compression.
// It returns an array of 256 bool.
def main(bool[256] a, bool[256] b, bool[256] c, bool[256] d, bool[256] e, bool[256] f) -> (bool[256]):
def main(u32[8] a, u32[8] b, u32[8] c, u32[8] d, u32[8] e, u32[8] f) -> (u32[8]):
bool[256] IV = IVconstants()
bool[256] digest1 = sha256(a, b, IV)
bool[256] digest2 = sha256(c, d, digest1)
bool[256] digest3 = sha256(e, f, digest2)
u32[8] IV = IVconstants()
u32[8] digest1 = sha256([...a, ...b], IV)
u32[8] digest2 = sha256([...c, ...d], digest1)
u32[8] digest3 = sha256([...e, ...f], digest2)
return digest3

View file

@ -2,29 +2,20 @@ import "./512bit" as sha256
// A function that takes 1 bool[256] array as input
// and returns the sha256 full round output as an array of 256 bool.
def main(bool[256] a) -> (bool[256]):
def main(u32[8] a) -> (u32[8]):
// Hash is computed on 256 bits of input
// padding fits in the remaining 256 bits of the first block
// add dummy block (single "1" followed by "0" + total length)
bool[256] dummyblock1 = [ \
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
u32[8] dummyblock1 = [ \
0x80000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000100
]
digest = sha256(a, dummyblock1)
return digest
return sha256(a, dummyblock1)

View file

@ -1,15 +1,9 @@
import "./IVconstants" as IVconstants
import "./shaRoundNoBoolCheck" as sha256
import "./shaRound" as sha256
// A function that takes 2 bool[256] arrays as inputs
// and returns their sha256 compression function as an array of 256 bool.
// In contrast to full_round.zok no padding is being applied
def main(bool[256] a, bool[256] b) -> (bool[256]):
// A function that takes 2 u32[8] arrays as inputs
// and returns their sha256 compression function as an array of 8 u32.
// a and b is NOT checked to be of type bool
def main(u32[8] a, u32[8] b) -> (u32[8]):
bool[256] IV = IVconstants()
bool[256] digest = sha256(a, b, IV)
//digest is constraint to be of type bool
return digest
return sha256([...a, ...b], IVconstants())

View file

@ -1,22 +1,19 @@
import "../../utils/pack/pack128" as pack128
import "../../utils/pack/unpack128" as unpack128
import "../../utils/pack/u32/pack128" as pack128
import "../../utils/pack/u32/unpack128" as unpack128
import "./512bitPadded" as sha256
// A function that takes an array of 4 field elements as inputs, unpacks each of them to 128
// bits (big endian), concatenates them and applies sha256.
// It then returns an array of two field elements, each representing 128 bits of the result.
def main(field[4] preimage) -> (field[2]):
bool[128] a = unpack128(preimage[0])
bool[128] b = unpack128(preimage[1])
bool[128] c = unpack128(preimage[2])
bool[128] d = unpack128(preimage[3])
u32[4] a_bits = unpack128(preimage[0])
u32[4] b_bits = unpack128(preimage[1])
u32[4] c_bits = unpack128(preimage[2])
u32[4] d_bits = unpack128(preimage[3])
bool[256] lhs = [...a, ...b]
bool[256] rhs = [...c, ...d]
u32[8] lhs = [...a_bits, ...b_bits]
u32[8] rhs = [...c_bits, ...d_bits]
bool[256] r = sha256(lhs, rhs)
u32[8] r = sha256(lhs, rhs)
field res0 = pack128(r[..128])
field res1 = pack128(r[128..])
return [res0, res1]
return [pack128(r[0..4]), pack128(r[4..8])]

View file

@ -2,16 +2,31 @@ import "./1024bit" as sha256
// A function that takes 2 bool[256] arrays as inputs
// and returns their sha256 full round output as an array of 256 bool.
def main(bool[256] a, bool[256] b) -> (bool[256]):
def main(u32[8] a, u32[8] b) -> (u32[8]):
// Hash is computed on the full 512bit block size
// padding does not fit in the primary block
// add dummy block (single "1" followed by "0" + total length)
bool[256] dummyblock1 = [true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false]
u32[8] dummyblock1 = [ \
0x80000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000
]
// total length of message is 512 bits: 0b1000000000
bool[256] dummyblock2 = [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false, false, false, false]
u32[8] dummyblock2 = [ \
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000000,
0x00000200
]
bool[256] digest = sha256(a, b, dummyblock1, dummyblock2)
return digest
return sha256(a, b, dummyblock1, dummyblock2)

View file

@ -1,15 +1,4 @@
// SHA-256 is specified in FIPS 180-3 and initial values are listed in section 5.3.3
// https://csrc.nist.gov/csrc/media/publications/fips/180/3/archive/2008-10-31/documents/fips180-3_final.pdf
def main() -> (bool[256]):
bool[32] h0 = [false, true, true, false, true, false, true, false, false, false, false, false, true, false, false, true, true, true, true, false, false, true, true, false, false, true, true, false, false, true, true, true]
bool[32] h1 = [true, false, true, true, true, false, true, true, false, true, true, false, false, true, true, true, true, false, true, false, true, true, true, false, true, false, false, false, false, true, false, true]
bool[32] h2 = [false, false, true, true, true, true, false, false, false, true, true, false, true, true, true, false, true, true, true, true, false, false, true, true, false, true, true, true, false, false, true, false]
bool[32] h3 = [true, false, true, false, false, true, false, true, false, true, false, false, true, true, true, true, true, true, true, true, false, true, false, true, false, false, true, true, true, false, true, false]
bool[32] h4 = [false, true, false, true, false, false, false, true, false, false, false, false, true, true, true, false, false, true, false, true, false, false, true, false, false, true, true, true, true, true, true, true]
bool[32] h5 = [true, false, false, true, true, false, true, true, false, false, false, false, false, true, false, true, false, true, true, false, true, false, false, false, true, false, false, false, true, true, false, false]
bool[32] h6 = [false, false, false, true, true, true, true, true, true, false, false, false, false, false, true, true, true, true, false, true, true, false, false, true, true, false, true, false, true, false, true, true]
bool[32] h7 = [false, true, false, true, true, false, true, true, true, true, true, false, false, false, false, false, true, true, false, false, true, true, false, true, false, false, false, true, true, false, false, true]
bool[256] IV = [...h0, ...h1, ...h2, ...h3, ...h4, ...h5, ...h6, ...h7]
return IV
def main() -> (u32[8]):
return [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19]

View file

@ -0,0 +1,126 @@
import "EMBED/u32_to_bits" as to_bits
import "EMBED/u32_from_bits" as from_bits
import "./IVconstants.zok"
def right_rotate_2(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[30..], ...b[..30]])
def right_rotate_6(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[26..], ...b[..26]])
def right_rotate_7(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[25..], ...b[..25]])
def right_rotate_11(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[21..], ...b[..21]])
def right_rotate_13(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[19..], ...b[..19]])
def right_rotate_17(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[15..], ...b[..15]])
def right_rotate_18(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[14..], ...b[..14]])
def right_rotate_19(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[13..], ...b[..13]])
def right_rotate_22(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[10..], ...b[..10]])
def right_rotate_25(u32 e) -> (u32):
bool[32] b = to_bits(e)
return from_bits([...b[7..], ...b[..7]])
def extend(u32[64] w, field i) -> (u32):
u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3)
u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10)
return w[i-16] + s0 + w[i-7] + s1
def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> (u32):
// ch := (e and f) xor ((not e) and g)
u32 ch = (e & f) ^ ((!e) & g)
// S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
u32 S1 = right_rotate_6(e) ^ right_rotate_11(e) ^ right_rotate_25(e)
// temp1 := h + S1 + ch + k + w
return h + S1 + ch + k + w
def temp2(u32 a, u32 b, u32 c) -> (u32):
// maj := (a and b) xor (a and c) xor (b and c)
u32 maj = (a & b) ^ (a & c) ^ (b & c)
// S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
u32 S0 = right_rotate_2(a) ^ right_rotate_13(a) ^ right_rotate_22(a)
// temp2 := S0 + maj
return S0 + maj
def main(u32[16] input, u32[8] current) -> (u32[8]):
u32 h0 = current[0]
u32 h1 = current[1]
u32 h2 = current[2]
u32 h3 = current[3]
u32 h4 = current[4]
u32 h5 = current[5]
u32 h6 = current[6]
u32 h7 = current[7]
u32[64] k = [0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2]
u32[64] w = [...input, ...[0x00000000; 48]]
for field i in 16..64 do
u32 r = extend(w, i)
w[i] = r
endfor
u32 a = h0
u32 b = h1
u32 c = h2
u32 d = h3
u32 e = h4
u32 f = h5
u32 g = h6
u32 h = h7
for field i in 0..64 do
u32 t1 = temp1(e, f, g, h, k[i], w[i])
u32 t2 = temp2(a, b, c)
h = g
g = f
f = e
e = d + t1
d = c
c = b
b = a
a = t1 + t2
endfor
h0 = h0 + a
h1 = h1 + b
h2 = h2 + c
h3 = h3 + d
h4 = h4 + e
h5 = h5 + f
h6 = h6 + g
h7 = h7 + h
return [h0, h1, h2, h3, h4, h5, h6, h7]

View file

@ -1,6 +0,0 @@
import "EMBED/sha256round" as sha256round
// a and b is NOT checked to be 0 or 1
// the return value is checked to be 0 or 1
// IV vector is checked to be of type bool
def main(bool[256] a, bool[256] b, bool[256] IV) -> (bool[256]):
return sha256round([...a, ...b], IV)

View file

@ -1,2 +1,2 @@
def main(bool selector, bool[256] lhs, bool[256] rhs) -> (bool[512]):
def main(bool selector, u32[8] lhs, u32[8] rhs) -> (u32[16]):
return if selector then [...rhs, ...lhs] else [...lhs, ...rhs] fi

View file

@ -1,10 +1,12 @@
import "hashes/sha256/1024bitPadded" as sha256
import "ecc/edwardsScalarMult" as scalarMult
import "ecc/edwardsAdd" as add
import "utils/pack/nonStrictUnpack256" as unpack256
import "utils/pack/bool/nonStrictUnpack256" as unpack256bool
import "utils/pack/u32/nonStrictUnpack256" as unpack256u
import "ecc/edwardsOnCurve" as onCurve
import "ecc/edwardsOrderCheck" as orderCheck
from "ecc/babyjubjubParams" import BabyJubJubParams
import "utils/casts/u32_8_to_bool_256"
/// Verifies an EdDSA Signature.
///
@ -27,20 +29,19 @@ from "ecc/babyjubjubParams" import BabyJubJubParams
///
/// Returns:
/// Return true for S being a valid EdDSA Signature, false otherwise.
def main(private field[2] R, private field S, field[2] A, bool[256] M0, bool[256] M1, BabyJubJubParams context) -> (bool):
def main(private field[2] R, private field S, field[2] A, u32[8] M0, u32[8] M1, BabyJubJubParams context) -> (bool):
field[2] G = [context.Gu, context.Gv]
// Check if R is on curve and if it is not in a small subgroup. A is public input and can be checked offline
true == onCurve(R, context) // throws if R is not on curve
field isPrimeOrder = orderCheck(R, context)
1 == isPrimeOrder
assert(onCurve(R, context)) // throws if R is not on curve
assert(orderCheck(R, context))
bool[256] Rx = unpack256(R[0])
bool[256] Ax = unpack256(A[0])
bool[256] hRAM = sha256(Rx, Ax, M0, M1)
u32[8] Rx = unpack256u(R[0])
u32[8] Ax = unpack256u(A[0])
bool[256] hRAM = u32_8_to_bool_256(sha256(Rx, Ax, M0, M1))
bool[256] sBits = unpack256(S)
bool[256] sBits = unpack256bool(S)
field[2] lhs = scalarMult(sBits, G, context)
field[2] AhRAM = scalarMult(hRAM, A, context)

View file

@ -0,0 +1,4 @@
import "EMBED/u32_from_bits" as from_bits
def main(bool[128] bits) -> (u32[4]):
return [from_bits(bits[0..32]), from_bits(bits[32..64]), from_bits(bits[64..96]), from_bits(bits[96..128])]

View file

@ -0,0 +1,4 @@
import "EMBED/u32_from_bits" as from_bits
def main(bool[256] bits) -> (u32[8]):
return [from_bits(bits[0..32]), from_bits(bits[32..64]), from_bits(bits[64..96]), from_bits(bits[96..128]), from_bits(bits[128..160]), from_bits(bits[160..192]), from_bits(bits[192..224]), from_bits(bits[224..256])]

View file

@ -0,0 +1,4 @@
import "EMBED/u32_to_bits" as to_bits
def main(u32[4] input) -> (bool[128]):
return [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3])]

View file

@ -0,0 +1,4 @@
import "EMBED/u32_to_bits" as to_bits
def main(u32[8] input) -> (bool[256]):
return [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3]), ...to_bits(input[4]), ...to_bits(input[5]), ...to_bits(input[6]), ...to_bits(input[7])]

View file

@ -1,2 +1,2 @@
def main(bool selector, bool[256] lhs, bool[256] rhs) -> (bool[256]):
def main(bool selector, u32[8] lhs, u32[8] rhs) -> (u32[8]):
return if selector then rhs else lhs fi

View file

@ -10,4 +10,4 @@ def main(field i) -> (bool[256]):
bool[254] b = unpack(i)
return [false, false, ...b]
return [false, false, ...b]

View file

@ -1,9 +1,13 @@
#pragma curve bn128
def main(bool[128] bits) -> (field):
field out = 0
for field j in 0..128 do
field i = 128 - (j + 1)
field len = 128
for field j in 0..len do
field i = len - (j + 1)
out = out + if bits[i] then (2 ** j) else 0 fi
endfor

View file

@ -0,0 +1,14 @@
#pragma curve bn128
def main(bool[256] input) -> (field):
field out = 0
field len = 256
for field j in 0..len do
field i = len - (j + 1)
out = out + if bits[i] then (2 ** j) else 0 fi
endfor
return out

View file

@ -6,6 +6,6 @@ def main(field i) -> (bool[128]):
bool[254] b = unpack(i)
b[0..126] == [false; 126]
assert(b[0..126] == [false; 126])
return b[126..254]

View file

@ -1,10 +0,0 @@
def main(bool[256] bits) -> (field):
field out = 0
for field j in 0..256 do
field i = 256 - (j + 1)
out = out + if bits[i] then (2 ** j) else 0 fi
endfor
return out

View file

@ -0,0 +1,12 @@
#pragma curve bn128
// Non-strict version:
// Note that this does not strongly enforce that the commitment is
// in the field.
import "../bool/nonStrictUnpack256" as unpack
import "../../casts/bool_256_to_u32_8" as from_bits
def main(field i) -> (u32[8]):
return from_bits(unpack(i))

View file

@ -0,0 +1,10 @@
#pragma curve bn128
import "EMBED/u32_to_bits" as to_bits
import "../bool/pack128"
def main(u32[4] input) -> (field):
bool[128] bits = [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3])]
return pack128(bits)

View file

@ -0,0 +1,10 @@
#pragma curve bn128
import "EMBED/u32_to_bits" as to_bits
import "../bool/pack256"
def main(u32[8] input) -> (field):
bool[256] bits = [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3]), ...to_bits(input[4]), ...to_bits(input[5]), ...to_bits(input[6]), ...to_bits(input[7])]
return pack256(bits)

View file

@ -0,0 +1,7 @@
#pragma curve bn128
import "../bool/unpack128" as unpack
import "../../casts/bool_128_to_u32_4" as from_bits
def main(field i) -> (u32[4]):
return from_bits(unpack(i))

Some files were not shown because too many files have changed in this diff Show more