1
0
Fork 0
mirror of synced 2025-09-22 19:57:54 +00:00

refactor, fix tests

This commit is contained in:
dark64 2023-02-28 02:04:02 +01:00
parent 6f61a93855
commit 1ef8649150
14 changed files with 314 additions and 144 deletions

1
Cargo.lock generated
View file

@ -2959,6 +2959,7 @@ name = "zokrates_ast"
version = "0.1.4"
dependencies = [
"ark-bls12-377",
"byteorder",
"cfg-if 0.1.10",
"csv",
"derivative",

View file

@ -123,6 +123,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();
@ -148,6 +149,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();

View file

@ -120,6 +120,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();
@ -145,6 +146,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();

View file

@ -404,6 +404,7 @@ mod tests {
),
Statement::constraint(Variable::new(1), Variable::public(0)),
],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();
@ -439,6 +440,7 @@ mod tests {
),
Statement::constraint(Variable::new(1), Variable::public(0)),
],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();

View file

@ -9,6 +9,7 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"]
ark = ["ark-bls12-377", "zokrates_embed/ark"]
[dependencies]
byteorder = "1.4.3"
zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" }
cfg-if = "0.1"
zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false }

View file

@ -4,6 +4,7 @@ use crate::{
};
use super::{ProgIterator, Statement};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::Deserialize;
use serde_cbor::{self, StreamDeserializer};
use std::io::{Read, Seek, Write};
@ -12,7 +13,7 @@ use zokrates_field::*;
type DynamicError = Box<dyn std::error::Error>;
const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0];
const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3];
const FILE_VERSION: &[u8; 4] = &[3, 0, 0, 0];
#[derive(PartialEq, Eq, Debug)]
pub enum ProgEnum<
@ -62,47 +63,195 @@ impl<
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SectionType {
Arguments = 1,
Constraints = 2,
Solvers = 3,
}
impl TryFrom<u32> for SectionType {
type Error = String;
fn try_from(value: u32) -> Result<Self, Self::Error> {
match value {
1 => Ok(SectionType::Arguments),
2 => Ok(SectionType::Constraints),
3 => Ok(SectionType::Solvers),
_ => Err("invalid section type".to_string()),
}
}
}
#[derive(Debug, Clone)]
pub struct Section {
pub ty: SectionType,
pub offset: u64,
pub length: u64,
}
impl Section {
pub fn new(ty: SectionType) -> Self {
Self {
ty,
offset: 0,
length: 0,
}
}
pub fn set_offset(&mut self, offset: u64) {
self.offset = offset;
}
pub fn set_length(&mut self, length: u64) {
self.length = length;
}
}
#[derive(Debug, Clone)]
pub struct ProgHeader {
pub magic: [u8; 4],
pub version: [u8; 4],
pub curve_id: [u8; 4],
pub constraint_count: u32,
pub return_count: u32,
pub sections: Vec<Section>,
}
impl ProgHeader {
pub fn write<W: Write>(&self, mut w: W) -> std::io::Result<()> {
w.write_all(&self.magic)?;
w.write_all(&self.version)?;
w.write_all(&self.curve_id)?;
w.write_u32::<LittleEndian>(self.constraint_count)?;
w.write_u32::<LittleEndian>(self.return_count)?;
w.write_u32::<LittleEndian>(self.sections.len() as u32)?;
for s in &self.sections {
w.write_u32::<LittleEndian>(s.ty as u32)?;
w.write_u64::<LittleEndian>(s.offset)?;
w.write_u64::<LittleEndian>(s.length)?;
}
Ok(())
}
pub fn read<R: Read>(mut r: R) -> std::io::Result<Self> {
let mut magic = [0; 4];
r.read_exact(&mut magic)?;
let mut version = [0; 4];
r.read_exact(&mut version)?;
let mut curve_id = [0; 4];
r.read_exact(&mut curve_id)?;
let constraint_count = r.read_u32::<LittleEndian>()?;
let return_count = r.read_u32::<LittleEndian>()?;
let section_count = r.read_u32::<LittleEndian>()?;
let mut sections = vec![];
for _ in 0..section_count {
let id = r.read_u32::<LittleEndian>()?;
let mut section = Section::new(
SectionType::try_from(id)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
);
section.set_offset(r.read_u64::<LittleEndian>()?);
section.set_length(r.read_u64::<LittleEndian>()?);
sections.push(section);
}
Ok(ProgHeader {
magic,
version,
curve_id,
constraint_count,
return_count,
sections,
})
}
fn get_section(&self, ty: SectionType) -> Option<&Section> {
self.sections.iter().find(|s| s.ty == ty)
}
}
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
/// serialize a program iterator, returning the number of constraints serialized
/// Note that we only return constraints, not other statements such as directives
pub fn serialize<W: Write + Seek>(self, mut w: W) -> Result<usize, DynamicError> {
use super::folder::Folder;
w.write_all(ZOKRATES_MAGIC)?;
w.write_all(ZOKRATES_VERSION_3)?;
w.write_all(&T::id())?;
const SECTION_COUNT: usize = 3;
const HEADER_SIZE: usize = 24 + SECTION_COUNT * 20;
let solver_list_ptr_offset = w.stream_position()?;
w.write_all(&[0u8; std::mem::size_of::<u64>()])?; // reserve 8 bytes
let mut header = ProgHeader {
magic: *ZOKRATES_MAGIC,
version: *FILE_VERSION,
curve_id: T::id(),
constraint_count: 0,
return_count: self.return_count as u32,
sections: Vec::with_capacity(SECTION_COUNT),
};
serde_cbor::to_writer(&mut w, &self.arguments)?;
serde_cbor::to_writer(&mut w, &self.return_count)?;
w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header
let arguments = {
let mut section = Section::new(SectionType::Arguments);
section.set_offset(w.stream_position()?);
serde_cbor::to_writer(&mut w, &self.arguments)?;
section.set_length(w.stream_position()? - section.offset);
section
};
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default();
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
let mut count: usize = 0;
let statements = self.statements.into_iter();
let constraints = {
let mut section = Section::new(SectionType::Constraints);
section.set_offset(w.stream_position()?);
let mut count = 0;
for s in statements {
if matches!(s, Statement::Constraint(..)) {
count += 1;
let statements = self.statements.into_iter();
for s in statements {
if matches!(s, Statement::Constraint(..)) {
count += 1;
}
let s: Vec<Statement<T>> = solver_indexer
.fold_statement(s)
.into_iter()
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
.collect();
for s in s {
serde_cbor::to_writer(&mut w, &s)?;
}
}
let s: Vec<Statement<T>> = solver_indexer
.fold_statement(s)
.into_iter()
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
.collect();
for s in s {
serde_cbor::to_writer(&mut w, &s)?;
}
}
let solver_list_offset = w.stream_position()?;
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
section.set_length(w.stream_position()? - section.offset);
section
};
w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?;
w.write_all(&solver_list_offset.to_le_bytes())?;
let solvers = {
let mut section = Section::new(SectionType::Solvers);
section.set_offset(w.stream_position()?);
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
section.set_length(w.stream_position()? - section.offset);
section
};
header.constraint_count = count as u32;
header
.sections
.extend_from_slice(&[arguments, constraints, solvers]);
w.rewind()?;
header.write(&mut w)?;
unconstrained_variable_detector
.finalize()
@ -135,128 +284,108 @@ impl<'de, R: Read + Seek>
>
{
pub fn deserialize(mut r: R) -> Result<Self, String> {
let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?;
// Check the magic number, `ZOK`
let mut magic = [0; 4];
r.read_exact(&mut magic)
.map_err(|_| String::from("Cannot read magic number"))?;
if &header.magic != ZOKRATES_MAGIC {
return Err("Invalid magic number".to_string());
}
if &magic == ZOKRATES_MAGIC {
// Check the version, 2
let mut version = [0; 4];
r.read_exact(&mut version)
.map_err(|_| String::from("Cannot read version"))?;
// Check the file version
if &header.version != FILE_VERSION {
return Err("Invalid file version".to_string());
}
if &version == ZOKRATES_VERSION_3 {
// Check the curve identifier, deserializing accordingly
let mut curve = [0; 4];
r.read_exact(&mut curve)
.map_err(|_| String::from("Cannot read curve identifier"))?;
let arguments = {
let section = header.get_section(SectionType::Arguments).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let mut buffer = [0u8; std::mem::size_of::<u64>()];
r.read_exact(&mut buffer)
.map_err(|_| String::from("Cannot read solver list offset"))?;
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))?
};
let solver_list_offset = u64::from_le_bytes(buffer);
let solvers_section = header.get_section(SectionType::Solvers).unwrap();
r.seek(std::io::SeekFrom::Start(solvers_section.offset))
.unwrap();
let (arguments, return_count) = {
match header.curve_id {
m if m == Bls12_381Field::id() => {
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
let arguments: Vec<super::Parameter> = Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read parameters"))?;
let return_count = usize::deserialize(&mut p)
.map_err(|_| String::from("Cannot read return count"))?;
(arguments, return_count)
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let statement_offset = r.stream_position().unwrap();
r.seek(std::io::SeekFrom::Start(solver_list_offset))
.unwrap();
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
match curve {
m if m == Bls12_381Field::id() => {
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read solver list"))?
};
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_381Field>>();
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_381Field>>();
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
return_count,
solvers,
)))
}
m if m == Bn128Field::id() => {
let solvers: Vec<Solver<'de, Bn128Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read solver list"))?
};
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bn128Field>>();
Ok(ProgEnum::Bn128Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
return_count,
solvers,
)))
}
m if m == Bls12_377Field::id() => {
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read solver list"))?
};
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_377Field>>();
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
return_count,
solvers,
)))
}
m if m == Bw6_761Field::id() => {
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read solver list"))?
};
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bw6_761Field>>();
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
return_count,
solvers,
)))
}
_ => Err(String::from("Unknown curve identifier")),
}
} else {
Err(String::from("Unknown version"))
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
} else {
Err(String::from("Wrong magic number"))
m if m == Bn128Field::id() => {
let solvers: Vec<Solver<'de, Bn128Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bn128Field>>();
Ok(ProgEnum::Bn128Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
m if m == Bls12_377Field::id() => {
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_377Field>>();
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
m if m == Bw6_761Field::id() => {
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bw6_761Field>>();
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
_ => Err(String::from("Unknown curve identifier")),
}
}
}

View file

@ -214,6 +214,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let rng = &mut StdRng::from_entropy();

View file

@ -276,6 +276,7 @@ mod tests {
arguments: vec![Parameter::private(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let interpreter = Interpreter::default();
@ -297,6 +298,7 @@ mod tests {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))],
solvers: vec![],
};
let interpreter = Interpreter::default();
@ -318,6 +320,7 @@ mod tests {
arguments: vec![],
return_count: 1,
statements: vec![Statement::constraint(Variable::one(), Variable::public(0))],
solvers: vec![],
};
let interpreter = Interpreter::default();
@ -350,6 +353,7 @@ mod tests {
Variable::public(1),
),
],
solvers: vec![],
};
let interpreter = Interpreter::default();
@ -373,6 +377,7 @@ mod tests {
LinComb::from(Variable::new(42)) + LinComb::one(),
Variable::public(0),
)],
solvers: vec![],
};
let interpreter = Interpreter::default();
@ -400,6 +405,7 @@ mod tests {
LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)),
Variable::public(0),
)],
solvers: vec![],
};
let interpreter = Interpreter::default();

