Skip to content

Commit

Permalink
Optimizations to KZG commitment scheme (microsoft#300)
Browse files Browse the repository at this point in the history
* simplify kzg_verify_batch closure

credit: storojs72

* parallel computation of polynomials

credit: storojs72

* eliminate computation of the last commitment

credit: storojs72

* update tests to account for the new behavior of the Prove method

* simplify tests
  • Loading branch information
srinathsetty authored and huitseeker committed Jan 25, 2024
1 parent eab053b commit 55a26d2
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 69 deletions.
155 changes: 86 additions & 69 deletions src/provider/hyperkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ use crate::{
evaluation::EvaluationEngineTrait,
AbsorbInROTrait, Engine, ROTrait, TranscriptEngineTrait, TranscriptReprTrait,
},
zip_with,
};
use core::{
marker::PhantomData,
ops::{Add, Mul, MulAssign},
};
use ff::Field;
use halo2curves::bn256::{Fq as Bn256Fq, Fr as Bn256Fr, G1 as Bn256G1};
use itertools::Itertools;
use rand_core::OsRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -366,10 +368,10 @@ where
ck: &CommitmentKey<E>,
_pk: &Self::ProverKey,
transcript: &mut <E as Engine>::TE,
C: &Commitment<E>,
_C: &Commitment<E>,
hat_P: &[E::Scalar],
point: &[E::Scalar],
eval: &E::Scalar,
_eval: &E::Scalar,
) -> Result<Self::EvaluationArgument, NovaError> {
let x: Vec<E::Scalar> = point.to_vec();

Expand Down Expand Up @@ -406,8 +408,7 @@ where
E::CE::commit(ck, &h).comm.preprocessed()
};

let kzg_open_batch = |C: &[G1<E>],
f: &[Vec<E::Scalar>],
let kzg_open_batch = |f: &[Vec<E::Scalar>],
u: &[E::Scalar],
transcript: &mut <E as Engine>::TE|
-> (Vec<G1<E>>, Vec<Vec<E::Scalar>>) {
Expand Down Expand Up @@ -447,18 +448,18 @@ where

let k = f.len();
let t = u.len();
assert!(C.len() == k);

// The verifier needs f_i(u_j), so we compute them here
// (V will compute B(u_j) itself)
let mut v = vec![vec!(E::Scalar::ZERO; k); t];
for i in 0..t {
v.par_iter_mut().enumerate().for_each(|(i, v_i)| {
// for each point u
for (j, f_j) in f.iter().enumerate().take(k) {
v_i.par_iter_mut().zip_eq(f).for_each(|(v_ij, f)| {
// for each poly f
v[i][j] = poly_eval(f_j, u[i]); // = f_j(u_i)
}
}
// for each poly f (except the last one - since it is constant)
*v_ij = poly_eval(f, u[i]);
});
});

let q = Self::get_batch_challenge(&v, transcript);
let B = kzg_compute_batch_polynomial(f, q);
Expand All @@ -484,21 +485,18 @@ where
assert_eq!(n, 1 << ell); // Below we assume that n is a power of two

// Phase 1 -- create commitments com_1, ..., com_\ell
// We do not compute final Pi (and its commitment) as it is constant and equals to 'eval'
// also known to verifier, so can be derived on its side as well
let mut polys: Vec<Vec<E::Scalar>> = Vec::new();
polys.push(hat_P.to_vec());
for i in 0..ell {
for i in 0..ell - 1 {
let Pi_len = polys[i].len() / 2;
let mut Pi = vec![E::Scalar::ZERO; Pi_len];

#[allow(clippy::needless_range_loop)]
for j in 0..Pi_len {
Pi[j] = x[ell-i-1] * polys[i][2*j + 1] // Odd part of P^(i-1)
+ (E::Scalar::ONE - x[ell-i-1]) * polys[i][2*j]; // Even part of P^(i-1)
}

if i == ell - 1 && *eval != Pi[0] {
return Err(NovaError::UnSat);
}
Pi.par_iter_mut().enumerate().for_each(|(j, Pi_j)| {
*Pi_j = x[ell - i - 1] * (polys[i][2 * j + 1] - polys[i][2 * j]) + polys[i][2 * j];
});

polys.push(Pi);
}
Expand All @@ -517,9 +515,7 @@ where
let u = vec![r, -r, r * r];

// Phase 3 -- create response
let mut com_all = com.clone();
com_all.insert(0, C.comm.preprocessed());
let (w, v) = kzg_open_batch(&com_all, &polys, &u, transcript);
let (w, v) = kzg_open_batch(&polys, &u, transcript);

Ok(EvaluationArgument { com, w, v })
}
Expand Down Expand Up @@ -551,52 +547,70 @@ where
let q = Self::get_batch_challenge(v, transcript);
let q_powers = Self::batch_challenge_powers(q, k); // 1, q, q^2, ..., q^(k-1)

// Compute the commitment to the batched polynomial B(X)
let C_B = (<E::GE as DlogGroup>::group(&C[0])
+ E::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k]))
.preprocessed();

