refactor
This commit is contained in:
parent
0f31a1b42e
commit
475744bf6e
2 changed files with 67 additions and 103 deletions
|
@ -20,7 +20,7 @@ mod witness;
|
|||
|
||||
pub use self::expression::QuadComb;
|
||||
pub use self::expression::{CanonicalLinComb, LinComb};
|
||||
pub use self::serialize::ProgEnum;
|
||||
pub use self::serialize::{ProgEnum, ProgHeader};
|
||||
pub use crate::common::Parameter;
|
||||
pub use crate::common::RuntimeError;
|
||||
pub use crate::common::Solver;
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
use crate::{
|
||||
ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer},
|
||||
Solver,
|
||||
};
|
||||
use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer};
|
||||
|
||||
use super::{ProgIterator, Statement};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
@ -65,7 +62,7 @@ impl<
|
|||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum SectionType {
|
||||
Arguments = 1,
|
||||
Parameters = 1,
|
||||
Constraints = 2,
|
||||
Solvers = 3,
|
||||
}
|
||||
|
@ -75,7 +72,7 @@ impl TryFrom<u32> for SectionType {
|
|||
|
||||
fn try_from(value: u32) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
1 => Ok(SectionType::Arguments),
|
||||
1 => Ok(SectionType::Parameters),
|
||||
2 => Ok(SectionType::Constraints),
|
||||
3 => Ok(SectionType::Solvers),
|
||||
_ => Err("invalid section type".to_string()),
|
||||
|
@ -198,21 +195,23 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
|
|||
|
||||
w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header
|
||||
|
||||
let arguments = {
|
||||
let mut section = Section::new(SectionType::Arguments);
|
||||
// write parameters
|
||||
if self.arguments.len() > 0 {
|
||||
let mut section = Section::new(SectionType::Parameters);
|
||||
section.set_offset(w.stream_position()?);
|
||||
|
||||
serde_cbor::to_writer(&mut w, &self.arguments)?;
|
||||
|
||||
section.set_length(w.stream_position()? - section.offset);
|
||||
section
|
||||
};
|
||||
header.sections.push(section);
|
||||
}
|
||||
|
||||
let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default();
|
||||
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
|
||||
let mut count: usize = 0;
|
||||
|
||||
let constraints = {
|
||||
// write constraints
|
||||
{
|
||||
let mut section = Section::new(SectionType::Constraints);
|
||||
section.set_offset(w.stream_position()?);
|
||||
|
||||
|
@ -232,24 +231,23 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
|
|||
}
|
||||
|
||||
section.set_length(w.stream_position()? - section.offset);
|
||||
section
|
||||
header.sections.push(section);
|
||||
};
|
||||
|
||||
let solvers = {
|
||||
// write solvers
|
||||
if solver_indexer.solvers.len() > 0 {
|
||||
let mut section = Section::new(SectionType::Solvers);
|
||||
section.set_offset(w.stream_position()?);
|
||||
|
||||
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
|
||||
|
||||
section.set_length(w.stream_position()? - section.offset);
|
||||
section
|
||||
};
|
||||
header.sections.push(section);
|
||||
}
|
||||
|
||||
header.constraint_count = count as u32;
|
||||
header
|
||||
.sections
|
||||
.extend_from_slice(&[arguments, constraints, solvers]);
|
||||
|
||||
// rewind to write the header
|
||||
w.rewind()?;
|
||||
header.write(&mut w)?;
|
||||
|
||||
|
@ -283,6 +281,52 @@ impl<'de, R: Read + Seek>
|
|||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bw6_761Field>>,
|
||||
>
|
||||
{
|
||||
fn read<T: Field>(
|
||||
mut r: R,
|
||||
header: &ProgHeader,
|
||||
) -> ProgIterator<
|
||||
'de,
|
||||
T,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, T>>,
|
||||
> {
|
||||
let parameters = match header.get_section(SectionType::Parameters) {
|
||||
Some(section) => {
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read parameters"))
|
||||
.unwrap()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
let solvers = match header.get_section(SectionType::Solvers) {
|
||||
Some(section) => {
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read solvers"))
|
||||
.unwrap()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
let section = header.get_section(SectionType::Constraints).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<T>>();
|
||||
|
||||
ProgIterator::new(
|
||||
parameters,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
header.return_count as usize,
|
||||
solvers,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn deserialize(mut r: R) -> Result<Self, String> {
|
||||
let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?;
|
||||
|
||||
|
@ -296,95 +340,15 @@ impl<'de, R: Read + Seek>
|
|||
return Err("Invalid file version".to_string());
|
||||
}
|
||||
|
||||
let arguments = {
|
||||
let section = header.get_section(SectionType::Arguments).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))?
|
||||
};
|
||||
|
||||
let solvers_section = header.get_section(SectionType::Solvers).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(solvers_section.offset))
|
||||
.unwrap();
|
||||
|
||||
match header.curve_id {
|
||||
m if m == Bls12_381Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
|
||||
};
|
||||
|
||||
let section = header.get_section(SectionType::Constraints).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bls12_381Field>>();
|
||||
|
||||
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
header.return_count as usize,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
m if m == Bn128Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bn128Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
|
||||
};
|
||||
|
||||
let section = header.get_section(SectionType::Constraints).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bn128Field>>();
|
||||
|
||||
Ok(ProgEnum::Bn128Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
header.return_count as usize,
|
||||
solvers,
|
||||
)))
|
||||
Ok(ProgEnum::Bls12_381Program(Self::read(r, &header)))
|
||||
}
|
||||
m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))),
|
||||
m if m == Bls12_377Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
|
||||
};
|
||||
|
||||
let section = header.get_section(SectionType::Constraints).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bls12_377Field>>();
|
||||
|
||||
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
header.return_count as usize,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
m if m == Bw6_761Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
|
||||
};
|
||||
|
||||
let section = header.get_section(SectionType::Constraints).unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bw6_761Field>>();
|
||||
|
||||
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
header.return_count as usize,
|
||||
solvers,
|
||||
)))
|
||||
Ok(ProgEnum::Bls12_377Program(Self::read(r, &header)))
|
||||
}
|
||||
m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))),
|
||||
_ => Err(String::from("Unknown curve identifier")),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue