Merge pull request #955 from Zokrates/bit_lt_embed
Add bit LT embed to make unpack safe for any size
This commit is contained in:
commit
b324e17684
26 changed files with 678 additions and 357 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2451,6 +2451,7 @@ dependencies = [
|
|||
"zokrates_common",
|
||||
"zokrates_embed",
|
||||
"zokrates_field",
|
||||
"zokrates_fs_resolver",
|
||||
"zokrates_pest_ast",
|
||||
]
|
||||
|
||||
|
|
1
changelogs/unreleased/955-schaeff
Normal file
1
changelogs/unreleased/955-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Make the stdlib `unpack` function safe against overflows of bit decompositions for any size of output, introduce `unpack_unchecked` for cases that do not require determinism
|
|
@ -0,0 +1,5 @@
|
|||
from "EMBED" import bit_array_le
|
||||
|
||||
// Calling the `bit_array_le` embed on a non-constant second argument should fail at compile-time
|
||||
def main(bool[1] a, bool[1] b) -> bool:
|
||||
return bit_array_le::<1>(a, b)
|
|
@ -60,6 +60,7 @@ sha2 = { version = "0.9.3", optional = true }
|
|||
[dev-dependencies]
|
||||
wasm-bindgen-test = "^0.3.0"
|
||||
pretty_assertions = "0.6.1"
|
||||
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
|
||||
|
||||
[build-dependencies]
|
||||
cc = { version = "1.0", features = ["parallel"], optional = true }
|
||||
|
|
|
@ -28,6 +28,7 @@ cfg_if::cfg_if! {
|
|||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
|
||||
pub enum FlatEmbed {
|
||||
BitArrayLe,
|
||||
U32ToField,
|
||||
Unpack,
|
||||
U8ToBits,
|
||||
|
@ -47,6 +48,30 @@ pub enum FlatEmbed {
|
|||
impl FlatEmbed {
|
||||
pub fn signature(&self) -> DeclarationSignature<'static> {
|
||||
match self {
|
||||
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
||||
.generics(vec![Some(DeclarationConstant::Generic(
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
))])
|
||||
.inputs(vec![
|
||||
DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
)),
|
||||
DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
},
|
||||
)),
|
||||
])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
FlatEmbed::U32ToField => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -172,6 +197,7 @@ impl FlatEmbed {
|
|||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
|
||||
FlatEmbed::U32ToField => "_U32_TO_FIELD",
|
||||
FlatEmbed::Unpack => "_UNPACK",
|
||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||
|
@ -449,14 +475,9 @@ fn use_variable(
|
|||
/// * bit_width the number of bits we want to decompose to
|
||||
///
|
||||
/// # Remarks
|
||||
/// * the return value of the `FlatFunction` is not deterministic if `bit_width == T::get_required_bits()`
|
||||
/// as we decompose over `log_2(p) + 1 bits, some
|
||||
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()`
|
||||
/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
|
||||
let nbits = T::get_required_bits();
|
||||
|
||||
assert!(bit_width <= nbits);
|
||||
|
||||
let mut counter = 0;
|
||||
|
||||
let mut layout = HashMap::new();
|
||||
|
|
|
@ -223,7 +223,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
b: &[bool],
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let len = b.len();
|
||||
assert_eq!(a.len(), T::get_required_bits());
|
||||
assert_eq!(a.len(), b.len());
|
||||
|
||||
let mut is_not_smaller_run = vec![];
|
||||
|
@ -1164,6 +1163,58 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
crate::embed::FlatEmbed::U8FromBits => {
|
||||
vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())]
|
||||
}
|
||||
crate::embed::FlatEmbed::BitArrayLe => {
|
||||
// get the length of the bit arrays
|
||||
let len = generics[0];
|
||||
|
||||
// split the arguments into the two bit arrays of size `len`
|
||||
let (expressions, constants) = (
|
||||
param_expressions[..len as usize].to_vec(),
|
||||
param_expressions[len as usize..].to_vec(),
|
||||
);
|
||||
|
||||
// define variables for the variable bits
|
||||
let variables: Vec<_> = expressions
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let e = self
|
||||
.flatten_expression(statements_flattened, e)
|
||||
.get_field_unchecked();
|
||||
self.define(e, statements_flattened)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// get constants for the constant bits
|
||||
let constants: Vec<_> = constants
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
self.flatten_expression(statements_flattened, e)
|
||||
.get_field_unchecked()
|
||||
})
|
||||
.map(|e| match e {
|
||||
FlatExpression::Number(n) if n == T::one() => true,
|
||||
FlatExpression::Number(n) if n == T::zero() => false,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// get the list of conditions which must hold iff the `<=` relation holds
|
||||
let conditions =
|
||||
self.constant_le_check(statements_flattened, &variables, &constants);
|
||||
|
||||
// return `len(conditions) == sum(conditions)`
|
||||
vec![FlatUExpression::with_field(
|
||||
self.eq_check(
|
||||
statements_flattened,
|
||||
T::from(conditions.len()).into(),
|
||||
conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
),
|
||||
)]
|
||||
}
|
||||
funct => {
|
||||
let funct = funct.synthetize(&generics);
|
||||
|
||||
|
@ -1924,8 +1975,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// constants do not require directives
|
||||
if let Some(FlatExpression::Number(ref x)) = e.field {
|
||||
let bits: Vec<_> = Interpreter::default()
|
||||
.execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(FlatExpression::Number)
|
||||
|
|
|
@ -148,6 +148,10 @@ impl Importer {
|
|||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::Unpack),
|
||||
},
|
||||
"bit_array_le" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
|
||||
},
|
||||
"u64_to_bits" => SymbolDeclaration {
|
||||
id: symbol.get_alias(),
|
||||
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::ir::Directive;
|
||||
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
|
||||
use crate::solvers::Solver;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -66,30 +65,25 @@ impl Interpreter {
|
|||
}
|
||||
},
|
||||
Statement::Directive(ref d) => {
|
||||
match (&d.solver, &d.inputs, self.should_try_out_of_range) {
|
||||
(Solver::Bits(bitwidth), inputs, true)
|
||||
if inputs[0].left.0.len() > 1
|
||||
|| inputs[0].right.0.len() > 1
|
||||
&& *bitwidth == T::get_required_bits() =>
|
||||
{
|
||||
Self::try_solve_out_of_range(&d, &mut witness)
|
||||
}
|
||||
_ => {
|
||||
let inputs: Vec<_> = d
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| i.evaluate(&witness).unwrap())
|
||||
.collect();
|
||||
match self.execute_solver(&d.solver, &inputs) {
|
||||
Ok(res) => {
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(_) => return Err(Error::Solver),
|
||||
};
|
||||
let mut inputs: Vec<_> = d
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| i.evaluate(&witness).unwrap())
|
||||
.collect();
|
||||
|
||||
let res = match (&d.solver, self.should_try_out_of_range) {
|
||||
(Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => {
|
||||
Ok(Self::try_solve_with_out_of_range_bits(
|
||||
*bitwidth,
|
||||
inputs.pop().unwrap(),
|
||||
))
|
||||
}
|
||||
_ => Self::execute_solver(&d.solver, &inputs),
|
||||
}
|
||||
.map_err(|_| Error::Solver)?;
|
||||
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -98,34 +92,31 @@ impl Interpreter {
|
|||
Ok(Witness(witness))
|
||||
}
|
||||
|
||||
fn try_solve_out_of_range<T: Field>(d: &Directive<T>, witness: &mut BTreeMap<FlatVariable, T>) {
|
||||
fn try_solve_with_out_of_range_bits<T: Field>(bit_width: usize, input: T) -> Vec<T> {
|
||||
use num::traits::Pow;
|
||||
use num_bigint::BigUint;
|
||||
|
||||
let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
|
||||
|
||||
// we target the `2a - 2b` part of the `<` check by only returning out-of-range results
|
||||
// when the input is not a single summand
|
||||
let value = d.inputs[0].evaluate(&witness).unwrap();
|
||||
let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint();
|
||||
let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) {
|
||||
candidate
|
||||
} else {
|
||||
value.to_biguint()
|
||||
input.to_biguint()
|
||||
};
|
||||
|
||||
let mut num = input;
|
||||
let mut res = vec![];
|
||||
let bits = T::get_required_bits();
|
||||
for i in (0..bits).rev() {
|
||||
if T::from(2).to_biguint().pow(i as usize) <= num {
|
||||
num -= T::from(2).to_biguint().pow(i as usize);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
}
|
||||
}
|
||||
assert_eq!(num, T::zero().to_biguint());
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
let padding = bit_width - T::get_required_bits();
|
||||
|
||||
(0..padding)
|
||||
.map(|_| T::zero())
|
||||
.chain((0..T::get_required_bits()).rev().scan(input, |state, i| {
|
||||
if BigUint::from(2usize).pow(i) <= *state {
|
||||
*state = (*state).clone() - BigUint::from(2usize).pow(i);
|
||||
Some(T::one())
|
||||
} else {
|
||||
Some(T::zero())
|
||||
}
|
||||
}))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
|
||||
|
@ -139,11 +130,7 @@ impl Interpreter {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn execute_solver<T: Field>(
|
||||
&self,
|
||||
solver: &Solver,
|
||||
inputs: &[T],
|
||||
) -> Result<Vec<T>, String> {
|
||||
pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> {
|
||||
let (expected_input_count, expected_output_count) = solver.get_signature();
|
||||
assert_eq!(inputs.len(), expected_input_count);
|
||||
|
||||
|
@ -156,18 +143,23 @@ impl Interpreter {
|
|||
],
|
||||
},
|
||||
Solver::Bits(bit_width) => {
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
let padding = bit_width.saturating_sub(T::get_required_bits());
|
||||
|
||||
for i in (0..*bit_width).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
}
|
||||
}
|
||||
res
|
||||
let bit_width = bit_width - padding;
|
||||
|
||||
let num = inputs[0].clone();
|
||||
|
||||
(0..padding)
|
||||
.map(|_| T::zero())
|
||||
.chain((0..bit_width).rev().scan(num, |state, i| {
|
||||
if T::from(2).pow(i) <= *state {
|
||||
*state = (*state).clone() - T::from(2).pow(i);
|
||||
Some(T::one())
|
||||
} else {
|
||||
Some(T::zero())
|
||||
}
|
||||
}))
|
||||
.collect()
|
||||
}
|
||||
Solver::Xor => {
|
||||
let x = inputs[0].clone();
|
||||
|
@ -346,16 +338,14 @@ mod tests {
|
|||
fn execute() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![0];
|
||||
let interpreter = Interpreter::default();
|
||||
let r = interpreter
|
||||
.execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let r = Interpreter::execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
|
@ -364,16 +354,14 @@ mod tests {
|
|||
fn execute_non_eq() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![1];
|
||||
let interpreter = Interpreter::default();
|
||||
let r = interpreter
|
||||
.execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let r = Interpreter::execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
|
@ -382,10 +370,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_one() {
|
||||
let inputs = vec![Bn128Field::from(1)];
|
||||
let interpreter = Interpreter::default();
|
||||
let res = interpreter
|
||||
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
let res =
|
||||
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], Bn128Field::from(1));
|
||||
for r in &res[0..253] {
|
||||
assert_eq!(*r, Bn128Field::from(0));
|
||||
|
@ -395,10 +382,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_42() {
|
||||
let inputs = vec![Bn128Field::from(42)];
|
||||
let interpreter = Interpreter::default();
|
||||
let res = interpreter
|
||||
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
let res =
|
||||
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], Bn128Field::from(0));
|
||||
assert_eq!(res[252], Bn128Field::from(1));
|
||||
assert_eq!(res[251], Bn128Field::from(0));
|
||||
|
@ -407,4 +393,15 @@ mod tests {
|
|||
assert_eq!(res[248], Bn128Field::from(1));
|
||||
assert_eq!(res[247], Bn128Field::from(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn five_hundred_bits_of_1() {
|
||||
let inputs = vec![Bn128Field::from(1)];
|
||||
let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap();
|
||||
|
||||
let mut expected = vec![Bn128Field::from(0); 500];
|
||||
expected[499] = Bn128Field::from(1);
|
||||
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -140,9 +140,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
// unwrap inputs to their constant value
|
||||
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
// run the solver
|
||||
let outputs = Interpreter::default()
|
||||
.execute_solver(&d.solver, &inputs)
|
||||
.unwrap();
|
||||
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
|
||||
assert_eq!(outputs.len(), d.outputs.len());
|
||||
|
||||
// insert the results in the substitution
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
use crate::embed::FlatEmbed;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use crate::typed_absy::{
|
||||
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
|
||||
UExpressionInner,
|
||||
result_folder::ResultFolder,
|
||||
result_folder::{fold_expression_list_inner, fold_uint_expression_inner},
|
||||
Constant, TypedExpressionListInner, Types, UBitwidth, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
pub struct ShiftChecker;
|
||||
pub struct ConstantArgumentChecker;
|
||||
|
||||
impl ShiftChecker {
|
||||
impl ConstantArgumentChecker {
|
||||
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
ShiftChecker.fold_program(p)
|
||||
ConstantArgumentChecker.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Error = String;
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
|
@ -52,4 +54,33 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
|
|||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_expression_list_inner(
|
||||
&mut self,
|
||||
tys: &Types<'ast, T>,
|
||||
l: TypedExpressionListInner<'ast, T>,
|
||||
) -> Result<TypedExpressionListInner<'ast, T>, Error> {
|
||||
match l {
|
||||
TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => {
|
||||
let arguments = arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
if arguments[1].is_constant() {
|
||||
Ok(TypedExpressionListInner::EmbedCall(
|
||||
FlatEmbed::BitArrayLe,
|
||||
generics,
|
||||
arguments,
|
||||
))
|
||||
} else {
|
||||
Err(format!(
|
||||
"Cannot compare to a variable value, found `{}`",
|
||||
arguments[1]
|
||||
))
|
||||
}
|
||||
}
|
||||
l => fold_expression_list_inner(self, tys, l),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -570,7 +570,7 @@ fn fold_select_expression<'ast, T: Field, E>(
|
|||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
select: typed_absy::SelectExpression<'ast, T, E>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().ty)
|
||||
let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().clone().ty)
|
||||
.unwrap()
|
||||
.get_primitive_count();
|
||||
|
||||
|
|
|
@ -5,21 +5,21 @@
|
|||
//! @date 2018
|
||||
|
||||
mod branch_isolator;
|
||||
mod constant_argument_checker;
|
||||
mod constant_inliner;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod propagation;
|
||||
mod reducer;
|
||||
mod shift_checker;
|
||||
mod uint_optimizer;
|
||||
mod unconstrained_vars;
|
||||
mod variable_write_remover;
|
||||
|
||||
use self::branch_isolator::Isolator;
|
||||
use self::constant_argument_checker::ConstantArgumentChecker;
|
||||
use self::flatten_complex_types::Flattener;
|
||||
use self::propagation::Propagator;
|
||||
use self::reducer::reduce_program;
|
||||
use self::shift_checker::ShiftChecker;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use self::variable_write_remover::VariableWriteRemover;
|
||||
|
@ -39,7 +39,7 @@ pub trait Analyse {
|
|||
pub enum Error {
|
||||
Reducer(self::reducer::Error),
|
||||
Propagation(self::propagation::Error),
|
||||
NonConstantShift(self::shift_checker::Error),
|
||||
NonConstantArgument(self::constant_argument_checker::Error),
|
||||
}
|
||||
|
||||
impl From<reducer::Error> for Error {
|
||||
|
@ -54,9 +54,9 @@ impl From<propagation::Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<shift_checker::Error> for Error {
|
||||
fn from(e: shift_checker::Error) -> Self {
|
||||
Error::NonConstantShift(e)
|
||||
impl From<constant_argument_checker::Error> for Error {
|
||||
fn from(e: constant_argument_checker::Error) -> Self {
|
||||
Error::NonConstantArgument(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ impl fmt::Display for Error {
|
|||
match self {
|
||||
Error::Reducer(e) => write!(f, "{}", e),
|
||||
Error::Propagation(e) => write!(f, "{}", e),
|
||||
Error::NonConstantShift(e) => write!(f, "{}", e),
|
||||
Error::NonConstantArgument(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,9 +107,9 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
let r = VariableWriteRemover::apply(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// detect non constant shifts
|
||||
log::debug!("Static analyser: Detect non constant shifts");
|
||||
let r = ShiftChecker::check(r).map_err(Error::from)?;
|
||||
// detect non constant shifts and constant lt bounds
|
||||
log::debug!("Static analyser: Detect non constant arguments");
|
||||
let r = ConstantArgumentChecker::check(r).map_err(Error::from)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// convert to zir, removing complex types
|
||||
|
|
|
@ -124,126 +124,6 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
|
||||
match e {
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
|
||||
TypedExpression::Array(a) => match a.as_inner() {
|
||||
ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => is_constant(e),
|
||||
_ => false,
|
||||
}),
|
||||
ArrayExpressionInner::Slice(box a, box from, box to) => {
|
||||
is_constant(&from.clone().into())
|
||||
&& is_constant(&to.clone().into())
|
||||
&& is_constant(&a.clone().into())
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
is_constant(&count.clone().into()) && is_constant(&e)
|
||||
}
|
||||
_ => false,
|
||||
},
|
||||
TypedExpression::Struct(a) => match a.as_inner() {
|
||||
StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
|
||||
_ => false,
|
||||
},
|
||||
TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
// in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
|
||||
fn to_canonical_constant<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
|
||||
fn to_canonical_constant_aux<T: Field>(
|
||||
e: TypedExpressionOrSpread<T>,
|
||||
) -> Vec<TypedExpression<T>> {
|
||||
match e {
|
||||
TypedExpressionOrSpread::Expression(e) => vec![e],
|
||||
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
v.into_iter().flat_map(to_canonical_constant_aux).collect()
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
match e {
|
||||
TypedExpression::Array(a) => {
|
||||
let array_ty = a.ty();
|
||||
|
||||
match a.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.flat_map(to_canonical_constant_aux)
|
||||
.map(|e| e.into())
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
.into(),
|
||||
ArrayExpressionInner::Slice(box a, box from, box to) => {
|
||||
let from = match from.into_inner() {
|
||||
UExpressionInner::Value(from) => from as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let to = match to.into_inner() {
|
||||
UExpressionInner::Value(to) => to as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let v = match a.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => v,
|
||||
_ => unreachable!("should be an array value"),
|
||||
};
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.flat_map(to_canonical_constant_aux)
|
||||
.map(|e| e.into())
|
||||
.enumerate()
|
||||
.filter(|(index, _)| index >= &from && index < &to)
|
||||
.map(|(_, e)| e)
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
.into()
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
let count = match count.into_inner() {
|
||||
UExpressionInner::Value(from) => from as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let e = to_canonical_constant(e);
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(e); count].into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
TypedExpression::Struct(s) => {
|
||||
let struct_ty = s.ty().clone();
|
||||
|
||||
match s.into_inner() {
|
||||
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
|
||||
expressions.into_iter().map(to_canonical_constant).collect(),
|
||||
)
|
||||
.annotate(struct_ty)
|
||||
.into(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
e => e,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||
type Error = Error;
|
||||
|
||||
|
@ -341,10 +221,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
};
|
||||
|
||||
if is_constant(&expr) {
|
||||
if expr.is_constant() {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
let expr = to_canonical_constant(expr);
|
||||
let expr = expr.into_canonical_constant();
|
||||
|
||||
assert!(self.constants.insert(var.id, expr).is_none());
|
||||
|
||||
|
@ -352,7 +232,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
assignee => match self.try_get_constant_mut(&assignee) {
|
||||
Ok((_, c)) => {
|
||||
*c = to_canonical_constant(expr);
|
||||
*c = expr.into_canonical_constant();
|
||||
Ok(vec![])
|
||||
}
|
||||
Err(v) => match self.constants.remove(&v.id) {
|
||||
|
@ -423,7 +303,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
|
||||
let argument = arguments.pop().unwrap();
|
||||
|
||||
let argument = to_canonical_constant(argument);
|
||||
let argument = argument.into_canonical_constant();
|
||||
|
||||
match ArrayExpression::try_from(argument)
|
||||
.unwrap()
|
||||
|
@ -498,10 +378,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
match arguments.iter().all(|a| is_constant(a)) {
|
||||
match arguments.iter().all(|a| a.is_constant()) {
|
||||
true => {
|
||||
let r: Option<TypedExpression<'ast, T>> = match embed {
|
||||
FlatEmbed::U32ToField => None, // todo
|
||||
FlatEmbed::BitArrayLe => None, // todo
|
||||
FlatEmbed::U64FromBits => Some(process_u_from_bits(
|
||||
assignees.clone(),
|
||||
arguments.clone(),
|
||||
|
@ -1173,6 +1054,34 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
},
|
||||
None => Ok(ArrayExpressionInner::Identifier(id)),
|
||||
},
|
||||
ArrayExpressionInner::Value(exprs) => {
|
||||
Ok(ArrayExpressionInner::Value(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression_or_spread(e))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flat_map(|e| {
|
||||
match e {
|
||||
// simplify `...[a, b]` to `a, b`
|
||||
TypedExpressionOrSpread::Spread(TypedSpread {
|
||||
array:
|
||||
ArrayExpression {
|
||||
inner: ArrayExpressionInner::Value(v),
|
||||
..
|
||||
},
|
||||
}) => v.0,
|
||||
e => vec![e],
|
||||
}
|
||||
})
|
||||
// ignore spreads over empty arrays
|
||||
.filter_map(|e| match e {
|
||||
TypedExpressionOrSpread::Spread(s) if s.array.size() == 0 => None,
|
||||
e => Some(e),
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1273,13 +1273,8 @@ mod tests {
|
|||
// def main():
|
||||
// # PUSH CALL to foo::<1>
|
||||
// # PUSH CALL to bar::<2>
|
||||
// field[2] a_1 = [...[1]], 0]
|
||||
// field[2] #RET_0_1 = a_1
|
||||
// # POP CALL
|
||||
// field[1] ret := #RET_0_1[0..1]
|
||||
// field[1] #RET_0 = ret
|
||||
// # POP CALL
|
||||
// field[1] b_0 := #RET_0
|
||||
// return
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
|
@ -1452,71 +1447,8 @@ mod tests {
|
|||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
TypedExpressionOrSpread::Spread(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(0)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 2u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(1),
|
||||
Type::FieldElement,
|
||||
2u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 2u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
Variable::array("ret", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Slice(
|
||||
box ArrayExpressionInner::Identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(1),
|
||||
)
|
||||
.annotate(Type::FieldElement, 2u32),
|
||||
box 0u32.into(),
|
||||
box 1u32.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)),
|
||||
Type::FieldElement,
|
||||
1u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier("ret".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Identifier(Identifier::from(CoreIdentifier::Call(0)))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Return(vec![]),
|
||||
],
|
||||
signature: DeclarationSignature::new(),
|
||||
|
|
|
@ -738,7 +738,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::right_shift(left.clone(), right.clone())),
|
||||
.fold_uint_expression(UExpression::right_shift(left.clone(), right)),
|
||||
UExpression::right_shift(left_expected, right_expected).with_max(output_max)
|
||||
);
|
||||
}
|
||||
|
@ -761,7 +761,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
|
||||
.fold_uint_expression(UExpression::left_shift(left.clone(), right)),
|
||||
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::typed_absy::types::{
|
|||
};
|
||||
use crate::typed_absy::UBitwidth;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
|
||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse,
|
||||
IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed,
|
||||
TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||
};
|
||||
|
@ -585,7 +585,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
array: Self,
|
||||
target_array_ty: &GArrayType<S>,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let array_ty = array.ty();
|
||||
let array_ty = array.ty().clone();
|
||||
|
||||
// elements must fit in the target type
|
||||
match array.into_inner() {
|
||||
|
|
|
@ -1029,8 +1029,8 @@ impl<'ast, T> From<bool> for BooleanExpression<'ast, T> {
|
|||
/// type checking
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub struct ArrayExpression<'ast, T> {
|
||||
ty: Box<ArrayType<'ast, T>>,
|
||||
inner: ArrayExpressionInner<'ast, T>,
|
||||
pub ty: Box<ArrayType<'ast, T>>,
|
||||
pub inner: ArrayExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||
|
@ -1149,25 +1149,6 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> {
|
|||
pub fn size(&self) -> UExpression<'ast, T> {
|
||||
self.ty.size.clone()
|
||||
}
|
||||
|
||||
pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> {
|
||||
&mut self.inner
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> {
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub fn ty(&self) -> ArrayType<'ast, T> {
|
||||
ArrayType {
|
||||
size: self.size(),
|
||||
ty: box self.inner_type().clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
|
@ -1495,15 +1476,23 @@ pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
|
|||
type Inner;
|
||||
type Ty: Clone + IntoTypes<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty;
|
||||
|
||||
fn into_inner(self) -> Self::Inner;
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner;
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&Type::FieldElement
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self
|
||||
}
|
||||
|
@ -1511,12 +1500,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&Type::Boolean
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self
|
||||
}
|
||||
|
@ -1524,12 +1521,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
type Inner = UExpressionInner<'ast, T>;
|
||||
type Ty = UBitwidth;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&self.bitwidth
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
@ -1537,12 +1542,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
|
||||
type Inner = StructExpressionInner<'ast, T>;
|
||||
type Ty = StructType<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&self.ty
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
@ -1550,12 +1563,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
||||
type Inner = ArrayExpressionInner<'ast, T>;
|
||||
type Ty = ArrayType<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&self.ty
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
@ -1563,12 +1584,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&Type::Int
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self
|
||||
}
|
||||
|
@ -1576,12 +1605,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
type Inner = TypedExpressionListInner<'ast, T>;
|
||||
type Ty = Types<'ast, T>;
|
||||
|
||||
fn ty(&self) -> &Self::Ty {
|
||||
&self.types
|
||||
}
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
@ -1589,6 +1626,10 @@ impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
|
|||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn as_inner_mut(&mut self) -> &mut Self::Inner {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
// Enums types to enable returning e.g a member expression OR another type of expression of this type
|
||||
|
@ -1769,7 +1810,7 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> {
|
||||
fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self {
|
||||
let ty = s.ty().members.iter().find(|member| id == member.id);
|
||||
let bitwidth = match ty {
|
||||
|
@ -1949,7 +1990,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn block(statements: Vec<TypedStatement<'ast, T>>, value: Self) -> Self {
|
||||
let array_ty = value.ty();
|
||||
let array_ty = value.ty().clone();
|
||||
ArrayExpressionInner::Block(BlockExpression::new(statements, value))
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
}
|
||||
|
@ -1962,3 +2003,199 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> {
|
|||
StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Constant: Sized {
|
||||
// return whether this is constant
|
||||
fn is_constant(&self) -> bool;
|
||||
|
||||
// canonicalize an expression *that we know to be constant*
|
||||
// for example for [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
|
||||
fn into_canonical_constant(self) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self, FieldElementExpression::Number(..))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self, BooleanExpression::Value(..))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for UExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
matches!(self.as_inner(), UExpressionInner::Value(..))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
match self.as_inner() {
|
||||
ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e {
|
||||
TypedExpressionOrSpread::Expression(e) => e.is_constant(),
|
||||
TypedExpressionOrSpread::Spread(s) => s.array.is_constant(),
|
||||
}),
|
||||
ArrayExpressionInner::Slice(box a, box from, box to) => {
|
||||
from.is_constant() && to.is_constant() && a.is_constant()
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
count.is_constant() && e.is_constant()
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_canonical_constant(self) -> Self {
|
||||
fn into_canonical_constant_aux<T: Field>(
|
||||
e: TypedExpressionOrSpread<T>,
|
||||
) -> Vec<TypedExpression<T>> {
|
||||
match e {
|
||||
TypedExpressionOrSpread::Expression(e) => vec![e],
|
||||
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => v
|
||||
.into_iter()
|
||||
.flat_map(into_canonical_constant_aux)
|
||||
.collect(),
|
||||
ArrayExpressionInner::Slice(box v, box from, box to) => {
|
||||
let from = match from.into_inner() {
|
||||
UExpressionInner::Value(v) => v,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let to = match to.into_inner() {
|
||||
UExpressionInner::Value(v) => v,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let v = match v.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => v,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
v.into_iter()
|
||||
.flat_map(into_canonical_constant_aux)
|
||||
.skip(from as usize)
|
||||
.take(to as usize - from as usize)
|
||||
.collect()
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
let count = match count.into_inner() {
|
||||
UExpressionInner::Value(count) => count,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
vec![e; count as usize]
|
||||
}
|
||||
a => unreachable!("{}", a),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let array_ty = self.ty().clone();
|
||||
|
||||
match self.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.flat_map(into_canonical_constant_aux)
|
||||
.map(|e| e.into())
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size),
|
||||
ArrayExpressionInner::Slice(box a, box from, box to) => {
|
||||
let from = match from.into_inner() {
|
||||
UExpressionInner::Value(from) => from as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let to = match to.into_inner() {
|
||||
UExpressionInner::Value(to) => to as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let v = match a.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => v,
|
||||
_ => unreachable!("should be an array value"),
|
||||
};
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.flat_map(into_canonical_constant_aux)
|
||||
.map(|e| e.into())
|
||||
.skip(from)
|
||||
.take(to - from)
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
}
|
||||
ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
let count = match count.into_inner() {
|
||||
UExpressionInner::Value(from) => from as usize,
|
||||
_ => unreachable!("should be a uint value"),
|
||||
};
|
||||
|
||||
let e = e.into_canonical_constant();
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(e); count].into(),
|
||||
)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for StructExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
match self.as_inner() {
|
||||
StructExpressionInner::Value(v) => v.iter().all(|e| e.is_constant()),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_canonical_constant(self) -> Self {
|
||||
let struct_ty = self.ty().clone();
|
||||
|
||||
match self.into_inner() {
|
||||
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
|
||||
expressions
|
||||
.into_iter()
|
||||
.map(|e| e.into_canonical_constant())
|
||||
.collect(),
|
||||
)
|
||||
.annotate(struct_ty),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Constant for TypedExpression<'ast, T> {
|
||||
fn is_constant(&self) -> bool {
|
||||
match self {
|
||||
TypedExpression::FieldElement(e) => e.is_constant(),
|
||||
TypedExpression::Boolean(e) => e.is_constant(),
|
||||
TypedExpression::Array(e) => e.is_constant(),
|
||||
TypedExpression::Struct(e) => e.is_constant(),
|
||||
TypedExpression::Uint(e) => e.is_constant(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_canonical_constant(self) -> Self {
|
||||
match self {
|
||||
TypedExpression::FieldElement(e) => e.into_canonical_constant().into(),
|
||||
TypedExpression::Boolean(e) => e.into_canonical_constant().into(),
|
||||
TypedExpression::Array(e) => e.into_canonical_constant().into(),
|
||||
TypedExpression::Struct(e) => e.into_canonical_constant().into(),
|
||||
TypedExpression::Uint(e) => e.into_canonical_constant().into(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ use zokrates_core::{
|
|||
ir::Interpreter,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
#[test]
|
||||
fn lt_field() {
|
||||
|
@ -74,3 +75,83 @@ fn lt_uint() {
|
|||
)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack256() {
|
||||
let source = r#"
|
||||
import "utils/pack/bool/unpack256"
|
||||
|
||||
def main(private field a):
|
||||
bool[256] bits = unpack256(a)
|
||||
assert(bits[255])
|
||||
return
|
||||
"#
|
||||
.to_string();
|
||||
|
||||
// let's try to prove that the least significant bit of 0 is 1
|
||||
// we exploit the fact that the bits of 0 are the bits of p, and p is even
|
||||
// we want this to still fail
|
||||
|
||||
let stdlib_path = std::fs::canonicalize(
|
||||
std::env::current_dir()
|
||||
.unwrap()
|
||||
.join("../zokrates_stdlib/stdlib"),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let res: CompilationArtifacts<Bn128Field> = compile(
|
||||
source,
|
||||
"./path/to/file".into(),
|
||||
Some(&FileSystemResolver::with_stdlib_root(
|
||||
stdlib_path.to_str().unwrap(),
|
||||
)),
|
||||
&CompileConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let interpreter = Interpreter::try_out_of_range();
|
||||
|
||||
assert!(interpreter
|
||||
.execute(&res.prog(), &[Bn128Field::from(0)])
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack256_unchecked() {
|
||||
let source = r#"
|
||||
import "utils/pack/bool/nonStrictUnpack256"
|
||||
|
||||
def main(private field a):
|
||||
bool[256] bits = nonStrictUnpack256(a)
|
||||
assert(bits[255])
|
||||
return
|
||||
"#
|
||||
.to_string();
|
||||
|
||||
// let's try to prove that the least significant bit of 0 is 1
|
||||
// we exploit the fact that the bits of 0 are the bits of p, and p is odd
|
||||
// we want this to succeed as the non strict version does not enforce the bits to be in range
|
||||
|
||||
let stdlib_path = std::fs::canonicalize(
|
||||
std::env::current_dir()
|
||||
.unwrap()
|
||||
.join("../zokrates_stdlib/stdlib"),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let res: CompilationArtifacts<Bn128Field> = compile(
|
||||
source,
|
||||
"./path/to/file".into(),
|
||||
Some(&FileSystemResolver::with_stdlib_root(
|
||||
stdlib_path.to_str().unwrap(),
|
||||
)),
|
||||
&CompileConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let interpreter = Interpreter::try_out_of_range();
|
||||
|
||||
assert!(interpreter
|
||||
.execute(&res.prog(), &[Bn128Field::from(0)])
|
||||
.is_ok());
|
||||
}
|
||||
|
|
|
@ -26,17 +26,16 @@ impl<'a> Resolver<io::Error> for FileSystemResolver<'a> {
|
|||
) -> Result<(String, PathBuf), io::Error> {
|
||||
let source = Path::new(&import_location);
|
||||
|
||||
if !current_location.is_file() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{} was expected to be a file", current_location.display()),
|
||||
));
|
||||
}
|
||||
|
||||
// paths starting with `./` or `../` are interpreted relative to the current file
|
||||
// other paths `abc/def` are interpreted relative to the standard library root path
|
||||
let base = match source.components().next() {
|
||||
Some(Component::CurDir) | Some(Component::ParentDir) => {
|
||||
if !current_location.is_file() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{} was expected to be a file", current_location.display()),
|
||||
));
|
||||
}
|
||||
current_location.parent().unwrap().into()
|
||||
}
|
||||
_ => PathBuf::from(self.stdlib_root_path.unwrap_or("")),
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
#pragma curve bn128
|
||||
|
||||
import "./unpack" as unpack
|
||||
import "./unpack_unchecked"
|
||||
|
||||
// Unpack a field element as 256 big-endian bits
|
||||
// Note: uniqueness of the output is not guaranteed
|
||||
// For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)`
|
||||
def main(field i) -> bool[256]:
|
||||
|
||||
bool[254] b = unpack::<254>(i)
|
||||
bool[254] b = unpack_unchecked::<254>(i)
|
||||
|
||||
return [false, false, ...b]
|
|
@ -1,12 +1,12 @@
|
|||
#pragma curve bn128
|
||||
|
||||
from "EMBED" import unpack
|
||||
import "./unpack_unchecked.zok"
|
||||
from "field" import FIELD_SIZE_IN_BITS
|
||||
from "EMBED" import bit_array_le
|
||||
|
||||
// Unpack a field element as N big endian bits
|
||||
def main<N>(field i) -> bool[N]:
|
||||
|
||||
assert(N <= 254)
|
||||
|
||||
bool[N] res = unpack(i)
|
||||
bool[N] res = unpack_unchecked(i)
|
||||
|
||||
assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::<FIELD_SIZE_IN_BITS>(-1)]) else true fi)
|
||||
|
||||
return res
|
|
@ -1,9 +1,7 @@
|
|||
#pragma curve bn128
|
||||
|
||||
import "./unpack" as unpack
|
||||
|
||||
// Unpack a field element as 128 big-endian bits
|
||||
// Precondition: the input is smaller or equal to `2**128 - 1`
|
||||
// If the input is larger than `2**128 - 1`, the output is truncated.
|
||||
def main(field i) -> bool[128]:
|
||||
bool[128] res = unpack::<128>(i)
|
||||
return res
|
7
zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok
Normal file
7
zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
import "./unpack" as unpack
|
||||
|
||||
// Unpack a field element as 256 big-endian bits
|
||||
// If the input is larger than `2**256 - 1`, the output is truncated.
|
||||
def main(field i) -> bool[256]:
|
||||
bool[256] res = unpack::<256>(i)
|
||||
return res
|
|
@ -0,0 +1,9 @@
|
|||
from "EMBED" import unpack
|
||||
|
||||
// Unpack a field element as N big endian bits without checking for overflows
|
||||
// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either
|
||||
def main<N>(field i) -> bool[N]:
|
||||
|
||||
bool[N] res = unpack(i)
|
||||
|
||||
return res
|
16
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json
Normal file
16
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/utils/pack/bool/unpack256.zok",
|
||||
"curves": ["Bn128"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
24
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok
Normal file
24
zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok
Normal file
|
@ -0,0 +1,24 @@
|
|||
import "utils/pack/bool/unpack256" as unpack256
|
||||
|
||||
def testFive() -> bool:
|
||||
|
||||
bool[256] b = unpack256(5)
|
||||
|
||||
assert(b == [...[false; 253], true, false, true])
|
||||
|
||||
return true
|
||||
|
||||
def testZero() -> bool:
|
||||
|
||||
bool[256] b = unpack256(0)
|
||||
|
||||
assert(b == [false; 256])
|
||||
|
||||
return true
|
||||
|
||||
def main():
|
||||
|
||||
assert(testFive())
|
||||
assert(testZero())
|
||||
|
||||
return
|
Loading…
Reference in a new issue