1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

add bit lt embed, fail on non constant bound, implement safe unpack

This commit is contained in:
schaeff 2021-08-04 14:55:17 +02:00
parent caf4dd2acb
commit b1c9a171f8
16 changed files with 277 additions and 80 deletions

View file

@ -0,0 +1,5 @@
from "EMBED" import bit_array_le
// Unpack a field element as N big endian bits
def main(bool[1] a, bool[1] b) -> bool:
return bit_array_le::<1>(a, b)

View file

@ -28,6 +28,7 @@ cfg_if::cfg_if! {
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub enum FlatEmbed {
BitArrayLe,
U32ToField,
Unpack,
U8ToBits,
@ -47,6 +48,30 @@ pub enum FlatEmbed {
impl FlatEmbed {
pub fn signature(&self) -> DeclarationSignature<'static> {
match self {
FlatEmbed::BitArrayLe => DeclarationSignature::new()
.generics(vec![Some(DeclarationConstant::Generic(
GenericIdentifier {
name: "N",
index: 0,
},
))])
.inputs(vec![
DeclarationType::array((
DeclarationType::Boolean,
GenericIdentifier {
name: "N",
index: 0,
},
)),
DeclarationType::array((
DeclarationType::Boolean,
GenericIdentifier {
name: "N",
index: 0,
},
)),
])
.outputs(vec![DeclarationType::Boolean]),
FlatEmbed::U32ToField => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
@ -172,6 +197,7 @@ impl FlatEmbed {
pub fn id(&self) -> &'static str {
match self {
&FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
FlatEmbed::U32ToField => "_U32_TO_FIELD",
FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS",
@ -453,10 +479,6 @@ fn use_variable(
/// as we decompose over `log_2(p) + 1 bits, some
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let nbits = T::get_required_bits();
assert!(bit_width <= nbits);
let mut counter = 0;
let mut layout = HashMap::new();

View file

@ -222,7 +222,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
b: &[bool],
) -> Vec<FlatExpression<T>> {
let len = b.len();
assert_eq!(a.len(), T::get_required_bits());
assert_eq!(a.len(), b.len());
let mut is_not_smaller_run = vec![];
@ -984,7 +983,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
// check that the decomposition is in the field with a strict `< p` checks
self.constant_le_check(
self.enforce_constant_le_check(
statements_flattened,
&sub_bits_be,
&T::max_value().bit_vector_be(),
@ -1161,6 +1160,52 @@ impl<'ast, T: Field> Flattener<'ast, T> {
crate::embed::FlatEmbed::U8FromBits => {
vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())]
}
crate::embed::FlatEmbed::BitArrayLe => {
let len = generics[0];
let (expressions, constants) = (
param_expressions[..len as usize].to_vec(),
param_expressions[len as usize..].to_vec(),
);
let variables: Vec<_> = expressions
.into_iter()
.map(|e| {
let e = self
.flatten_expression(statements_flattened, e)
.get_field_unchecked();
self.define(e, statements_flattened)
})
.collect();
let constants: Vec<_> = constants
.into_iter()
.map(|e| {
self.flatten_expression(statements_flattened, e)
.get_field_unchecked()
})
.map(|e| match e {
FlatExpression::Number(n) => n == T::one(),
_ => unreachable!(),
})
.collect();
let conditions =
self.constant_le_check(statements_flattened, &variables, &constants);
// return `len(conditions) == sum(conditions)`
vec![FlatUExpression::with_field(
self.eq_check(
statements_flattened,
T::from(conditions.len()).into(),
conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
),
)]
}
funct => {
let funct = funct.synthetize(&generics);

View file

@ -148,6 +148,10 @@ impl Importer {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Unpack),
},
"bit_array_le" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
},
"u64_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64ToBits),

View file

@ -156,10 +156,18 @@ impl Interpreter {
],
},
Solver::Bits(bit_width) => {
let padding = bit_width.saturating_sub(T::get_required_bits());
let bit_width = bit_width - padding;
let mut num = inputs[0].clone();
let mut res = vec![];
for i in (0..*bit_width).rev() {
for _ in 0..padding {
res.push(T::zero());
}
for i in (0..bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
@ -407,4 +415,18 @@ mod tests {
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
#[test]
fn five_hundred_bits_of_1() {
let inputs = vec![Bn128Field::from(1)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(500), &inputs)
.unwrap();
let mut expected = vec![Bn128Field::from(0); 500];
expected[499] = Bn128Field::from(1);
assert_eq!(res, expected);
}
}

View file

@ -0,0 +1,99 @@
use crate::embed::FlatEmbed;
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{
result_folder::ResultFolder,
result_folder::{fold_expression_list_inner, fold_uint_expression_inner},
ArrayExpressionInner, BooleanExpression, TypedExpression, TypedExpressionListInner,
TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner,
};
use zokrates_field::Field;
pub struct ConstantArgumentChecker;
impl ConstantArgumentChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
ConstantArgumentChecker.fold_program(p)
}
}
pub type Error = String;
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
type Error = Error;
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Error> {
match e {
UExpressionInner::LeftShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} >> {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_expression_list_inner(
&mut self,
tys: &Types<'ast, T>,
l: TypedExpressionListInner<'ast, T>,
) -> Result<TypedExpressionListInner<'ast, T>, Error> {
match l {
TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => {
let arguments = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<Vec<_>, _>>()?;
match arguments[1] {
TypedExpression::Array(ref a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
if v.0.iter().all(|v| {
matches!(
v,
TypedExpressionOrSpread::Expression(TypedExpression::Boolean(
BooleanExpression::Value(_)
))
)
}) {
Ok(TypedExpressionListInner::EmbedCall(
FlatEmbed::BitArrayLe,
generics,
arguments,
))
} else {
Err(format!("Cannot compare to a variable value, found `{}`", a))
}
}
v => Err(format!("Cannot compare to a variable value, found `{}`", v)),
},
_ => unreachable!(),
}
}
l => fold_expression_list_inner(self, tys, l),
}
}
}

View file

@ -5,21 +5,21 @@
//! @date 2018
mod branch_isolator;
mod constant_argument_checker;
mod constant_inliner;
mod flat_propagation;
mod flatten_complex_types;
mod propagation;
mod reducer;
mod shift_checker;
mod uint_optimizer;
mod unconstrained_vars;
mod variable_write_remover;
use self::branch_isolator::Isolator;
use self::constant_argument_checker::ConstantArgumentChecker;
use self::flatten_complex_types::Flattener;
use self::propagation::Propagator;
use self::reducer::reduce_program;
use self::shift_checker::ShiftChecker;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_write_remover::VariableWriteRemover;
@ -39,7 +39,7 @@ pub trait Analyse {
pub enum Error {
Reducer(self::reducer::Error),
Propagation(self::propagation::Error),
NonConstantShift(self::shift_checker::Error),
NonConstantShift(self::constant_argument_checker::Error),
}
impl From<reducer::Error> for Error {
@ -54,8 +54,8 @@ impl From<propagation::Error> for Error {
}
}
impl From<shift_checker::Error> for Error {
fn from(e: shift_checker::Error) -> Self {
impl From<constant_argument_checker::Error> for Error {
fn from(e: constant_argument_checker::Error) -> Self {
Error::NonConstantShift(e)
}
}
@ -90,8 +90,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = Propagator::propagate(r).map_err(Error::from)?;
// remove assignment to variable index
let r = VariableWriteRemover::apply(r);
// detect non constant shifts
let r = ShiftChecker::check(r).map_err(Error::from)?;
// detect non constant shifts and constant lt bounds
let r = ConstantArgumentChecker::check(r).map_err(Error::from)?;
// convert to zir, removing complex types
let zir = Flattener::flatten(r);
// optimize uint expressions

View file

@ -502,6 +502,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
true => {
let r: Option<TypedExpression<'ast, T>> = match embed {
FlatEmbed::U32ToField => None, // todo
FlatEmbed::BitArrayLe => None, // todo
FlatEmbed::U64FromBits => Some(process_u_from_bits(
assignees.clone(),
arguments.clone(),

View file

@ -1,55 +0,0 @@
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
UExpressionInner,
};
use zokrates_field::Field;
pub struct ShiftChecker;
impl ShiftChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
ShiftChecker.fold_program(p)
}
}
pub type Error = String;
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
type Error = Error;
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Error> {
match e {
UExpressionInner::LeftShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} >> {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
}

View file

@ -1,12 +1,12 @@
#pragma curve bn128
import "./unpack" as unpack
import "./unpack_unchecked"
// Unpack a field element as 256 big-endian bits
// Note: uniqueness of the output is not guaranteed
// For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)`
def main(field i) -> bool[256]:
bool[254] b = unpack::<254>(i)
bool[254] b = unpack_unchecked::<254>(i)
return [false, false, ...b]

View file

@ -1,12 +1,12 @@
#pragma curve bn128
from "EMBED" import unpack
import "./unpack_unchecked.zok"
from "field" import FIELD_SIZE_IN_BITS
from "EMBED" import bit_array_le
// Unpack a field element as N big endian bits
def main<N>(field i) -> bool[N]:
assert(N <= 254)
bool[N] res = unpack(i)
bool[N] res = unpack_unchecked(i)
assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::<FIELD_SIZE_IN_BITS>(-1)]) else true fi)
return res

View file

@ -1,9 +1,7 @@
#pragma curve bn128
import "./unpack" as unpack
// Unpack a field element as 128 big-endian bits
// Precondition: the input is smaller or equal to `2**128 - 1`
// If the input is larger than `2**128 - 1`, the output is truncated.
def main(field i) -> bool[128]:
bool[128] res = unpack::<128>(i)
return res

View file

@ -0,0 +1,7 @@
import "./unpack" as unpack
// Unpack a field element as 256 big-endian bits
// If the input is larger than `2**256 - 1`, the output is truncated.
def main(field i) -> bool[256]:
bool[256] res = unpack::<256>(i)
return res

View file

@ -0,0 +1,9 @@
from "EMBED" import unpack
// Unpack a field element as N big endian bits without checking for overflows
// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either
def main<N>(field i) -> bool[N]:
bool[N] res = unpack(i)
return res

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/utils/pack/bool/unpack256.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": []
}
}
}
]
}

View file

@ -0,0 +1,24 @@
import "utils/pack/bool/unpack256" as unpack256
def testFive() -> bool:
bool[256] b = unpack256(5)
assert(b == [...[false; 253], true, false, true])
return true
def testZero() -> bool:
bool[256] b = unpack256(0)
assert(b == [false; 256])
return true
def main():
assert(testFive())
assert(testZero())
return