1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

merge dev, fix conflicts, fix comment in test

This commit is contained in:
schaeff 2021-08-09 11:43:03 +02:00
commit 249187a157
14 changed files with 219 additions and 56 deletions

65
Cargo.lock generated
View file

@ -37,6 +37,15 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "aho-corasick"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "ansi_term" name = "ansi_term"
version = "0.11.0" version = "0.11.0"
@ -776,6 +785,19 @@ version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "env_logger"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b2cf0344971ee6c64c31be0d530793fba457d322dfec2810c453d0ef228f9c3"
dependencies = [
"atty",
"humantime",
"log",
"regex 1.5.4",
"termcolor",
]
[[package]] [[package]]
name = "environment" name = "environment"
version = "0.1.1" version = "0.1.1"
@ -1076,6 +1098,12 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "idna" name = "idna"
version = "0.2.3" version = "0.2.3"
@ -1682,13 +1710,24 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9329abc99e39129fcceabd24cf5d85b4671ef7c29c50e972bc5afe32438ec384" checksum = "9329abc99e39129fcceabd24cf5d85b4671ef7c29c50e972bc5afe32438ec384"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick 0.6.10",
"memchr", "memchr",
"regex-syntax", "regex-syntax 0.5.6",
"thread_local", "thread_local",
"utf8-ranges", "utf8-ranges",
] ]
[[package]]
name = "regex"
version = "1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
dependencies = [
"aho-corasick 0.7.18",
"memchr",
"regex-syntax 0.6.25",
]
[[package]] [[package]]
name = "regex-automata" name = "regex-automata"
version = "0.1.10" version = "0.1.10"
@ -1704,6 +1743,12 @@ dependencies = [
"ucd-util", "ucd-util",
] ]
[[package]]
name = "regex-syntax"
version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]] [[package]]
name = "remove_dir_all" name = "remove_dir_all"
version = "0.5.3" version = "0.5.3"
@ -1959,6 +2004,15 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "termcolor"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "textwrap" name = "textwrap"
version = "0.11.0" version = "0.11.0"
@ -2336,10 +2390,12 @@ dependencies = [
"cfg-if 0.1.10", "cfg-if 0.1.10",
"clap", "clap",
"dirs", "dirs",
"env_logger",
"fs_extra", "fs_extra",
"glob 0.2.11", "glob 0.2.11",
"lazy_static", "lazy_static",
"regex", "log",
"regex 0.2.11",
"serde_json", "serde_json",
"tempdir", "tempdir",
"zokrates_abi", "zokrates_abi",
@ -2378,6 +2434,7 @@ dependencies = [
"git2", "git2",
"hex", "hex",
"lazy_static", "lazy_static",
"log",
"num", "num",
"num-bigint 0.2.6", "num-bigint 0.2.6",
"pairing_ce", "pairing_ce",
@ -2385,7 +2442,7 @@ dependencies = [
"rand 0.4.6", "rand 0.4.6",
"rand 0.7.3", "rand 0.7.3",
"reduce", "reduce",
"regex", "regex 0.2.11",
"serde", "serde",
"serde_json", "serde_json",
"sha2 0.9.5", "sha2 0.9.5",

View file

@ -0,0 +1 @@
Add compiler logs

View file

@ -0,0 +1 @@
Fix constant range check in uint lt check

View file

@ -12,6 +12,8 @@ bellman = ["zokrates_core/bellman"]
ark = ["zokrates_core/ark"] ark = ["zokrates_core/ark"]
[dependencies] [dependencies]
log = "0.4"
env_logger = "0.9.0"
cfg-if = "0.1" cfg-if = "0.1"
clap = "2.26.2" clap = "2.26.2"
bincode = "0.8.0" bincode = "0.8.0"

View file

@ -1,5 +1,5 @@
from "EMBED" import bit_array_le from "EMBED" import bit_array_le
// Unpack a field element as N big endian bits // Calling the `bit_array_le` embed on a non-constant second argument should fail at compile-time
def main(bool[1] a, bool[1] b) -> bool: def main(bool[1] a, bool[1] b) -> bool:
return bit_array_le::<1>(a, b) return bit_array_le::<1>(a, b)

View file

@ -20,6 +20,8 @@ fn main() {
// set a custom panic hook // set a custom panic hook
std::panic::set_hook(Box::new(panic_hook)); std::panic::set_hook(Box::new(panic_hook));
env_logger::init();
cli().unwrap_or_else(|e| { cli().unwrap_or_else(|e| {
println!("{}", e); println!("{}", e);
std::process::exit(1); std::process::exit(1);

View file

@ -99,6 +99,8 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
let abi_spec_path = Path::new(sub_matches.value_of("abi-spec").unwrap()); let abi_spec_path = Path::new(sub_matches.value_of("abi-spec").unwrap());
let hr_output_path = bin_output_path.to_path_buf().with_extension("ztf"); let hr_output_path = bin_output_path.to_path_buf().with_extension("ztf");
log::debug!("Load entry point file {}", path.display());
let file = File::open(path.clone()) let file = File::open(path.clone())
.map_err(|why| format!("Could not open {}: {}", path.display(), why))?; .map_err(|why| format!("Could not open {}: {}", path.display(), why))?;
@ -131,6 +133,9 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
.isolate_branches(sub_matches.is_present("isolate-branches")); .isolate_branches(sub_matches.is_present("isolate-branches"));
let resolver = FileSystemResolver::with_stdlib_root(stdlib_path); let resolver = FileSystemResolver::with_stdlib_root(stdlib_path);
log::debug!("Compile");
let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config) let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config)
.map_err(|e| { .map_err(|e| {
format!( format!(
@ -148,6 +153,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
let num_constraints = program_flattened.constraint_count(); let num_constraints = program_flattened.constraint_count();
// serialize flattened program and write to binary file // serialize flattened program and write to binary file
log::debug!("Serialize program");
let bin_output_file = File::create(&bin_output_path) let bin_output_file = File::create(&bin_output_path)
.map_err(|why| format!("Could not create {}: {}", bin_output_path.display(), why))?; .map_err(|why| format!("Could not create {}: {}", bin_output_path.display(), why))?;
@ -156,6 +162,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
program_flattened.serialize(&mut writer); program_flattened.serialize(&mut writer);
// serialize ABI spec and write to JSON file // serialize ABI spec and write to JSON file
log::debug!("Serialize ABI");
let abi_spec_file = File::create(&abi_spec_path) let abi_spec_file = File::create(&abi_spec_path)
.map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?; .map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?;
@ -173,6 +180,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
if sub_matches.is_present("ztf") { if sub_matches.is_present("ztf") {
// write human-readable output file // write human-readable output file
log::debug!("Serialize human readable program");
let hr_output_file = File::create(&hr_output_path) let hr_output_file = File::create(&hr_output_path)
.map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?; .map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?;

View file

@ -16,6 +16,7 @@ multicore = ["bellman_ce/multicore"]
ark = ["ark-ff", "ark-ec", "ark-bn254", "ark-bls12-377", "ark-bw6-761", "ark-gm17", "ark-serialize", "ark-relations", "ark-marlin", "ark-poly", "ark-poly-commit", "zokrates_field/ark", "sha2"] ark = ["ark-ff", "ark-ec", "ark-bn254", "ark-bls12-377", "ark-bw6-761", "ark-gm17", "ark-serialize", "ark-relations", "ark-marlin", "ark-poly", "ark-poly-commit", "zokrates_field/ark", "sha2"]
[dependencies] [dependencies]
log = "0.4"
cfg-if = "0.1" cfg-if = "0.1"
num = { version = "0.1.36", default-features = false } num = { version = "0.1.36", default-features = false }
num-bigint = { version = "0.2", default-features = false } num-bigint = { version = "0.2", default-features = false }

View file

@ -192,18 +192,23 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
let (typed_ast, abi) = check_with_arena(source, location, resolver, config, &arena)?; let (typed_ast, abi) = check_with_arena(source, location, resolver, config, &arena)?;
// flatten input program // flatten input program
log::debug!("Flatten");
let program_flattened = Flattener::flatten(typed_ast, config); let program_flattened = Flattener::flatten(typed_ast, config);
// analyse (constant propagation after call resolution) // analyse (constant propagation after call resolution)
log::debug!("Analyse flat program");
let program_flattened = program_flattened.analyse(); let program_flattened = program_flattened.analyse();
// convert to ir // convert to ir
log::debug!("Convert to IR");
let ir_prog = ir::Prog::from(program_flattened); let ir_prog = ir::Prog::from(program_flattened);
// optimize // optimize
log::debug!("Optimise IR");
let optimized_ir_prog = ir_prog.optimize(); let optimized_ir_prog = ir_prog.optimize();
// analyse (check constraints) // analyse (check constraints)
log::debug!("Analyse IR");
let optimized_ir_prog = optimized_ir_prog.analyse(); let optimized_ir_prog = optimized_ir_prog.analyse();
Ok(CompilationArtifacts { Ok(CompilationArtifacts {
@ -231,7 +236,12 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
arena: &'ast Arena<String>, arena: &'ast Arena<String>,
) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> { ) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> {
let source = arena.alloc(source); let source = arena.alloc(source);
let compiled = compile_program::<T, E>(source, location, resolver, &arena)?;
log::debug!("Parse program with entry file {}", location.display());
let compiled = parse_program::<T, E>(source, location, resolver, &arena)?;
log::debug!("Check semantics");
// check semantics // check semantics
let typed_ast = Checker::check(compiled) let typed_ast = Checker::check(compiled)
@ -239,13 +249,15 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
let main_module = typed_ast.main.clone(); let main_module = typed_ast.main.clone();
log::debug!("Run static analysis");
// analyse (unroll and constant propagation) // analyse (unroll and constant propagation)
typed_ast typed_ast
.analyse(config) .analyse(config)
.map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)])) .map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)]))
} }
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>( pub fn parse_program<'ast, T: Field, E: Into<imports::Error>>(
source: &'ast str, source: &'ast str,
location: FilePath, location: FilePath,
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
@ -253,7 +265,7 @@ pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
) -> Result<Program<'ast>, CompileErrors> { ) -> Result<Program<'ast>, CompileErrors> {
let mut modules = HashMap::new(); let mut modules = HashMap::new();
let main = compile_module::<T, E>(&source, location.clone(), resolver, &mut modules, &arena)?; let main = parse_module::<T, E>(&source, location.clone(), resolver, &mut modules, &arena)?;
modules.insert(location.clone(), main); modules.insert(location.clone(), main);
@ -263,21 +275,29 @@ pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
}) })
} }
pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>( pub fn parse_module<'ast, T: Field, E: Into<imports::Error>>(
source: &'ast str, source: &'ast str,
location: FilePath, location: FilePath,
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
modules: &mut HashMap<OwnedModuleId, Module<'ast>>, modules: &mut HashMap<OwnedModuleId, Module<'ast>>,
arena: &'ast Arena<String>, arena: &'ast Arena<String>,
) -> Result<Module<'ast>, CompileErrors> { ) -> Result<Module<'ast>, CompileErrors> {
log::debug!("Generate pest AST for {}", location.display());
let ast = pest::generate_ast(&source) let ast = pest::generate_ast(&source)
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?; .map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?;
log::debug!("Process macros for {}", location.display());
let ast = process_macros::<T>(ast) let ast = process_macros::<T>(ast)
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?; .map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?;
log::debug!("Generate absy for {}", location.display());
let module_without_imports: Module = Module::from(ast); let module_without_imports: Module = Module::from(ast);
log::debug!("Apply imports to absy for {}", location.display());
Importer::apply_imports::<T, E>( Importer::apply_imports::<T, E>(
module_without_imports, module_without_imports,
location.clone(), location.clone(),

View file

@ -40,18 +40,18 @@ pub struct Flattener<'ast, T: Field> {
} }
trait FlattenOutput<T: Field>: Sized { trait FlattenOutput<T: Field>: Sized {
fn flat(&self) -> FlatExpression<T>; fn flat(self) -> FlatExpression<T>;
} }
impl<T: Field> FlattenOutput<T> for FlatExpression<T> { impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
fn flat(&self) -> FlatExpression<T> { fn flat(self) -> FlatExpression<T> {
self.clone() self
} }
} }
impl<T: Field> FlattenOutput<T> for FlatUExpression<T> { impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
fn flat(&self) -> FlatExpression<T> { fn flat(self) -> FlatExpression<T> {
self.clone().get_field_unchecked() self.get_field_unchecked()
} }
} }
@ -215,6 +215,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// **true => a -> 0 /// **true => a -> 0
/// sizeUnkown * /// sizeUnkown *
/// **false => a -> {0,1} /// **false => a -> {0,1}
#[must_use]
fn constant_le_check( fn constant_le_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<T>,
@ -904,8 +905,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// Y == X * M // Y == X * M
// 0 == (1-Y) * X // 0 == (1-Y) * X
assert!(lhs.metadata.clone().unwrap().should_reduce.to_bool()); assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool()); assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool());
let lhs = self let lhs = self
.flatten_uint_expression(statements_flattened, lhs) .flatten_uint_expression(statements_flattened, lhs)
@ -923,7 +924,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
let eq = self.flatten_boolean_expression( let eq = self.flatten_boolean_expression(
statements_flattened, statements_flattened,
BooleanExpression::FieldEq(box lhs.clone(), box rhs.clone()), BooleanExpression::FieldEq(box lhs, box rhs),
); );
FlatExpression::Add(box eq, box lt) FlatExpression::Add(box eq, box lt)
} }
@ -1017,7 +1018,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
let eq = self.flatten_boolean_expression( let eq = self.flatten_boolean_expression(
statements_flattened, statements_flattened,
BooleanExpression::UintEq(box lhs.clone(), box rhs.clone()), BooleanExpression::UintEq(box lhs, box rhs),
); );
FlatExpression::Add(box eq, box lt) FlatExpression::Add(box eq, box lt)
} }
@ -1044,7 +1045,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box x.clone(), box x.clone(),
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
), ),
FlatExpression::Mult(box x.clone(), box y.clone()), FlatExpression::Mult(box x, box y),
RuntimeError::Or, RuntimeError::Or,
)); ));
name_x_or_y.into() name_x_or_y.into()
@ -1088,9 +1089,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
bitwidth: UBitwidth, bitwidth: UBitwidth,
) -> Vec<FlatUExpression<T>> { ) -> Vec<FlatUExpression<T>> {
let expression = UExpression::try_from(expression).unwrap(); let expression = UExpression::try_from(expression).unwrap();
let from = expression.metadata.clone().unwrap().bitwidth(); let from = expression.metadata.as_ref().unwrap().bitwidth();
let p = self.flatten_uint_expression(statements_flattened, expression); let p = self.flatten_uint_expression(statements_flattened, expression);
self.get_bits(p, from as usize, bitwidth, statements_flattened) self.get_bits(&p, from as usize, bitwidth, statements_flattened)
.into_iter() .into_iter()
.map(FlatUExpression::with_field) .map(FlatUExpression::with_field)
.collect() .collect()
@ -1127,27 +1128,29 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<T>,
embed: FlatEmbed, embed: FlatEmbed,
generics: Vec<u32>, generics: Vec<u32>,
param_expressions: Vec<ZirExpression<'ast, T>>, mut param_expressions: Vec<ZirExpression<'ast, T>>,
) -> Vec<FlatUExpression<T>> { ) -> Vec<FlatUExpression<T>> {
match embed { match embed {
crate::embed::FlatEmbed::U64ToBits => self.flatten_u_to_bits( crate::embed::FlatEmbed::U64ToBits => self.flatten_u_to_bits(
statements_flattened, statements_flattened,
param_expressions[0].clone(), param_expressions.pop().unwrap(),
64.into(), 64.into(),
), ),
crate::embed::FlatEmbed::U32ToBits => self.flatten_u_to_bits( crate::embed::FlatEmbed::U32ToBits => self.flatten_u_to_bits(
statements_flattened, statements_flattened,
param_expressions[0].clone(), param_expressions.pop().unwrap(),
32.into(), 32.into(),
), ),
crate::embed::FlatEmbed::U16ToBits => self.flatten_u_to_bits( crate::embed::FlatEmbed::U16ToBits => self.flatten_u_to_bits(
statements_flattened, statements_flattened,
param_expressions[0].clone(), param_expressions.pop().unwrap(),
16.into(), 16.into(),
), ),
crate::embed::FlatEmbed::U8ToBits => { crate::embed::FlatEmbed::U8ToBits => self.flatten_u_to_bits(
self.flatten_u_to_bits(statements_flattened, param_expressions[0].clone(), 8.into()) statements_flattened,
} param_expressions.pop().unwrap(),
8.into(),
),
crate::embed::FlatEmbed::U64FromBits => { crate::embed::FlatEmbed::U64FromBits => {
vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 64.into())] vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 64.into())]
} }
@ -1344,10 +1347,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.define(e, statements_flattened).into() self.define(e, statements_flattened).into()
} else if n == T::from(1) { } else if n == T::from(1) {
self.define( self.define(
FlatExpression::Sub( FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box e),
box FlatExpression::Number(T::from(1)),
box e.clone(),
),
statements_flattened, statements_flattened,
) )
.into() .into()
@ -1370,8 +1370,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box FlatExpression::Sub(box y.clone(), box name.into()), box FlatExpression::Sub(box y.clone(), box name.into()),
), ),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Add(box x.clone(), box x.clone()), box FlatExpression::Add(box x.clone(), box x),
box y.clone(), box y,
), ),
RuntimeError::Xor, RuntimeError::Xor,
), ),
@ -1442,7 +1442,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// q in range // q in range
let _ = self.get_bits( let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::from(q)), &FlatUExpression::with_field(FlatExpression::from(q)),
target_bitwidth.to_usize(), target_bitwidth.to_usize(),
target_bitwidth, target_bitwidth,
statements_flattened, statements_flattened,
@ -1450,7 +1450,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// r in range // r in range
let _ = self.get_bits( let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::from(r)), &FlatUExpression::with_field(FlatExpression::from(r)),
target_bitwidth.to_usize(), target_bitwidth.to_usize(),
target_bitwidth, target_bitwidth,
statements_flattened, statements_flattened,
@ -1458,7 +1458,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// r < d <=> r - d + 2**w < 2**w // r < d <=> r - d + 2**w < 2**w
let _ = self.get_bits( let _ = self.get_bits(
FlatUExpression::with_field(FlatExpression::Add( &FlatUExpression::with_field(FlatExpression::Add(
box FlatExpression::Sub(box r.into(), box d.clone()), box FlatExpression::Sub(box r.into(), box d.clone()),
box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth.to_usize() as u32))), box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth.to_usize() as u32))),
)), )),
@ -1565,7 +1565,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
UExpressionInner::Sub(box left, box right) => { UExpressionInner::Sub(box left, box right) => {
// see uint optimizer for the reasoning here // see uint optimizer for the reasoning here
let offset = FlatExpression::Number(T::from(2).pow(std::cmp::max( let offset = FlatExpression::Number(T::from(2).pow(std::cmp::max(
right.metadata.clone().unwrap().bitwidth() as usize, right.metadata.as_ref().unwrap().bitwidth() as usize,
target_bitwidth as usize, target_bitwidth as usize,
))); )));
@ -1869,10 +1869,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
} }
(x, y) => self (x, y) => self
.define( .define(FlatExpression::Mult(box x, box y), statements_flattened)
FlatExpression::Mult(box x.clone(), box y.clone()),
statements_flattened,
)
.into(), .into(),
}) })
.collect(); .collect();
@ -1934,12 +1931,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let res = match should_reduce { let res = match should_reduce {
true => { true => {
let bits = self.get_bits( let bits =
res.clone(), self.get_bits(&res, actual_bitwidth, target_bitwidth, statements_flattened);
actual_bitwidth,
target_bitwidth,
statements_flattened,
);
let field = if actual_bitwidth > target_bitwidth.to_usize() { let field = if actual_bitwidth > target_bitwidth.to_usize() {
bits.iter().enumerate().fold( bits.iter().enumerate().fold(
@ -1970,7 +1963,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn get_bits( fn get_bits(
&mut self, &mut self,
e: FlatUExpression<T>, e: &FlatUExpression<T>,
from: usize, from: usize,
to: UBitwidth, to: UBitwidth,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<T>,
@ -2043,7 +2036,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert_eq!(bits.len(), to); assert_eq!(bits.len(), to);
self.bits_cache.insert(e.field.unwrap(), bits.clone()); self.bits_cache
.insert(e.field.clone().unwrap(), bits.clone());
self.bits_cache.insert(sum, bits.clone()); self.bits_cache.insert(sum, bits.clone());
bits bits
@ -2647,7 +2641,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// to constrain unsigned integer inputs to be in range, we get their bit decomposition. // to constrain unsigned integer inputs to be in range, we get their bit decomposition.
// it will be cached // it will be cached
self.get_bits( self.get_bits(
FlatUExpression::with_field(FlatExpression::Identifier(variable)), &FlatUExpression::with_field(FlatExpression::Identifier(variable)),
bitwidth.to_usize(), bitwidth.to_usize(),
bitwidth, bitwidth,
statements_flattened, statements_flattened,

View file

@ -5,7 +5,7 @@
//! @date 2018 //! @date 2018
use crate::absy::*; use crate::absy::*;
use crate::compile::compile_module; use crate::compile::parse_module;
use crate::compile::{CompileErrorInner, CompileErrors}; use crate::compile::{CompileErrorInner, CompileErrors};
use crate::embed::FlatEmbed; use crate::embed::FlatEmbed;
use crate::parser::Position; use crate::parser::Position;
@ -226,7 +226,7 @@ impl Importer {
Some(_) => {} Some(_) => {}
None => { None => {
let source = arena.alloc(source); let source = arena.alloc(source);
let compiled = compile_module::<T, E>( let compiled = parse_module::<T, E>(
source, source,
new_location.clone(), new_location.clone(),
resolver, resolver,

View file

@ -21,12 +21,30 @@ use zokrates_field::Field;
impl<T: Field> Prog<T> { impl<T: Field> Prog<T> {
pub fn optimize(self) -> Self { pub fn optimize(self) -> Self {
// remove redefinitions // remove redefinitions
log::debug!("Constraints: {}", self.constraint_count());
log::debug!("Optimizer: Remove redefinitions");
let r = RedefinitionOptimizer::optimize(self); let r = RedefinitionOptimizer::optimize(self);
log::debug!("Done");
// remove constraints that are always satisfied // remove constraints that are always satisfied
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove tautologies");
let r = TautologyOptimizer::optimize(r); let r = TautologyOptimizer::optimize(r);
// // deduplicate directives which take the same input log::debug!("Done");
// deduplicate directives which take the same input
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate directive");
let r = DirectiveOptimizer::optimize(r); let r = DirectiveOptimizer::optimize(r);
log::debug!("Done");
// remove duplicate constraints // remove duplicate constraints
DuplicateOptimizer::optimize(r) log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate constraints");
let r = DuplicateOptimizer::optimize(r);
log::debug!("Done");
log::debug!("Constraints: {}", r.constraint_count());
r
} }
} }

View file

@ -73,29 +73,54 @@ impl fmt::Display for Error {
impl<'ast, T: Field> TypedProgram<'ast, T> { impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> { pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// inline user-defined constants // inline user-defined constants
log::debug!("Static analyser: Inline constants");
let r = ConstantInliner::inline(self); let r = ConstantInliner::inline(self);
log::trace!("\n{}", r);
// isolate branches // isolate branches
let r = if config.isolate_branches { let r = if config.isolate_branches {
Isolator::isolate(r) log::debug!("Static analyser: Isolate branches");
let r = Isolator::isolate(r);
log::trace!("\n{}", r);
r
} else { } else {
log::debug!("Static analyser: Branch isolation skipped");
r r
}; };
// reduce the program to a single function // reduce the program to a single function
log::debug!("Static analyser: Reduce program");
let r = reduce_program(r).map_err(Error::from)?; let r = reduce_program(r).map_err(Error::from)?;
log::trace!("\n{}", r);
// generate abi // generate abi
log::debug!("Static analyser: Generate abi");
let abi = r.abi(); let abi = r.abi();
// propagate // propagate
log::debug!("Static analyser: Propagate");
let r = Propagator::propagate(r).map_err(Error::from)?; let r = Propagator::propagate(r).map_err(Error::from)?;
log::trace!("\n{}", r);
// remove assignment to variable index // remove assignment to variable index
log::debug!("Static analyser: Remove variable index");
let r = VariableWriteRemover::apply(r); let r = VariableWriteRemover::apply(r);
log::trace!("\n{}", r);
// detect non constant shifts and constant lt bounds // detect non constant shifts and constant lt bounds
log::debug!("Static analyser: Detect non constant arguments");
let r = ConstantArgumentChecker::check(r).map_err(Error::from)?; let r = ConstantArgumentChecker::check(r).map_err(Error::from)?;
log::trace!("\n{}", r);
// convert to zir, removing complex types // convert to zir, removing complex types
log::debug!("Static analyser: Convert to zir");
let zir = Flattener::flatten(r); let zir = Flattener::flatten(r);
log::trace!("\n{}", zir);
// optimize uint expressions // optimize uint expressions
log::debug!("Static analyser: Optimize uints");
let zir = UintOptimizer::optimize(zir); let zir = UintOptimizer::optimize(zir);
log::trace!("\n{}", zir);
Ok((zir, abi)) Ok((zir, abi))
} }
@ -103,12 +128,14 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
impl<T: Field> Analyse for FlatProg<T> { impl<T: Field> Analyse for FlatProg<T> {
fn analyse(self) -> Self { fn analyse(self) -> Self {
log::debug!("Static analyser: Propagate flat");
self.propagate() self.propagate()
} }
} }
impl<T: Field> Analyse for Prog<T> { impl<T: Field> Analyse for Prog<T> {
fn analyse(self) -> Self { fn analyse(self) -> Self {
log::debug!("Static analyser: Detect unconstrained zir");
UnconstrainedVariableDetector::detect(self) UnconstrainedVariableDetector::detect(self)
} }
} }

View file

@ -76,6 +76,38 @@ fn lt_uint() {
.is_err()); .is_err());
} }
#[test]
fn lt_uint() {
let source = r#"
def main(private u32 a, private u32 b):
field x = if a < b then 3333 else 4444 fi
assert(x == 3333)
return
"#
.to_string();
// let's try to prove that "10000u32 < 5555u32" is true by exploiting
// the fact that `2*10000 - 2*5555` has two distinct bit decompositions
// we chose the one which is out of range, ie the sum check features an overflow
let res: CompilationArtifacts<Bn128Field> = compile(
source,
"./path/to/file".into(),
None::<&dyn Resolver<io::Error>>,
&CompileConfig::default(),
)
.unwrap();
let interpreter = Interpreter::try_out_of_range();
assert!(interpreter
.execute(
&res.prog(),
&[Bn128Field::from(10000), Bn128Field::from(5555)]
)
.is_err());
}
#[test] #[test]
fn unpack256() { fn unpack256() {
let source = r#" let source = r#"