1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00
This commit is contained in:
dark64 2023-03-20 20:21:23 +01:00
parent 0f31a1b42e
commit 475744bf6e
2 changed files with 67 additions and 103 deletions

View file

@ -20,7 +20,7 @@ mod witness;
pub use self::expression::QuadComb;
pub use self::expression::{CanonicalLinComb, LinComb};
pub use self::serialize::ProgEnum;
pub use self::serialize::{ProgEnum, ProgHeader};
pub use crate::common::Parameter;
pub use crate::common::RuntimeError;
pub use crate::common::Solver;

View file

@ -1,7 +1,4 @@
use crate::{
ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer},
Solver,
};
use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer};
use super::{ProgIterator, Statement};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@ -65,7 +62,7 @@ impl<
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SectionType {
Arguments = 1,
Parameters = 1,
Constraints = 2,
Solvers = 3,
}
@ -75,7 +72,7 @@ impl TryFrom<u32> for SectionType {
fn try_from(value: u32) -> Result<Self, Self::Error> {
match value {
1 => Ok(SectionType::Arguments),
1 => Ok(SectionType::Parameters),
2 => Ok(SectionType::Constraints),
3 => Ok(SectionType::Solvers),
_ => Err("invalid section type".to_string()),
@ -198,21 +195,23 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header
let arguments = {
let mut section = Section::new(SectionType::Arguments);
// write parameters
if self.arguments.len() > 0 {
let mut section = Section::new(SectionType::Parameters);
section.set_offset(w.stream_position()?);
serde_cbor::to_writer(&mut w, &self.arguments)?;
section.set_length(w.stream_position()? - section.offset);
section
};
header.sections.push(section);
}
let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default();
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
let mut count: usize = 0;
let constraints = {
// write constraints
{
let mut section = Section::new(SectionType::Constraints);
section.set_offset(w.stream_position()?);
@ -232,24 +231,23 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
}
section.set_length(w.stream_position()? - section.offset);
section
header.sections.push(section);
};
let solvers = {
// write solvers
if solver_indexer.solvers.len() > 0 {
let mut section = Section::new(SectionType::Solvers);
section.set_offset(w.stream_position()?);
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
section.set_length(w.stream_position()? - section.offset);
section
};
header.sections.push(section);
}
header.constraint_count = count as u32;
header
.sections
.extend_from_slice(&[arguments, constraints, solvers]);
// rewind to write the header
w.rewind()?;
header.write(&mut w)?;
@ -283,6 +281,52 @@ impl<'de, R: Read + Seek>
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bw6_761Field>>,
>
{
fn read<T: Field>(
mut r: R,
header: &ProgHeader,
) -> ProgIterator<
'de,
T,
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, T>>,
> {
let parameters = match header.get_section(SectionType::Parameters) {
Some(section) => {
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read parameters"))
.unwrap()
}
None => vec![],
};
let solvers = match header.get_section(SectionType::Solvers) {
Some(section) => {
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p)
.map_err(|_| String::from("Cannot read solvers"))
.unwrap()
}
None => vec![],
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<T>>();
ProgIterator::new(
parameters,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)
}
pub fn deserialize(mut r: R) -> Result<Self, String> {
let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?;
@ -296,95 +340,15 @@ impl<'de, R: Read + Seek>
return Err("Invalid file version".to_string());
}
let arguments = {
let section = header.get_section(SectionType::Arguments).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))?
};
let solvers_section = header.get_section(SectionType::Solvers).unwrap();
r.seek(std::io::SeekFrom::Start(solvers_section.offset))
.unwrap();
match header.curve_id {
m if m == Bls12_381Field::id() => {
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_381Field>>();
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
m if m == Bn128Field::id() => {
let solvers: Vec<Solver<'de, Bn128Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bn128Field>>();
Ok(ProgEnum::Bn128Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
Ok(ProgEnum::Bls12_381Program(Self::read(r, &header)))
}
m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))),
m if m == Bls12_377Field::id() => {
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bls12_377Field>>();
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
}
m if m == Bw6_761Field::id() => {
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))?
};
let section = header.get_section(SectionType::Constraints).unwrap();
r.seek(std::io::SeekFrom::Start(section.offset)).unwrap();
let p = serde_cbor::Deserializer::from_reader(r);
let s = p.into_iter::<Statement<Bw6_761Field>>();
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
arguments,
UnwrappedStreamDeserializer { s },
header.return_count as usize,
solvers,
)))
Ok(ProgEnum::Bls12_377Program(Self::read(r, &header)))
}
m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))),
_ => Err(String::from("Unknown curve identifier")),
}
}