1
0
Fork 0
mirror of synced 2025-09-23 04:08:33 +00:00

Merge pull request #955 from Zokrates/bit_lt_embed

Add bit LT embed to make unpack safe for any size
This commit is contained in:
Thibaut Schaeffer 2021-08-16 11:25:09 +02:00 committed by GitHub
commit b324e17684
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 678 additions and 357 deletions

1
Cargo.lock generated
View file

@ -2451,6 +2451,7 @@ dependencies = [
"zokrates_common",
"zokrates_embed",
"zokrates_field",
"zokrates_fs_resolver",
"zokrates_pest_ast",
]

View file

@ -0,0 +1 @@
Make the stdlib `unpack` function safe against overflows of bit decompositions for any size of output, introduce `unpack_unchecked` for cases that do not require determinism

View file

@ -0,0 +1,5 @@
from "EMBED" import bit_array_le
// Calling the `bit_array_le` embed on a non-constant second argument should fail at compile-time
def main(bool[1] a, bool[1] b) -> bool:
return bit_array_le::<1>(a, b)

View file

@ -60,6 +60,7 @@ sha2 = { version = "0.9.3", optional = true }
[dev-dependencies]
wasm-bindgen-test = "^0.3.0"
pretty_assertions = "0.6.1"
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
[build-dependencies]
cc = { version = "1.0", features = ["parallel"], optional = true }

View file

@ -28,6 +28,7 @@ cfg_if::cfg_if! {
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub enum FlatEmbed {
BitArrayLe,
U32ToField,
Unpack,
U8ToBits,
@ -47,6 +48,30 @@ pub enum FlatEmbed {
impl FlatEmbed {
pub fn signature(&self) -> DeclarationSignature<'static> {
match self {
FlatEmbed::BitArrayLe => DeclarationSignature::new()
.generics(vec![Some(DeclarationConstant::Generic(
GenericIdentifier {
name: "N",
index: 0,
},
))])
.inputs(vec![
DeclarationType::array((
DeclarationType::Boolean,
GenericIdentifier {
name: "N",
index: 0,
},
)),
DeclarationType::array((
DeclarationType::Boolean,
GenericIdentifier {
name: "N",
index: 0,
},
)),
])
.outputs(vec![DeclarationType::Boolean]),
FlatEmbed::U32ToField => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
@ -172,6 +197,7 @@ impl FlatEmbed {
pub fn id(&self) -> &'static str {
match self {
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
FlatEmbed::U32ToField => "_U32_TO_FIELD",
FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS",
@ -449,14 +475,9 @@ fn use_variable(
/// * bit_width the number of bits we want to decompose to
///
/// # Remarks
/// * the return value of the `FlatFunction` is not deterministic if `bit_width == T::get_required_bits()`
/// as we decompose over `log_2(p) + 1 bits, some
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()`
/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let nbits = T::get_required_bits();
assert!(bit_width <= nbits);
let mut counter = 0;
let mut layout = HashMap::new();

View file

@ -223,7 +223,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
b: &[bool],
) -> Vec<FlatExpression<T>> {
let len = b.len();
assert_eq!(a.len(), T::get_required_bits());
assert_eq!(a.len(), b.len());
let mut is_not_smaller_run = vec![];
@ -1164,6 +1163,58 @@ impl<'ast, T: Field> Flattener<'ast, T> {
crate::embed::FlatEmbed::U8FromBits => {
vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())]
}
crate::embed::FlatEmbed::BitArrayLe => {
// get the length of the bit arrays
let len = generics[0];
// split the arguments into the two bit arrays of size `len`
let (expressions, constants) = (
param_expressions[..len as usize].to_vec(),
param_expressions[len as usize..].to_vec(),
);
// define variables for the variable bits
let variables: Vec<_> = expressions
.into_iter()
.map(|e| {
let e = self
.flatten_expression(statements_flattened, e)
.get_field_unchecked();
self.define(e, statements_flattened)
})
.collect();
// get constants for the constant bits
let constants: Vec<_> = constants
.into_iter()
.map(|e| {
self.flatten_expression(statements_flattened, e)
.get_field_unchecked()
})
.map(|e| match e {
FlatExpression::Number(n) if n == T::one() => true,
FlatExpression::Number(n) if n == T::zero() => false,
_ => unreachable!(),
})
.collect();
// get the list of conditions which must hold iff the `<=` relation holds
let conditions =
self.constant_le_check(statements_flattened, &variables, &constants);
// return `len(conditions) == sum(conditions)`
vec![FlatUExpression::with_field(
self.eq_check(
statements_flattened,
T::from(conditions.len()).into(),
conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
),
)]
}
funct => {
let funct = funct.synthetize(&generics);
@ -1924,8 +1975,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// constants do not require directives
if let Some(FlatExpression::Number(ref x)) = e.field {
let bits: Vec<_> = Interpreter::default()
.execute_solver(&Solver::bits(to), &[x.clone()])
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
.unwrap()
.into_iter()
.map(FlatExpression::Number)

View file

@ -148,6 +148,10 @@ impl Importer {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Unpack),
},
"bit_array_le" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
},
"u64_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64ToBits),

View file

@ -1,5 +1,4 @@
use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::Directive;
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
use crate::solvers::Solver;
use serde::{Deserialize, Serialize};
@ -66,30 +65,25 @@ impl Interpreter {
}
},
Statement::Directive(ref d) => {
match (&d.solver, &d.inputs, self.should_try_out_of_range) {
(Solver::Bits(bitwidth), inputs, true)
if inputs[0].left.0.len() > 1
|| inputs[0].right.0.len() > 1
&& *bitwidth == T::get_required_bits() =>
{
Self::try_solve_out_of_range(&d, &mut witness)
}
_ => {
let inputs: Vec<_> = d
.inputs
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
match self.execute_solver(&d.solver, &inputs) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
continue;
}
Err(_) => return Err(Error::Solver),
};
let mut inputs: Vec<_> = d
.inputs
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
let res = match (&d.solver, self.should_try_out_of_range) {
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
Ok(Self::try_solve_with_out_of_range_bits(
*bitwidth,
inputs.pop().unwrap(),
))
}
_ => Self::execute_solver(&d.solver, &inputs),
}
.map_err(|_| Error::Solver)?;
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
}
}
@ -98,34 +92,31 @@ impl Interpreter {
Ok(Witness(witness))
}
fn try_solve_out_of_range<T: Field>(d: &Directive<T>, witness: &mut BTreeMap<FlatVariable, T>) {
fn try_solve_with_out_of_range_bits<T: Field>(bit_width: usize, input: T) -> Vec<T> {
use num::traits::Pow;
use num_bigint::BigUint;
let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
// we target the `2a - 2b` part of the `<` check by only returning out-of-range results
// when the input is not a single summand
let value = d.inputs[0].evaluate(&witness).unwrap();
let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) {
candidate
} else {
value.to_biguint()
input.to_biguint()
};
let mut num = input;
let mut res = vec![];
let bits = T::get_required_bits();
for i in (0..bits).rev() {
if T::from(2).to_biguint().pow(i as usize) <= num {
num -= T::from(2).to_biguint().pow(i as usize);
res.push(T::one());
} else {
res.push(T::zero());
}
}
assert_eq!(num, T::zero().to_biguint());
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(*o, res[i].clone());
}
let padding = bit_width - T::get_required_bits();
(0..padding)
.map(|_| T::zero())
.chain((0..T::get_required_bits()).rev().scan(input, |state, i| {
if BigUint::from(2usize).pow(i) <= *state {
*state = (*state).clone() - BigUint::from(2usize).pow(i);
Some(T::one())
} else {
Some(T::zero())
}
}))
.collect()
}
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
@ -139,11 +130,7 @@ impl Interpreter {
}
}
pub fn execute_solver<T: Field>(
&self,
solver: &Solver,
inputs: &[T],
) -> Result<Vec<T>, String> {
pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert_eq!(inputs.len(), expected_input_count);
@ -156,18 +143,23 @@ impl Interpreter {
],
},
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
let padding = bit_width.saturating_sub(T::get_required_bits());
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
} else {
res.push(T::zero());
}
}
res
let bit_width = bit_width - padding;
let num = inputs[0].clone();
(0..padding)
.map(|_| T::zero())
.chain((0..bit_width).rev().scan(num, |state, i| {
if T::from(2).pow(i) <= *state {
*state = (*state).clone() - T::from(2).pow(i);
Some(T::one())
} else {
Some(T::zero())
}
}))
.collect()
}
Solver::Xor => {
let x = inputs[0].clone();
@ -346,16 +338,14 @@ mod tests {
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let r = Interpreter::execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
@ -364,16 +354,14 @@ mod tests {
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let r = Interpreter::execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
@ -382,10 +370,9 @@ mod tests {
#[test]
fn bits_of_one() {
let inputs = vec![Bn128Field::from(1)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
let res =
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for r in &res[0..253] {
assert_eq!(*r, Bn128Field::from(0));
@ -395,10 +382,9 @@ mod tests {
#[test]
fn bits_of_42() {
let inputs = vec![Bn128Field::from(42)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
let res =
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
@ -407,4 +393,15 @@ mod tests {
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
#[test]
fn five_hundred_bits_of_1() {
let inputs = vec![Bn128Field::from(1)];
let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap();
let mut expected = vec![Bn128Field::from(0); 500];
expected[499] = Bn128Field::from(1);
assert_eq!(res, expected);
}
}

View file

@ -140,9 +140,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
// unwrap inputs to their constant value
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = Interpreter::default()
.execute_solver(&d.solver, &inputs)
.unwrap();
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
assert_eq!(outputs.len(), d.outputs.len());
// insert the results in the substitution

View file

@ -1,20 +1,22 @@
use crate::embed::FlatEmbed;
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
UExpressionInner,
result_folder::ResultFolder,
result_folder::{fold_expression_list_inner, fold_uint_expression_inner},
Constant, TypedExpressionListInner, Types, UBitwidth, UExpressionInner,
};
use zokrates_field::Field;
pub struct ShiftChecker;
pub struct ConstantArgumentChecker;
impl ShiftChecker {
impl ConstantArgumentChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
ShiftChecker.fold_program(p)
ConstantArgumentChecker.fold_program(p)
}
}
pub type Error = String;
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
type Error = Error;
fn fold_uint_expression_inner(
@ -52,4 +54,33 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_expression_list_inner(
&mut self,
tys: &Types<'ast, T>,
l: TypedExpressionListInner<'ast, T>,
) -> Result<TypedExpressionListInner<'ast, T>, Error> {
match l {
TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => {
let arguments = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<Vec<_>, _>>()?;
if arguments[1].is_constant() {
Ok(TypedExpressionListInner::EmbedCall(
FlatEmbed::BitArrayLe,
generics,
arguments,
))
} else {
Err(format!(
"Cannot compare to a variable value, found `{}`",
arguments[1]
))
}
}
l => fold_expression_list_inner(self, tys, l),
}
}
}

View file

@ -570,7 +570,7 @@ fn fold_select_expression<'ast, T: Field, E>(
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
select: typed_absy::SelectExpression<'ast, T, E>,
) -> Vec<zir::ZirExpression<'ast, T>> {
let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().ty)
let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().clone().ty)
.unwrap()
.get_primitive_count();

View file

@ -5,21 +5,21 @@
//! @date 2018
mod branch_isolator;
mod constant_argument_checker;
mod constant_inliner;
mod flat_propagation;
mod flatten_complex_types;
mod propagation;
mod reducer;
mod shift_checker;
mod uint_optimizer;
mod unconstrained_vars;
mod variable_write_remover;
use self::branch_isolator::Isolator;
use self::constant_argument_checker::ConstantArgumentChecker;
use self::flatten_complex_types::Flattener;
use self::propagation::Propagator;
use self::reducer::reduce_program;
use self::shift_checker::ShiftChecker;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_write_remover::VariableWriteRemover;
@ -39,7 +39,7 @@ pub trait Analyse {
pub enum Error {
Reducer(self::reducer::Error),
Propagation(self::propagation::Error),
NonConstantShift(self::shift_checker::Error),
NonConstantArgument(self::constant_argument_checker::Error),
}
impl From<reducer::Error> for Error {
@ -54,9 +54,9 @@ impl From<propagation::Error> for Error {
}
}
impl From<shift_checker::Error> for Error {
fn from(e: shift_checker::Error) -> Self {
Error::NonConstantShift(e)
impl From<constant_argument_checker::Error> for Error {
fn from(e: constant_argument_checker::Error) -> Self {
Error::NonConstantArgument(e)
}
}
@ -65,7 +65,7 @@ impl fmt::Display for Error {
match self {
Error::Reducer(e) => write!(f, "{}", e),
Error::Propagation(e) => write!(f, "{}", e),
Error::NonConstantShift(e) => write!(f, "{}", e),
Error::NonConstantArgument(e) => write!(f, "{}", e),
}
}
}
@ -107,9 +107,9 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = VariableWriteRemover::apply(r);
log::trace!("\n{}", r);
// detect non constant shifts
log::debug!("Static analyser: Detect non constant shifts");
let r = ShiftChecker::check(r).map_err(Error::from)?;
// detect non constant shifts and constant lt bounds
log::debug!("Static analyser: Detect non constant arguments");
let r = ConstantArgumentChecker::check(r).map_err(Error::from)?;
log::trace!("\n{}", r);
// convert to zir, removing complex types

View file

@ -124,126 +124,6 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
}
}
fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e {
TypedExpressionOrSpread::Expression(e) => is_constant(e),
_ => false,
}),
ArrayExpressionInner::Slice(box a, box from, box to) => {
is_constant(&from.clone().into())
&& is_constant(&to.clone().into())
&& is_constant(&a.clone().into())
}
ArrayExpressionInner::Repeat(box e, box count) => {
is_constant(&count.clone().into()) && is_constant(&e)
}
_ => false,
},
TypedExpression::Struct(a) => match a.as_inner() {
StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)),
_ => false,
}
}
// in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
fn to_canonical_constant<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
fn to_canonical_constant_aux<T: Field>(
e: TypedExpressionOrSpread<T>,
) -> Vec<TypedExpression<T>> {
match e {
TypedExpressionOrSpread::Expression(e) => vec![e],
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
ArrayExpressionInner::Value(v) => {
v.into_iter().flat_map(to_canonical_constant_aux).collect()
}
_ => unimplemented!(),
},
}
}
match e {
TypedExpression::Array(a) => {
let array_ty = a.ty();
match a.into_inner() {
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
v.into_iter()
.flat_map(to_canonical_constant_aux)
.map(|e| e.into())
.collect::<Vec<_>>()
.into(),
)
.annotate(*array_ty.ty, array_ty.size)
.into(),
ArrayExpressionInner::Slice(box a, box from, box to) => {
let from = match from.into_inner() {
UExpressionInner::Value(from) => from as usize,
_ => unreachable!("should be a uint value"),
};
let to = match to.into_inner() {
UExpressionInner::Value(to) => to as usize,
_ => unreachable!("should be a uint value"),
};
let v = match a.into_inner() {
ArrayExpressionInner::Value(v) => v,
_ => unreachable!("should be an array value"),
};
ArrayExpressionInner::Value(
v.into_iter()
.flat_map(to_canonical_constant_aux)
.map(|e| e.into())
.enumerate()
.filter(|(index, _)| index >= &from && index < &to)
.map(|(_, e)| e)
.collect::<Vec<_>>()
.into(),
)
.annotate(*array_ty.ty, array_ty.size)
.into()
}
ArrayExpressionInner::Repeat(box e, box count) => {
let count = match count.into_inner() {
UExpressionInner::Value(from) => from as usize,
_ => unreachable!("should be a uint value"),
};
let e = to_canonical_constant(e);
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(e); count].into(),
)
.annotate(*array_ty.ty, array_ty.size)
.into()
}
_ => unreachable!(),
}
}
TypedExpression::Struct(s) => {
let struct_ty = s.ty().clone();
match s.into_inner() {
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
expressions.into_iter().map(to_canonical_constant).collect(),
)
.annotate(struct_ty)
.into(),
_ => unreachable!(),
}
}
e => e,
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
type Error = Error;
@ -341,10 +221,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
};
if is_constant(&expr) {
if expr.is_constant() {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = to_canonical_constant(expr);
let expr = expr.into_canonical_constant();
assert!(self.constants.insert(var.id, expr).is_none());
@ -352,7 +232,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = to_canonical_constant(expr);
*c = expr.into_canonical_constant();
Ok(vec![])
}
Err(v) => match self.constants.remove(&v.id) {
@ -423,7 +303,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let argument = arguments.pop().unwrap();
let argument = to_canonical_constant(argument);
let argument = argument.into_canonical_constant();
match ArrayExpression::try_from(argument)
.unwrap()
@ -498,10 +378,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
match arguments.iter().all(|a| is_constant(a)) {
match arguments.iter().all(|a| a.is_constant()) {
true => {
let r: Option<TypedExpression<'ast, T>> = match embed {
FlatEmbed::U32ToField => None, // todo
FlatEmbed::BitArrayLe => None, // todo
FlatEmbed::U64FromBits => Some(process_u_from_bits(
assignees.clone(),
arguments.clone(),
@ -1173,6 +1054,34 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
},
None => Ok(ArrayExpressionInner::Identifier(id)),
},
ArrayExpressionInner::Value(exprs) => {
Ok(ArrayExpressionInner::Value(
exprs
.into_iter()
.map(|e| self.fold_expression_or_spread(e))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|e| {
match e {
// simplify `...[a, b]` to `a, b`
TypedExpressionOrSpread::Spread(TypedSpread {
array:
ArrayExpression {
inner: ArrayExpressionInner::Value(v),
..
},
}) => v.0,
e => vec![e],
}
})
// ignore spreads over empty arrays
.filter_map(|e| match e {
TypedExpressionOrSpread::Spread(s) if s.array.size() == 0 => None,
e => Some(e),
})
.collect(),
))
}
e => fold_array_expression_inner(self, ty, e),
}
}

View file

@ -1273,13 +1273,8 @@ mod tests {
// def main():
// # PUSH CALL to foo::<1>
// # PUSH CALL to bar::<2>
// field[2] a_1 = [...[1]], 0]
// field[2] #RET_0_1 = a_1
// # POP CALL
// field[1] ret := #RET_0_1[0..1]
// field[1] #RET_0 = ret
// # POP CALL
// field[1] b_0 := #RET_0
// return
let foo_signature = DeclarationSignature::new()
@ -1452,71 +1447,8 @@ mod tests {
.collect(),
),
),
TypedStatement::Definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
.into(),
ArrayExpressionInner::Value(
vec![
TypedExpressionOrSpread::Spread(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(1)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
FieldElementExpression::Number(Bn128Field::from(0)).into(),
]
.into(),
)
.annotate(Type::FieldElement, 2u32)
.into(),
),
TypedStatement::Definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(1),
Type::FieldElement,
2u32,
)
.into(),
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 2u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
Variable::array("ret", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Slice(
box ArrayExpressionInner::Identifier(
Identifier::from(CoreIdentifier::Call(0)).version(1),
)
.annotate(Type::FieldElement, 2u32),
box 0u32.into(),
box 1u32.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)),
Type::FieldElement,
1u32,
)
.into(),
ArrayExpressionInner::Identifier("ret".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Identifier(Identifier::from(CoreIdentifier::Call(0)))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(vec![]),
],
signature: DeclarationSignature::new(),

View file

@ -738,7 +738,7 @@ mod tests {
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::right_shift(left.clone(), right.clone())),
.fold_uint_expression(UExpression::right_shift(left.clone(), right)),
UExpression::right_shift(left_expected, right_expected).with_max(output_max)
);
}
@ -761,7 +761,7 @@ mod tests {
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
.fold_uint_expression(UExpression::left_shift(left.clone(), right)),
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
);
}

View file

@ -5,7 +5,7 @@ use crate::typed_absy::types::{
};
use crate::typed_absy::UBitwidth;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse,
IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
};
@ -585,7 +585,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
array: Self,
target_array_ty: &GArrayType<S>,
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.ty();
let array_ty = array.ty().clone();
// elements must fit in the target type
match array.into_inner() {

View file

@ -1029,8 +1029,8 @@ impl<'ast, T> From<bool> for BooleanExpression<'ast, T> {
/// type checking
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub struct ArrayExpression<'ast, T> {
ty: Box<ArrayType<'ast, T>>,
inner: ArrayExpressionInner<'ast, T>,
pub ty: Box<ArrayType<'ast, T>>,
pub inner: ArrayExpressionInner<'ast, T>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
@ -1149,25 +1149,6 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> {
pub fn size(&self) -> UExpression<'ast, T> {
self.ty.size.clone()
}
pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> {
&self.inner
}
pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> {
&mut self.inner
}
pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> {
self.inner
}
pub fn ty(&self) -> ArrayType<'ast, T> {
ArrayType {
size: self.size(),
ty: box self.inner_type().clone(),
}
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
@ -1495,15 +1476,23 @@ pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
type Inner;
type Ty: Clone + IntoTypes<'ast, T>;
fn ty(&self) -> &Self::Ty;
fn into_inner(self) -> Self::Inner;
fn as_inner(&self) -> &Self::Inner;
fn as_inner_mut(&mut self) -> &mut Self::Inner;
}
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
fn ty(&self) -> &Self::Ty {
&Type::FieldElement
}
fn into_inner(self) -> Self::Inner {
self
}
@ -1511,12 +1500,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
self
}
}
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
fn ty(&self) -> &Self::Ty {
&Type::Boolean
}
fn into_inner(self) -> Self::Inner {
self
}
@ -1524,12 +1521,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
self
}
}
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
type Inner = UExpressionInner<'ast, T>;
type Ty = UBitwidth;
fn ty(&self) -> &Self::Ty {
&self.bitwidth
}
fn into_inner(self) -> Self::Inner {
self.inner
}
@ -1537,12 +1542,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
&mut self.inner
}
}
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
type Inner = StructExpressionInner<'ast, T>;
type Ty = StructType<'ast, T>;
fn ty(&self) -> &Self::Ty {
&self.ty
}
fn into_inner(self) -> Self::Inner {
self.inner
}
@ -1550,12 +1563,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
&mut self.inner
}
}
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
type Inner = ArrayExpressionInner<'ast, T>;
type Ty = ArrayType<'ast, T>;
fn ty(&self) -> &Self::Ty {
&self.ty
}
fn into_inner(self) -> Self::Inner {
self.inner
}
@ -1563,12 +1584,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
&mut self.inner
}
}
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
fn ty(&self) -> &Self::Ty {
&Type::Int
}
fn into_inner(self) -> Self::Inner {
self
}
@ -1576,12 +1605,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
self
}
}
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
type Inner = TypedExpressionListInner<'ast, T>;
type Ty = Types<'ast, T>;
fn ty(&self) -> &Self::Ty {
&self.types
}
fn into_inner(self) -> Self::Inner {
self.inner
}
@ -1589,6 +1626,10 @@ impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
fn as_inner_mut(&mut self) -> &mut Self::Inner {
&mut self.inner
}
}
// Enums types to enable returning e.g a member expression OR another type of expression of this type
@ -1769,7 +1810,7 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> {
}
}
impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> {
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
let ty = s.ty().members.iter().find(|member| id == member.id);
let bitwidth = match ty {
@ -1949,7 +1990,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
let array_ty = value.ty();
let array_ty = value.ty().clone();
ArrayExpressionInner::Block(BlockExpression::new(statements, value))
.annotate(*array_ty.ty, array_ty.size)
}
@ -1962,3 +2003,199 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty)
}
}
pub trait Constant: Sized {
// return whether this is constant
fn is_constant(&self) -> bool;
// canonicalize an expression *that we know to be constant*
// for example for [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
fn into_canonical_constant(self) -> Self {
self
}
}
impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> {
fn is_constant(&self) -> bool {
matches!(self, FieldElementExpression::Number(..))
}
}
impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> {
fn is_constant(&self) -> bool {
matches!(self, BooleanExpression::Value(..))
}
}
impl<'ast, T: Field> Constant for UExpression<'ast, T> {
fn is_constant(&self) -> bool {
matches!(self.as_inner(), UExpressionInner::Value(..))
}
}
impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> {
fn is_constant(&self) -> bool {
match self.as_inner() {
ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e {
TypedExpressionOrSpread::Expression(e) => e.is_constant(),
TypedExpressionOrSpread::Spread(s) => s.array.is_constant(),
}),
ArrayExpressionInner::Slice(box a, box from, box to) => {
from.is_constant() && to.is_constant() && a.is_constant()
}
ArrayExpressionInner::Repeat(box e, box count) => {
count.is_constant() && e.is_constant()
}
_ => false,
}
}
fn into_canonical_constant(self) -> Self {
fn into_canonical_constant_aux<T: Field>(
e: TypedExpressionOrSpread<T>,
) -> Vec<TypedExpression<T>> {
match e {
TypedExpressionOrSpread::Expression(e) => vec![e],
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
ArrayExpressionInner::Value(v) => v
.into_iter()
.flat_map(into_canonical_constant_aux)
.collect(),
ArrayExpressionInner::Slice(box v, box from, box to) => {
let from = match from.into_inner() {
UExpressionInner::Value(v) => v,
_ => unreachable!(),
};
let to = match to.into_inner() {
UExpressionInner::Value(v) => v,
_ => unreachable!(),
};
let v = match v.into_inner() {
ArrayExpressionInner::Value(v) => v,
_ => unreachable!(),
};
v.into_iter()
.flat_map(into_canonical_constant_aux)
.skip(from as usize)
.take(to as usize - from as usize)
.collect()
}
ArrayExpressionInner::Repeat(box e, box count) => {
let count = match count.into_inner() {
UExpressionInner::Value(count) => count,
_ => unreachable!(),
};
vec![e; count as usize]
}
a => unreachable!("{}", a),
},
}
}
let array_ty = self.ty().clone();
match self.into_inner() {
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
v.into_iter()
.flat_map(into_canonical_constant_aux)
.map(|e| e.into())
.collect::<Vec<_>>()
.into(),
)
.annotate(*array_ty.ty, array_ty.size),
ArrayExpressionInner::Slice(box a, box from, box to) => {
let from = match from.into_inner() {
UExpressionInner::Value(from) => from as usize,
_ => unreachable!("should be a uint value"),
};
let to = match to.into_inner() {
UExpressionInner::Value(to) => to as usize,
_ => unreachable!("should be a uint value"),
};
let v = match a.into_inner() {
ArrayExpressionInner::Value(v) => v,
_ => unreachable!("should be an array value"),
};
ArrayExpressionInner::Value(
v.into_iter()
.flat_map(into_canonical_constant_aux)
.map(|e| e.into())
.skip(from)
.take(to - from)
.collect::<Vec<_>>()
.into(),
)
.annotate(*array_ty.ty, array_ty.size)
}
ArrayExpressionInner::Repeat(box e, box count) => {
let count = match count.into_inner() {
UExpressionInner::Value(from) => from as usize,
_ => unreachable!("should be a uint value"),
};
let e = e.into_canonical_constant();
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(e); count].into(),
)
.annotate(*array_ty.ty, array_ty.size)
}
_ => unreachable!(),
}
}
}
impl<'ast, T: Field> Constant for StructExpression<'ast, T> {
fn is_constant(&self) -> bool {
match self.as_inner() {
StructExpressionInner::Value(v) => v.iter().all(|e| e.is_constant()),
_ => false,
}
}
fn into_canonical_constant(self) -> Self {
let struct_ty = self.ty().clone();
match self.into_inner() {
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
expressions
.into_iter()
.map(|e| e.into_canonical_constant())
.collect(),
)
.annotate(struct_ty),
_ => unreachable!(),
}
}
}
impl<'ast, T: Field> Constant for TypedExpression<'ast, T> {
fn is_constant(&self) -> bool {
match self {
TypedExpression::FieldElement(e) => e.is_constant(),
TypedExpression::Boolean(e) => e.is_constant(),
TypedExpression::Array(e) => e.is_constant(),
TypedExpression::Struct(e) => e.is_constant(),
TypedExpression::Uint(e) => e.is_constant(),
_ => unreachable!(),
}
}
fn into_canonical_constant(self) -> Self {
match self {
TypedExpression::FieldElement(e) => e.into_canonical_constant().into(),
TypedExpression::Boolean(e) => e.into_canonical_constant().into(),
TypedExpression::Array(e) => e.into_canonical_constant().into(),
TypedExpression::Struct(e) => e.into_canonical_constant().into(),
TypedExpression::Uint(e) => e.into_canonical_constant().into(),
_ => unreachable!(),
}
}
}

View file

@ -10,6 +10,7 @@ use zokrates_core::{
ir::Interpreter,
};
use zokrates_field::Bn128Field;
use zokrates_fs_resolver::FileSystemResolver;
#[test]
fn lt_field() {
@ -74,3 +75,83 @@ fn lt_uint() {
)
.is_err());
}
#[test]
fn unpack256() {
let source = r#"
import "utils/pack/bool/unpack256"
def main(private field a):
bool[256] bits = unpack256(a)
assert(bits[255])
return
"#
.to_string();
// let's try to prove that the least significant bit of 0 is 1
// we exploit the fact that the bits of 0 are the bits of p, and p is even
// we want this to still fail
let stdlib_path = std::fs::canonicalize(
std::env::current_dir()
.unwrap()
.join("../zokrates_stdlib/stdlib"),
)
.unwrap();
let res: CompilationArtifacts<Bn128Field> = compile(
source,
"./path/to/file".into(),
Some(&FileSystemResolver::with_stdlib_root(
stdlib_path.to_str().unwrap(),
)),
&CompileConfig::default(),
)
.unwrap();
let interpreter = Interpreter::try_out_of_range();
assert!(interpreter
.execute(&res.prog(), &[Bn128Field::from(0)])
.is_err());
}
#[test]
fn unpack256_unchecked() {
let source = r#"
import "utils/pack/bool/nonStrictUnpack256"
def main(private field a):
bool[256] bits = nonStrictUnpack256(a)
assert(bits[255])
return
"#
.to_string();
// let's try to prove that the least significant bit of 0 is 1
// we exploit the fact that the bits of 0 are the bits of p, and p is odd
// we want this to succeed as the non strict version does not enforce the bits to be in range
let stdlib_path = std::fs::canonicalize(
std::env::current_dir()
.unwrap()
.join("../zokrates_stdlib/stdlib"),
)
.unwrap();
let res: CompilationArtifacts<Bn128Field> = compile(
source,
"./path/to/file".into(),
Some(&FileSystemResolver::with_stdlib_root(
stdlib_path.to_str().unwrap(),
)),
&CompileConfig::default(),
)
.unwrap();
let interpreter = Interpreter::try_out_of_range();
assert!(interpreter
.execute(&res.prog(), &[Bn128Field::from(0)])
.is_ok());
}

View file

@ -26,17 +26,16 @@ impl<'a> Resolver<io::Error> for FileSystemResolver<'a> {
) -> Result<(String, PathBuf), io::Error> {
let source = Path::new(&import_location);
if !current_location.is_file() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("{} was expected to be a file", current_location.display()),
));
}
// paths starting with `./` or `../` are interpreted relative to the current file
// other paths `abc/def` are interpreted relative to the standard library root path
let base = match source.components().next() {
Some(Component::CurDir) | Some(Component::ParentDir) => {
if !current_location.is_file() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("{} was expected to be a file", current_location.display()),
));
}
current_location.parent().unwrap().into()
}
_ => PathBuf::from(self.stdlib_root_path.unwrap_or("")),

View file

@ -1,12 +1,12 @@
#pragma curve bn128
import "./unpack" as unpack
import "./unpack_unchecked"
// Unpack a field element as 256 big-endian bits
// Note: uniqueness of the output is not guaranteed
// For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)`
def main(field i) -> bool[256]:
bool[254] b = unpack::<254>(i)
bool[254] b = unpack_unchecked::<254>(i)
return [false, false, ...b]

View file

@ -1,12 +1,12 @@
#pragma curve bn128
from "EMBED" import unpack
import "./unpack_unchecked.zok"
from "field" import FIELD_SIZE_IN_BITS
from "EMBED" import bit_array_le
// Unpack a field element as N big endian bits
def main<N>(field i) -> bool[N]:
assert(N <= 254)
bool[N] res = unpack(i)
bool[N] res = unpack_unchecked(i)
assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::<FIELD_SIZE_IN_BITS>(-1)]) else true fi)
return res

View file

@ -1,9 +1,7 @@
#pragma curve bn128
import "./unpack" as unpack
// Unpack a field element as 128 big-endian bits
// Precondition: the input is smaller or equal to `2**128 - 1`
// If the input is larger than `2**128 - 1`, the output is truncated.
def main(field i) -> bool[128]:
bool[128] res = unpack::<128>(i)
return res

View file

@ -0,0 +1,7 @@
import "./unpack" as unpack
// Unpack a field element as 256 big-endian bits
// If the input is larger than `2**256 - 1`, the output is truncated.
def main(field i) -> bool[256]:
bool[256] res = unpack::<256>(i)
return res

View file

@ -0,0 +1,9 @@
from "EMBED" import unpack
// Unpack a field element as N big endian bits without checking for overflows
// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either
def main<N>(field i) -> bool[N]:
bool[N] res = unpack(i)
return res

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/utils/pack/bool/unpack256.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": []
}
}
}
]
}

View file

@ -0,0 +1,24 @@
import "utils/pack/bool/unpack256" as unpack256
def testFive() -> bool:
bool[256] b = unpack256(5)
assert(b == [...[false; 253], true, false, true])
return true
def testZero() -> bool:
bool[256] b = unpack256(0)
assert(b == [false; 256])
return true
def main():
assert(testFive())
assert(testZero())
return