View file

@ -44,6 +44,7 @@ mod tests {
None,
),
],
solvers: vec![],
};
let mut r1cs = vec![];

View file

@ -296,6 +296,7 @@ mod tests {
Variable::public(0).into(),
None,
)],
solvers: vec![],
};
let mut buf = Vec::new();
@ -365,6 +366,7 @@ mod tests {
None,
),
],
solvers: vec![],
};
let mut buf = Vec::new();

View file

@ -81,6 +81,7 @@ mod tests {
],
return_count: 0,
arguments: vec![],
solvers: vec![],
};
let expected = p.clone();
@ -117,6 +118,7 @@ mod tests {
],
return_count: 0,
arguments: vec![],
solvers: vec![],
};
let expected = Prog {
@ -132,6 +134,7 @@ mod tests {
],
return_count: 0,
arguments: vec![],
solvers: vec![],
};
assert_eq!(

View file

@ -256,12 +256,14 @@ mod tests {
Statement::definition(out, y),
],
return_count: 1,
solvers: vec![],
};
let optimized: Prog<Bn128Field> = Prog {
arguments: vec![x],
statements: vec![Statement::definition(out, x.id)],
return_count: 1,
solvers: vec![],
};
let mut optimizer = RedefinitionOptimizer::init(&p);
@ -280,6 +282,7 @@ mod tests {
arguments: vec![x],
statements: vec![Statement::definition(one, x.id)],
return_count: 1,
solvers: vec![],
};
let optimized = p.clone();
@ -316,6 +319,7 @@ mod tests {
Statement::definition(out, z),
],
return_count: 1,
solvers: vec![],
};
let optimized: Prog<Bn128Field> = Prog {
@ -325,6 +329,7 @@ mod tests {
Statement::definition(out, x.id),
],
return_count: 1,
solvers: vec![],
};
let mut optimizer = RedefinitionOptimizer::init(&p);
@ -365,6 +370,7 @@ mod tests {
Statement::definition(out_1, w),
],
return_count: 2,
solvers: vec![],
};
let optimized: Prog<Bn128Field> = Prog {
@ -374,6 +380,7 @@ mod tests {
Statement::definition(out_1, Bn128Field::from(1)),
],
return_count: 2,
solvers: vec![],
};
let mut optimizer = RedefinitionOptimizer::init(&p);
@ -422,6 +429,7 @@ mod tests {
Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)),
],
return_count: 1,
solvers: vec![],
};
let expected: Prog<Bn128Field> = Prog {
@ -442,6 +450,7 @@ mod tests {
),
],
return_count: 1,
solvers: vec![],
};
let mut optimizer = RedefinitionOptimizer::init(&p);
@ -479,6 +488,7 @@ mod tests {
Statement::definition(z, LinComb::from(x.id)),
],
return_count: 0,
solvers: vec![],
};
let optimized = p.clone();
@ -507,6 +517,7 @@ mod tests {
Statement::constraint(x.id, Bn128Field::from(2)),
],
return_count: 1,
solvers: vec![],
};
let optimized = p.clone();

