1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

remove branch isolation

This commit is contained in:
dark64 2023-10-23 19:03:50 +02:00
parent c7e4e29ad0
commit 755016c911
19 changed files with 237 additions and 471 deletions

View file

@ -1,43 +0,0 @@
// Isolate branches means making sure that any branch is enclosed in a block.
// This is important, because we want any statement resulting from inlining any branch to be isolated from the coller, so that its panics can be conditional to the branch being logically run
// `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks.
use zokrates_ast::common::{Fold, WithSpan};
use zokrates_ast::typed::folder::*;
use zokrates_ast::typed::*;
use zokrates_field::Field;
pub struct Isolator;
impl Isolator {
pub fn isolate<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
let mut isolator = Isolator;
isolator.fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
fn fold_conditional_expression<
E: Expr<'ast, T> + Block<'ast, T> + Fold<Self> + Conditional<'ast, T>,
>(
&mut self,
_: &E::Ty,
e: ConditionalExpression<'ast, T, E>,
) -> ConditionalOrExpression<'ast, T, E> {
let span = e.get_span();
let consequence_span = e.consequence.get_span();
let alternative_span = e.alternative.get_span();
ConditionalOrExpression::Conditional(
ConditionalExpression::new(
self.fold_boolean_expression(*e.condition),
E::block(vec![], e.consequence.fold(self)).span(consequence_span),
E::block(vec![], e.alternative.fold(self)).span(alternative_span),
e.kind,
)
.span(span),
)
}
}

View file