// Compute the batched openings
// compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1]
let B_u = (0..t)
.map(|i| {
assert_eq!(q_powers.len(), v[i].len());
q_powers.iter().zip(v[i].iter()).map(|(a, b)| *a * *b).sum()
})
.collect::<Vec<E::Scalar>>();

let d_0 = Self::verifier_second_challenge(W, transcript);
let d = [d_0, d_0 * d_0];
let d_1 = d_0 * d_0;

// Shorthand to convert from preprocessed G1 elements to non-preprocessed
let from_ppG1 = |P: &G1<E>| <E::GE as DlogGroup>::group(P);
// Shorthand to convert from preprocessed G2 elements to non-preprocessed
let from_ppG2 = |P: &G2<E>| <<E::GE as PairingGroup>::G2 as DlogGroup>::group(P);

assert!(t == 3);
assert_eq!(t, 3);
assert_eq!(W.len(), 3);
// We write a special case for t=3, since this what is required for
// mlkzg. Following the paper directly, we must compute:
// hyperkzg. Following the paper directly, we must compute:
// let L0 = C_B - vk.G * B_u[0] + W[0] * u[0];
// let L1 = C_B - vk.G * B_u[1] + W[1] * u[1];
// let L2 = C_B - vk.G * B_u[2] + W[2] * u[2];
// let R0 = -W[0];
// let R1 = -W[1];
// let R2 = -W[2];
// let L = L0 + L1*d[0] + L2*d[1];
// let R = R0 + R1*d[0] + R2*d[1];
// let L = L0 + L1*d_0 + L2*d_1;
// let R = R0 + R1*d_0 + R2*d_1;
//
// We group terms to reduce the number of scalar mults (to seven):
// In Rust, we could use MSMs for these, and speed up verification.
let L = from_ppG1(&C_B) * (E::Scalar::ONE + d[0] + d[1])
- from_ppG1(&vk.G) * (B_u[0] + d[0] * B_u[1] + d[1] * B_u[2])
+ from_ppG1(&W[0]) * u[0]
+ from_ppG1(&W[1]) * (u[1] * d[0])
+ from_ppG1(&W[2]) * (u[2] * d[1]);
//
// Note, that while computing L, the intermediate computation of C_B together with computing
// L0, L1, L2 can be replaced by single MSM of C with the powers of q multiplied by (1 + d_0 + d_1)
// with additionally concatenated inputs for scalars/bases.

let q_power_multiplier = E::Scalar::ONE + d_0 + d_1;

let q_powers_multiplied: Vec<E::Scalar> = q_powers
.par_iter()
.map(|q_power| *q_power * q_power_multiplier)
.collect();

// Compute the batched openings
// compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1]
let B_u = v
.into_par_iter()
.map(|v_i| zip_with!(iter, (q_powers, v_i), |a, b| *a * *b).sum())
.collect::<Vec<E::Scalar>>();

let L = E::GE::vartime_multiscalar_mul(
&[
&q_powers_multiplied[..k],
&[
u[0],
(u[1] * d_0),
(u[2] * d_1),
-(B_u[0] + d_0 * B_u[1] + d_1 * B_u[2]),
],
]
.concat(),
&[
&C[..k],
&[W[0].clone(), W[1].clone(), W[2].clone(), vk.G.clone()],
]
.concat(),
);

let R0 = from_ppG1(&W[0]);
let R1 = from_ppG1(&W[1]);
let R2 = from_ppG1(&W[2]);
let R = R0 + R1 * d[0] + R2 * d[1];
let R = R0 + R1 * d_0 + R2 * d_1;

