add bit lt embed, fail on non constant bound, implement safe unpack
This commit is contained in:
parent
caf4dd2acb
commit
b1c9a171f8
16 changed files with 277 additions and 80 deletions
|
@ -0,0 +1,5 @@
|
|||
from "EMBED" import bit_array_le
|
||||
|
||||
// Unpack a field element as N big endian bits
|
||||
def main(bool[1] a, bool[1] b) -> bool:
|
||||
return bit_array_le::<1>(a, b)
|
|
@ -28,6 +28,7 @@ cfg_if::cfg_if! {
|
|||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
|
||||
pub enum FlatEmbed {
|
||||
BitArrayLe,
|
||||
U32ToField,
|
||||
Unpack,
|
||||
U8ToBits,
|
||||
|
@ -47,6 +48,30 @@ pub enum FlatEmbed {
|
|||
impl FlatEmbed {
|
||||
pub fn signature(&self) -> DeclarationSignature<'static> {
|
||||
match self {
|
||||
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
||||
.generics(vec![Some(DeclarationConstant::Generic(
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
))])
|
||||
.inputs(vec![
|
||||
DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
)),
|
||||
DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
)),
|
||||
])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
FlatEmbed::U32ToField => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -172,6 +197,7 @@ impl FlatEmbed {
|
|||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
&FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
|
||||
FlatEmbed::U32ToField => "_U32_TO_FIELD",
|
||||
FlatEmbed::Unpack => "_UNPACK",
|
||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||
|
@ -453,10 +479,6 @@ fn use_variable(
|
|||
/// as we decompose over `log_2(p) + 1 bits, some
|
||||
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
|
||||
let nbits = T::get_required_bits();
|
||||
|
||||
assert!(bit_width <= nbits);
|
||||
|
||||
let mut counter = 0;
|
||||
|
||||
let mut layout = HashMap::new();
|
||||
|
|
|
@ -222,7 +222,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
b: &[bool],
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let len = b.len();
|
||||
assert_eq!(a.len(), T::get_required_bits());
|
||||
assert_eq!(a.len(), b.len());
|
||||
|
||||
let mut is_not_smaller_run = vec![];
|
||||
|
@ -984,7 +983,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.constant_le_check(
|
||||
self.enforce_constant_le_check(
|
||||
statements_flattened,
|
||||
&sub_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
|
@ -1161,6 +1160,52 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
crate::embed::FlatEmbed::U8FromBits => {
|
||||
vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())]
|
||||
}
|
||||
crate::embed::FlatEmbed::BitArrayLe => {
|
||||
let len = generics[0];
|
||||
|
||||
let (expressions, constants) = (
|
||||
param_expressions[..len as usize].to_vec(),
|
||||
param_expressions[len as usize..].to_vec(),
|
||||
);
|
||||
|
||||
let variables: Vec<_> = expressions
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let e = self
|
||||
.flatten_expression(statements_flattened, e)
|
||||
.get_field_unchecked();
|
||||
self.define(e, statements_flattened)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let constants: Vec<_> = constants
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
self.flatten_expression(statements_flattened, e)
|
||||
.get_field_unchecked()
|
||||
})
|
||||
.map(|e| match e {
|
||||
FlatExpression::Number(n) => n == T::one(),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let conditions =
|
||||
self.constant_le_check(statements_flattened, &variables, &constants);
|
||||
|
||||
// return `len(conditions) == sum(conditions)`
|
||||
vec![FlatUExpression::with_field(
|
||||
self.eq_check(
|
||||
statements_flattened,
|
||||
T::from(conditions.len()).into(),
|
||||
conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
),
|
||||
)]
|
||||
}
|
||||
funct => {
|
||||
let funct = funct.synthetize(&generics);
|
||||
|
||||
|
|
|
@ -148,6 +148,10 @@ impl Importer {
|
|||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::Unpack),
|
||||
},
|
||||
"bit_array_le" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
|
||||
},
|
||||
"u64_to_bits" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
|
||||
|
|
|
@ -156,10 +156,18 @@ impl Interpreter {
|
|||
],
|
||||
},
|
||||
Solver::Bits(bit_width) => {
|
||||
let padding = bit_width.saturating_sub(T::get_required_bits());
|
||||
|
||||
let bit_width = bit_width - padding;
|
||||
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..*bit_width).rev() {
|
||||
for _ in 0..padding {
|
||||
res.push(T::zero());
|
||||
}
|
||||
|
||||
for i in (0..bit_width).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
|
@ -407,4 +415,18 @@ mod tests {
|
|||
assert_eq!(res[248], Bn128Field::from(1));
|
||||
assert_eq!(res[247], Bn128Field::from(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn five_hundred_bits_of_1() {
|
||||
let inputs = vec![Bn128Field::from(1)];
|
||||
let interpreter = Interpreter::default();
|
||||
let res = interpreter
|
||||
.execute_solver(&Solver::Bits(500), &inputs)
|
||||
.unwrap();
|
||||
|
||||
let mut expected = vec![Bn128Field::from(0); 500];
|
||||
expected[499] = Bn128Field::from(1);
|
||||
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
use crate::embed::FlatEmbed;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use crate::typed_absy::{
|
||||
result_folder::ResultFolder,
|
||||
result_folder::{fold_expression_list_inner, fold_uint_expression_inner},
|
||||
ArrayExpressionInner, BooleanExpression, TypedExpression, TypedExpressionListInner,
|
||||
TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
pub struct ConstantArgumentChecker;
|
||||
|
||||
impl ConstantArgumentChecker {
|
||||
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
ConstantArgumentChecker.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Error = String;
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||
match e {
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} >> {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_expression_list_inner(
|
||||
&mut self,
|
||||
tys: &Types<'ast, T>,
|
||||
l: TypedExpressionListInner<'ast, T>,
|
||||
) -> Result<TypedExpressionListInner<'ast, T>, Error> {
|
||||
match l {
|
||||
TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => {
|
||||
let arguments = arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
match arguments[1] {
|
||||
TypedExpression::Array(ref a) => match a.as_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
if v.0.iter().all(|v| {
|
||||
matches!(
|
||||
v,
|
||||
TypedExpressionOrSpread::Expression(TypedExpression::Boolean(
|
||||
BooleanExpression::Value(_)
|
||||
))
|
||||
)
|
||||
}) {
|
||||
Ok(TypedExpressionListInner::EmbedCall(
|
||||
FlatEmbed::BitArrayLe,
|
||||
generics,
|
||||
arguments,
|
||||
))
|
||||
} else {
|
||||
Err(format!("Cannot compare to a variable value, found `{}`", a))
|
||||
}
|
||||
}
|
||||
v => Err(format!("Cannot compare to a variable value, found `{}`", v)),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
l => fold_expression_list_inner(self, tys, l),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,21 +5,21 @@
|
|||
//! @date 2018
|
||||
|
||||
mod branch_isolator;
|
||||
mod constant_argument_checker;
|
||||
mod constant_inliner;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod propagation;
|
||||
mod reducer;
|
||||
mod shift_checker;
|
||||
mod uint_optimizer;
|
||||
mod unconstrained_vars;
|
||||
mod variable_write_remover;
|
||||
|
||||
use self::branch_isolator::Isolator;
|
||||
use self::constant_argument_checker::ConstantArgumentChecker;
|
||||
use self::flatten_complex_types::Flattener;
|
||||
use self::propagation::Propagator;
|
||||
use self::reducer::reduce_program;
|
||||
use self::shift_checker::ShiftChecker;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use self::variable_write_remover::VariableWriteRemover;
|
||||
|
@ -39,7 +39,7 @@ pub trait Analyse {
|
|||
pub enum Error {
|
||||
Reducer(self::reducer::Error),
|
||||
Propagation(self::propagation::Error),
|
||||
NonConstantShift(self::shift_checker::Error),
|
||||
NonConstantShift(self::constant_argument_checker::Error),
|
||||
}
|
||||
|
||||
impl From<reducer::Error> for Error {
|
||||
|
@ -54,8 +54,8 @@ impl From<propagation::Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<shift_checker::Error> for Error {
|
||||
fn from(e: shift_checker::Error) -> Self {
|
||||
impl From<constant_argument_checker::Error> for Error {
|
||||
fn from(e: constant_argument_checker::Error) -> Self {
|
||||
Error::NonConstantShift(e)
|
||||
}
|
||||
}
|
||||
|
@ -90,8 +90,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
let r = Propagator::propagate(r).map_err(Error::from)?;
|
||||
// remove assignment to variable index
|
||||
let r = VariableWriteRemover::apply(r);
|
||||
// detect non constant shifts
|
||||
let r = ShiftChecker::check(r).map_err(Error::from)?;
|
||||
// detect non constant shifts and constant lt bounds
|
||||
let r = ConstantArgumentChecker::check(r).map_err(Error::from)?;
|
||||
// convert to zir, removing complex types
|
||||
let zir = Flattener::flatten(r);
|
||||
// optimize uint expressions
|
||||
|
|
|
@ -502,6 +502,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
true => {
|
||||
let r: Option<TypedExpression<'ast, T>> = match embed {
|
||||
FlatEmbed::U32ToField => None, // todo
|
||||
FlatEmbed::BitArrayLe => None, // todo
|
||||
FlatEmbed::U64FromBits => Some(process_u_from_bits(
|
||||
assignees.clone(),
|
||||
arguments.clone(),
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
use crate::typed_absy::TypedProgram;
|
||||
use crate::typed_absy::{
|
||||
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
|
||||
UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
pub struct ShiftChecker;
|
||||
|
||||
impl ShiftChecker {
|
||||
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
ShiftChecker.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Error = String;
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||
match e {
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} >> {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
#pragma curve bn128
|
||||
|
||||
import "./unpack" as unpack
|
||||
import "./unpack_unchecked"
|
||||
|
||||
// Unpack a field element as 256 big-endian bits
|
||||
// Note: uniqueness of the output is not guaranteed
|
||||
// For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)`
|
||||
def main(field i) -> bool[256]:
|
||||
|
||||
bool[254] b = unpack::<254>(i)
|
||||
bool[254] b = unpack_unchecked::<254>(i)
|
||||
|
||||
return [false, false, ...b]
|
|
@ -1,12 +1,12 @@
|
|||
#pragma curve bn128
|
||||
|
||||
from "EMBED" import unpack
|
||||
import "./unpack_unchecked.zok"
|
||||
from "field" import FIELD_SIZE_IN_BITS
|
||||
from "EMBED" import bit_array_le
|
||||
|
||||
// Unpack a field element as N big endian bits
|
||||
def main<N>(field i) -> bool[N]:
|
||||
|
||||
assert(N <= 254)
|
||||
|
||||
bool[N] res = unpack(i)
|
||||
bool[N] res = unpack_unchecked(i)
|
||||
|
||||
assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::<FIELD_SIZE_IN_BITS>(-1)]) else true fi)
|
||||
|
||||
return res
|
|
@ -1,9 +1,7 @@
|
|||
#pragma curve bn128
|
||||
|
||||
import "./unpack" as unpack
|
||||
|
||||
// Unpack a field element as 128 big-endian bits
|
||||
// Precondition: the input is smaller or equal to `2**128 - 1`
|
||||
// If the input is larger than `2**128 - 1`, the output is truncated.
|
||||
def main(field i) -> bool[128]:
|
||||
bool[128] res = unpack::<128>(i)
|
||||
return res
|
7
zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok
Normal file
7
zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
import "./unpack" as unpack
|
||||
|
||||
// Unpack a field element as 256 big-endian bits
|
||||
// If the input is larger than `2**256 - 1`, the output is truncated.
|
||||
def main(field i) -> bool[256]:
|
||||
bool[256] res = unpack::<256>(i)
|
||||
return res
|
|
@ -0,0 +1,9 @@
|
|||
from "EMBED" import unpack
|
||||
|
||||
// Unpack a field element as N big endian bits without checking for overflows
|
||||
// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either
|
||||
def main<N>(field i) -> bool[N]:
|
||||
|
||||
bool[N] res = unpack(i)
|
||||
|
||||
return res
|
16
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json
Normal file
16
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/utils/pack/bool/unpack256.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
24
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok
Normal file
24
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok
Normal file
|
@ -0,0 +1,24 @@
|
|||
import "utils/pack/bool/unpack256" as unpack256
|
||||
|
||||
def testFive() -> bool:
|
||||
|
||||
bool[256] b = unpack256(5)
|
||||
|
||||
assert(b == [...[false; 253], true, false, true])
|
||||
|
||||
return true
|
||||
|
||||
def testZero() -> bool:
|
||||
|
||||
bool[256] b = unpack256(0)
|
||||
|
||||
assert(b == [false; 256])
|
||||
|
||||
return true
|
||||
|
||||
def main():
|
||||
|
||||
assert(testFive())
|
||||
assert(testZero())
|
||||
|
||||
return
|
Loading…
Reference in a new issue