implement all unsigned types
This commit is contained in:
parent
a80ebf0b15
commit
cbd573a17c
10 changed files with 490 additions and 225 deletions
|
@ -15,12 +15,11 @@ use zokrates_field::Field;
|
|||
pub enum FlatEmbed {
|
||||
Sha256Round,
|
||||
Unpack(usize),
|
||||
CheckU8,
|
||||
CheckU16,
|
||||
CheckU32,
|
||||
U8ToBits,
|
||||
U16ToBits,
|
||||
U32ToBits,
|
||||
U8FromBits,
|
||||
U16FromBits,
|
||||
U32FromBits,
|
||||
}
|
||||
|
||||
|
@ -36,15 +35,6 @@ impl FlatEmbed {
|
|||
FlatEmbed::Unpack(bitwidth) => Signature::new()
|
||||
.inputs(vec![Type::FieldElement])
|
||||
.outputs(vec![Type::array(Type::FieldElement, *bitwidth)]),
|
||||
FlatEmbed::CheckU8 => Signature::new()
|
||||
.inputs(vec![Type::Uint(8)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 8)]),
|
||||
FlatEmbed::CheckU16 => Signature::new()
|
||||
.inputs(vec![Type::Uint(16)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 16)]),
|
||||
FlatEmbed::CheckU32 => Signature::new()
|
||||
.inputs(vec![Type::Uint(32)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 32)]),
|
||||
FlatEmbed::U8ToBits => Signature::new()
|
||||
.inputs(vec![Type::Uint(8)])
|
||||
.outputs(vec![Type::array(Type::Boolean, 8)]),
|
||||
|
@ -54,6 +44,12 @@ impl FlatEmbed {
|
|||
FlatEmbed::U32ToBits => Signature::new()
|
||||
.inputs(vec![Type::Uint(32)])
|
||||
.outputs(vec![Type::array(Type::Boolean, 32)]),
|
||||
FlatEmbed::U8FromBits => Signature::new()
|
||||
.outputs(vec![Type::Uint(8)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 8)]),
|
||||
FlatEmbed::U16FromBits => Signature::new()
|
||||
.outputs(vec![Type::Uint(16)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 16)]),
|
||||
FlatEmbed::U32FromBits => Signature::new()
|
||||
.outputs(vec![Type::Uint(32)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 32)]),
|
||||
|
@ -68,12 +64,11 @@ impl FlatEmbed {
|
|||
match self {
|
||||
FlatEmbed::Sha256Round => "_SHA256_ROUND",
|
||||
FlatEmbed::Unpack(_) => "_UNPACK",
|
||||
FlatEmbed::CheckU8 => "_CHECK_U8",
|
||||
FlatEmbed::CheckU16 => "_CHECK_U16",
|
||||
FlatEmbed::CheckU32 => "_CHECK_U32",
|
||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||
FlatEmbed::U16ToBits => "_U16_TO_BITS",
|
||||
FlatEmbed::U32ToBits => "_U32_TO_BITS",
|
||||
FlatEmbed::U8FromBits => "_U8_FROM_BITS",
|
||||
FlatEmbed::U16FromBits => "_U16_FROM_BITS",
|
||||
FlatEmbed::U32FromBits => "_U32_FROM_BITS",
|
||||
}
|
||||
}
|
||||
|
@ -83,13 +78,7 @@ impl FlatEmbed {
|
|||
match self {
|
||||
FlatEmbed::Sha256Round => sha256_round(),
|
||||
FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth),
|
||||
FlatEmbed::CheckU8 => unpack_to_bitwidth(8),
|
||||
FlatEmbed::CheckU16 => unpack_to_bitwidth(16),
|
||||
FlatEmbed::CheckU32 => unpack_to_bitwidth(32),
|
||||
FlatEmbed::U8ToBits => unreachable!(),
|
||||
FlatEmbed::U16ToBits => unreachable!(),
|
||||
FlatEmbed::U32ToBits => unreachable!(),
|
||||
FlatEmbed::U32FromBits => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,10 @@
|
|||
//! @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>
|
||||
//! @date 2017
|
||||
|
||||
mod utils;
|
||||
|
||||
use self::utils::flat_expression_from_bits;
|
||||
|
||||
use crate::flat_absy::*;
|
||||
use crate::solvers::Solver;
|
||||
use crate::zir::types::{FunctionIdentifier, FunctionKey, Signature, Type};
|
||||
|
@ -49,48 +53,16 @@ pub struct Flattener<'ast, T: Field> {
|
|||
}
|
||||
|
||||
trait FlattenOutput<T: Field>: Sized {
|
||||
// fn branches(self, other: Self) -> (Self, Self);
|
||||
|
||||
fn flat(&self) -> FlatExpression<T>;
|
||||
}
|
||||
|
||||
impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
|
||||
// fn branches(self, other: Self) -> (Self, Self) {
|
||||
// (self, other)
|
||||
// }
|
||||
|
||||
fn flat(&self) -> FlatExpression<T> {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
|
||||
// fn branches(self, other: Self) -> (Self, Self) {
|
||||
// let left_bits = self.bits.unwrap();
|
||||
// let right_bits = other.bits.unwrap();
|
||||
// let size = std::cmp::max(left_bits.len(), right_bits.len());
|
||||
|
||||
// let left_bits = (0..size - left_bits.len())
|
||||
// .map(|_| FlatExpression::Number(T::from(0)))
|
||||
// .chain(left_bits)
|
||||
// .collect();
|
||||
// let right_bits = (0..size - right_bits.len())
|
||||
// .map(|_| FlatExpression::Number(T::from(0)))
|
||||
// .chain(right_bits)
|
||||
// .collect();
|
||||
|
||||
// (
|
||||
// FlatUExpression {
|
||||
// bits: Some(left_bits),
|
||||
// ..self
|
||||
// },
|
||||
// FlatUExpression {
|
||||
// bits: Some(right_bits),
|
||||
// ..other
|
||||
// },
|
||||
// )
|
||||
// }
|
||||
|
||||
fn flat(&self) -> FlatExpression<T> {
|
||||
self.clone().get_field_unchecked()
|
||||
}
|
||||
|
@ -186,21 +158,7 @@ impl<T: Field> FlatUExpression<T> {
|
|||
match self.field {
|
||||
Some(f) => f,
|
||||
None => match self.bits {
|
||||
Some(bits) => {
|
||||
//assert_eq!(bits.len(), 32);
|
||||
bits.into_iter().rev().enumerate().fold(
|
||||
FlatExpression::Number(T::from(0)),
|
||||
|acc, (index, bit)| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(index)),
|
||||
box bit,
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
Some(bits) => flat_expression_from_bits(bits),
|
||||
None => unreachable!(),
|
||||
},
|
||||
}
|
||||
|
@ -668,7 +626,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let funct = self.get_embed(&key, &symbols);
|
||||
|
||||
if funct == crate::embed::FlatEmbed::U32ToBits {
|
||||
match funct {
|
||||
crate::embed::FlatEmbed::U32ToBits => {
|
||||
assert_eq!(param_expressions.len(), 1);
|
||||
let mut param_expressions = param_expressions;
|
||||
let p = param_expressions.pop().unwrap();
|
||||
|
@ -683,10 +642,43 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.into_iter()
|
||||
.map(|b| FlatUExpression::with_field(b))
|
||||
.collect();
|
||||
return bits;
|
||||
bits
|
||||
}
|
||||
|
||||
if funct == crate::embed::FlatEmbed::U32FromBits {
|
||||
crate::embed::FlatEmbed::U16ToBits => {
|
||||
assert_eq!(param_expressions.len(), 1);
|
||||
let mut param_expressions = param_expressions;
|
||||
let p = param_expressions.pop().unwrap();
|
||||
let p = match p {
|
||||
ZirExpression::Uint(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let from = p.metadata.clone().unwrap().bitwidth();
|
||||
let p = self.flatten_uint_expression(symbols, statements_flattened, p);
|
||||
let bits = self
|
||||
.get_bits(p, from as usize, 16, statements_flattened)
|
||||
.into_iter()
|
||||
.map(|b| FlatUExpression::with_field(b))
|
||||
.collect();
|
||||
bits
|
||||
}
|
||||
crate::embed::FlatEmbed::U8ToBits => {
|
||||
assert_eq!(param_expressions.len(), 1);
|
||||
let mut param_expressions = param_expressions;
|
||||
let p = param_expressions.pop().unwrap();
|
||||
let p = match p {
|
||||
ZirExpression::Uint(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let from = p.metadata.clone().unwrap().bitwidth();
|
||||
let p = self.flatten_uint_expression(symbols, statements_flattened, p);
|
||||
let bits = self
|
||||
.get_bits(p, from as usize, 8, statements_flattened)
|
||||
.into_iter()
|
||||
.map(|b| FlatUExpression::with_field(b))
|
||||
.collect();
|
||||
bits
|
||||
}
|
||||
crate::embed::FlatEmbed::U32FromBits => {
|
||||
assert_eq!(param_expressions.len(), 32);
|
||||
let param_expressions: Vec<_> = param_expressions
|
||||
.into_iter()
|
||||
|
@ -696,9 +688,33 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
return vec![FlatUExpression::with_bits(param_expressions)];
|
||||
vec![FlatUExpression::with_bits(param_expressions)]
|
||||
}
|
||||
crate::embed::FlatEmbed::U16FromBits => {
|
||||
assert_eq!(param_expressions.len(), 16);
|
||||
let param_expressions: Vec<_> = param_expressions
|
||||
.into_iter()
|
||||
.map(|p| {
|
||||
self.flatten_expression(symbols, statements_flattened, p)
|
||||
.get_field_unchecked()
|
||||
})
|
||||
.collect();
|
||||
|
||||
vec![FlatUExpression::with_bits(param_expressions)]
|
||||
}
|
||||
crate::embed::FlatEmbed::U8FromBits => {
|
||||
assert_eq!(param_expressions.len(), 8);
|
||||
let param_expressions: Vec<_> = param_expressions
|
||||
.into_iter()
|
||||
.map(|p| {
|
||||
self.flatten_expression(symbols, statements_flattened, p)
|
||||
.get_field_unchecked()
|
||||
})
|
||||
.collect();
|
||||
|
||||
vec![FlatUExpression::with_bits(param_expressions)]
|
||||
}
|
||||
funct => {
|
||||
let funct = funct.synthetize();
|
||||
|
||||
let mut replacement_map = HashMap::new();
|
||||
|
@ -707,7 +723,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// Rename Parameters, assign them to values in call. Resolve complex expressions with definitions
|
||||
let params_flattened = param_expressions
|
||||
.into_iter()
|
||||
.map(|param_expr| self.flatten_expression(symbols, statements_flattened, param_expr))
|
||||
.map(|param_expr| {
|
||||
self.flatten_expression(symbols, statements_flattened, param_expr)
|
||||
})
|
||||
.into_iter()
|
||||
.map(|x| x.get_field_unchecked())
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -781,6 +799,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens an expression
|
||||
///
|
||||
|
@ -1293,12 +1313,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
vec![e.field.clone().unwrap()],
|
||||
)));
|
||||
|
||||
let bits: Vec<_> = bits
|
||||
.into_iter()
|
||||
.map(|b| FlatExpression::Identifier(b))
|
||||
.collect();
|
||||
|
||||
// decompose to the actual bitwidth
|
||||
|
||||
// bit checks
|
||||
statements_flattened.extend((0..from).map(|i| {
|
||||
FlatStatement::Condition(
|
||||
bits[i].clone().into(),
|
||||
bits[i].clone(),
|
||||
FlatExpression::Mult(
|
||||
box bits[i].clone().into(),
|
||||
box bits[i].clone().into(),
|
||||
|
@ -1306,18 +1331,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}));
|
||||
|
||||
let sum = bits.iter().enumerate().fold(
|
||||
FlatExpression::Number(T::from(0)),
|
||||
|acc, (index, bit)| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(from - index - 1)),
|
||||
box bit.clone().into(),
|
||||
),
|
||||
)
|
||||
},
|
||||
);
|
||||
let sum = flat_expression_from_bits(bits.clone());
|
||||
|
||||
// sum check
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
|
@ -1325,15 +1339,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
sum.clone(),
|
||||
));
|
||||
|
||||
// truncate to the `to` lowest bits
|
||||
let bits = bits[from - to..].to_vec();
|
||||
|
||||
assert_eq!(bits.len(), to);
|
||||
|
||||
let bits: Vec<_> = bits
|
||||
.into_iter()
|
||||
.map(|b| FlatExpression::Identifier(b))
|
||||
.collect();
|
||||
|
||||
self.bits_cache.insert(e.field.unwrap(), bits.clone());
|
||||
self.bits_cache.insert(sum, bits.clone());
|
||||
|
||||
|
@ -1681,7 +1691,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
if key.id == "_U32_FROM_BITS" {
|
||||
if ["_U32_FROM_BITS", "_U16_FROM_BITS", "_U8_FROM_BITS"].contains(&key.id) {
|
||||
let bits = exprs
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
|
|
31
zokrates_core/src/flatten/utils.rs
Normal file
31
zokrates_core/src/flatten/utils.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use flat_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub fn flat_expression_from_bits<T: Field>(v: Vec<FlatExpression<T>>) -> FlatExpression<T> {
|
||||
fn flat_expression_from_bits_aux<T: Field>(
|
||||
v: Vec<(T, FlatExpression<T>)>,
|
||||
) -> FlatExpression<T> {
|
||||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (coeff, var) = v[0].clone();
|
||||
FlatExpression::Mult(box FlatExpression::Number(coeff), box var)
|
||||
}
|
||||
n => {
|
||||
let (u, v) = v.split_at(n / 2);
|
||||
FlatExpression::Add(
|
||||
box flat_expression_from_bits_aux(u.to_vec()),
|
||||
box flat_expression_from_bits_aux(v.to_vec()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flat_expression_from_bits_aux(
|
||||
v.into_iter()
|
||||
.rev()
|
||||
.enumerate()
|
||||
.map(|(index, var)| (T::from(2).pow(index), var))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
|
@ -182,6 +182,28 @@ impl Importer {
|
|||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
"EMBED/u16_to_bits" => {
|
||||
let alias = alias.unwrap_or("u16_to_bits");
|
||||
|
||||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
id: &alias,
|
||||
symbol: Symbol::Flat(FlatEmbed::U16ToBits),
|
||||
}
|
||||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
"EMBED/u8_to_bits" => {
|
||||
let alias = alias.unwrap_or("u8_to_bits");
|
||||
|
||||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
id: &alias,
|
||||
symbol: Symbol::Flat(FlatEmbed::U8ToBits),
|
||||
}
|
||||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
"EMBED/u32_from_bits" => {
|
||||
let alias = alias.unwrap_or("u32_from_bits");
|
||||
|
||||
|
@ -193,9 +215,31 @@ impl Importer {
|
|||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
"EMBED/u16_from_bits" => {
|
||||
let alias = alias.unwrap_or("u16_from_bits");
|
||||
|
||||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
id: &alias,
|
||||
symbol: Symbol::Flat(FlatEmbed::U16FromBits),
|
||||
}
|
||||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
"EMBED/u8_from_bits" => {
|
||||
let alias = alias.unwrap_or("u8_from_bits");
|
||||
|
||||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
id: &alias,
|
||||
symbol: Symbol::Flat(FlatEmbed::U8FromBits),
|
||||
}
|
||||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
}
|
||||
s => {
|
||||
return Err(CompileErrorInner::ImportError(
|
||||
Error::new(format!("Embed {} not found. Options are \"EMBED/sha256round\", \"EMBED/unpack\", \"EMBED/u32_to_bits\", \"EMBED/u32_from_bits\"", s)).with_pos(Some(pos)),
|
||||
Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)),
|
||||
)
|
||||
.in_file(&location)
|
||||
.into());
|
||||
|
|
|
@ -108,26 +108,37 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
|||
let sha256_round = crate::embed::FlatEmbed::Sha256Round;
|
||||
let sha256_round_key = sha256_round.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u8 = crate::embed::FlatEmbed::CheckU8;
|
||||
let check_u8_key = check_u8.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u16 = crate::embed::FlatEmbed::CheckU16;
|
||||
let check_u16_key = check_u16.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u32 = crate::embed::FlatEmbed::CheckU32;
|
||||
let check_u32_key = check_u32.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u32_to_bits` embed
|
||||
let u32_to_bits = crate::embed::FlatEmbed::U32ToBits;
|
||||
let u32_to_bits_key = u32_to_bits.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u32_to_bits` embed
|
||||
// define a function in the main module for the `u16_to_bits` embed
|
||||
let u16_to_bits = crate::embed::FlatEmbed::U16ToBits;
|
||||
let u16_to_bits_key = u16_to_bits.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u8_to_bits` embed
|
||||
let u8_to_bits = crate::embed::FlatEmbed::U8ToBits;
|
||||
let u8_to_bits_key = u8_to_bits.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u32_from_bits` embed
|
||||
let u32_from_bits = crate::embed::FlatEmbed::U32FromBits;
|
||||
let u32_from_bits_key = u32_from_bits.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u16_from_bits` embed
|
||||
let u16_from_bits = crate::embed::FlatEmbed::U16FromBits;
|
||||
let u16_from_bits_key = u16_from_bits.key::<T>();
|
||||
|
||||
// define a function in the main module for the `u8_from_bits` embed
|
||||
let u8_from_bits = crate::embed::FlatEmbed::U8FromBits;
|
||||
let u8_from_bits_key = u8_from_bits.key::<T>();
|
||||
|
||||
println!("{:?}", unpack_key);
|
||||
println!(
|
||||
"{:?}",
|
||||
crate::embed::FlatEmbed::Unpack(T::get_required_bits()).signature()
|
||||
);
|
||||
println!("{:?}", crate::embed::FlatEmbed::U32FromBits.signature());
|
||||
|
||||
// return a program with a single module containing `main`, `_UNPACK`, and `_SHA256_ROUND
|
||||
TypedProgram {
|
||||
main: "main".into(),
|
||||
|
@ -137,11 +148,12 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
|||
functions: vec![
|
||||
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
|
||||
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
|
||||
(check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
|
||||
(check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
|
||||
(check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
|
||||
(u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)),
|
||||
(u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)),
|
||||
(u16_from_bits_key, TypedFunctionSymbol::Flat(u16_from_bits)),
|
||||
(u8_from_bits_key, TypedFunctionSymbol::Flat(u8_from_bits)),
|
||||
(u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)),
|
||||
(u16_to_bits_key, TypedFunctionSymbol::Flat(u16_to_bits)),
|
||||
(u8_to_bits_key, TypedFunctionSymbol::Flat(u8_to_bits)),
|
||||
(main_key, main),
|
||||
]
|
||||
.into_iter()
|
||||
|
|
|
@ -55,6 +55,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
let r = VariableAccessRemover::apply(r);
|
||||
|
||||
let zir = Flattener::flatten(r.clone());
|
||||
|
||||
// constrain inputs
|
||||
let zir = InputConstrainer::constrain(zir);
|
||||
|
||||
|
|
|
@ -142,7 +142,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
.collect();
|
||||
|
||||
if arguments.iter().all(|a| is_constant(a))
|
||||
&& (key.id == "_U32_FROM_BITS" || key.id == "_U32_TO_BITS")
|
||||
&& [
|
||||
"_U32_FROM_BITS",
|
||||
"_U16_FROM_BITS",
|
||||
"_U8_FROM_BITS",
|
||||
"_U32_TO_BITS",
|
||||
"_U16_TO_BITS",
|
||||
"_U8_TO_BITS",
|
||||
]
|
||||
.contains(&key.id)
|
||||
{
|
||||
let expr: TypedExpression<'ast, T> = if key.id == "_U32_FROM_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
|
@ -180,6 +188,78 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
},
|
||||
_ => unreachable!("should be an array"),
|
||||
}
|
||||
} else if key.id == "_U16_FROM_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
match arguments[0].clone() {
|
||||
TypedExpression::Array(a) => match a.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
assert_eq!(v.len(), 16);
|
||||
UExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.map(|v| match v {
|
||||
TypedExpression::Boolean(
|
||||
BooleanExpression::Value(v),
|
||||
) => v,
|
||||
_ => unreachable!("should be a boolean"),
|
||||
})
|
||||
.enumerate()
|
||||
.fold(0, |acc, (i, v)| {
|
||||
if v {
|
||||
acc + 2u128.pow(
|
||||
(16 - i - 1).try_into().unwrap(),
|
||||
)
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
}),
|
||||
)
|
||||
.annotate(16)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be an array value"),
|
||||
},
|
||||
_ => unreachable!("should be an array"),
|
||||
}
|
||||
} else if key.id == "_U8_FROM_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
match arguments[0].clone() {
|
||||
TypedExpression::Array(a) => match a.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
assert_eq!(v.len(), 8);
|
||||
UExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.map(|v| match v {
|
||||
TypedExpression::Boolean(
|
||||
BooleanExpression::Value(v),
|
||||
) => v,
|
||||
_ => unreachable!("should be a boolean"),
|
||||
})
|
||||
.enumerate()
|
||||
.fold(0, |acc, (i, v)| {
|
||||
if v {
|
||||
acc + 2u128.pow(
|
||||
(8 - i - 1).try_into().unwrap(),
|
||||
)
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
}),
|
||||
)
|
||||
.annotate(8)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be an array value"),
|
||||
},
|
||||
_ => unreachable!("should be an array"),
|
||||
}
|
||||
} else if key.id == "_U32_TO_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
@ -212,6 +292,70 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
},
|
||||
_ => unreachable!("should be a uint"),
|
||||
}
|
||||
} else if key.id == "_U16_TO_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
match arguments[0].clone() {
|
||||
TypedExpression::Uint(a) => match a.into_inner() {
|
||||
UExpressionInner::Value(v) => {
|
||||
let mut num = v;
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..16).rev() {
|
||||
if 2u128.pow(i) <= num {
|
||||
num = num - 2u128.pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
}
|
||||
assert_eq!(num, 0);
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect(),
|
||||
)
|
||||
.annotate(Type::Boolean, 16)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be a uint value"),
|
||||
},
|
||||
_ => unreachable!("should be a uint"),
|
||||
}
|
||||
} else if key.id == "_U8_TO_BITS" {
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
match arguments[0].clone() {
|
||||
TypedExpression::Uint(a) => match a.into_inner() {
|
||||
UExpressionInner::Value(v) => {
|
||||
let mut num = v;
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..8).rev() {
|
||||
if 2u128.pow(i) <= num {
|
||||
num = num - 2u128.pow(i);
|
||||
res.push(true);
|
||||
} else {
|
||||
res.push(false);
|
||||
}
|
||||
}
|
||||
assert_eq!(num, 0);
|
||||
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect(),
|
||||
)
|
||||
.annotate(Type::Boolean, 8)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!("should be a uint value"),
|
||||
},
|
||||
_ => unreachable!("should be a uint"),
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
|
|
@ -359,6 +359,34 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
ZirExpressionList::FunctionCall(key, arguments, ty),
|
||||
)]
|
||||
}
|
||||
"_U16_FROM_BITS" => {
|
||||
assert_eq!(lhs.len(), 1);
|
||||
self.register(
|
||||
lhs[0].clone(),
|
||||
UMetadata {
|
||||
max: T::from(2).pow(16) - T::from(1),
|
||||
should_reduce: ShouldReduce::False,
|
||||
},
|
||||
);
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty),
|
||||
)]
|
||||
}
|
||||
"_U8_FROM_BITS" => {
|
||||
assert_eq!(lhs.len(), 1);
|
||||
self.register(
|
||||
lhs[0].clone(),
|
||||
UMetadata {
|
||||
max: T::from(2).pow(8) - T::from(1),
|
||||
should_reduce: ShouldReduce::False,
|
||||
},
|
||||
);
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty),
|
||||
)]
|
||||
}
|
||||
_ => vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(
|
||||
|
|
|
@ -4,31 +4,31 @@
|
|||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0x00000000"]
|
||||
"values": ["0x00000000", "0x0000", "0x00"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0x00000000"]
|
||||
"values": ["0x00000000", "0x0000", "0x00"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0x00000002"]
|
||||
"values": ["0xffffffff", "0xffff", "0xff"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0x00000002"]
|
||||
"values": ["0xffffffff", "0xffff", "0xff"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0x12345678"]
|
||||
"values": ["0x12345678", "0x1234", "0x12"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0x12345678"]
|
||||
"values": ["0x12345678", "0x1234", "0x12"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import "EMBED/u32_to_bits" as to_bits
|
||||
import "EMBED/u32_from_bits" as from_bits
|
||||
import "EMBED/u32_to_bits" as to_bits_32
|
||||
import "EMBED/u32_from_bits" as from_bits_32
|
||||
import "EMBED/u16_to_bits" as to_bits_16
|
||||
import "EMBED/u16_from_bits" as from_bits_16
|
||||
import "EMBED/u8_to_bits" as to_bits_8
|
||||
import "EMBED/u8_from_bits" as from_bits_8
|
||||
|
||||
def main(u32 e) -> (u32):
|
||||
bool[32] f = to_bits(e)
|
||||
return from_bits(f)
|
||||
def main(u32 e, u16 f, u8 g) -> (u32, u16, u8):
|
||||
bool[32] e_bits = to_bits_32(e)
|
||||
bool[16] f_bits = to_bits_16(f)
|
||||
bool[8] g_bits = to_bits_8(g)
|
||||
return from_bits_32(e_bits), from_bits_16(f_bits), from_bits_8(g_bits)
|
Loading…
Reference in a new issue