diff --git a/src/lib/gadgets/sha256.ts b/src/lib/gadgets/sha256.ts index c85dc47ff3..d513b610f4 100644 --- a/src/lib/gadgets/sha256.ts +++ b/src/lib/gadgets/sha256.ts @@ -1,7 +1,8 @@ // https://csrc.nist.gov/pubs/fips/180-4/upd1/final - import { Field } from '../field.js'; import { UInt32 } from '../int.js'; +import { TupleN } from '../util/types.js'; +import { assert, bitSlice, exists } from './common.js'; import { Gadgets } from './gadgets.js'; export { SHA256 }; @@ -149,36 +150,21 @@ function Maj(x: UInt32, y: UInt32, z: UInt32) { } function SigmaZero(x: UInt32) { - let rotr2 = ROTR(2, x); - let rotr13 = ROTR(13, x); - let rotr22 = ROTR(22, x); - - return rotr2.xor(rotr13).xor(rotr22); + return sigma(x, [2, 13, 22]); } function SigmaOne(x: UInt32) { - let rotr6 = ROTR(6, x); - let rotr11 = ROTR(11, x); - let rotr25 = ROTR(25, x); - - return rotr6.xor(rotr11).xor(rotr25); + return sigma(x, [6, 11, 25]); } // lowercase sigma = delta to avoid confusing function names function DeltaZero(x: UInt32) { - let rotr7 = ROTR(7, x); - let rotr18 = ROTR(18, x); - let shr3 = SHR(3, x); - - return rotr7.xor(rotr18).xor(shr3); + return sigma(x, [3, 7, 18], true); } function DeltaOne(x: UInt32) { - let rotr17 = ROTR(17, x); - let rotr19 = ROTR(19, x); - let shr10 = SHR(10, x); - return rotr17.xor(rotr19).xor(shr10); + return sigma(x, [10, 17, 19], true); } function ROTR(n: number, x: UInt32) { @@ -189,3 +175,82 @@ function SHR(n: number, x: UInt32) { let val = x.rightShift(n); return val; } + +function sigmaSimple(u: UInt32, bits: TupleN, firstShifted = false) { + let [r0, r1, r2] = bits; + let rot0 = firstShifted ? SHR(r0, u) : ROTR(r0, u); + let rot1 = ROTR(r1, u); + let rot2 = ROTR(r2, u); + return rot0.xor(rot1).xor(rot2); +} + +function sigma(u: UInt32, bits: TupleN, firstShifted = false) { + if (u.isConstant()) return sigmaSimple(u, bits, firstShifted); + + let [r0, r1, r2] = bits; // TODO assert bits are sorted + let x = u.value; + + let d0 = r0; + let d1 = r1 - r0; + let d2 = r2 - r1; + let d3 = 32 - r2; + + // decompose x into 4 chunks of size d0, d1, d2, d3 + let [x0, x1, x2, x3] = exists(4, () => { + let xx = x.toBigInt(); + return [ + bitSlice(xx, 0, d0), + bitSlice(xx, r0, d1), + bitSlice(xx, r1, d2), + bitSlice(xx, r2, d3), + ]; + }); + + // range check each chunk + // we only need to range check to 16 bits relying on the requirement that + // the rotated values are range-checked to 32 bits later; see comments below + rangeCheck16(x0); + rangeCheck16(x1); + rangeCheck16(x2); + rangeCheck16(x3); + + // prove x decomposition + + // x === x0 + x1*2^d0 + x2*2^(d0+d1) + x3*2^(d0+d1+d2) + let x23 = x2.add(x3.mul(1 << d2)).seal(); + let x123 = x1.add(x23.mul(1 << d1)).seal(); + x0.add(x123.mul(1 << d0)).assertEquals(x); + // ^ proves that 2^(32-d3)*x3 < x < 2^32 => x3 < 2^d3 + + // reassemble chunks into rotated values + + let xRotR0: Field; + + if (!firstShifted) { + // rotr(x, r0) = x1 + x2*2^d1 + x3*2^(d1+d2) + x0*2^(d1+d2+d3) + xRotR0 = x123.add(x0.mul(1 << (d1 + d2 + d3))).seal(); + // ^ proves that 2^(32-d0)*x0 < xRotR0 => x0 < 2^d0 if we check xRotR0 < 2^32 later + } else { + // shr(x, r0) = x1 + x2*2^d1 + x3*2^(d1+d2) + xRotR0 = x123; + + // finish x0 < 2^d0 proof: + rangeCheck16(x0.mul(1 << (16 - d0)).seal()); + } + + // rotr(x, r1) = x2 + x3*2^d2 + x0*2^(d2+d3) + x1*2^(d2+d3+d0) + let x01 = x0.add(x1.mul(1 << d0)).seal(); + let xRotR1 = x23.add(x01.mul(1 << (d2 + d3))).seal(); + // ^ proves that 2^(32-d1)*x1 < xRotR1 => x1 < 2^d1 if we check xRotR1 < 2^32 later + + // rotr(x, r2) = x3 + x0*2^d3 + x1*2^(d3+d0) + x2*2^(d3+d0+d1) + let x012 = x01.add(x2.mul(1 << (d0 + d1))).seal(); + let xRotR2 = x3.add(x012.mul(1 << d3)).seal(); + // ^ proves that 2^(32-d2)*x2 < xRotR2 => x2 < 2^d2 if we check xRotR2 < 2^32 later + + return UInt32.from(xRotR0).xor(xRotR1).xor(xRotR2); +} + +function rangeCheck16(x: Field) { + x.rangeCheckHelper(16).assertEquals(x); +} diff --git a/src/lib/int.ts b/src/lib/int.ts index ac59f5c78e..f161e83187 100644 --- a/src/lib/int.ts +++ b/src/lib/int.ts @@ -735,8 +735,10 @@ class UInt32 extends CircuitValue { * c.assertEquals(0b0110); * ``` */ - xor(x: UInt32) { - return UInt32.from(Gadgets.xor(this.value, x.value, UInt32.NUM_BITS)); + xor(x: UInt32 | Field) { + return UInt32.from( + Gadgets.xor(this.value, UInt32.from(x).value, UInt32.NUM_BITS) + ); } /**