View file

@ -443,6 +443,7 @@ mod tests {
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
&[],
)
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
@ -459,6 +460,7 @@ mod tests {
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
&[],
)
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
@ -469,9 +471,12 @@ mod tests {
#[test]
fn bits_of_one() {
let inputs = vec![Bn128Field::from(1)];
let res =
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
let res = Interpreter::execute_solver(
&Solver::Bits(Bn128Field::get_required_bits()),
&inputs,
&[],
)
.unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for r in &res[0..253] {
assert_eq!(*r, Bn128Field::from(0));
@ -481,9 +486,12 @@ mod tests {
#[test]
fn bits_of_42() {
let inputs = vec![Bn128Field::from(42)];
let res =
Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
let res = Interpreter::execute_solver(
&Solver::Bits(Bn128Field::get_required_bits()),
&inputs,
&[],
)
.unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
@ -496,7 +504,7 @@ mod tests {
#[test]
fn five_hundred_bits_of_1() {
let inputs = vec![Bn128Field::from(1)];
let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap();
let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs, &[]).unwrap();
let mut expected = vec![Bn128Field::from(0); 500];
expected[499] = Bn128Field::from(1);

View file

@ -20,6 +20,7 @@ fn generate_proof() {
arguments: vec![Parameter::public(Variable::new(0))],
return_count: 1,
statements: vec![Statement::constraint(Variable::new(0), Variable::new(0))],
solvers: vec![],
};
let interpreter = Interpreter::default();