diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index d5e0d637..68261cb2 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -15,12 +15,11 @@ use zokrates_field::Field; pub enum FlatEmbed { Sha256Round, Unpack(usize), - CheckU8, - CheckU16, - CheckU32, U8ToBits, U16ToBits, U32ToBits, + U8FromBits, + U16FromBits, U32FromBits, } @@ -36,15 +35,6 @@ impl FlatEmbed { FlatEmbed::Unpack(bitwidth) => Signature::new() .inputs(vec![Type::FieldElement]) .outputs(vec![Type::array(Type::FieldElement, *bitwidth)]), - FlatEmbed::CheckU8 => Signature::new() - .inputs(vec![Type::Uint(8)]) - .outputs(vec![Type::array(Type::FieldElement, 8)]), - FlatEmbed::CheckU16 => Signature::new() - .inputs(vec![Type::Uint(16)]) - .outputs(vec![Type::array(Type::FieldElement, 16)]), - FlatEmbed::CheckU32 => Signature::new() - .inputs(vec![Type::Uint(32)]) - .outputs(vec![Type::array(Type::FieldElement, 32)]), FlatEmbed::U8ToBits => Signature::new() .inputs(vec![Type::Uint(8)]) .outputs(vec![Type::array(Type::Boolean, 8)]), @@ -54,6 +44,12 @@ impl FlatEmbed { FlatEmbed::U32ToBits => Signature::new() .inputs(vec![Type::Uint(32)]) .outputs(vec![Type::array(Type::Boolean, 32)]), + FlatEmbed::U8FromBits => Signature::new() + .outputs(vec![Type::Uint(8)]) + .inputs(vec![Type::array(Type::Boolean, 8)]), + FlatEmbed::U16FromBits => Signature::new() + .outputs(vec![Type::Uint(16)]) + .inputs(vec![Type::array(Type::Boolean, 16)]), FlatEmbed::U32FromBits => Signature::new() .outputs(vec![Type::Uint(32)]) .inputs(vec![Type::array(Type::Boolean, 32)]), @@ -68,12 +64,11 @@ impl FlatEmbed { match self { FlatEmbed::Sha256Round => "_SHA256_ROUND", FlatEmbed::Unpack(_) => "_UNPACK", - FlatEmbed::CheckU8 => "_CHECK_U8", - FlatEmbed::CheckU16 => "_CHECK_U16", - FlatEmbed::CheckU32 => "_CHECK_U32", FlatEmbed::U8ToBits => "_U8_TO_BITS", FlatEmbed::U16ToBits => "_U16_TO_BITS", FlatEmbed::U32ToBits => "_U32_TO_BITS", + FlatEmbed::U8FromBits => "_U8_FROM_BITS", + FlatEmbed::U16FromBits => "_U16_FROM_BITS", FlatEmbed::U32FromBits => "_U32_FROM_BITS", } } @@ -83,13 +78,7 @@ impl FlatEmbed { match self { FlatEmbed::Sha256Round => sha256_round(), FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth), - FlatEmbed::CheckU8 => unpack_to_bitwidth(8), - FlatEmbed::CheckU16 => unpack_to_bitwidth(16), - FlatEmbed::CheckU32 => unpack_to_bitwidth(32), - FlatEmbed::U8ToBits => unreachable!(), - FlatEmbed::U16ToBits => unreachable!(), - FlatEmbed::U32ToBits => unreachable!(), - FlatEmbed::U32FromBits => unreachable!(), + _ => unreachable!(), } } } diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 190159fd..a1304991 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -5,6 +5,10 @@ //! @author Jacob Eberhardt //! @date 2017 +mod utils; + +use self::utils::flat_expression_from_bits; + use crate::flat_absy::*; use crate::solvers::Solver; use crate::zir::types::{FunctionIdentifier, FunctionKey, Signature, Type}; @@ -49,48 +53,16 @@ pub struct Flattener<'ast, T: Field> { } trait FlattenOutput: Sized { - // fn branches(self, other: Self) -> (Self, Self); - fn flat(&self) -> FlatExpression; } impl FlattenOutput for FlatExpression { - // fn branches(self, other: Self) -> (Self, Self) { - // (self, other) - // } - fn flat(&self) -> FlatExpression { self.clone() } } impl FlattenOutput for FlatUExpression { - // fn branches(self, other: Self) -> (Self, Self) { - // let left_bits = self.bits.unwrap(); - // let right_bits = other.bits.unwrap(); - // let size = std::cmp::max(left_bits.len(), right_bits.len()); - - // let left_bits = (0..size - left_bits.len()) - // .map(|_| FlatExpression::Number(T::from(0))) - // .chain(left_bits) - // .collect(); - // let right_bits = (0..size - right_bits.len()) - // .map(|_| FlatExpression::Number(T::from(0))) - // .chain(right_bits) - // .collect(); - - // ( - // FlatUExpression { - // bits: Some(left_bits), - // ..self - // }, - // FlatUExpression { - // bits: Some(right_bits), - // ..other - // }, - // ) - // } - fn flat(&self) -> FlatExpression { self.clone().get_field_unchecked() } @@ -186,21 +158,7 @@ impl FlatUExpression { match self.field { Some(f) => f, None => match self.bits { - Some(bits) => { - //assert_eq!(bits.len(), 32); - bits.into_iter().rev().enumerate().fold( - FlatExpression::Number(T::from(0)), - |acc, (index, bit)| { - FlatExpression::Add( - box acc, - box FlatExpression::Mult( - box FlatExpression::Number(T::from(2).pow(index)), - box bit, - ), - ) - }, - ) - } + Some(bits) => flat_expression_from_bits(bits), None => unreachable!(), }, } @@ -668,117 +626,179 @@ impl<'ast, T: Field> Flattener<'ast, T> { let funct = self.get_embed(&key, &symbols); - if funct == crate::embed::FlatEmbed::U32ToBits { - assert_eq!(param_expressions.len(), 1); - let mut param_expressions = param_expressions; - let p = param_expressions.pop().unwrap(); - let p = match p { - ZirExpression::Uint(e) => e, - _ => unreachable!(), - }; - let from = p.metadata.clone().unwrap().bitwidth(); - let p = self.flatten_uint_expression(symbols, statements_flattened, p); - let bits = self - .get_bits(p, from as usize, 32, statements_flattened) - .into_iter() - .map(|b| FlatUExpression::with_field(b)) - .collect(); - return bits; - } - - if funct == crate::embed::FlatEmbed::U32FromBits { - assert_eq!(param_expressions.len(), 32); - let param_expressions: Vec<_> = param_expressions - .into_iter() - .map(|p| { - self.flatten_expression(symbols, statements_flattened, p) - .get_field_unchecked() - }) - .collect(); - - return vec![FlatUExpression::with_bits(param_expressions)]; - } - - let funct = funct.synthetize(); - - let mut replacement_map = HashMap::new(); - - // Handle complex parameters and assign values: - // Rename Parameters, assign them to values in call. Resolve complex expressions with definitions - let params_flattened = param_expressions - .into_iter() - .map(|param_expr| self.flatten_expression(symbols, statements_flattened, param_expr)) - .into_iter() - .map(|x| x.get_field_unchecked()) - .collect::>(); - - for (concrete_argument, formal_argument) in - params_flattened.into_iter().zip(funct.arguments) - { - let new_var = self.define(concrete_argument, statements_flattened); - replacement_map.insert(formal_argument.id, new_var); - } - - // Ensure renaming and correct returns: - // add all flattened statements, adapt return statements - - let (mut return_statements, statements): (Vec<_>, Vec<_>) = - funct.statements.into_iter().partition(|s| match s { - FlatStatement::Return(..) => true, - _ => false, - }); - - let statements: Vec<_> = statements - .into_iter() - .map(|stat| match stat { - // set return statements as expression result - FlatStatement::Return(..) => unreachable!(), - FlatStatement::Definition(var, rhs) => { - let new_var = self.use_sym(); - replacement_map.insert(var, new_var); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Definition(new_var, new_rhs) - } - FlatStatement::Condition(lhs, rhs) => { - let new_lhs = lhs.apply_substitution(&replacement_map); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Condition(new_lhs, new_rhs) - } - FlatStatement::Directive(d) => { - let new_outputs = d - .outputs - .into_iter() - .map(|o| { - let new_o = self.use_sym(); - replacement_map.insert(o, new_o); - new_o - }) - .collect(); - let new_inputs = d - .inputs - .into_iter() - .map(|i| i.apply_substitution(&replacement_map)) - .collect(); - FlatStatement::Directive(FlatDirective { - outputs: new_outputs, - solver: d.solver, - inputs: new_inputs, + match funct { + crate::embed::FlatEmbed::U32ToBits => { + assert_eq!(param_expressions.len(), 1); + let mut param_expressions = param_expressions; + let p = param_expressions.pop().unwrap(); + let p = match p { + ZirExpression::Uint(e) => e, + _ => unreachable!(), + }; + let from = p.metadata.clone().unwrap().bitwidth(); + let p = self.flatten_uint_expression(symbols, statements_flattened, p); + let bits = self + .get_bits(p, from as usize, 32, statements_flattened) + .into_iter() + .map(|b| FlatUExpression::with_field(b)) + .collect(); + bits + } + crate::embed::FlatEmbed::U16ToBits => { + assert_eq!(param_expressions.len(), 1); + let mut param_expressions = param_expressions; + let p = param_expressions.pop().unwrap(); + let p = match p { + ZirExpression::Uint(e) => e, + _ => unreachable!(), + }; + let from = p.metadata.clone().unwrap().bitwidth(); + let p = self.flatten_uint_expression(symbols, statements_flattened, p); + let bits = self + .get_bits(p, from as usize, 16, statements_flattened) + .into_iter() + .map(|b| FlatUExpression::with_field(b)) + .collect(); + bits + } + crate::embed::FlatEmbed::U8ToBits => { + assert_eq!(param_expressions.len(), 1); + let mut param_expressions = param_expressions; + let p = param_expressions.pop().unwrap(); + let p = match p { + ZirExpression::Uint(e) => e, + _ => unreachable!(), + }; + let from = p.metadata.clone().unwrap().bitwidth(); + let p = self.flatten_uint_expression(symbols, statements_flattened, p); + let bits = self + .get_bits(p, from as usize, 8, statements_flattened) + .into_iter() + .map(|b| FlatUExpression::with_field(b)) + .collect(); + bits + } + crate::embed::FlatEmbed::U32FromBits => { + assert_eq!(param_expressions.len(), 32); + let param_expressions: Vec<_> = param_expressions + .into_iter() + .map(|p| { + self.flatten_expression(symbols, statements_flattened, p) + .get_field_unchecked() }) + .collect(); + + vec![FlatUExpression::with_bits(param_expressions)] + } + crate::embed::FlatEmbed::U16FromBits => { + assert_eq!(param_expressions.len(), 16); + let param_expressions: Vec<_> = param_expressions + .into_iter() + .map(|p| { + self.flatten_expression(symbols, statements_flattened, p) + .get_field_unchecked() + }) + .collect(); + + vec![FlatUExpression::with_bits(param_expressions)] + } + crate::embed::FlatEmbed::U8FromBits => { + assert_eq!(param_expressions.len(), 8); + let param_expressions: Vec<_> = param_expressions + .into_iter() + .map(|p| { + self.flatten_expression(symbols, statements_flattened, p) + .get_field_unchecked() + }) + .collect(); + + vec![FlatUExpression::with_bits(param_expressions)] + } + funct => { + let funct = funct.synthetize(); + + let mut replacement_map = HashMap::new(); + + // Handle complex parameters and assign values: + // Rename Parameters, assign them to values in call. Resolve complex expressions with definitions + let params_flattened = param_expressions + .into_iter() + .map(|param_expr| { + self.flatten_expression(symbols, statements_flattened, param_expr) + }) + .into_iter() + .map(|x| x.get_field_unchecked()) + .collect::>(); + + for (concrete_argument, formal_argument) in + params_flattened.into_iter().zip(funct.arguments) + { + let new_var = self.define(concrete_argument, statements_flattened); + replacement_map.insert(formal_argument.id, new_var); } - FlatStatement::Log(s) => FlatStatement::Log(s), - }) - .collect(); - statements_flattened.extend(statements); + // Ensure renaming and correct returns: + // add all flattened statements, adapt return statements - match return_statements.pop().unwrap() { - FlatStatement::Return(list) => list - .expressions - .into_iter() - .map(|x| x.apply_substitution(&replacement_map)) - .map(|x| FlatUExpression::with_field(x)) - .collect(), - _ => unreachable!(), + let (mut return_statements, statements): (Vec<_>, Vec<_>) = + funct.statements.into_iter().partition(|s| match s { + FlatStatement::Return(..) => true, + _ => false, + }); + + let statements: Vec<_> = statements + .into_iter() + .map(|stat| match stat { + // set return statements as expression result + FlatStatement::Return(..) => unreachable!(), + FlatStatement::Definition(var, rhs) => { + let new_var = self.use_sym(); + replacement_map.insert(var, new_var); + let new_rhs = rhs.apply_substitution(&replacement_map); + FlatStatement::Definition(new_var, new_rhs) + } + FlatStatement::Condition(lhs, rhs) => { + let new_lhs = lhs.apply_substitution(&replacement_map); + let new_rhs = rhs.apply_substitution(&replacement_map); + FlatStatement::Condition(new_lhs, new_rhs) + } + FlatStatement::Directive(d) => { + let new_outputs = d + .outputs + .into_iter() + .map(|o| { + let new_o = self.use_sym(); + replacement_map.insert(o, new_o); + new_o + }) + .collect(); + let new_inputs = d + .inputs + .into_iter() + .map(|i| i.apply_substitution(&replacement_map)) + .collect(); + FlatStatement::Directive(FlatDirective { + outputs: new_outputs, + solver: d.solver, + inputs: new_inputs, + }) + } + FlatStatement::Log(s) => FlatStatement::Log(s), + }) + .collect(); + + statements_flattened.extend(statements); + + match return_statements.pop().unwrap() { + FlatStatement::Return(list) => list + .expressions + .into_iter() + .map(|x| x.apply_substitution(&replacement_map)) + .map(|x| FlatUExpression::with_field(x)) + .collect(), + _ => unreachable!(), + } + } } } @@ -1293,12 +1313,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { vec![e.field.clone().unwrap()], ))); + let bits: Vec<_> = bits + .into_iter() + .map(|b| FlatExpression::Identifier(b)) + .collect(); + // decompose to the actual bitwidth // bit checks statements_flattened.extend((0..from).map(|i| { FlatStatement::Condition( - bits[i].clone().into(), + bits[i].clone(), FlatExpression::Mult( box bits[i].clone().into(), box bits[i].clone().into(), @@ -1306,18 +1331,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) })); - let sum = bits.iter().enumerate().fold( - FlatExpression::Number(T::from(0)), - |acc, (index, bit)| { - FlatExpression::Add( - box acc, - box FlatExpression::Mult( - box FlatExpression::Number(T::from(2).pow(from - index - 1)), - box bit.clone().into(), - ), - ) - }, - ); + let sum = flat_expression_from_bits(bits.clone()); // sum check statements_flattened.push(FlatStatement::Condition( @@ -1325,15 +1339,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { sum.clone(), )); + // truncate to the `to` lowest bits let bits = bits[from - to..].to_vec(); assert_eq!(bits.len(), to); - let bits: Vec<_> = bits - .into_iter() - .map(|b| FlatExpression::Identifier(b)) - .collect(); - self.bits_cache.insert(e.field.unwrap(), bits.clone()); self.bits_cache.insert(sum, bits.clone()); @@ -1681,7 +1691,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); - if key.id == "_U32_FROM_BITS" { + if ["_U32_FROM_BITS", "_U16_FROM_BITS", "_U8_FROM_BITS"].contains(&key.id) { let bits = exprs .into_iter() .map(|e| { diff --git a/zokrates_core/src/flatten/utils.rs b/zokrates_core/src/flatten/utils.rs new file mode 100644 index 00000000..9a280838 --- /dev/null +++ b/zokrates_core/src/flatten/utils.rs @@ -0,0 +1,31 @@ +use flat_absy::*; +use zokrates_field::Field; + +pub fn flat_expression_from_bits(v: Vec>) -> FlatExpression { + fn flat_expression_from_bits_aux( + v: Vec<(T, FlatExpression)>, + ) -> FlatExpression { + match v.len() { + 0 => FlatExpression::Number(T::zero()), + 1 => { + let (coeff, var) = v[0].clone(); + FlatExpression::Mult(box FlatExpression::Number(coeff), box var) + } + n => { + let (u, v) = v.split_at(n / 2); + FlatExpression::Add( + box flat_expression_from_bits_aux(u.to_vec()), + box flat_expression_from_bits_aux(v.to_vec()), + ) + } + } + } + + flat_expression_from_bits_aux( + v.into_iter() + .rev() + .enumerate() + .map(|(index, var)| (T::from(2).pow(index), var)) + .collect::>(), + ) +} diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 17508bea..1e8db324 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -182,6 +182,28 @@ impl Importer { .start_end(pos.0, pos.1), ); } + "EMBED/u16_to_bits" => { + let alias = alias.unwrap_or("u16_to_bits"); + + symbols.push( + SymbolDeclaration { + id: &alias, + symbol: Symbol::Flat(FlatEmbed::U16ToBits), + } + .start_end(pos.0, pos.1), + ); + } + "EMBED/u8_to_bits" => { + let alias = alias.unwrap_or("u8_to_bits"); + + symbols.push( + SymbolDeclaration { + id: &alias, + symbol: Symbol::Flat(FlatEmbed::U8ToBits), + } + .start_end(pos.0, pos.1), + ); + } "EMBED/u32_from_bits" => { let alias = alias.unwrap_or("u32_from_bits"); @@ -193,9 +215,31 @@ impl Importer { .start_end(pos.0, pos.1), ); } + "EMBED/u16_from_bits" => { + let alias = alias.unwrap_or("u16_from_bits"); + + symbols.push( + SymbolDeclaration { + id: &alias, + symbol: Symbol::Flat(FlatEmbed::U16FromBits), + } + .start_end(pos.0, pos.1), + ); + } + "EMBED/u8_from_bits" => { + let alias = alias.unwrap_or("u8_from_bits"); + + symbols.push( + SymbolDeclaration { + id: &alias, + symbol: Symbol::Flat(FlatEmbed::U8FromBits), + } + .start_end(pos.0, pos.1), + ); + } s => { return Err(CompileErrorInner::ImportError( - Error::new(format!("Embed {} not found. Options are \"EMBED/sha256round\", \"EMBED/unpack\", \"EMBED/u32_to_bits\", \"EMBED/u32_from_bits\"", s)).with_pos(Some(pos)), + Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)), ) .in_file(&location) .into()); diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 6cdafc33..1c3d5b37 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -108,26 +108,37 @@ impl<'ast, T: Field> Inliner<'ast, T> { let sha256_round = crate::embed::FlatEmbed::Sha256Round; let sha256_round_key = sha256_round.key::(); - // define a function in the main module for the `check_u8` embed - let check_u8 = crate::embed::FlatEmbed::CheckU8; - let check_u8_key = check_u8.key::(); - - // define a function in the main module for the `check_u8` embed - let check_u16 = crate::embed::FlatEmbed::CheckU16; - let check_u16_key = check_u16.key::(); - - // define a function in the main module for the `check_u8` embed - let check_u32 = crate::embed::FlatEmbed::CheckU32; - let check_u32_key = check_u32.key::(); - // define a function in the main module for the `u32_to_bits` embed let u32_to_bits = crate::embed::FlatEmbed::U32ToBits; let u32_to_bits_key = u32_to_bits.key::(); - // define a function in the main module for the `u32_to_bits` embed + // define a function in the main module for the `u16_to_bits` embed + let u16_to_bits = crate::embed::FlatEmbed::U16ToBits; + let u16_to_bits_key = u16_to_bits.key::(); + + // define a function in the main module for the `u8_to_bits` embed + let u8_to_bits = crate::embed::FlatEmbed::U8ToBits; + let u8_to_bits_key = u8_to_bits.key::(); + + // define a function in the main module for the `u32_from_bits` embed let u32_from_bits = crate::embed::FlatEmbed::U32FromBits; let u32_from_bits_key = u32_from_bits.key::(); + // define a function in the main module for the `u16_from_bits` embed + let u16_from_bits = crate::embed::FlatEmbed::U16FromBits; + let u16_from_bits_key = u16_from_bits.key::(); + + // define a function in the main module for the `u8_from_bits` embed + let u8_from_bits = crate::embed::FlatEmbed::U8FromBits; + let u8_from_bits_key = u8_from_bits.key::(); + + println!("{:?}", unpack_key); + println!( + "{:?}", + crate::embed::FlatEmbed::Unpack(T::get_required_bits()).signature() + ); + println!("{:?}", crate::embed::FlatEmbed::U32FromBits.signature()); + // return a program with a single module containing `main`, `_UNPACK`, and `_SHA256_ROUND TypedProgram { main: "main".into(), @@ -137,11 +148,12 @@ impl<'ast, T: Field> Inliner<'ast, T> { functions: vec![ (unpack_key, TypedFunctionSymbol::Flat(unpack)), (sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)), - (check_u8_key, TypedFunctionSymbol::Flat(check_u8)), - (check_u16_key, TypedFunctionSymbol::Flat(check_u16)), - (check_u32_key, TypedFunctionSymbol::Flat(check_u32)), - (u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)), (u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)), + (u16_from_bits_key, TypedFunctionSymbol::Flat(u16_from_bits)), + (u8_from_bits_key, TypedFunctionSymbol::Flat(u8_from_bits)), + (u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)), + (u16_to_bits_key, TypedFunctionSymbol::Flat(u16_to_bits)), + (u8_to_bits_key, TypedFunctionSymbol::Flat(u8_to_bits)), (main_key, main), ] .into_iter() diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index a4f00110..9a96ab78 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -55,6 +55,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = VariableAccessRemover::apply(r); let zir = Flattener::flatten(r.clone()); + // constrain inputs let zir = InputConstrainer::constrain(zir); diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 333e69cc..73c9f865 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -142,7 +142,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { .collect(); if arguments.iter().all(|a| is_constant(a)) - && (key.id == "_U32_FROM_BITS" || key.id == "_U32_TO_BITS") + && [ + "_U32_FROM_BITS", + "_U16_FROM_BITS", + "_U8_FROM_BITS", + "_U32_TO_BITS", + "_U16_TO_BITS", + "_U8_TO_BITS", + ] + .contains(&key.id) { let expr: TypedExpression<'ast, T> = if key.id == "_U32_FROM_BITS" { assert_eq!(variables.len(), 1); @@ -180,6 +188,78 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { }, _ => unreachable!("should be an array"), } + } else if key.id == "_U16_FROM_BITS" { + assert_eq!(variables.len(), 1); + assert_eq!(arguments.len(), 1); + + use std::convert::TryInto; + + match arguments[0].clone() { + TypedExpression::Array(a) => match a.into_inner() { + ArrayExpressionInner::Value(v) => { + assert_eq!(v.len(), 16); + UExpressionInner::Value( + v.into_iter() + .map(|v| match v { + TypedExpression::Boolean( + BooleanExpression::Value(v), + ) => v, + _ => unreachable!("should be a boolean"), + }) + .enumerate() + .fold(0, |acc, (i, v)| { + if v { + acc + 2u128.pow( + (16 - i - 1).try_into().unwrap(), + ) + } else { + acc + } + }), + ) + .annotate(16) + .into() + } + _ => unreachable!("should be an array value"), + }, + _ => unreachable!("should be an array"), + } + } else if key.id == "_U8_FROM_BITS" { + assert_eq!(variables.len(), 1); + assert_eq!(arguments.len(), 1); + + use std::convert::TryInto; + + match arguments[0].clone() { + TypedExpression::Array(a) => match a.into_inner() { + ArrayExpressionInner::Value(v) => { + assert_eq!(v.len(), 8); + UExpressionInner::Value( + v.into_iter() + .map(|v| match v { + TypedExpression::Boolean( + BooleanExpression::Value(v), + ) => v, + _ => unreachable!("should be a boolean"), + }) + .enumerate() + .fold(0, |acc, (i, v)| { + if v { + acc + 2u128.pow( + (8 - i - 1).try_into().unwrap(), + ) + } else { + acc + } + }), + ) + .annotate(8) + .into() + } + _ => unreachable!("should be an array value"), + }, + _ => unreachable!("should be an array"), + } } else if key.id == "_U32_TO_BITS" { assert_eq!(variables.len(), 1); assert_eq!(arguments.len(), 1); @@ -212,6 +292,70 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { }, _ => unreachable!("should be a uint"), } + } else if key.id == "_U16_TO_BITS" { + assert_eq!(variables.len(), 1); + assert_eq!(arguments.len(), 1); + + match arguments[0].clone() { + TypedExpression::Uint(a) => match a.into_inner() { + UExpressionInner::Value(v) => { + let mut num = v; + let mut res = vec![]; + + for i in (0..16).rev() { + if 2u128.pow(i) <= num { + num = num - 2u128.pow(i); + res.push(true); + } else { + res.push(false); + } + } + assert_eq!(num, 0); + + ArrayExpressionInner::Value( + res.into_iter() + .map(|v| BooleanExpression::Value(v).into()) + .collect(), + ) + .annotate(Type::Boolean, 16) + .into() + } + _ => unreachable!("should be a uint value"), + }, + _ => unreachable!("should be a uint"), + } + } else if key.id == "_U8_TO_BITS" { + assert_eq!(variables.len(), 1); + assert_eq!(arguments.len(), 1); + + match arguments[0].clone() { + TypedExpression::Uint(a) => match a.into_inner() { + UExpressionInner::Value(v) => { + let mut num = v; + let mut res = vec![]; + + for i in (0..8).rev() { + if 2u128.pow(i) <= num { + num = num - 2u128.pow(i); + res.push(true); + } else { + res.push(false); + } + } + assert_eq!(num, 0); + + ArrayExpressionInner::Value( + res.into_iter() + .map(|v| BooleanExpression::Value(v).into()) + .collect(), + ) + .annotate(Type::Boolean, 8) + .into() + } + _ => unreachable!("should be a uint value"), + }, + _ => unreachable!("should be a uint"), + } } else { unreachable!() }; diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index 4fb38727..86552c6b 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -359,6 +359,34 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ZirExpressionList::FunctionCall(key, arguments, ty), )] } + "_U16_FROM_BITS" => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(16) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + vec![ZirStatement::MultipleDefinition( + lhs, + ZirExpressionList::FunctionCall(key, arguments, ty), + )] + } + "_U8_FROM_BITS" => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(8) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + vec![ZirStatement::MultipleDefinition( + lhs, + ZirExpressionList::FunctionCall(key, arguments, ty), + )] + } _ => vec![ZirStatement::MultipleDefinition( lhs, ZirExpressionList::FunctionCall( diff --git a/zokrates_core_test/tests/tests/uint/from_to_bits.json b/zokrates_core_test/tests/tests/uint/from_to_bits.json index 0bb9e58c..e086d206 100644 --- a/zokrates_core_test/tests/tests/uint/from_to_bits.json +++ b/zokrates_core_test/tests/tests/uint/from_to_bits.json @@ -4,31 +4,31 @@ "tests": [ { "input": { - "values": ["0x00000000"] + "values": ["0x00000000", "0x0000", "0x00"] }, "output": { "Ok": { - "values": ["0x00000000"] + "values": ["0x00000000", "0x0000", "0x00"] } } }, { "input": { - "values": ["0x00000002"] + "values": ["0xffffffff", "0xffff", "0xff"] }, "output": { "Ok": { - "values": ["0x00000002"] + "values": ["0xffffffff", "0xffff", "0xff"] } } }, { "input": { - "values": ["0x12345678"] + "values": ["0x12345678", "0x1234", "0x12"] }, "output": { "Ok": { - "values": ["0x12345678"] + "values": ["0x12345678", "0x1234", "0x12"] } } } diff --git a/zokrates_core_test/tests/tests/uint/from_to_bits.zok b/zokrates_core_test/tests/tests/uint/from_to_bits.zok index 2220d771..6bb7b804 100644 --- a/zokrates_core_test/tests/tests/uint/from_to_bits.zok +++ b/zokrates_core_test/tests/tests/uint/from_to_bits.zok @@ -1,6 +1,12 @@ -import "EMBED/u32_to_bits" as to_bits -import "EMBED/u32_from_bits" as from_bits +import "EMBED/u32_to_bits" as to_bits_32 +import "EMBED/u32_from_bits" as from_bits_32 +import "EMBED/u16_to_bits" as to_bits_16 +import "EMBED/u16_from_bits" as from_bits_16 +import "EMBED/u8_to_bits" as to_bits_8 +import "EMBED/u8_from_bits" as from_bits_8 -def main(u32 e) -> (u32): - bool[32] f = to_bits(e) - return from_bits(f) \ No newline at end of file +def main(u32 e, u16 f, u8 g) -> (u32, u16, u8): + bool[32] e_bits = to_bits_32(e) + bool[16] f_bits = to_bits_16(f) + bool[8] g_bits = to_bits_8(g) + return from_bits_32(e_bits), from_bits_16(f_bits), from_bits_8(g_bits) \ No newline at end of file