// Check that e(L, vk.H) == e(R, vk.tau_H)
(<E::GE as PairingGroup>::pairing(&L, &from_ppG2(&vk.H)))
Expand Down Expand Up @@ -624,18 +638,15 @@ where
if v.len() != 3 {
return Err(NovaError::ProofVerifyError);
}
if v[0].len() != ell + 1 || v[1].len() != ell + 1 || v[2].len() != ell + 1 {
if v[0].len() != ell || v[1].len() != ell || v[2].len() != ell {
return Err(NovaError::ProofVerifyError);
}
let ypos = &v[0];
let yneg = &v[1];
let Y = &v[2];
let mut Y = v[2].to_vec();
Y.push(*y);

// Check consistency of (Y, ypos, yneg)
if Y[ell] != *y {
return Err(NovaError::ProofVerifyError);
}

let two = E::Scalar::from(2u64);
for i in 0..ell {
if two * r * Y[i + 1]
Expand Down Expand Up @@ -685,52 +696,58 @@ mod tests {
type Fr = <E as Engine>::Scalar;

#[test]
fn test_mlkzg_eval() {
fn test_hyperkzg_eval() {
// Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2
let n = 4;
let ck: CommitmentKey<E> = CommitmentEngine::setup(b"test", n);
let (pk, _vk): (ProverKey<E>, VerifierKey<E>) = EvaluationEngine::setup(&ck);
let (pk, vk): (ProverKey<E>, VerifierKey<E>) = EvaluationEngine::setup(&ck);

// poly is in eval. representation; evaluated at [(0,0), (0,1), (1,0), (1,1)]
let poly = vec![Fr::from(1), Fr::from(2), Fr::from(2), Fr::from(4)];

let C = CommitmentEngine::commit(&ck, &poly);
let mut tr = Keccak256Transcript::new(b"TestEval");

// Call the prover with a (point, eval) pair. The prover recomputes
// poly(point) = eval', and fails if eval' != eval
let test_inner = |point: Vec<Fr>, eval: Fr| -> Result<(), NovaError> {
let mut tr = Keccak256Transcript::new(b"TestEval");
let proof = EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).unwrap();
let mut tr = Keccak256Transcript::new(b"TestEval");
EvaluationEngine::verify(&vk, &mut tr, &C, &point, &eval, &proof)
};

// Call the prover with a (point, eval) pair.
// The prover does not recompute so it may produce a proof, but it should not verify
let point = vec![Fr::from(0), Fr::from(0)];
let eval = Fr::ONE;
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok());
assert!(test_inner(point, eval).is_ok());

let point = vec![Fr::from(0), Fr::from(1)];
let eval = Fr::from(2);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok());
assert!(test_inner(point, eval).is_ok());

let point = vec![Fr::from(1), Fr::from(1)];
let eval = Fr::from(4);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok());
assert!(test_inner(point, eval).is_ok());

let point = vec![Fr::from(0), Fr::from(2)];
let eval = Fr::from(3);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok());
assert!(test_inner(point, eval).is_ok());

let point = vec![Fr::from(2), Fr::from(2)];
let eval = Fr::from(9);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok());
assert!(test_inner(point, eval).is_ok());

// Try a couple incorrect evaluations and expect failure
let point = vec![Fr::from(2), Fr::from(2)];
let eval = Fr::from(50);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_err());
assert!(test_inner(point, eval).is_err());

let point = vec![Fr::from(0), Fr::from(2)];
let eval = Fr::from(4);
assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_err());
assert!(test_inner(point, eval).is_err());
}

#[test]
fn test_mlkzg() {
fn test_hyperkzg() {
let n = 4;

// poly = [1, 2, 1, 4]
Expand Down Expand Up @@ -778,7 +795,7 @@ mod tests {

// Change the proof and expect verification to fail
let mut bad_proof = proof.clone();
bad_proof.com[0] = (bad_proof.com[0] + bad_proof.com[1]).to_affine();
bad_proof.com[0] = (bad_proof.com[0] + bad_proof.com[0]).to_affine();
let mut verifier_transcript2 = Keccak256Transcript::new(b"TestEval");
assert!(EvaluationEngine::verify(
&vk,
Expand All @@ -792,8 +809,8 @@ mod tests {
}

#[test]
fn test_mlkzg_more() {
// test the mlkzg prover and verifier with random instances (derived from a seed)
fn test_hyperkzg_more() {
// test the hyperkzg prover and verifier with random instances (derived from a seed)
for ell in [4, 5, 6] {
let mut rng = rand::rngs::StdRng::seed_from_u64(ell as u64);

Expand Down
Loading

0 comments on commit 55a26d2

Please sign in to comment.