integrate uint§
This commit is contained in:
parent
7bba70f782
commit
e7c911f73a
21 changed files with 797 additions and 225 deletions
|
@ -9,5 +9,5 @@ members = [
|
|||
"zokrates_embed",
|
||||
"zokrates_abi",
|
||||
"zokrates_test",
|
||||
"zokrates_core_test",
|
||||
"zokrates_core_test"
|
||||
]
|
4
u8.zok
4
u8.zok
|
@ -1,2 +1,2 @@
|
|||
def main(u8 a) -> (u8):
|
||||
return a
|
||||
def main(u32 a, u32 b) -> (u32):
|
||||
return a + b
|
|
@ -52,6 +52,7 @@ enum Value<T> {
|
|||
#[derive(PartialEq, Debug)]
|
||||
enum CheckedValue<T> {
|
||||
U8(u8),
|
||||
U32(u32),
|
||||
Field(T),
|
||||
Boolean(bool),
|
||||
Array(Vec<CheckedValue<T>>),
|
||||
|
@ -148,6 +149,7 @@ impl<T: From<usize>> Encode<T> for CheckedValue<T> {
|
|||
match self {
|
||||
CheckedValue::Field(t) => vec![t],
|
||||
CheckedValue::U8(t) => vec![T::from(t as usize)],
|
||||
CheckedValue::U32(t) => vec![T::from(t as usize)],
|
||||
CheckedValue::Boolean(b) => vec![if b { 1.into() } else { 0.into() }],
|
||||
CheckedValue::Array(a) => a.into_iter().flat_map(|v| v.encode()).collect(),
|
||||
CheckedValue::Struct(s) => s.into_iter().flat_map(|(_, v)| v.encode()).collect(),
|
||||
|
@ -181,7 +183,12 @@ impl<T: Field> Decode<T> for CheckedValue<T> {
|
|||
|
||||
match expected {
|
||||
Type::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
|
||||
Type::U8 => CheckedValue::U8(u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 16).unwrap()),
|
||||
Type::Uint(8) => CheckedValue::U8(
|
||||
u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
|
||||
),
|
||||
Type::Uint(32) => CheckedValue::U32(
|
||||
u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
|
||||
),
|
||||
Type::Boolean => {
|
||||
let v = raw.pop().unwrap();
|
||||
CheckedValue::Boolean(if v == 0.into() {
|
||||
|
@ -208,6 +215,7 @@ impl<T: Field> Decode<T> for CheckedValue<T> {
|
|||
})
|
||||
.collect(),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -269,6 +277,7 @@ impl<T: Field> Into<serde_json::Value> for CheckedValue<T> {
|
|||
match self {
|
||||
CheckedValue::Field(f) => serde_json::Value::String(f.to_dec_string()),
|
||||
CheckedValue::U8(u) => serde_json::Value::String(format!("{:#x}", u)),
|
||||
CheckedValue::U32(u) => serde_json::Value::String(format!("{:#x}", u)),
|
||||
CheckedValue::Boolean(b) => serde_json::Value::Bool(b),
|
||||
CheckedValue::Array(a) => {
|
||||
serde_json::Value::Array(a.into_iter().map(|e| e.into()).collect())
|
||||
|
|
|
@ -614,7 +614,8 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
|
|||
pest::Type::Basic(t) => match t {
|
||||
pest::BasicType::Field(t) => absy::UnresolvedType::FieldElement.span(t.span),
|
||||
pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span),
|
||||
pest::BasicType::U8(t) =>absy::UnresolvedType::U8.span(t.span),
|
||||
pest::BasicType::U8(t) => absy::UnresolvedType::Uint(8).span(t.span),
|
||||
pest::BasicType::U32(t) => absy::UnresolvedType::Uint(32).span(t.span),
|
||||
},
|
||||
pest::Type::Array(t) => {
|
||||
let inner_type = match t.ty {
|
||||
|
@ -623,7 +624,8 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
|
|||
absy::UnresolvedType::FieldElement.span(t.span)
|
||||
}
|
||||
pest::BasicType::Boolean(t) => absy::UnresolvedType::Boolean.span(t.span),
|
||||
pest::BasicType::U8(t) =>absy::UnresolvedType::U8.span(t.span),
|
||||
pest::BasicType::U8(t) => absy::UnresolvedType::Uint(8).span(t.span),
|
||||
pest::BasicType::U32(t) => absy::UnresolvedType::Uint(32).span(t.span),
|
||||
},
|
||||
pest::BasicOrStructType::Struct(t) => {
|
||||
absy::UnresolvedType::User(t.span.as_str().to_string()).span(t.span)
|
||||
|
|
|
@ -11,7 +11,7 @@ pub type UserTypeId = String;
|
|||
pub enum UnresolvedType {
|
||||
FieldElement,
|
||||
Boolean,
|
||||
U8,
|
||||
Uint(usize),
|
||||
Array(Box<UnresolvedTypeNode>, usize),
|
||||
User(UserTypeId),
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ impl fmt::Display for UnresolvedType {
|
|||
match self {
|
||||
UnresolvedType::FieldElement => write!(f, "field"),
|
||||
UnresolvedType::Boolean => write!(f, "bool"),
|
||||
UnresolvedType::U8 => write!(f, "u8"),
|
||||
UnresolvedType::Uint(bitwidth) => write!(f, "u{}", bitwidth),
|
||||
UnresolvedType::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
|
||||
UnresolvedType::User(i) => write!(f, "{}", i),
|
||||
}
|
||||
|
|
|
@ -233,7 +233,7 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
|
|||
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
|
||||
.collect();
|
||||
|
||||
let helper = Helper::bits();
|
||||
let helper = Helper::bits(T::get_required_bits());
|
||||
|
||||
let signature = Signature {
|
||||
inputs: vec![Type::FieldElement],
|
||||
|
@ -326,7 +326,7 @@ mod tests {
|
|||
(0..FieldPrime::get_required_bits())
|
||||
.map(|i| FlatVariable::new(i + 1))
|
||||
.collect(),
|
||||
Helper::bits(),
|
||||
Helper::bits(FieldPrime::get_required_bits()),
|
||||
vec![FlatVariable::new(0)]
|
||||
))
|
||||
);
|
||||
|
|
|
@ -48,14 +48,14 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for U8Expression<'ast> {
|
||||
impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast> {
|
||||
fn flatten(
|
||||
self,
|
||||
flattener: &mut Flattener<'ast, T>,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
vec![flattener.flatten_u8_expression(symbols, statements_flattened, self)]
|
||||
vec![flattener.flatten_uint_expression(symbols, statements_flattened, self)]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
|
|||
statements_flattened,
|
||||
self,
|
||||
),
|
||||
Type::U8 => flattener.flatten_array_expression::<U8Expression<'ast>>(
|
||||
Type::Uint(..) => flattener.flatten_array_expression::<UExpression<'ast>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
self,
|
||||
|
@ -252,9 +252,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
Type::FieldElement => FieldElementExpression::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::U8 => U8Expression::try_from(v)
|
||||
.unwrap()
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Uint(..) => UExpression::try_from(v).unwrap().flatten(
|
||||
self,
|
||||
symbols,
|
||||
statements_flattened,
|
||||
),
|
||||
Type::Boolean => BooleanExpression::try_from(v).unwrap().flatten(
|
||||
self,
|
||||
symbols,
|
||||
|
@ -343,12 +345,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
BooleanExpression::member(consequence.clone(), member_id.clone()),
|
||||
BooleanExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
Type::U8 => self.flatten_if_else_expression(
|
||||
Type::Uint(..) => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
condition.clone(),
|
||||
U8Expression::member(consequence.clone(), member_id.clone()),
|
||||
U8Expression::member(alternative.clone(), member_id),
|
||||
UExpression::member(consequence.clone(), member_id.clone()),
|
||||
UExpression::member(alternative.clone(), member_id),
|
||||
),
|
||||
Type::Struct(..) => self.flatten_if_else_expression(
|
||||
symbols,
|
||||
|
@ -467,13 +469,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
array,
|
||||
index,
|
||||
),
|
||||
Type::U8 => self
|
||||
.flatten_select_expression::<U8Expression<'ast>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
),
|
||||
Type::Uint(..) => self.flatten_select_expression::<UExpression<'ast>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
array,
|
||||
index,
|
||||
),
|
||||
Type::Array(..) => self
|
||||
.flatten_select_expression::<ArrayExpression<'ast, T>>(
|
||||
symbols,
|
||||
|
@ -625,7 +626,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
lhs_bits_be.clone(),
|
||||
Helper::bits(),
|
||||
Helper::bits(T::get_required_bits()),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
|
||||
|
@ -672,7 +673,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
rhs_bits_be.clone(),
|
||||
Helper::bits(),
|
||||
Helper::bits(T::get_required_bits()),
|
||||
vec![rhs_id],
|
||||
)));
|
||||
|
||||
|
@ -725,7 +726,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
sub_bits_be.clone(),
|
||||
Helper::bits(),
|
||||
Helper::bits(T::get_required_bits()),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
|
@ -761,7 +762,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// lhs and rhs are booleans, they flatten to 0 or 1
|
||||
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
|
||||
let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs);
|
||||
// Wanted: Not(X - Y)**2 which is an XNOR
|
||||
// Wanted: Not(X - Y)**2 which is an XNOR
|
||||
// We know that X and Y are [0, 1]
|
||||
// (X - Y) can become a negative values, which is why squaring the result is needed
|
||||
// Negating this returns correct result
|
||||
|
@ -776,15 +777,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// | 0 | 0 | 0 | 1 |
|
||||
// +---+---+-------+---------------+
|
||||
|
||||
let x_sub_y = FlatExpression::Sub(box x, box y);
|
||||
let x_sub_y = FlatExpression::Sub(box x, box y);
|
||||
let name_x_mult_x = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
name_x_mult_x,
|
||||
FlatExpression::Mult(
|
||||
box x_sub_y.clone(),
|
||||
box x_sub_y,
|
||||
),
|
||||
FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y),
|
||||
));
|
||||
|
||||
FlatExpression::Sub(
|
||||
|
@ -1042,8 +1040,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
TypedExpression::Boolean(e) => {
|
||||
vec![self.flatten_boolean_expression(symbols, statements_flattened, e)]
|
||||
}
|
||||
TypedExpression::U8(e) => {
|
||||
vec![self.flatten_u8_expression(symbols, statements_flattened, e)]
|
||||
TypedExpression::Uint(e) => {
|
||||
let e = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
bitwidth: None,
|
||||
}),
|
||||
..e
|
||||
};
|
||||
let e = e.reduce::<T>();
|
||||
|
||||
vec![self.flatten_uint_expression(symbols, statements_flattened, e)]
|
||||
}
|
||||
TypedExpression::Array(e) => match e.inner_type().clone() {
|
||||
Type::FieldElement => self
|
||||
|
@ -1057,7 +1064,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened,
|
||||
e,
|
||||
),
|
||||
Type::U8 => self.flatten_array_expression::<U8Expression<'ast>>(
|
||||
Type::Uint(..) => self.flatten_array_expression::<UExpression<'ast>>(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
e,
|
||||
|
@ -1079,24 +1086,131 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Flattens a u8 expression
|
||||
/// Flattens a uint expression
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `symbols` - Available functions in in this context
|
||||
/// * `statements_flattened` - Vector where new flattened statements can be added.
|
||||
/// * `expr` - `U8ElementExpression` that will be flattened.
|
||||
fn flatten_u8_expression(
|
||||
/// * `expr` - `UExpression` that will be flattened.
|
||||
fn flatten_uint_expression(
|
||||
&mut self,
|
||||
symbols: &TypedFunctionSymbols<'ast, T>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
expr: U8Expression<'ast>,
|
||||
expr: UExpression<'ast>,
|
||||
) -> FlatExpression<T> {
|
||||
match expr {
|
||||
U8Expression::Value(x) => FlatExpression::Number(T::from(x as usize)), // force to be a field element
|
||||
U8Expression::Identifier(x) => {
|
||||
let target_bitwidth = expr.bitwidth;
|
||||
|
||||
let metadata = expr.metadata.clone().unwrap().clone();
|
||||
let actual_bitwidth = metadata.bitwidth.unwrap();
|
||||
let should_reduce = metadata.should_reduce.unwrap();
|
||||
|
||||
let res = match expr.into_inner() {
|
||||
UExpressionInner::Value(x) => FlatExpression::Number(T::from(x as u128)), // force to be a field element
|
||||
UExpressionInner::Identifier(x) => {
|
||||
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()[0])
|
||||
}
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
let right_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Add(box new_left, box new_right)
|
||||
}
|
||||
UExpressionInner::Mult(box left, box right) => {
|
||||
if metadata.should_reduce.unwrap() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
let right_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Mult(box new_left, box new_right)
|
||||
}
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => unimplemented!(),
|
||||
};
|
||||
|
||||
match should_reduce {
|
||||
true => {
|
||||
let bits = (0..actual_bitwidth)
|
||||
.map(|_| self.use_sym())
|
||||
.collect::<Vec<_>>();
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
bits.clone(),
|
||||
Helper::Rust(RustHelper::Bits(actual_bitwidth)),
|
||||
vec![res.clone()],
|
||||
)));
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
// decompose to the actual bitwidth
|
||||
|
||||
// bit checks
|
||||
statements_flattened.extend((0..actual_bitwidth).map(|i| {
|
||||
FlatStatement::Condition(
|
||||
bits[i].clone().into(),
|
||||
FlatExpression::Mult(
|
||||
box bits[i].clone().into(),
|
||||
box bits[i].clone().into(),
|
||||
),
|
||||
)
|
||||
}));
|
||||
|
||||
// sum check
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
res.clone(),
|
||||
(0..actual_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(actual_bitwidth - i - 1)),
|
||||
box bits[i].into(),
|
||||
),
|
||||
)
|
||||
}),
|
||||
));
|
||||
|
||||
// truncate to the target bitwidth
|
||||
(0..target_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
|
||||
box bits[i + actual_bitwidth - target_bitwidth].into(),
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
false => res,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1392,10 +1506,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FieldElementExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::U8 => U8Expression::if_else(
|
||||
Type::Uint(..) => UExpression::if_else(
|
||||
condition.clone(),
|
||||
U8Expression::member(consequence.clone(), id.clone()),
|
||||
U8Expression::member(alternative.clone(), id.clone()),
|
||||
UExpression::member(consequence.clone(), id.clone()),
|
||||
UExpression::member(alternative.clone(), id.clone()),
|
||||
)
|
||||
.flatten(self, symbols, statements_flattened),
|
||||
Type::Boolean => BooleanExpression::if_else(
|
||||
|
|
|
@ -77,8 +77,8 @@ impl Helper {
|
|||
Helper::Rust(RustHelper::Identity)
|
||||
}
|
||||
|
||||
pub fn bits() -> Self {
|
||||
Helper::Rust(RustHelper::Bits)
|
||||
pub fn bits(bitwidth: usize) -> Self {
|
||||
Helper::Rust(RustHelper::Bits(bitwidth))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ use zokrates_field::field::Field;
|
|||
pub enum RustHelper {
|
||||
Identity,
|
||||
ConditionEq,
|
||||
Bits,
|
||||
Bits(usize),
|
||||
Div,
|
||||
Sha256Round,
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ impl Signed for RustHelper {
|
|||
match self {
|
||||
RustHelper::Identity => (1, 1),
|
||||
RustHelper::ConditionEq => (1, 2),
|
||||
RustHelper::Bits => (1, 254),
|
||||
RustHelper::Bits(bitwidth) => (1, *bitwidth),
|
||||
RustHelper::Div => (2, 1),
|
||||
RustHelper::Sha256Round => (768, 26935),
|
||||
}
|
||||
|
@ -38,11 +38,11 @@ impl<T: Field> Executable<T> for RustHelper {
|
|||
true => Ok(vec![T::zero(), T::one()]),
|
||||
false => Ok(vec![T::one(), T::one() / inputs[0].clone()]),
|
||||
},
|
||||
RustHelper::Bits => {
|
||||
RustHelper::Bits(bitwidth) => {
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
let bits = 254;
|
||||
for i in (0..bits).rev() {
|
||||
|
||||
for i in (0..*bitwidth).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
|
@ -77,7 +77,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_one() {
|
||||
let inputs = vec![FieldPrime::from(1)];
|
||||
let res = RustHelper::Bits.execute(&inputs).unwrap();
|
||||
let res = RustHelper::Bits(FieldPrime::get_required_bits())
|
||||
.execute(&inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], FieldPrime::from(1));
|
||||
for i in 0..252 {
|
||||
assert_eq!(res[i], FieldPrime::from(0));
|
||||
|
@ -87,7 +89,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_42() {
|
||||
let inputs = vec![FieldPrime::from(42)];
|
||||
let res = RustHelper::Bits.execute(&inputs).unwrap();
|
||||
let res = RustHelper::Bits(FieldPrime::get_required_bits())
|
||||
.execute(&inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], FieldPrime::from(0));
|
||||
assert_eq!(res[252], FieldPrime::from(1));
|
||||
assert_eq!(res[251], FieldPrime::from(0));
|
||||
|
|
|
@ -394,7 +394,7 @@ library BN256G2 {
|
|||
}
|
||||
"#;
|
||||
|
||||
pub const SOLIDITY_PAIRING_LIB_V2 : &str = r#"// This file is MIT Licensed.
|
||||
pub const SOLIDITY_PAIRING_LIB_V2: &str = r#"// This file is MIT Licensed.
|
||||
//
|
||||
// Copyright 2017 Christian Reitwiessner
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
@ -543,7 +543,7 @@ library Pairing {
|
|||
}
|
||||
"#;
|
||||
|
||||
pub const SOLIDITY_PAIRING_LIB : &str = r#"// This file is MIT Licensed.
|
||||
pub const SOLIDITY_PAIRING_LIB: &str = r#"// This file is MIT Licensed.
|
||||
//
|
||||
// Copyright 2017 Christian Reitwiessner
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
|
|
@ -712,7 +712,7 @@ impl<'ast> Checker<'ast> {
|
|||
match ty {
|
||||
UnresolvedType::FieldElement => Ok(Type::FieldElement),
|
||||
UnresolvedType::Boolean => Ok(Type::Boolean),
|
||||
UnresolvedType::U8 => Ok(Type::U8),
|
||||
UnresolvedType::Uint(bitwidth) => Ok(Type::Uint(bitwidth)),
|
||||
UnresolvedType::Array(t, size) => Ok(Type::Array(
|
||||
box self.check_type(*t, module_id, types)?,
|
||||
size,
|
||||
|
@ -1038,8 +1038,8 @@ impl<'ast> Checker<'ast> {
|
|||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
Type::U8 => U8Expression::select(
|
||||
e.clone().annotate(Type::U8, size),
|
||||
Type::Uint(bitwidth) => UExpression::select(
|
||||
e.clone().annotate(Type::Uint(*bitwidth), size),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
|
@ -1049,9 +1049,7 @@ impl<'ast> Checker<'ast> {
|
|||
)
|
||||
.into(),
|
||||
Type::Array(box ty, s) => ArrayExpression::select(
|
||||
e
|
||||
.clone()
|
||||
.annotate(Type::array(ty.clone(), *s), size),
|
||||
e.clone().annotate(Type::array(ty.clone(), *s), size),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
|
@ -1095,7 +1093,9 @@ impl<'ast> Checker<'ast> {
|
|||
match self.get_scope(&name) {
|
||||
Some(v) => match v.id.get_type() {
|
||||
Type::Boolean => Ok(BooleanExpression::Identifier(name.into()).into()),
|
||||
Type::U8 => Ok(U8Expression::Identifier(name.into()).into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpressionInner::Identifier(name.into())
|
||||
.annotate(bitwidth)
|
||||
.into()),
|
||||
Type::FieldElement => {
|
||||
Ok(FieldElementExpression::Identifier(name.into()).into())
|
||||
}
|
||||
|
@ -1120,11 +1120,26 @@ impl<'ast> Checker<'ast> {
|
|||
(TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => {
|
||||
Ok(FieldElementExpression::Add(box e1, box e2).into())
|
||||
}
|
||||
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
|
||||
if e1.get_type() == e2.get_type() {
|
||||
Ok(UExpression::add(e1, e2).into())
|
||||
} else {
|
||||
Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Cannot apply `+` to {:?}, {:?}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
(t1, t2) => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected only field elements, found {:?}, {:?}",
|
||||
"Cannot apply `+` to {:?}, {:?}",
|
||||
t1.get_type(),
|
||||
t2.get_type()
|
||||
),
|
||||
|
@ -1158,11 +1173,26 @@ impl<'ast> Checker<'ast> {
|
|||
(TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => {
|
||||
Ok(FieldElementExpression::Mult(box e1, box e2).into())
|
||||
}
|
||||
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
|
||||
if e1.get_type() == e2.get_type() {
|
||||
Ok(UExpression::mult(e1, e2).into())
|
||||
} else {
|
||||
Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Cannot apply `*` to {:?}, {:?}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
(t1, t2) => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected only field elements, found {:?}, {:?}",
|
||||
"Cannot apply `*` to {:?}, {:?}",
|
||||
t1.get_type(),
|
||||
t2.get_type()
|
||||
),
|
||||
|
@ -1503,23 +1533,13 @@ impl<'ast> Checker<'ast> {
|
|||
match (array, self.check_expression(e, module_id, &types)?) {
|
||||
(TypedExpression::Array(a), TypedExpression::FieldElement(i)) => {
|
||||
match a.inner_type().clone() {
|
||||
Type::FieldElement =>
|
||||
Ok(FieldElementExpression::select(a, i).into()),
|
||||
|
||||
Type::U8 =>
|
||||
Ok(U8Expression::select(a, i).into()),
|
||||
|
||||
Type::Boolean =>
|
||||
Ok(BooleanExpression::select(a, i).into()),
|
||||
|
||||
Type::Array(..) =>
|
||||
Ok(ArrayExpression::select(a, i)
|
||||
.into()),
|
||||
|
||||
Type::Struct(..) =>
|
||||
Ok(StructExpression::select(a, i)
|
||||
.into()),
|
||||
|
||||
Type::FieldElement => {
|
||||
Ok(FieldElementExpression::select(a, i).into())
|
||||
}
|
||||
Type::Uint(..) => Ok(UExpression::select(a, i).into()),
|
||||
Type::Boolean => Ok(BooleanExpression::select(a, i).into()),
|
||||
Type::Array(..) => Ok(ArrayExpression::select(a, i).into()),
|
||||
Type::Struct(..) => Ok(StructExpression::select(a, i).into()),
|
||||
}
|
||||
}
|
||||
(a, e) => Err(Error {
|
||||
|
@ -1548,21 +1568,22 @@ impl<'ast> Checker<'ast> {
|
|||
|
||||
match ty {
|
||||
Some(ty) => match ty {
|
||||
Type::FieldElement =>
|
||||
Ok(FieldElementExpression::member(s, id.to_string()).into()),
|
||||
Type::FieldElement => {
|
||||
Ok(FieldElementExpression::member(s, id.to_string()).into())
|
||||
}
|
||||
|
||||
Type::Boolean =>
|
||||
Ok(BooleanExpression::member(s, id.to_string()).into()),
|
||||
Type::Boolean => {
|
||||
Ok(BooleanExpression::member(s, id.to_string()).into())
|
||||
}
|
||||
|
||||
Type::U8 =>
|
||||
Ok(U8Expression::member(s, id.to_string()).into()),
|
||||
|
||||
Type::Array(..) =>
|
||||
Ok(ArrayExpression::member(s.clone(), id.to_string()).into()),
|
||||
Type::Struct(..) =>
|
||||
Ok(StructExpression::member(s.clone(), id.to_string())
|
||||
.into())
|
||||
Type::Uint(..) => Ok(UExpression::member(s, id.to_string()).into()),
|
||||
|
||||
Type::Array(..) => {
|
||||
Ok(ArrayExpression::member(s.clone(), id.to_string()).into())
|
||||
}
|
||||
Type::Struct(..) => {
|
||||
Ok(StructExpression::member(s.clone(), id.to_string()).into())
|
||||
}
|
||||
},
|
||||
None => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
@ -1646,20 +1667,35 @@ impl<'ast> Checker<'ast> {
|
|||
.annotate(Type::Boolean, size)
|
||||
.into())
|
||||
}
|
||||
Type::U8 => {
|
||||
ty @ Type::Uint(..) => {
|
||||
// we check all expressions have that same type
|
||||
let mut unwrapped_expressions = vec![];
|
||||
|
||||
for e in expressions_checked {
|
||||
let unwrapped_e = match e {
|
||||
TypedExpression::U8(e) => Ok(e),
|
||||
TypedExpression::Uint(e) => {
|
||||
if e.get_type() == ty {
|
||||
Ok(e)
|
||||
} else {
|
||||
Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected {} to have type {}, but type is {}",
|
||||
e,
|
||||
ty,
|
||||
e.get_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
e => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Expected {} to have type {}, but type is {}",
|
||||
e,
|
||||
inferred_type,
|
||||
ty,
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
|
@ -1670,7 +1706,7 @@ impl<'ast> Checker<'ast> {
|
|||
let size = unwrapped_expressions.len();
|
||||
|
||||
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
|
||||
.annotate(Type::U8, size)
|
||||
.annotate(ty, size)
|
||||
.into())
|
||||
}
|
||||
ty @ Type::Array(..) => {
|
||||
|
|
|
@ -54,7 +54,8 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
|||
b.clone().into(),
|
||||
BooleanExpression::And(box b.clone(), box b).into(),
|
||||
)),
|
||||
TypedExpression::U8(_) => {
|
||||
TypedExpression::Uint(bitwidth) => {
|
||||
// TODO constrain by checking that it decomposes correctly
|
||||
}
|
||||
TypedExpression::Array(a) => {
|
||||
for i in 0..a.size() {
|
||||
|
@ -64,7 +65,7 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
|||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
Type::U8 => U8Expression::select(
|
||||
Type::Uint(..) => UExpression::select(
|
||||
a.clone(),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
|
@ -96,7 +97,7 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
|||
FieldElementExpression::member(s.clone(), id.clone()).into()
|
||||
}
|
||||
Type::Boolean => BooleanExpression::member(s.clone(), id.clone()).into(),
|
||||
Type::U8 => U8Expression::member(s.clone(), id.clone()).into(),
|
||||
Type::Uint(..) => UExpression::member(s.clone(), id.clone()).into(),
|
||||
Type::Array(..) => ArrayExpression::member(s.clone(), id.clone()).into(),
|
||||
Type::Struct(..) => StructExpression::member(s.clone(), id.clone()).into(),
|
||||
};
|
||||
|
@ -115,7 +116,7 @@ impl<'ast, T: Field> Folder<'ast, T> for InputConstrainer<'ast, T> {
|
|||
let e = match v.get_type() {
|
||||
Type::FieldElement => FieldElementExpression::Identifier(v.id).into(),
|
||||
Type::Boolean => BooleanExpression::Identifier(v.id).into(),
|
||||
Type::U8 => U8Expression::Identifier(v.id).into(),
|
||||
Type::Uint(bitwidth) => UExpressionInner::Identifier(v.id).annotate(bitwidth).into(),
|
||||
Type::Struct(members) => StructExpressionInner::Identifier(v.id)
|
||||
.annotate(members)
|
||||
.into(),
|
||||
|
|
|
@ -185,13 +185,13 @@ impl<'ast> Unroller<'ast> {
|
|||
),
|
||||
)
|
||||
.into(),
|
||||
Type::U8 => U8Expression::if_else(
|
||||
Type::Uint(..) => UExpression::if_else(
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
U8Expression::select(
|
||||
UExpression::select(
|
||||
base.clone(),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
|
@ -200,13 +200,13 @@ impl<'ast> Unroller<'ast> {
|
|||
new_expression.clone(),
|
||||
statements,
|
||||
) {
|
||||
TypedExpression::U8(e) => e,
|
||||
TypedExpression::Uint(e) => e,
|
||||
e => unreachable!(
|
||||
"the interior was expected to be a u8, was {}",
|
||||
"the interior was expected to be a uint, was {}",
|
||||
e.get_type()
|
||||
),
|
||||
},
|
||||
U8Expression::select(
|
||||
UExpression::select(
|
||||
base.clone(),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
|
@ -253,21 +253,17 @@ impl<'ast> Unroller<'ast> {
|
|||
.into()
|
||||
}
|
||||
}
|
||||
Type::U8 => {
|
||||
Type::Uint(..) => {
|
||||
if id == head {
|
||||
Self::choose_many(
|
||||
U8Expression::member(
|
||||
base.clone(),
|
||||
head.clone(),
|
||||
)
|
||||
.into(),
|
||||
UExpression::member(base.clone(), head.clone())
|
||||
.into(),
|
||||
tail.clone(),
|
||||
new_expression.clone(),
|
||||
statements,
|
||||
)
|
||||
} else {
|
||||
U8Expression::member(base.clone(), id.clone())
|
||||
.into()
|
||||
UExpression::member(base.clone(), id.clone()).into()
|
||||
}
|
||||
}
|
||||
Type::Boolean => {
|
||||
|
@ -370,8 +366,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
Type::Boolean => {
|
||||
BooleanExpression::Identifier(variable.id.clone().into()).into()
|
||||
}
|
||||
Type::U8 => {
|
||||
U8Expression::Identifier(variable.id.clone().into()).into()
|
||||
Type::Uint(bitwidth) => {
|
||||
UExpressionInner::Identifier(variable.id.clone().into())
|
||||
.annotate(bitwidth)
|
||||
.into()
|
||||
}
|
||||
Type::Array(box ty, size) => {
|
||||
ArrayExpressionInner::Identifier(variable.id.clone().into())
|
||||
|
|
|
@ -60,7 +60,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
match e {
|
||||
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
|
||||
TypedExpression::Boolean(e) => self.fold_boolean_expression(e).into(),
|
||||
TypedExpression::U8(e) => self.fold_u8_expression(e).into(),
|
||||
TypedExpression::Uint(e) => self.fold_uint_expression(e).into(),
|
||||
TypedExpression::Array(e) => self.fold_array_expression(e).into(),
|
||||
TypedExpression::Struct(e) => self.fold_struct_expression(e).into(),
|
||||
}
|
||||
|
@ -107,12 +107,18 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
) -> BooleanExpression<'ast, T> {
|
||||
fold_boolean_expression(self, e)
|
||||
}
|
||||
fn fold_u8_expression(
|
||||
&mut self,
|
||||
e: U8Expression<'ast>,
|
||||
) -> U8Expression<'ast> {
|
||||
fold_u8_expression(self, e)
|
||||
fn fold_uint_expression(&mut self, e: UExpression<'ast>) -> UExpression<'ast> {
|
||||
fold_uint_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: usize,
|
||||
e: UExpressionInner<'ast>,
|
||||
) -> UExpressionInner<'ast> {
|
||||
fold_uint_expression_inner(self, bitwidth, e)
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &Type,
|
||||
|
@ -371,13 +377,42 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_u8_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: U8Expression<'ast>,
|
||||
) -> U8Expression<'ast> {
|
||||
e: UExpression<'ast>,
|
||||
) -> UExpression<'ast> {
|
||||
UExpression {
|
||||
inner: f.fold_uint_expression_inner(e.bitwidth, e.inner),
|
||||
..e
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
bitwidth: usize,
|
||||
e: UExpressionInner<'ast>,
|
||||
) -> UExpressionInner<'ast> {
|
||||
match e {
|
||||
U8Expression::Value(v) => U8Expression::Value(v),
|
||||
U8Expression::Identifier(id) => U8Expression::Identifier(f.fold_name(id)),
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)),
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left = fold_uint_expression(f, left);
|
||||
let right = fold_uint_expression(f, right);
|
||||
|
||||
UExpressionInner::Add(box left, box right)
|
||||
}
|
||||
UExpressionInner::Mult(box left, box right) => {
|
||||
let left = fold_uint_expression(f, left);
|
||||
let right = fold_uint_expression(f, right);
|
||||
|
||||
UExpressionInner::Mult(box left, box right)
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left = fold_uint_expression(f, left);
|
||||
let right = fold_uint_expression(f, right);
|
||||
|
||||
UExpressionInner::Xor(box left, box right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
57
zokrates_core/src/typed_absy/identifier.rs
Normal file
57
zokrates_core/src/typed_absy/identifier.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use std::fmt;
|
||||
use typed_absy::types::FunctionKey;
|
||||
use typed_absy::TypedModuleId;
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
pub id: &'ast str,
|
||||
/// the version of the variable, used after SSA transformation
|
||||
pub version: usize,
|
||||
/// the call stack of the variable, used when inlining
|
||||
pub stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>,
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for Identifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.stack.len() == 0 && self.version == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"{}_{}_{}",
|
||||
self.stack
|
||||
.iter()
|
||||
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("_"),
|
||||
self.id,
|
||||
self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
id,
|
||||
version: 0,
|
||||
stack: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl<'ast> Identifier<'ast> {
|
||||
pub fn version(mut self, version: usize) -> Self {
|
||||
self.version = version;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stack(mut self, stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>) -> Self {
|
||||
self.stack = stack;
|
||||
self
|
||||
}
|
||||
}
|
|
@ -6,13 +6,16 @@
|
|||
//! @date 2017
|
||||
|
||||
pub mod folder;
|
||||
mod identifier;
|
||||
mod parameter;
|
||||
pub mod types;
|
||||
mod uint;
|
||||
mod variable;
|
||||
|
||||
pub use crate::typed_absy::parameter::Parameter;
|
||||
pub use crate::typed_absy::types::Type;
|
||||
pub use crate::typed_absy::variable::Variable;
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::types::Type;
|
||||
pub use self::variable::Variable;
|
||||
pub use typed_absy::uint::{UExpression, UExpressionInner, UMetadata};
|
||||
|
||||
use crate::typed_absy::types::{FunctionKey, MemberId, Signature};
|
||||
use embed::FlatEmbed;
|
||||
|
@ -23,16 +26,7 @@ use zokrates_field::field::Field;
|
|||
|
||||
pub use self::folder::Folder;
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
pub id: &'ast str,
|
||||
/// the version of the variable, used after SSA transformation
|
||||
pub version: usize,
|
||||
/// the call stack of the variable, used when inlining
|
||||
pub stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>,
|
||||
}
|
||||
pub use self::identifier::Identifier;
|
||||
|
||||
/// An identifier for a `TypedModule`. Typically a path or uri.
|
||||
pub type TypedModuleId = String;
|
||||
|
@ -82,49 +76,6 @@ pub struct TypedModule<'ast, T: Field> {
|
|||
pub functions: TypedFunctionSymbols<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for Identifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.stack.len() == 0 && self.version == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"{}_{}_{}",
|
||||
self.stack
|
||||
.iter()
|
||||
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("_"),
|
||||
self.id,
|
||||
self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
id,
|
||||
version: 0,
|
||||
stack: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl<'ast> Identifier<'ast> {
|
||||
pub fn version(mut self, version: usize) -> Self {
|
||||
self.version = version;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stack(mut self, stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>) -> Self {
|
||||
self.stack = stack;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum TypedFunctionSymbol<'ast, T: Field> {
|
||||
Here(TypedFunction<'ast, T>),
|
||||
|
@ -375,7 +326,7 @@ pub trait Typed {
|
|||
pub enum TypedExpression<'ast, T: Field> {
|
||||
Boolean(BooleanExpression<'ast, T>),
|
||||
FieldElement(FieldElementExpression<'ast, T>),
|
||||
U8(U8Expression<'ast>),
|
||||
Uint(UExpression<'ast>),
|
||||
Array(ArrayExpression<'ast, T>),
|
||||
Struct(StructExpression<'ast, T>),
|
||||
}
|
||||
|
@ -392,9 +343,9 @@ impl<'ast, T: Field> From<FieldElementExpression<'ast, T>> for TypedExpression<'
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> From<U8Expression<'ast>> for TypedExpression<'ast, T> {
|
||||
fn from(e: U8Expression<'ast>) -> TypedExpression<T> {
|
||||
TypedExpression::U8(e)
|
||||
impl<'ast, T: Field> From<UExpression<'ast>> for TypedExpression<'ast, T> {
|
||||
fn from(e: UExpression<'ast>) -> TypedExpression<T> {
|
||||
TypedExpression::Uint(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -415,7 +366,7 @@ impl<'ast, T: Field> fmt::Display for TypedExpression<'ast, T> {
|
|||
match *self {
|
||||
TypedExpression::Boolean(ref e) => write!(f, "{}", e),
|
||||
TypedExpression::FieldElement(ref e) => write!(f, "{}", e),
|
||||
TypedExpression::U8(ref e) => write!(f, "{}", e),
|
||||
TypedExpression::Uint(ref e) => write!(f, "{}", e),
|
||||
TypedExpression::Array(ref e) => write!(f, "{}", e),
|
||||
TypedExpression::Struct(ref s) => write!(f, "{}", s),
|
||||
}
|
||||
|
@ -427,7 +378,7 @@ impl<'ast, T: Field> fmt::Debug for TypedExpression<'ast, T> {
|
|||
match *self {
|
||||
TypedExpression::Boolean(ref e) => write!(f, "{:?}", e),
|
||||
TypedExpression::FieldElement(ref e) => write!(f, "{:?}", e),
|
||||
TypedExpression::U8(ref e) => write!(f, "{:?}", e),
|
||||
TypedExpression::Uint(ref e) => write!(f, "{:?}", e),
|
||||
TypedExpression::Array(ref e) => write!(f, "{:?}", e),
|
||||
TypedExpression::Struct(ref s) => write!(f, "{}", s),
|
||||
}
|
||||
|
@ -496,7 +447,7 @@ impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
|
|||
TypedExpression::Boolean(ref e) => e.get_type(),
|
||||
TypedExpression::FieldElement(ref e) => e.get_type(),
|
||||
TypedExpression::Array(ref e) => e.get_type(),
|
||||
TypedExpression::U8(ref e) => e.get_type(),
|
||||
TypedExpression::Uint(ref e) => e.get_type(),
|
||||
TypedExpression::Struct(ref s) => s.get_type(),
|
||||
}
|
||||
}
|
||||
|
@ -520,9 +471,9 @@ impl<'ast, T: Field> Typed for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> Typed for U8Expression<'ast> {
|
||||
impl<'ast> Typed for UExpression<'ast> {
|
||||
fn get_type(&self) -> Type {
|
||||
Type::U8
|
||||
Type::Uint(self.bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -587,13 +538,6 @@ pub enum FieldElementExpression<'ast, T: Field> {
|
|||
),
|
||||
}
|
||||
|
||||
/// An expression of type `u8`
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
pub enum U8Expression<'ast> {
|
||||
Value(u8),
|
||||
Identifier(Identifier<'ast>),
|
||||
}
|
||||
|
||||
/// An expression of type `bool`
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
pub enum BooleanExpression<'ast, T: Field> {
|
||||
|
@ -772,12 +716,12 @@ impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for BooleanExpression<'as
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for U8Expression<'ast> {
|
||||
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for UExpression<'ast> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(te: TypedExpression<'ast, T>) -> Result<U8Expression<'ast>, Self::Error> {
|
||||
fn try_from(te: TypedExpression<'ast, T>) -> Result<UExpression<'ast>, Self::Error> {
|
||||
match te {
|
||||
TypedExpression::U8(e) => Ok(e),
|
||||
TypedExpression::Uint(e) => Ok(e),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
|
@ -838,7 +782,7 @@ impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for U8Expression<'ast> {
|
||||
impl<'ast> fmt::Display for UExpression<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
unimplemented!()
|
||||
}
|
||||
|
@ -909,12 +853,6 @@ impl<'ast, T: Field> fmt::Debug for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Debug for U8Expression<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
@ -1057,14 +995,14 @@ impl<'ast, T: Field> IfElse<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> IfElse<'ast, T> for U8Expression<'ast> {
|
||||
impl<'ast, T: Field> IfElse<'ast, T> for UExpression<'ast> {
|
||||
fn if_else(
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: Self,
|
||||
alternative: Self,
|
||||
) -> Self {
|
||||
unimplemented!()
|
||||
// U8Expression::IfElse(box condition, box consequence, box alternative)
|
||||
// UExpression::IfElse(box condition, box consequence, box alternative)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1108,10 +1046,10 @@ impl<'ast, T: Field> Select<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Select<'ast, T> for U8Expression<'ast> {
|
||||
impl<'ast, T: Field> Select<'ast, T> for UExpression<'ast> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
unimplemented!()
|
||||
// U8Expression::Select(box array, box index)
|
||||
// UExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1153,10 +1091,10 @@ impl<'ast, T: Field> Member<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Member<'ast, T> for U8Expression<'ast> {
|
||||
impl<'ast, T: Field> Member<'ast, T> for UExpression<'ast> {
|
||||
fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self {
|
||||
unimplemented!()
|
||||
// U8Expression::Member(box s, member_id)
|
||||
// UExpression::Member(box s, member_id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ pub type MemberId = String;
|
|||
pub enum Type {
|
||||
FieldElement,
|
||||
Boolean,
|
||||
U8,
|
||||
Uint(usize),
|
||||
Array(Box<Type>, usize),
|
||||
Struct(Vec<(MemberId, Type)>),
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ impl fmt::Display for Type {
|
|||
match self {
|
||||
Type::FieldElement => write!(f, "field"),
|
||||
Type::Boolean => write!(f, "bool"),
|
||||
Type::U8 => write!(f, "u8"),
|
||||
Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth),
|
||||
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
|
||||
Type::Struct(ref members) => write!(
|
||||
f,
|
||||
|
@ -38,7 +38,7 @@ impl fmt::Debug for Type {
|
|||
match self {
|
||||
Type::FieldElement => write!(f, "field"),
|
||||
Type::Boolean => write!(f, "bool"),
|
||||
Type::U8 => write!(f, "u8"),
|
||||
Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth),
|
||||
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
|
||||
Type::Struct(ref members) => write!(
|
||||
f,
|
||||
|
@ -62,7 +62,7 @@ impl Type {
|
|||
match self {
|
||||
Type::FieldElement => String::from("f"),
|
||||
Type::Boolean => String::from("b"),
|
||||
Type::U8 => String::from("u8"),
|
||||
Type::Uint(bitwidth) => format!("u{}", bitwidth),
|
||||
Type::Array(box ty, size) => format!("{}[{}]", ty.to_slug(), size),
|
||||
Type::Struct(members) => format!(
|
||||
"{{{}}}",
|
||||
|
@ -80,7 +80,7 @@ impl Type {
|
|||
match self {
|
||||
Type::FieldElement => 1,
|
||||
Type::Boolean => 1,
|
||||
Type::U8 => 1,
|
||||
Type::Uint(_) => 1,
|
||||
Type::Array(ty, size) => size * ty.get_primitive_count(),
|
||||
Type::Struct(members) => members.iter().map(|(_, t)| t.get_primitive_count()).sum(),
|
||||
}
|
||||
|
|
359
zokrates_core/src/typed_absy/uint.rs
Normal file
359
zokrates_core/src/typed_absy/uint.rs
Normal file
|
@ -0,0 +1,359 @@
|
|||
use typed_absy::identifier::Identifier;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
type Bitwidth = usize;
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
pub fn add(self, other: Self) -> UExpression<'ast> {
|
||||
let bitwidth = self.bitwidth;
|
||||
assert_eq!(bitwidth, other.bitwidth);
|
||||
UExpressionInner::Add(box self, box other).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn mult(self, other: Self) -> UExpression<'ast> {
|
||||
let bitwidth = self.bitwidth;
|
||||
assert_eq!(bitwidth, other.bitwidth);
|
||||
UExpressionInner::Mult(box self, box other).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn xor(self, other: Self) -> UExpression<'ast> {
|
||||
let bitwidth = self.bitwidth;
|
||||
assert_eq!(bitwidth, other.bitwidth);
|
||||
UExpressionInner::Xor(box self, box other).annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<u128> for UExpressionInner<'ast> {
|
||||
fn from(e: u128) -> Self {
|
||||
UExpressionInner::Value(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<&'ast str> for UExpressionInner<'ast> {
|
||||
fn from(e: &'ast str) -> Self {
|
||||
UExpressionInner::Identifier(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UMetadata {
|
||||
pub bitwidth: Option<Bitwidth>,
|
||||
pub should_reduce: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UExpression<'ast> {
|
||||
pub bitwidth: Bitwidth,
|
||||
pub metadata: Option<UMetadata>,
|
||||
pub inner: UExpressionInner<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum UExpressionInner<'ast> {
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(u128),
|
||||
Add(Box<UExpression<'ast>>, Box<UExpression<'ast>>),
|
||||
Mult(Box<UExpression<'ast>>, Box<UExpression<'ast>>),
|
||||
Xor(Box<UExpression<'ast>>, Box<UExpression<'ast>>),
|
||||
}
|
||||
|
||||
impl<'ast> UExpressionInner<'ast> {
|
||||
pub fn annotate(self, bitwidth: Bitwidth) -> UExpression<'ast> {
|
||||
UExpression {
|
||||
metadata: None,
|
||||
bitwidth,
|
||||
inner: self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
fn metadata(self, metadata: UMetadata) -> UExpression<'ast> {
|
||||
UExpression {
|
||||
metadata: Some(metadata),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bitwidth(a: u128) -> Bitwidth {
|
||||
(128 - a.leading_zeros()) as Bitwidth
|
||||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
pub fn reduce<T: Field>(self) -> Self {
|
||||
let max_bitwidth = T::get_required_bits() - 1;
|
||||
|
||||
let range = self.bitwidth;
|
||||
|
||||
assert!(range < max_bitwidth / 2);
|
||||
|
||||
let metadata = self.metadata;
|
||||
let inner = self.inner;
|
||||
|
||||
use self::UExpressionInner::*;
|
||||
|
||||
match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
bitwidth: Some(bitwidth(v)),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
}),
|
||||
Identifier(id) => Identifier(id).annotate(range).metadata(UMetadata {
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
}),
|
||||
Add(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the addition fits, so we just add
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::xor(left, right)
|
||||
}
|
||||
Mult(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = left_bitwidth + right_bitwidth; // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the multiplication doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(2 * range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the multiplication fits, so we just multiply
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
pub fn bitwidth(&self) -> Bitwidth {
|
||||
self.bitwidth
|
||||
}
|
||||
|
||||
pub fn as_inner(&self) -> &UExpressionInner<'ast> {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> UExpressionInner<'ast> {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::field::FieldPrime;
|
||||
|
||||
fn count_readjustments(e: UExpression) -> usize {
|
||||
let metadata = e.metadata;
|
||||
let inner = e.inner;
|
||||
|
||||
use self::UExpressionInner::*;
|
||||
|
||||
match inner {
|
||||
Identifier(_) => 0,
|
||||
Value(_) => 0,
|
||||
Mult(box left, box right) | Xor(box left, box right) | Add(box left, box right) => {
|
||||
let l = count_readjustments(left);
|
||||
let r = count_readjustments(right);
|
||||
r + l
|
||||
+ if metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(true))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn _100_times_a() {
|
||||
// a * 100 where a is 8 bits and the host field prime has 254 bits
|
||||
// we don't readjust until we overflow 253 bits
|
||||
// each addition increases the bitwidth by 1
|
||||
// 253 = 100*1 + 153, so we can do all multiplications without readjusting
|
||||
|
||||
let e = (0..100).fold(
|
||||
UExpressionInner::Identifier("a".into()).annotate(8),
|
||||
|acc, _| UExpression::add(acc, UExpressionInner::Identifier("a".into()).annotate(8)),
|
||||
);
|
||||
|
||||
let e = e.reduce::<FieldPrime>();
|
||||
|
||||
assert_eq!(count_readjustments(e), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn _100_powers_of_a() {
|
||||
// a ** 100 where a is 8 bits and the host field prime has 254 bits
|
||||
// we don't readjust until we overflow 253 bits
|
||||
// each multiplication increases the bitwidth by 8
|
||||
// 253 = 31 * 8 + 5, so we do 31 multiplications followed by one readjustment, three times. Then we have 7 multiplications left
|
||||
// we readjusted 3 times
|
||||
|
||||
let e = (0..100).fold(
|
||||
UExpressionInner::Identifier("a".into()).annotate(8),
|
||||
|acc, _| UExpression::mult(acc, UExpressionInner::Identifier("a".into()).annotate(8)),
|
||||
);
|
||||
|
||||
let e = e.reduce::<FieldPrime>();
|
||||
|
||||
assert_eq!(count_readjustments(e), 3);
|
||||
}
|
||||
}
|
|
@ -35,6 +35,7 @@ pub trait Field:
|
|||
From<i32>
|
||||
+ From<u32>
|
||||
+ From<usize>
|
||||
+ From<u128>
|
||||
+ Zero
|
||||
+ One
|
||||
+ Clone
|
||||
|
@ -205,6 +206,15 @@ impl From<usize> for FieldPrime {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<u128> for FieldPrime {
|
||||
fn from(num: u128) -> Self {
|
||||
let x = ToBigInt::to_bigint(&num).unwrap();
|
||||
FieldPrime {
|
||||
value: &x - x.div_floor(&*P) * &*P,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Zero for FieldPrime {
|
||||
fn zero() -> FieldPrime {
|
||||
FieldPrime {
|
||||
|
|
|
@ -18,7 +18,8 @@ parameter = {vis? ~ ty ~ identifier}
|
|||
ty_field = {"field"}
|
||||
ty_bool = {"bool"}
|
||||
ty_u8 = {"u8"}
|
||||
ty_basic = { ty_field | ty_bool | ty_u8 }
|
||||
ty_u32 = {"u32"}
|
||||
ty_basic = { ty_field | ty_bool | ty_u8 | ty_u32 }
|
||||
ty_basic_or_struct = { ty_basic | ty_struct }
|
||||
ty_array = { ty_basic_or_struct ~ ("[" ~ expression ~ "]")+ }
|
||||
ty = { ty_array | ty_basic | ty_struct }
|
||||
|
|
|
@ -244,6 +244,7 @@ mod ast {
|
|||
Field(FieldType<'ast>),
|
||||
Boolean(BooleanType<'ast>),
|
||||
U8(U8Type<'ast>),
|
||||
U32(U32Type<'ast>),
|
||||
}
|
||||
|
||||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
|
@ -283,6 +284,13 @@ mod ast {
|
|||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
#[pest_ast(rule(Rule::ty_u32))]
|
||||
pub struct U32Type<'ast> {
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromPest, PartialEq, Clone)]
|
||||
#[pest_ast(rule(Rule::ty_struct))]
|
||||
pub struct StructType<'ast> {
|
||||
|
|
Loading…
Reference in a new issue