@ -6,7 +6,6 @@
mod assembly_transformer;
mod boolean_array_comparator;
mod branch_isolator;
mod condition_redefiner;
mod constant_argument_checker;
mod constant_resolver;
@ -25,7 +24,6 @@ mod variable_write_remover;
mod zir_propagation;
use self::boolean_array_comparator::BooleanArrayComparator;
use self::branch_isolator::Isolator;
use self::condition_redefiner::ConditionRedefiner;
use self::constant_argument_checker::ConstantArgumentChecker;
use self::flatten_complex_types::Flattener;
@ -132,17 +130,6 @@ pub fn analyse<'ast, T: Field>(
let r = ConstantResolver::inline(p);
log::trace!("\n{}", r);
// isolate branches
let r = if config.isolate_branches {
log::debug!("Static analyser: Isolate branches");
let r = Isolator::isolate(r);
log::trace!("\n{}", r);
r
} else {
log::debug!("Static analyser: Branch isolation skipped");
r
};
// include logs
let r = if config.debug {
log::debug!("Static analyser: Include logs");

View file

@ -44,11 +44,6 @@ pub fn subcommand() -> App<'static, 'static> {
.possible_values(cli_constants::CURVES)
.default_value(BN128),
)
.arg(Arg::with_name("isolate-branches")
.long("isolate-branches")
.help("Isolate the execution of branches: a panic in a branch only makes the program panic if this branch is being logically executed")
.required(false)
)
}
pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
@ -94,10 +89,9 @@ fn cli_check<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
)),
}?;
let config =
CompileConfig::default().isolate_branches(sub_matches.is_present("isolate-branches"));
let config = CompileConfig::default();
let resolver = FileSystemResolver::with_stdlib_root(stdlib_path);
check::<T, _>(source, path, Some(&resolver), &config).map_err(|e| {
format!(
"Check failed:\n\n{}",

View file

@ -18,62 +18,71 @@ use zokrates_fs_resolver::FileSystemResolver;
pub fn subcommand() -> App<'static, 'static> {
SubCommand::with_name("compile")
.about("Compiles into a runnable constraint system")
.arg(Arg::with_name("input")
.short("i")
.long("input")
.help("Path of the source code")
.value_name("FILE")
.takes_value(true)
.required(true)
).arg(Arg::with_name("stdlib-path")
.long("stdlib-path")
.help("Path to the standard library")
.value_name("PATH")
.takes_value(true)
.required(false)
.env("ZOKRATES_STDLIB")
.default_value(cli_constants::DEFAULT_STDLIB_PATH.as_str())
).arg(Arg::with_name("abi-spec")
.short("s")
.long("abi-spec")
.help("Path of the ABI specification")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::ABI_SPEC_DEFAULT_PATH)
).arg(Arg::with_name("output")
.short("o")
.long("output")
.help("Path of the output binary")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::FLATTENED_CODE_DEFAULT_PATH)
).arg(Arg::with_name("r1cs")
.short("r1cs")
.long("r1cs")
.help("Path of the output r1cs file")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::CIRCOM_R1CS_DEFAULT_PATH)
).arg(Arg::with_name("curve")
.short("c")
.long("curve")
.help("Curve to be used in the compilation")
.takes_value(true)
.required(false)
.possible_values(cli_constants::CURVES)
.default_value(BN128)
).arg(Arg::with_name("isolate-branches")
.long("isolate-branches")
.help("Isolate the execution of branches: a panic in a branch only makes the program panic if this branch is being logically executed")
.required(false)
).arg(Arg::with_name("debug")
.long("debug")
.help("Include logs")
.required(false)
)
.arg(
Arg::with_name("input")
.short("i")
.long("input")
.help("Path of the source code")
.value_name("FILE")
.takes_value(true)
.required(true),
)
.arg(
Arg::with_name("stdlib-path")
.long("stdlib-path")
.help("Path to the standard library")
.value_name("PATH")
.takes_value(true)
.required(false)
.env("ZOKRATES_STDLIB")
.default_value(cli_constants::DEFAULT_STDLIB_PATH.as_str()),
)
.arg(
Arg::with_name("abi-spec")
.short("s")
.long("abi-spec")
.help("Path of the ABI specification")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::ABI_SPEC_DEFAULT_PATH),
)
.arg(
Arg::with_name("output")
.short("o")
.long("output")
.help("Path of the output binary")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::FLATTENED_CODE_DEFAULT_PATH),
)
.arg(
Arg::with_name("r1cs")
.short("r1cs")
.long("r1cs")
.help("Path of the output r1cs file")
.value_name("FILE")
.takes_value(true)
.required(false)
.default_value(cli_constants::CIRCOM_R1CS_DEFAULT_PATH),
)
.arg(
Arg::with_name("curve")
.short("c")
.long("curve")
.help("Curve to be used in the compilation")
.takes_value(true)
.required(false)
.possible_values(cli_constants::CURVES)
.default_value(BN128),
)
.arg(
Arg::with_name("debug")
.long("debug")
.help("Include logs")
.required(false),
)
}
pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
@ -124,10 +133,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
)),
}?;
let config = CompileConfig::default()
.isolate_branches(sub_matches.is_present("isolate-branches"))
.debug(sub_matches.is_present("debug"));
let config = CompileConfig::default().debug(sub_matches.is_present("debug"));
let resolver = FileSystemResolver::with_stdlib_root(stdlib_path);
log::debug!("Compile");

View file

