Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up rotation of 3 values in SHA2 #1296

Merged
merged 6 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 85 additions & 20 deletions src/lib/gadgets/sha256.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// https://csrc.nist.gov/pubs/fips/180-4/upd1/final

import { Field } from '../field.js';
import { UInt32 } from '../int.js';
import { TupleN } from '../util/types.js';
import { assert, bitSlice, exists } from './common.js';
import { Gadgets } from './gadgets.js';

export { SHA256 };
Expand Down Expand Up @@ -149,36 +150,21 @@ function Maj(x: UInt32, y: UInt32, z: UInt32) {
}

function SigmaZero(x: UInt32) {
let rotr2 = ROTR(2, x);
let rotr13 = ROTR(13, x);
let rotr22 = ROTR(22, x);

return rotr2.xor(rotr13).xor(rotr22);
return sigma(x, [2, 13, 22]);
}

function SigmaOne(x: UInt32) {
let rotr6 = ROTR(6, x);
let rotr11 = ROTR(11, x);
let rotr25 = ROTR(25, x);

return rotr6.xor(rotr11).xor(rotr25);
return sigma(x, [6, 11, 25]);
}

// lowercase sigma = delta to avoid confusing function names

function DeltaZero(x: UInt32) {
let rotr7 = ROTR(7, x);
let rotr18 = ROTR(18, x);
let shr3 = SHR(3, x);

return rotr7.xor(rotr18).xor(shr3);
return sigma(x, [3, 7, 18], true);
}

function DeltaOne(x: UInt32) {
let rotr17 = ROTR(17, x);
let rotr19 = ROTR(19, x);
let shr10 = SHR(10, x);
return rotr17.xor(rotr19).xor(shr10);
return sigma(x, [10, 17, 19], true);
}

function ROTR(n: number, x: UInt32) {
Expand All @@ -189,3 +175,82 @@ function SHR(n: number, x: UInt32) {
let val = x.rightShift(n);
return val;
}

function sigmaSimple(u: UInt32, bits: TupleN<number, 3>, firstShifted = false) {
let [r0, r1, r2] = bits;
let rot0 = firstShifted ? SHR(r0, u) : ROTR(r0, u);
let rot1 = ROTR(r1, u);
let rot2 = ROTR(r2, u);
return rot0.xor(rot1).xor(rot2);
}

function sigma(u: UInt32, bits: TupleN<number, 3>, firstShifted = false) {
if (u.isConstant()) return sigmaSimple(u, bits, firstShifted);

let [r0, r1, r2] = bits; // TODO assert bits are sorted
let x = u.value;

let d0 = r0;
let d1 = r1 - r0;
let d2 = r2 - r1;
let d3 = 32 - r2;

// decompose x into 4 chunks of size d0, d1, d2, d3
let [x0, x1, x2, x3] = exists(4, () => {
let xx = x.toBigInt();
return [
bitSlice(xx, 0, d0),
bitSlice(xx, r0, d1),
bitSlice(xx, r1, d2),
bitSlice(xx, r2, d3),
];
});

// range check each chunk
// we only need to range check to 16 bits relying on the requirement that
// the rotated values are range-checked to 32 bits later; see comments below
rangeCheck16(x0);
rangeCheck16(x1);
rangeCheck16(x2);
rangeCheck16(x3);

// prove x decomposition

// x === x0 + x1*2^d0 + x2*2^(d0+d1) + x3*2^(d0+d1+d2)
let x23 = x2.add(x3.mul(1 << d2)).seal();
let x123 = x1.add(x23.mul(1 << d1)).seal();
x0.add(x123.mul(1 << d0)).assertEquals(x);
// ^ proves that 2^(32-d3)*x3 < x < 2^32 => x3 < 2^d3

// reassemble chunks into rotated values

let xRotR0: Field;

if (!firstShifted) {
// rotr(x, r0) = x1 + x2*2^d1 + x3*2^(d1+d2) + x0*2^(d1+d2+d3)
xRotR0 = x123.add(x0.mul(1 << (d1 + d2 + d3))).seal();
// ^ proves that 2^(32-d0)*x0 < xRotR0 => x0 < 2^d0 if we check xRotR0 < 2^32 later
} else {
// shr(x, r0) = x1 + x2*2^d1 + x3*2^(d1+d2)
xRotR0 = x123;

// finish x0 < 2^d0 proof:
rangeCheck16(x0.mul(1 << (16 - d0)).seal());
}

// rotr(x, r1) = x2 + x3*2^d2 + x0*2^(d2+d3) + x1*2^(d2+d3+d0)
let x01 = x0.add(x1.mul(1 << d0)).seal();
let xRotR1 = x23.add(x01.mul(1 << (d2 + d3))).seal();
// ^ proves that 2^(32-d1)*x1 < xRotR1 => x1 < 2^d1 if we check xRotR1 < 2^32 later

// rotr(x, r2) = x3 + x0*2^d3 + x1*2^(d3+d0) + x2*2^(d3+d0+d1)
let x012 = x01.add(x2.mul(1 << (d0 + d1))).seal();
let xRotR2 = x3.add(x012.mul(1 << d3)).seal();
// ^ proves that 2^(32-d2)*x2 < xRotR2 => x2 < 2^d2 if we check xRotR2 < 2^32 later

return UInt32.from(xRotR0).xor(xRotR1).xor(xRotR2);
}

function rangeCheck16(x: Field) {
x.rangeCheckHelper(16).assertEquals(x);
}
6 changes: 4 additions & 2 deletions src/lib/int.ts
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,10 @@ class UInt32 extends CircuitValue {
* c.assertEquals(0b0110);
* ```
*/
xor(x: UInt32) {
return UInt32.from(Gadgets.xor(this.value, x.value, UInt32.NUM_BITS));
xor(x: UInt32 | Field) {
return UInt32.from(
Gadgets.xor(this.value, UInt32.from(x).value, UInt32.NUM_BITS)
);
}

/**
Expand Down
Loading