@ -33,7 +33,6 @@ use zokrates_ast::zir::{
BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter,
UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirStatement,
};
use zokrates_common::CompileConfig;
use zokrates_field::Field;
/// A container for statements produced during code generation
@ -80,13 +79,10 @@ impl<'ast, T> IntoIterator for FlatStatements<'ast, T> {
///
/// # Arguments
/// * `funct` - `ZirFunction` that will be flattened
pub fn from_program_and_config<T: Field>(
prog: ZirProgram<T>,
config: CompileConfig,
) -> FlattenerIterator<T> {
pub fn from_program_and_config<T: Field>(prog: ZirProgram<T>) -> FlattenerIterator<T> {
let funct = prog.main;
let mut flattener = Flattener::new(config);
let mut flattener = Flattener::new();
let mut statements_flattened = FlatStatements::default();
// push parameters
let arguments_flattened = funct
@ -137,7 +133,6 @@ impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> {
/// Flattener, computes flattened program.
#[derive(Debug)]
pub struct Flattener<'ast, T> {
config: CompileConfig,
/// Index of the next introduced variable while processing the program.
next_var_idx: usize,
/// `Variable`s corresponding to each `Identifier`
@ -275,9 +270,8 @@ impl<T: Field> FlatUExpression<T> {
impl<'ast, T: Field> Flattener<'ast, T> {
/// Returns a `Flattener` with fresh `layout`.
fn new(config: CompileConfig) -> Flattener<'ast, T> {
fn new() -> Flattener<'ast, T> {
Flattener {
config,
next_var_idx: 0,
layout: HashMap::new(),
bits_cache: HashMap::new(),
@ -571,68 +565,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
fn make_conditional(
&mut self,
statements: FlatStatements<'ast, T>,
condition: FlatExpression<T>,
) -> FlatStatements<'ast, T> {
statements
.into_iter()
.flat_map(|s| match s {
FlatStatement::Condition(s) => {
let span = s.get_span();
let mut output = FlatStatements::default();
output.set_span(span);
// we transform (a == b) into (c => (a == b)) which is (!c || (a == b))
// let's introduce new variables to make sure everything is linear
let name_lin = self.define(s.lin, &mut output);
let name_quad = self.define(s.quad, &mut output);
// let's introduce an expression which is 1 iff `a == b`
let y = FlatExpression::add(
FlatExpression::sub(name_lin.into(), name_quad.into()),
T::one().into(),
); // let's introduce !c
let x = FlatExpression::sub(T::one().into(), condition.clone());
assert!(x.is_linear() && y.is_linear());
let name_x_or_y = self.use_sym();
output.push_back(FlatStatement::directive(
vec![name_x_or_y],
Solver::Or,
vec![x.clone(), y.clone()],
));
output.push_back(FlatStatement::condition(
FlatExpression::add(
x.clone(),
FlatExpression::sub(y.clone(), name_x_or_y.into()),
),
FlatExpression::mul(x, y),
RuntimeError::BranchIsolation,
));
output.push_back(FlatStatement::condition(
name_x_or_y.into(),
T::one().into(),
s.error,
));
output
}
s => {
let mut v = FlatStatements::default();
v.push_back(s);
v
}
})
.fold(FlatStatements::default(), |mut acc, s| {
acc.push_back(s);
acc
})
}
/// Flatten an if/else expression
///
/// # Arguments
@ -658,35 +590,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let condition_id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(condition_id, condition_flat));
let (consequence, alternative) = if self.config.isolate_branches {
let mut consequence_statements = FlatStatements::default();
let consequence = consequence.flatten(self, &mut consequence_statements);
let mut alternative_statements = FlatStatements::default();
let alternative = alternative.flatten(self, &mut alternative_statements);
let consequence_statements =
self.make_conditional(consequence_statements, condition_id.into());
let alternative_statements = self.make_conditional(
alternative_statements,
FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()),
);
statements_flattened.extend(consequence_statements);
statements_flattened.extend(alternative_statements);
(consequence, alternative)
} else {
(
consequence.flatten(self, statements_flattened),
alternative.flatten(self, statements_flattened),
)
};
let consequence = consequence.flat();
let alternative = alternative.flat();
let consequence = consequence.flatten(self, statements_flattened).flat();
let alternative = alternative.flatten(self, statements_flattened).flat();
let consequence_id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(consequence_id, consequence));
@ -2448,34 +2353,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened
.push_back(FlatStatement::definition(condition_id, condition_flat));
if self.config.isolate_branches {
let mut consequence_statements = FlatStatements::default();
let mut alternative_statements = FlatStatements::default();
s.consequence
.into_iter()
.for_each(|s| self.flatten_statement(&mut consequence_statements, s));
s.alternative
.into_iter()
.for_each(|s| self.flatten_statement(&mut alternative_statements, s));
let consequence_statements =
self.make_conditional(consequence_statements, condition_id.into());
let alternative_statements = self.make_conditional(
alternative_statements,
FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()),
);
statements_flattened.extend(consequence_statements);
statements_flattened.extend(alternative_statements);
} else {
s.consequence
.into_iter()
.for_each(|s| self.flatten_statement(statements_flattened, s));
s.alternative
.into_iter()
.for_each(|s| self.flatten_statement(statements_flattened, s));
}
s.consequence
.into_iter()
.for_each(|s| self.flatten_statement(statements_flattened, s));
s.alternative
.into_iter()
.for_each(|s| self.flatten_statement(statements_flattened, s));
}
ZirStatement::Definition(s) => {
// define n variables with n the number of primitive types for v_type
@ -2999,13 +2882,10 @@ mod tests {
use zokrates_field::Bn128Field;
fn flatten_function<T: Field>(f: ZirFunction<T>) -> FlatProg<T> {
from_program_and_config(
ZirProgram {
main: f,
module_map: Default::default(),
},
CompileConfig::default(),
)
from_program_and_config(ZirProgram {
main: f,
module_map: Default::default(),
})
.collect()
}
@ -3801,7 +3681,6 @@ mod tests {
#[test]
fn if_else() {
let config = CompileConfig::default();
let expression = FieldElementExpression::conditional(
BooleanExpression::field_eq(
FieldElementExpression::value(Bn128Field::from(32)),
@ -3811,15 +3690,14 @@ mod tests {
FieldElementExpression::value(Bn128Field::from(51)),
);
let mut flattener = Flattener::new(config);
let mut flattener = Flattener::new();
flattener.flatten_field_expression(&mut FlatStatements::default(), expression);
}
#[test]
fn geq_leq() {
let config = CompileConfig::default();
let mut flattener = Flattener::new(config);
let mut flattener = Flattener::new();
let expression_le = BooleanExpression::field_le(
FieldElementExpression::value(Bn128Field::from(32)),
FieldElementExpression::value(Bn128Field::from(4)),
@ -3829,8 +3707,7 @@ mod tests {
#[test]
fn bool_and() {
let config = CompileConfig::default();
let mut flattener = Flattener::new(config);
let mut flattener = Flattener::new();
let expression = FieldElementExpression::conditional(
BooleanExpression::bitand(
@ -3853,8 +3730,7 @@ mod tests {
#[test]
fn div() {
// a = 5 / b / b
let config = CompileConfig::default();
let mut flattener = Flattener::new(config);
let mut flattener = Flattener::new();
let mut statements_flattened = FlatStatements::default();
let definition = ZirStatement::definition(

View file

@ -14,18 +14,11 @@ pub trait Resolver<E> {
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)]
pub struct CompileConfig {
#[serde(default)]
pub isolate_branches: bool,
#[serde(default)]
pub debug: bool,
}
impl CompileConfig {
pub fn isolate_branches(mut self, flag: bool) -> Self {
self.isolate_branches = flag;
self
}
pub fn debug(mut self, debug: bool) -> Self {
self.debug = debug;
self

View file

@ -183,7 +183,7 @@ pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
// flatten input program
log::debug!("Flatten");
let program_flattened = from_program_and_config(typed_ast, config);
let program_flattened = from_program_and_config(typed_ast);
// convert to ir
log::debug!("Convert to IR");

View file

@ -4,7 +4,7 @@ def throwing_bound<N>(u32 x) -> u32 {
}
// this compiles: the conditional, even though it can throw, has a constant compile-time value of `1`
// However, the assertions are still checked at runtime, which leads to panics without branch isolation.
// However, the assertions are still checked at runtime, which leads to panics
def main(u32 x) {
for u32 i in 0..x == 0 ? throwing_bound::<0>(x) : throwing_bound::<1>(x) {}
return;

View file

@ -0,0 +1,50 @@
{
"entry_point": "./tests/tests/panics/conditional_panic.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": [true]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "1",
"right": "0",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/conditional_panic.zok",
"position": {
"line": 7,
"col": 5
}
}
}
}
}
}
},
{
"input": {
"values": [false]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "1",
"right": "0",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/conditional_panic.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
}
]
}

View file

@ -0,0 +1,14 @@
def yes(bool x) -> bool {
assert(x);
return x;
}
def no(bool x) -> bool {
assert(!x);
return x;
}
def main(bool condition) -> bool {
// this will always panic
return condition ? yes(condition) : no(condition);
}

View file

@ -1,17 +1,48 @@
{
"entry_point": "./tests/tests/panics/deep_branch.zok",
"curves": ["Bn128"],
"config": {
"isolate_branches": true
},
"tests": [
{
"input": {
"values": [[false, false, false]]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
},
{
"input": {
"values": [[true, true, true]]
},
"output": {
"Ok": {
"value": [true, true, true]
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
},
@ -20,8 +51,20 @@
"values": [[false, false, false]]
},
"output": {
"Ok": {
"value": [false, false, false]
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
},
@ -30,8 +73,20 @@
"values": [[false, true, false]]
},
"output": {
"Ok": {
"value": [false, true, false]
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
},
@ -40,8 +95,20 @@
"values": [[true, false, true]]
},
"output": {
"Ok": {
"value": [true, false, true]
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
}

View file

@ -1,28 +0,0 @@
{
"entry_point": "./tests/tests/panics/deep_branch.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": [[false, false, false]]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/deep_branch.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
}
]
}

View file

@ -1,9 +1,6 @@
{
"entry_point": "./tests/tests/panics/internal_panic.zok",
"curves": ["Bn128"],
"config": {
"isolate_branches": true
},
"tests": [
{
"input": {
@ -20,8 +17,12 @@
"values": ["0"]
},
"output": {
"Ok": {
"value": "0"
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": "Inverse"
}
}
}
}

View file

@ -1,30 +0,0 @@
{
"entry_point": "./tests/tests/panics/internal_panic.zok",
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["1"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["0"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "0",
"right": "1",
"error": "Inverse"
}
}
}
}
]
}

View file

@ -1,51 +0,0 @@
{
"entry_point": "./tests/tests/panics/panic_isolation.zok",
"config": {
"isolate_branches": true
},
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": [true, ["42", "42"], "0"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "1",
"right": "21888242871839275222246405745257275088548364400416034343698204186575808495577",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/panic_isolation.zok",
"position": {
"line": 22,
"col": 5
}
}
}
}
}
}
},
{
"input": {
"values": [true, ["1", "1"], "1"]
},
"output": {
"Ok": {
"value": [true, ["1", "1"], "1"]
}
}
},
{
"input": {
"values": [false, ["2", "2"], "0"]
},
"output": {
"Ok": {
"value": [false, ["2", "2"], "0"]
}
}
}
]
}

View file

@ -1,38 +0,0 @@
def zero(field x) -> field {
assert(x == 0);
return 0;
}
def inverse(field x) -> field {
assert(x != 0);
return 1/x;
}
def yes(bool x) -> bool {
assert(x);
return x;
}
def no(bool x) -> bool {
assert(!x);
return x;
}
def ones(field[2] a) -> field[2] {
assert(a == [1, 1]);
return a;
}
def twos(field[2] a) -> field[2] {
assert(a == [2, 2]);
return a;
}
def main(bool condition, field[2] a, field x) -> (bool, field[2], field) {
// first branch asserts that `condition` is true, second branch asserts that `condition` is false. This should never throw.
// first branch asserts that all elements in `a` are 1, 2 in the second branch. This should throw only if `a` is neither ones or zeroes
// first branch asserts that `x` is zero and returns it, second branch asserts that `x` isn't 0 and returns its inverse (which internally generates a failing assert if x is 0). This should never throw
return (condition ? yes(condition) : no(condition), \
condition ? ones(a) : twos(a), \
x == 0 ? zero(x) : inverse(x));
}

View file

@ -1,31 +0,0 @@
{
"entry_point": "./tests/tests/panics/panic_isolation.zok",
"config": {
"isolate_branches": false
},
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": [true, ["1", "1"], "1"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "1",
"right": "0",
"error": {
"SourceAssertion": {
"file": "./tests/tests/panics/panic_isolation.zok",
"position": {
"line": 17,
"col": 5
}
}
}
}
}
}
}
]
}

View file

@ -12,7 +12,6 @@ declare module "zokrates-js" {
) => ResolverResult;
export interface CompileConfig {
isolate_branches?: boolean;
debug?: boolean;
}